// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved. import Foundation import GRDB public extension Database { func create( table: T.Type, options: TableOptions = [], body: (TypedTableDefinition) throws -> Void ) throws where T: TableRecord, T: ColumnExpressible { try create(table: T.databaseTableName, options: options) { tableDefinition in let typedDefinition: TypedTableDefinition = TypedTableDefinition(definition: tableDefinition) try body(typedDefinition) } } func alter( table: T.Type, body: (TypedTableAlteration) -> Void ) throws where T: TableRecord, T: ColumnExpressible { try alter(table: T.databaseTableName) { tableAlteration in let typedAlteration: TypedTableAlteration = TypedTableAlteration(alteration: tableAlteration) body(typedAlteration) } } func drop(table: T.Type) throws where T: TableRecord { try drop(table: T.databaseTableName) } func createIndex( withCustomName customName: String? = nil, on table: T.Type, columns: [T.Columns], options: IndexOptions = [], condition: (any SQLExpressible)? = nil ) throws where T: TableRecord, T: ColumnExpressible { guard !columns.isEmpty else { throw StorageError.invalidData } let indexName: String = ( customName ?? "\(T.databaseTableName)_on_\(columns.map { $0.name }.joined(separator: "_and_"))" ) try create( index: indexName, on: T.databaseTableName, columns: columns.map { $0.name }, options: options, condition: condition ) } func makeFTS5Pattern(rawPattern: String, forTable table: T.Type) throws -> FTS5Pattern where T: TableRecord, T: ColumnExpressible { return try makeFTS5Pattern(rawPattern: rawPattern, forTable: table.databaseTableName) } func interrupt() { guard sqliteConnection != nil else { return } sqlite3_interrupt(sqliteConnection) } /// This is a custom implementation of the `afterNextTransaction` method which executes the closures within their own /// transactions to allow for nesting of 'afterNextTransaction' actions /// /// **Note:** GRDB doesn't notify read-only transactions to transaction observers func afterNextTransactionNested( onCommit: @escaping (Database) -> Void, onRollback: @escaping (Database) -> Void = { _ in } ) { afterNextTransactionNestedOnce( dedupeId: UUID().uuidString, onCommit: onCommit, onRollback: onRollback ) } func afterNextTransactionNestedOnce( dedupeId: String, onCommit: @escaping (Database) -> Void, onRollback: @escaping (Database) -> Void = { _ in } ) { // Only allow a single observer per `dedupeId` per transaction, this allows us to // schedule an action to run at most once per transaction (eg. auto-scheduling a ConfigSyncJob // when receiving messages) guard !TransactionHandler.registeredHandlers.wrappedValue.contains(dedupeId) else { return } add( transactionObserver: TransactionHandler( identifier: dedupeId, onCommit: onCommit, onRollback: onRollback ), extent: .nextTransaction ) } } fileprivate class TransactionHandler: TransactionObserver { static var registeredHandlers: Atomic> = Atomic([]) let identifier: String let onCommit: (Database) -> Void let onRollback: (Database) -> Void init( identifier: String, onCommit: @escaping (Database) -> Void, onRollback: @escaping (Database) -> Void ) { self.identifier = identifier self.onCommit = onCommit self.onRollback = onRollback TransactionHandler.registeredHandlers.mutate { $0.insert(identifier) } } // Ignore changes func observes(eventsOfKind eventKind: DatabaseEventKind) -> Bool { false } func databaseDidChange(with event: DatabaseEvent) { } func databaseDidCommit(_ db: Database) { TransactionHandler.registeredHandlers.mutate { $0.remove(identifier) } do { try db.inTransaction { onCommit(db) return .commit } } catch { SNLog("[Database] afterNextTransactionNested onCommit failed") } } func databaseDidRollback(_ db: Database) { TransactionHandler.registeredHandlers.mutate { $0.remove(identifier) } do { try db.inTransaction { onRollback(db) return .commit } } catch { SNLog("[Database] afterNextTransactionNested onRollback failed") } } }