session-ios/SessionUtilitiesKit/Database/Utilities/PersistableRecord+Utilities.swift
Morgan Pretty 89df1261e3 Added a method to allow safer database inserts
Fixed an issue where the app settings were updating immediately making them seem to
Updated GRDB to version 6.1 and SQLCipher to 4.5.2
Added a method which allows for inserting into the database while omitting columns which exist in the object but not in the database (to allow old migrations to run with fewer issues)
Updated all the migrations to use the migration safe insert method
Removed some ObjC support extension functions
2022-10-21 17:32:51 +11:00

278 lines
9.5 KiB
Swift

// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.
import Foundation
import GRDB
// MARK: - Migration Safe Functions
/// Persistence helpers intended for use inside migrations.
///
/// Each helper wraps the record in a `MigrationSafeMutableRecord`, which reads the columns the
/// record's table has in `db` at call time and only encodes the properties matching one of them.
/// The wrapper stores its own `originalRecord`, so for value types the non-returning variants do
/// not write callback changes back to `self`; use the `...Inserted` / `...Saved` variants to get
/// the wrapped record back after the operation.
public extension MutablePersistableRecord where Self: TableRecord & EncodableRecord & Codable {
    /// Inserts the record, ignoring properties which have no matching column in the table.
    ///
    /// - Parameters:
    ///   - db: The database connection to write to.
    ///   - conflictResolution: Optional conflict policy forwarded to the underlying `insert`.
    /// - Throws: Errors from reading the table's columns or from the insert itself.
    func migrationSafeInsert(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws {
        // `var` because the call below needs a mutable receiver
        var wrapper = try MigrationSafeMutableRecord(db, originalRecord: self)
        try wrapper.insert(db, onConflict: conflictResolution)
    }

    /// Inserts the record, ignoring properties which have no matching column in the table, and
    /// returns the wrapped record as it is after the insert.
    ///
    /// - Parameters:
    ///   - db: The database connection to write to.
    ///   - conflictResolution: Optional conflict policy forwarded to the underlying `inserted`.
    /// - Returns: The `originalRecord` held by the wrapper returned from `inserted`.
    /// - Throws: Errors from reading the table's columns or from the insert itself.
    func migrationSafeInserted(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws -> Self {
        return try MigrationSafeMutableRecord(db, originalRecord: self)
            .inserted(db, onConflict: conflictResolution)
            .originalRecord
    }

    /// Saves the record, ignoring properties which have no matching column in the table.
    ///
    /// - Parameters:
    ///   - db: The database connection to write to.
    ///   - conflictResolution: Optional conflict policy forwarded to the underlying `save`.
    /// - Throws: Errors from reading the table's columns or from the save itself.
    func migrationSafeSave(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws {
        // `var` because the call below needs a mutable receiver
        var wrapper = try MigrationSafeMutableRecord(db, originalRecord: self)
        try wrapper.save(db, onConflict: conflictResolution)
    }

    /// Saves the record, ignoring properties which have no matching column in the table, and
    /// returns the wrapped record as it is after the save.
    ///
    /// - Parameters:
    ///   - db: The database connection to write to.
    ///   - conflictResolution: Optional conflict policy forwarded to the underlying `saved`.
    /// - Returns: The `originalRecord` held by the wrapper returned from `saved`.
    /// - Throws: Errors from reading the table's columns or from the save itself.
    func migrationSafeSaved(
        _ db: Database,
        onConflict conflictResolution: Database.ConflictResolution? = nil
    ) throws -> Self {
        return try MigrationSafeMutableRecord(db, originalRecord: self)
            .saved(db, onConflict: conflictResolution)
            .originalRecord
    }

    /// Upserts the record, ignoring properties which have no matching column in the table.
    ///
    /// - Parameter db: The database connection to write to.
    /// - Throws: Errors from reading the table's columns or from the upsert itself.
    func migrationSafeUpsert(_ db: Database) throws {
        // `var` because the call below needs a mutable receiver
        var wrapper = try MigrationSafeMutableRecord(db, originalRecord: self)
        try wrapper.upsert(db)
    }
}
// MARK: - MigrationSafeMutableRecord
/// Variant of `MigrationSafeMutableRecord` whose wrapped type is constrained to
/// `PersistableRecord`. It adds no members of its own.
private class MigrationSafeRecord<T>: MigrationSafeMutableRecord<T> where T: PersistableRecord & Encodable {}
/// Wraps a record so that encoding it only emits the properties whose names match a column of
/// the record's table, as read from the database when the wrapper is created.
///
/// Every persistence callback is forwarded to the wrapped record.
private class MigrationSafeMutableRecord<T>: MutablePersistableRecord, Encodable where T: MutablePersistableRecord & Encodable {
    /// The wrapper reports the wrapped type's table name as its own.
    public static var databaseTableName: String { return T.databaseTableName }

    /// The record being wrapped; read back by the `migrationSafe...` helpers in this file.
    fileprivate var originalRecord: T

    /// Names of the columns found in the table when the wrapper was initialised.
    private let availableColumnNames: [String]

    /// Creates a wrapper around `originalRecord`.
    ///
    /// - Parameters:
    ///   - db: The database used to look up the table's columns.
    ///   - originalRecord: The record to wrap.
    /// - Throws: Any error thrown while reading the table's columns.
    init(_ db: Database, originalRecord: T) throws {
        self.originalRecord = originalRecord

        // Record which columns the table has right now so that properties on the object without
        // a matching column can be filtered out when encoding
        let tableColumns = try db.columns(in: Self.databaseTableName)
        self.availableColumnNames = tableColumns.map { $0.name }
    }

    /// Encodes the wrapped record through a `FilteredEncoder` limited to the available columns.
    func encode(to encoder: Encoder) throws {
        try originalRecord.encode(
            to: FilteredEncoder(originalEncoder: encoder, availableKeys: availableColumnNames)
        )
    }

    // MARK: - Persistence Callbacks

    // Insert

    func willInsert(_ db: Database) throws { try originalRecord.willInsert(db) }

    func aroundInsert(_ db: Database, insert: () throws -> InsertionSuccess) throws {
        try originalRecord.aroundInsert(db, insert: insert)
    }

    func didInsert(_ inserted: InsertionSuccess) { originalRecord.didInsert(inserted) }

    // Update

    func willUpdate(_ db: Database, columns: Set<String>) throws {
        try originalRecord.willUpdate(db, columns: columns)
    }

    func aroundUpdate(_ db: Database, columns: Set<String>, update: () throws -> PersistenceSuccess) throws {
        try originalRecord.aroundUpdate(db, columns: columns, update: update)
    }

    func didUpdate(_ updated: PersistenceSuccess) { originalRecord.didUpdate(updated) }

    // Save

    func willSave(_ db: Database) throws { try originalRecord.willSave(db) }

    func aroundSave(_ db: Database, save: () throws -> PersistenceSuccess) throws {
        try originalRecord.aroundSave(db, save: save)
    }

    func didSave(_ saved: PersistenceSuccess) { originalRecord.didSave(saved) }

    // Delete

    func willDelete(_ db: Database) throws { try originalRecord.willDelete(db) }

    func aroundDelete(_ db: Database, delete: () throws -> Bool) throws {
        try originalRecord.aroundDelete(db, delete: delete)
    }

    func didDelete(deleted: Bool) { originalRecord.didDelete(deleted: deleted) }
}
// MARK: - FilteredEncoder
/// An `Encoder` which wraps another encoder and hands out keyed containers that only accept the
/// keys listed in `availableKeys`. Unkeyed and single-value containers come straight from the
/// wrapped encoder.
private class FilteredEncoder: Encoder {
    let originalEncoder: Encoder
    let availableKeys: [String]

    var codingPath: [CodingKey] { return originalEncoder.codingPath }
    var userInfo: [CodingUserInfoKey: Any] { return originalEncoder.userInfo }

    /// - Parameters:
    ///   - originalEncoder: The encoder that receives the values which pass the filter.
    ///   - availableKeys: The key names allowed through keyed containers.
    init(originalEncoder: Encoder, availableKeys: [String]) {
        self.availableKeys = availableKeys
        self.originalEncoder = originalEncoder
    }

    /// Returns the wrapped encoder's keyed container wrapped in a key filter.
    func container<Key: CodingKey>(keyedBy type: Key.Type) -> KeyedEncodingContainer<Key> {
        return KeyedEncodingContainer(
            FilteredKeyedEncodingContainer(
                availableKeys: availableKeys,
                originalContainer: originalEncoder.container(keyedBy: type)
            )
        )
    }

    /// Unkeyed containers have no keys to filter, so the wrapped encoder's one is used as-is.
    func unkeyedContainer() -> UnkeyedEncodingContainer {
        return originalEncoder.unkeyedContainer()
    }

    /// Single-value containers have no keys to filter, so the wrapped encoder's one is used as-is.
    func singleValueContainer() -> SingleValueEncodingContainer {
        return originalEncoder.singleValueContainer()
    }
}
// MARK: - FilteredKeyedEncodingContainer
/// A keyed encoding container which forwards a value to the wrapped container only when the
/// value's key name is one of `availableKeys`; values for any other key are silently dropped.
///
/// `encodeIfPresent` and `encodeConditional` are forwarded explicitly rather than left to the
/// protocol's default implementations. The default `encodeIfPresent` skips `nil` values without
/// consulting the wrapped container, which would bypass any wrapped container that customises
/// it (for example to write an explicit null for an absent optional).
///
/// Nested containers and super encoders are handed out by the wrapped container without any
/// key filtering.
private class FilteredKeyedEncodingContainer<Key: CodingKey>: KeyedEncodingContainerProtocol {
    let codingPath: [CodingKey]
    let availableKeys: [String]
    var originalContainer: KeyedEncodingContainer<Key>

    /// Hashed copy of `availableKeys` so each membership check doesn't scan the whole array.
    private let availableKeySet: Set<String>

    /// - Parameters:
    ///   - availableKeys: The key names which may be written to `originalContainer`.
    ///   - originalContainer: The container which receives the values that pass the filter.
    init(availableKeys: [String], originalContainer: KeyedEncodingContainer<Key>) {
        self.availableKeys = availableKeys
        self.availableKeySet = Set(availableKeys)
        self.codingPath = originalContainer.codingPath
        self.originalContainer = originalContainer
    }

    /// Whether values for `key` are allowed through to the wrapped container.
    private func isAvailable(_ key: Key) -> Bool {
        return availableKeySet.contains(key.stringValue)
    }

    // MARK: - Required Values

    func encodeNil(forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeNil(forKey: key)
    }

    func encode(_ value: Bool, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: String, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Double, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Float, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int8, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int16, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int32, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: Int64, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt8, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt16, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt32, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode(_ value: UInt64, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    func encode<T>(_ value: T, forKey key: Key) throws where T: Encodable {
        guard isAvailable(key) else { return }
        try originalContainer.encode(value, forKey: key)
    }

    // MARK: - Conditional Values

    func encodeConditional<T>(_ object: T, forKey key: Key) throws where T: AnyObject, T: Encodable {
        guard isAvailable(key) else { return }
        try originalContainer.encodeConditional(object, forKey: key)
    }

    // MARK: - Optional Values

    // These forward to the wrapped container's own `encodeIfPresent` so that it decides how an
    // absent value is handled, instead of the protocol default quietly doing nothing for `nil`

    func encodeIfPresent(_ value: Bool?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: String?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Double?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Float?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int8?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int16?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int32?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: Int64?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt8?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt16?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt32?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent(_ value: UInt64?, forKey key: Key) throws {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    func encodeIfPresent<T>(_ value: T?, forKey key: Key) throws where T: Encodable {
        guard isAvailable(key) else { return }
        try originalContainer.encodeIfPresent(value, forKey: key)
    }

    // MARK: - Nested Containers

    // NOTE: The functions below are plain pass-throughs; `key` is not checked against the
    // available keys because there is no container to return for a rejected key

    func nestedContainer<NestedKey>(keyedBy keyType: NestedKey.Type, forKey key: Key) -> KeyedEncodingContainer<NestedKey> where NestedKey: CodingKey {
        return originalContainer.nestedContainer(keyedBy: keyType, forKey: key)
    }

    func nestedUnkeyedContainer(forKey key: Key) -> UnkeyedEncodingContainer {
        return originalContainer.nestedUnkeyedContainer(forKey: key)
    }

    func superEncoder() -> Encoder {
        return originalContainer.superEncoder()
    }

    func superEncoder(forKey key: Key) -> Encoder {
        return originalContainer.superEncoder(forKey: key)
    }
}