diff --git a/CHANGELOG.md b/CHANGELOG.md index e71921cce3..da3ed4a3dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,36 @@ Release Notes ## Next Version +**New** + +- `DatabaseAggregate` is the protocol for custom aggregate functions (fixes [#236](https://github.com/groue/GRDB.swift/issues/236), [documentation](https://github.com/groue/GRDB.swift#custom-aggregates)): + + ```swift + struct MySum : DatabaseAggregate { + var sum: Int = 0 + + mutating func step(_ dbValues: [DatabaseValue]) { + if let int = Int.fromDatabaseValue(dbValues[0]) { + sum += int + } + } + + func finalize() -> DatabaseValueConvertible? { + return sum + } + } + + let dbQueue = DatabaseQueue() + let fn = DatabaseFunction("mysum", argumentCount: 1, aggregate: MySum.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + try db.execute("CREATE TABLE test(i)") + try db.execute("INSERT INTO test(i) VALUES (1)") + try db.execute("INSERT INTO test(i) VALUES (2)") + try Int.fetchOne(db, "SELECT mysum(i) FROM test")! // 3 + } + ``` + **Fixed** - `QueryInterfaceRequest.order(_:)` clears the eventual reversed flag, and better reflects the documentation of this method: "Any previous ordering is replaced." diff --git a/GRDB.xcodeproj/project.pbxproj b/GRDB.xcodeproj/project.pbxproj index 984a459ae8..d8ac9a7bdb 100755 --- a/GRDB.xcodeproj/project.pbxproj +++ b/GRDB.xcodeproj/project.pbxproj @@ -329,6 +329,21 @@ 564448891EF56B1B00DD2861 /* DatabaseAfterNextTransactionCommitTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564448821EF56B1B00DD2861 /* DatabaseAfterNextTransactionCommitTests.swift */; }; 5644488A1EF56B1B00DD2861 /* DatabaseAfterNextTransactionCommitTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564448821EF56B1B00DD2861 /* DatabaseAfterNextTransactionCommitTests.swift */; }; 564A50C81BFF4B7F00B3A3A2 /* DatabaseCollationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564A50C61BFF4B7F00B3A3A2 /* DatabaseCollationTests.swift */; }; + 564F9C1E1F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C1F1F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C201F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C211F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C221F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C231F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C241F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C251F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */; }; + 564F9C2D1F075DD200877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; + 564F9C2F1F07611400877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; + 564F9C301F07611500877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; + 564F9C311F07611600877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; + 564F9C321F07611700877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; + 564F9C331F07611800877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; + 564F9C341F07611900877A00 /* DatabaseFunction.swift in Sources */ = {isa = PBXBuildFile; fileRef = 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */; }; 565029C81E914DB700615A2C /* TableMappingTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 565029C71E914DB700615A2C /* TableMappingTests.swift */; }; 565029C91E914DB700615A2C /* TableMappingTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 565029C71E914DB700615A2C /* TableMappingTests.swift */; }; 565029CA1E914DB700615A2C /* TableMappingTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 565029C71E914DB700615A2C /* TableMappingTests.swift */; }; @@ -1903,6 +1918,8 @@ 5636E9BB1D22574100B9B05F /* Request.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Request.swift; sourceTree = ""; }; 564448821EF56B1B00DD2861 /* DatabaseAfterNextTransactionCommitTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DatabaseAfterNextTransactionCommitTests.swift; sourceTree = ""; }; 564A50C61BFF4B7F00B3A3A2 /* DatabaseCollationTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DatabaseCollationTests.swift; sourceTree = ""; }; + 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DatabaseAggregateTests.swift; sourceTree = ""; }; + 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DatabaseFunction.swift; sourceTree = ""; }; 565029C71E914DB700615A2C /* TableMappingTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = TableMappingTests.swift; sourceTree = ""; }; 565490A01D5A4798005622CB /* GRDB.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = GRDB.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 565490E51D5AE282005622CB /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = Platforms/WatchOS.platform/Developer/SDKs/WatchOS2.2.sdk/usr/lib/libsqlite3.tbd; sourceTree = DEVELOPER_DIR; }; @@ -2636,6 +2653,7 @@ children = ( 56DAA2C41DE99D8D006E10C8 /* Cursor */, 564448821EF56B1B00DD2861 /* DatabaseAfterNextTransactionCommitTests.swift */, + 564F9C1D1F069B4E00877A00 /* DatabaseAggregateTests.swift */, 564A50C61BFF4B7F00B3A3A2 /* DatabaseCollationTests.swift */, 56A238161B9C74A90082EB20 /* DatabaseErrorTests.swift */, 560C97C61BFD0B8400BF8471 /* DatabaseFunctionTests.swift */, @@ -2726,6 +2744,7 @@ 56DAA2DA1DE9C827006E10C8 /* Cursor.swift */, 56A238711B9C75030082EB20 /* Database.swift */, 56A238731B9C75030082EB20 /* DatabaseError.swift */, + 564F9C2C1F075DD200877A00 /* DatabaseFunction.swift */, 560A37A31C8F625000949E71 /* DatabasePool.swift */, 56A238741B9C75030082EB20 /* DatabaseQueue.swift */, 563363BF1C942C04000BE133 /* DatabaseReader.swift */, @@ -3834,6 +3853,7 @@ 560FC5231CB003810014AA8E /* Configuration.swift in Sources */, 566475CD1D981D5E00FF74B8 /* SQLFunctions.swift in Sources */, 5698AC381D9E5A590056AF8C /* FTS3Pattern.swift in Sources */, + 564F9C2F1F07611400877A00 /* DatabaseFunction.swift in Sources */, 5659F4991EA8D989004A4992 /* Pool.swift in Sources */, 56CEB54D1EAA359A00BFAF62 /* SQLExpressible.swift in Sources */, 5664759B1D97D8A000FF74B8 /* SQLCollection.swift in Sources */, @@ -3988,6 +4008,7 @@ 560FC5961CB00B880014AA8E /* DatabaseValueConvertibleFetchTests.swift in Sources */, 560FC5971CB00B880014AA8E /* DatabaseErrorTests.swift in Sources */, 560FC5991CB00B880014AA8E /* RecordMinimalPrimaryKeySingleTests.swift in Sources */, + 564F9C1F1F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 562393731DEE104400A6B01F /* MapCursorTests.swift in Sources */, 5698AC9F1DA4B0430056AF8C /* FTS4TableBuilderTests.swift in Sources */, 56176C621EACCCC7000F3F2B /* FTS5TableBuilderTests.swift in Sources */, @@ -4042,6 +4063,7 @@ 56B964B71DA51D010002DA19 /* FTS5TokenizerDescriptor.swift in Sources */, 56F5ABDA1D814330001F60CB /* NSData.swift in Sources */, 56D121601ED34978001347D2 /* Fixits-0.109.0.swift in Sources */, + 564F9C341F07611900877A00 /* DatabaseFunction.swift in Sources */, 566475A81D9810A400FF74B8 /* SQLSelectable+QueryInterface.swift in Sources */, 5659F4961EA8D964004A4992 /* ReadWriteBox.swift in Sources */, 5659F48E1EA8D94E004A4992 /* Utils.swift in Sources */, @@ -4223,6 +4245,7 @@ 565EFAF01D0436CE00A8FA9D /* NumericOverflowTests.swift in Sources */, 567156551CB16729007DC145 /* DatabaseValueConvertibleFetchTests.swift in Sources */, 567156561CB16729007DC145 /* DatabaseErrorTests.swift in Sources */, + 564F9C201F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 562393741DEE104400A6B01F /* MapCursorTests.swift in Sources */, 5698ACA01DA4B0430056AF8C /* FTS4TableBuilderTests.swift in Sources */, 56176C681EACCCC8000F3F2B /* FTS5TableBuilderTests.swift in Sources */, @@ -4284,6 +4307,7 @@ 56AFC9F71CB1A8BB00F48B96 /* RawRepresentable.swift in Sources */, 566475D01D981D5E00FF74B8 /* SQLFunctions.swift in Sources */, 5698AC3B1D9E5A590056AF8C /* FTS3Pattern.swift in Sources */, + 564F9C321F07611700877A00 /* DatabaseFunction.swift in Sources */, 5659F49C1EA8D989004A4992 /* Pool.swift in Sources */, 56CEB5501EAA359A00BFAF62 /* SQLExpressible.swift in Sources */, 5664759E1D97D8A000FF74B8 /* SQLCollection.swift in Sources */, @@ -4438,6 +4462,7 @@ 56AFCA641CB1AA9900F48B96 /* DatabaseValueConvertibleSubclassTests.swift in Sources */, 56A8C24A1D1918F10096E9D4 /* FoundationUUIDTests.swift in Sources */, 56AFCA651CB1AA9900F48B96 /* DatabaseErrorTests.swift in Sources */, + 564F9C231F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 5690C32B1D23E6D800E59934 /* FoundationDateComponentsTests.swift in Sources */, 5623936E1DEE0CD200A6B01F /* FlattenCursorTests.swift in Sources */, 56176C741EACCCCA000F3F2B /* FTS5TableBuilderTests.swift in Sources */, @@ -4568,6 +4593,7 @@ 56AFCAB91CB1ABC800F48B96 /* RecordPrimaryKeyRowIDTests.swift in Sources */, 56AFCABA1CB1ABC800F48B96 /* RecordEditedTests.swift in Sources */, 56AFCABB1CB1ABC800F48B96 /* DatabasePoolSchemaCacheTests.swift in Sources */, + 564F9C241F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 56AFCABC1CB1ABC800F48B96 /* RecordPrimaryKeySingleWithReplaceConflictResolutionTests.swift in Sources */, 5623936F1DEE0CD200A6B01F /* FlattenCursorTests.swift in Sources */, 56176C7A1EACCCCB000F3F2B /* FTS5TableBuilderTests.swift in Sources */, @@ -4684,6 +4710,7 @@ 56A238861B9C75030082EB20 /* DatabaseValue.swift in Sources */, 56CEB55D1EAA359A00BFAF62 /* SQLOrdering.swift in Sources */, 56CEB5641EAA359A00BFAF62 /* SQLSelectable.swift in Sources */, + 564F9C311F07611600877A00 /* DatabaseFunction.swift in Sources */, 56B964A01DA51B4C0002DA19 /* FTS5.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; @@ -4782,6 +4809,7 @@ 5698ACA21DA4B0430056AF8C /* FTS4TableBuilderTests.swift in Sources */, 56A2385C1B9C74A90082EB20 /* RecordPrimaryKeySingleWithReplaceConflictResolutionTests.swift in Sources */, 56A238401B9C74A90082EB20 /* DatabaseValueConvertibleSubclassTests.swift in Sources */, + 564F9C221F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 56A8C2481D1918F00096E9D4 /* FoundationUUIDTests.swift in Sources */, 56A2383C1B9C74A90082EB20 /* DatabaseErrorTests.swift in Sources */, 5690C32A1D23E6D800E59934 /* FoundationDateComponentsTests.swift in Sources */, @@ -4912,6 +4940,7 @@ 56D496741D81309E008276D7 /* RecordCopyTests.swift in Sources */, 56D496791D81309E008276D7 /* RecordWithColumnNameManglingTests.swift in Sources */, 56D4966C1D81309E008276D7 /* RecordMinimalPrimaryKeyRowIDTests.swift in Sources */, + 564F9C1E1F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 56D496861D813147008276D7 /* UpdateStatementTests.swift in Sources */, 56D4965D1D81304E008276D7 /* FoundationNSNumberTests.swift in Sources */, 56B021C91D8C0D3900B239BB /* MutablePersistablePersistenceConflictPolicyTests.swift in Sources */, @@ -4973,6 +5002,7 @@ 56A2387B1B9C75030082EB20 /* Configuration.swift in Sources */, 566475CC1D981D5E00FF74B8 /* SQLFunctions.swift in Sources */, 5698AC371D9E5A590056AF8C /* FTS3Pattern.swift in Sources */, + 564F9C2D1F075DD200877A00 /* DatabaseFunction.swift in Sources */, 5659F4981EA8D989004A4992 /* Pool.swift in Sources */, 56CEB54C1EAA359A00BFAF62 /* SQLExpressible.swift in Sources */, 5664759A1D97D8A000FF74B8 /* SQLCollection.swift in Sources */, @@ -5058,6 +5088,7 @@ F3BA80341CFB28A4003DC1BA /* Record.swift in Sources */, F3BA80161CFB2876003DC1BA /* Row.swift in Sources */, F3BA80101CFB2876003DC1BA /* DatabaseReader.swift in Sources */, + 564F9C331F07611800877A00 /* DatabaseFunction.swift in Sources */, 5659F49D1EA8D989004A4992 /* Pool.swift in Sources */, 56CEB5511EAA359A00BFAF62 /* SQLExpressible.swift in Sources */, 5698AC3C1D9E5A590056AF8C /* FTS3Pattern.swift in Sources */, @@ -5212,6 +5243,7 @@ 5698ACA51DA4B0430056AF8C /* FTS4TableBuilderTests.swift in Sources */, F3BA80B11CFB2FC4003DC1BA /* DatabaseTests.swift in Sources */, 56A8C24E1D1918F30096E9D4 /* FoundationUUIDTests.swift in Sources */, + 564F9C251F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, 5657AB3D1D108BA9006283EF /* FoundationDataTests.swift in Sources */, 5690C32D1D23E6D800E59934 /* FoundationDateComponentsTests.swift in Sources */, 562393791DEE104400A6B01F /* MapCursorTests.swift in Sources */, @@ -5273,6 +5305,7 @@ F3BA80901CFB2E7A003DC1BA /* Record.swift in Sources */, F3BA80721CFB2E55003DC1BA /* Row.swift in Sources */, F3BA806C1CFB2E55003DC1BA /* DatabaseReader.swift in Sources */, + 564F9C301F07611500877A00 /* DatabaseFunction.swift in Sources */, 5659F49A1EA8D989004A4992 /* Pool.swift in Sources */, 56CEB54E1EAA359A00BFAF62 /* SQLExpressible.swift in Sources */, 5698AC391D9E5A590056AF8C /* FTS3Pattern.swift in Sources */, @@ -5427,6 +5460,7 @@ 5698AC431DA2BED90056AF8C /* FTS3PatternTests.swift in Sources */, F3BA80E11CFB300F003DC1BA /* DatabaseValueConversionTests.swift in Sources */, 5623931B1DECC02000A6B01F /* RowFetchTests.swift in Sources */, + 564F9C211F069B4E00877A00 /* DatabaseAggregateTests.swift in Sources */, F3BA80ED1CFB3017003DC1BA /* RowFromDictionaryTests.swift in Sources */, 5690C3291D23E6D800E59934 /* FoundationDateComponentsTests.swift in Sources */, 5657AB391D108BA9006283EF /* FoundationDataTests.swift in Sources */, diff --git a/GRDB/Core/Database.swift b/GRDB/Core/Database.swift index b9fee2d41b..cced1be07e 100644 --- a/GRDB/Core/Database.swift +++ b/GRDB/Core/Database.swift @@ -872,116 +872,13 @@ extension Database { /// try Int.fetchOne(db, "SELECT succ(1)")! // 2 public func add(function: DatabaseFunction) { functions.update(with: function) - let functionPointer = unsafeBitCast(function, to: UnsafeMutableRawPointer.self) - let code = sqlite3_create_function_v2( - sqliteConnection, - function.name, - function.argumentCount, - SQLITE_UTF8 | function.eTextRep, - functionPointer, - { (context, argc, argv) in - let function = unsafeBitCast(sqlite3_user_data(context), to: DatabaseFunction.self) - do { - let result = try function.function(argc, argv) - switch result.storage { - case .null: - sqlite3_result_null(context) - case .int64(let int64): - sqlite3_result_int64(context, int64) - case .double(let double): - sqlite3_result_double(context, double) - case .string(let string): - sqlite3_result_text(context, string, -1, SQLITE_TRANSIENT) - case .blob(let data): - data.withUnsafeBytes { bytes in - sqlite3_result_blob(context, bytes, Int32(data.count), SQLITE_TRANSIENT) - } - } - } catch let error as DatabaseError { - if let message = error.message { - sqlite3_result_error(context, message, -1) - } - sqlite3_result_error_code(context, error.extendedResultCode.rawValue) - } catch { - sqlite3_result_error(context, "\(error)", -1) - } - }, nil, nil, nil) - - guard code == SQLITE_OK else { - // Assume a GRDB bug: there is no point throwing any error. - fatalError(DatabaseError(resultCode: code, message: lastErrorMessage).description) - } + function.install(in: self) } /// Remove an SQL function. public func remove(function: DatabaseFunction) { functions.remove(function) - let code = sqlite3_create_function_v2( - sqliteConnection, - function.name, - function.argumentCount, - SQLITE_UTF8 | function.eTextRep, - nil, nil, nil, nil, nil) - guard code == SQLITE_OK else { - // Assume a GRDB bug: there is no point throwing any error. - fatalError(DatabaseError(resultCode: code, message: lastErrorMessage).description) - } - } -} - - -/// An SQL function. -public final class DatabaseFunction { - public let name: String - let argumentCount: Int32 - let pure: Bool - let function: (Int32, UnsafeMutablePointer?) throws -> DatabaseValue - var eTextRep: Int32 { return pure ? SQLITE_DETERMINISTIC : 0 } - - /// Returns an SQL function. - /// - /// let fn = DatabaseFunction("succ", argumentCount: 1) { dbValues in - /// guard let int = Int.fromDatabaseValue(dbValues[0]) else { - /// return nil - /// } - /// return int + 1 - /// } - /// db.add(function: fn) - /// try Int.fetchOne(db, "SELECT succ(1)")! // 2 - /// - /// - parameters: - /// - name: The function name. - /// - argumentCount: The number of arguments of the function. If - /// omitted, or nil, the function accepts any number of arguments. - /// - pure: Whether the function is "pure", which means that its results - /// only depends on its inputs. When a function is pure, SQLite has - /// the opportunity to perform additional optimizations. Default value - /// is false. - /// - function: A function that takes an array of DatabaseValue - /// arguments, and returns an optional DatabaseValueConvertible such - /// as Int, String, NSDate, etc. The array is guaranteed to have - /// exactly *argumentCount* elements, provided *argumentCount* is - /// not nil. - public init(_ name: String, argumentCount: Int32? = nil, pure: Bool = false, function: @escaping ([DatabaseValue]) throws -> DatabaseValueConvertible?) { - self.name = name - self.argumentCount = argumentCount ?? -1 - self.pure = pure - self.function = { (argc, argv) in - let arguments = (0.. Bool { - return lhs.name == rhs.name && lhs.argumentCount == rhs.argumentCount + function.uninstall(in: self) } } diff --git a/GRDB/Core/DatabaseFunction.swift b/GRDB/Core/DatabaseFunction.swift new file mode 100644 index 0000000000..d664beb888 --- /dev/null +++ b/GRDB/Core/DatabaseFunction.swift @@ -0,0 +1,369 @@ +#if SWIFT_PACKAGE + import CSQLite +#endif + +/// An SQL function or aggregate. +public final class DatabaseFunction { + public let name: String + let argumentCount: Int32? + let pure: Bool + private let kind: Kind + fileprivate var nArg: Int32 { return argumentCount ?? -1 } + fileprivate var eTextRep: Int32 { return (SQLITE_UTF8 | (pure ? SQLITE_DETERMINISTIC : 0)) } + + /// Returns an SQL function. + /// + /// let fn = DatabaseFunction("succ", argumentCount: 1) { dbValues in + /// guard let int = Int.fromDatabaseValue(dbValues[0]) else { + /// return nil + /// } + /// return int + 1 + /// } + /// db.add(function: fn) + /// try Int.fetchOne(db, "SELECT succ(1)")! // 2 + /// + /// - parameters: + /// - name: The function name. + /// - argumentCount: The number of arguments of the function. If + /// omitted, or nil, the function accepts any number of arguments. + /// - pure: Whether the function is "pure", which means that its results + /// only depends on its inputs. When a function is pure, SQLite has + /// the opportunity to perform additional optimizations. Default value + /// is false. + /// - function: A function that takes an array of DatabaseValue + /// arguments, and returns an optional DatabaseValueConvertible such + /// as Int, String, NSDate, etc. The array is guaranteed to have + /// exactly *argumentCount* elements, provided *argumentCount* is + /// not nil. + public init(_ name: String, argumentCount: Int32? = nil, pure: Bool = false, function: @escaping ([DatabaseValue]) throws -> DatabaseValueConvertible?) { + self.name = name + self.argumentCount = argumentCount + self.pure = pure + self.kind = .function{ (argc, argv) in + let arguments = (0.. DatabaseValueConvertible? { + /// return sum + /// } + /// } + /// + /// let dbQueue = DatabaseQueue() + /// let fn = DatabaseFunction("mysum", argumentCount: 1, aggregate: MySum.self) + /// dbQueue.add(function: fn) + /// try dbQueue.inDatabase { db in + /// try db.execute("CREATE TABLE test(i)") + /// try db.execute("INSERT INTO test(i) VALUES (1)") + /// try db.execute("INSERT INTO test(i) VALUES (2)") + /// try Int.fetchOne(db, "SELECT mysum(i) FROM test")! // 3 + /// } + /// + /// - parameters: + /// - name: The function name. + /// - argumentCount: The number of arguments of the aggregate. If + /// omitted, or nil, the aggregate accepts any number of arguments. + /// - pure: Whether the aggregate is "pure", which means that its + /// results only depends on its inputs. When an aggregate is pure, + /// SQLite has the opportunity to perform additional optimizations. + /// Default value is false. + /// - aggregate: A type that implements the DatabaseAggregate protocol. + /// For each step of the aggregation, its `step` method is called with + /// an array of DatabaseValue arguments. The array is guaranteed to + /// have exactly *argumentCount* elements, provided *argumentCount* is + /// not nil. + public init(_ name: String, argumentCount: Int32? = nil, pure: Bool = false, aggregate: Aggregate.Type) { + self.name = name + self.argumentCount = argumentCount + self.pure = pure + self.kind = .aggregate { return Aggregate() } + } + + /// Calls sqlite3_create_function_v2 + /// See https://sqlite.org/c3ref/create_function.html + func install(in db: Database) { + // Retain the function definition + let definition = kind.definition + let definitionP = Unmanaged.passRetained(definition).toOpaque() + + let code = sqlite3_create_function_v2( + db.sqliteConnection, + name, + nArg, + eTextRep, + definitionP, + kind.xFunc, + kind.xStep, + kind.xFinal, + { definitionP in + // Release the function definition + Unmanaged.fromOpaque(definitionP!).release() + }) + + guard code == SQLITE_OK else { + // Assume a GRDB bug: there is no point throwing any error. + fatalError(DatabaseError(resultCode: code, message: db.lastErrorMessage).description) + } + } + + /// Calls sqlite3_create_function_v2 + /// See https://sqlite.org/c3ref/create_function.html + func uninstall(in db: Database) { + let code = sqlite3_create_function_v2( + db.sqliteConnection, + name, + nArg, + eTextRep, + nil, nil, nil, nil, nil) + + guard code == SQLITE_OK else { + // Assume a GRDB bug: there is no point throwing any error. + fatalError(DatabaseError(resultCode: code, message: db.lastErrorMessage).description) + } + } + + /// The way to compute the result of a function. + /// Feeds the `pApp` parameter of sqlite3_create_function_v2 + /// http://sqlite.org/capi3ref.html#sqlite3_create_function + private class FunctionDefinition { + let compute: (Int32, UnsafeMutablePointer?) throws -> DatabaseValueConvertible? + init(compute: @escaping (Int32, UnsafeMutablePointer?) throws -> DatabaseValueConvertible?) { + self.compute = compute + } + } + + /// The way to start an aggregate. + /// Feeds the `pApp` parameter of sqlite3_create_function_v2 + /// http://sqlite.org/capi3ref.html#sqlite3_create_function + private class AggregateDefinition { + let makeAggregate: () -> DatabaseAggregate + init(makeAggregate: @escaping () -> DatabaseAggregate) { + self.makeAggregate = makeAggregate + } + } + + /// The current state of an aggregate, storable in SQLite + private class AggregateContext { + var aggregate: DatabaseAggregate + var hasErrored = false + init(aggregate: DatabaseAggregate) { + self.aggregate = aggregate + } + } + + /// A function kind: an "SQL function" or an "aggregate". + /// See http://sqlite.org/capi3ref.html#sqlite3_create_function + private enum Kind { + /// A regular function: SELECT f(1) + case function((Int32, UnsafeMutablePointer?) throws -> DatabaseValueConvertible?) + + /// An aggregate: SELECT f(foo) FROM bar GROUP BY baz + case aggregate(() -> DatabaseAggregate) + + /// Feeds the `pApp` parameter of sqlite3_create_function_v2 + /// http://sqlite.org/capi3ref.html#sqlite3_create_function + var definition: AnyObject { + switch self { + case .function(let compute): + return FunctionDefinition(compute: compute) + case .aggregate(let makeAggregate): + return AggregateDefinition(makeAggregate: makeAggregate) + } + } + + /// Feeds the `xFunc` parameter of sqlite3_create_function_v2 + /// http://sqlite.org/capi3ref.html#sqlite3_create_function + var xFunc: (@convention(c) (OpaquePointer?, Int32, UnsafeMutablePointer?) -> Void)? { + guard case .function = self else { return nil } + return { (sqliteContext, argc, argv) in + let definition = Unmanaged.fromOpaque(sqlite3_user_data(sqliteContext)).takeUnretainedValue() + do { + try DatabaseFunction.report( + result: definition.compute(argc, argv), + in: sqliteContext) + } catch { + DatabaseFunction.report(error: error, in: sqliteContext) + } + } + } + + /// Feeds the `xStep` parameter of sqlite3_create_function_v2 + /// http://sqlite.org/capi3ref.html#sqlite3_create_function + var xStep: (@convention(c) (OpaquePointer?, Int32, UnsafeMutablePointer?) -> Void)? { + guard case .aggregate = self else { return nil } + return { (sqliteContext, argc, argv) in + let aggregateContextU = DatabaseFunction.unmanagedAggregateContext(sqliteContext) + let aggregateContext = aggregateContextU.takeUnretainedValue() + assert(!aggregateContext.hasErrored) + do { + let arguments = (0.. Void)? { + guard case .aggregate = self else { return nil } + return { (sqliteContext) in + let aggregateContextU = DatabaseFunction.unmanagedAggregateContext(sqliteContext) + let aggregateContext = aggregateContextU.takeUnretainedValue() + aggregateContextU.release() + + guard !aggregateContext.hasErrored else { + return + } + + do { + try DatabaseFunction.report( + result: aggregateContext.aggregate.finalize(), + in: sqliteContext) + } catch { + DatabaseFunction.report(error: error, in: sqliteContext) + } + } + } + } + + /// Helper function that extracts the current state of an aggregate from an + /// sqlite function execution context. + /// + /// The result must be released when the aggregate concludes. + /// + /// See https://sqlite.org/c3ref/context.html + /// See https://sqlite.org/c3ref/aggregate_context.html + private static func unmanagedAggregateContext(_ sqliteContext: OpaquePointer?) -> Unmanaged { + // The current aggregate buffer + let stride = MemoryLayout>.stride + let aggregateContextBufferP = UnsafeMutableRawBufferPointer(start: sqlite3_aggregate_context(sqliteContext, Int32(stride))!, count: stride) + + if aggregateContextBufferP.contains(where: { $0 != 0 }) { + // Buffer contains non-null pointer: load aggregate context + let aggregateContextP = aggregateContextBufferP.baseAddress!.assumingMemoryBound(to: Unmanaged.self) + return aggregateContextP.pointee + } else { + // Buffer contains null pointer: create aggregate context... + let aggregate = Unmanaged.fromOpaque(sqlite3_user_data(sqliteContext)) + .takeUnretainedValue() + .makeAggregate() + let aggregateContext = AggregateContext(aggregate: aggregate) + + // retain and store in SQLite's buffer + let aggregateContextU = Unmanaged.passRetained(aggregateContext) + var aggregateContextP = aggregateContextU.toOpaque() + withUnsafeBytes(of: &aggregateContextP) { + aggregateContextBufferP.copyBytes(from: $0) + } + return aggregateContextU + } + } + + private static func report(result: DatabaseValueConvertible?, in sqliteContext: OpaquePointer?) { + switch result?.databaseValue.storage ?? .null { + case .null: + sqlite3_result_null(sqliteContext) + case .int64(let int64): + sqlite3_result_int64(sqliteContext, int64) + case .double(let double): + sqlite3_result_double(sqliteContext, double) + case .string(let string): + sqlite3_result_text(sqliteContext, string, -1, SQLITE_TRANSIENT) + case .blob(let data): + data.withUnsafeBytes { bytes in + sqlite3_result_blob(sqliteContext, bytes, Int32(data.count), SQLITE_TRANSIENT) + } + } + } + + private static func report(error: Error, in sqliteContext: OpaquePointer?) { + if let error = error as? DatabaseError { + if let message = error.message { + sqlite3_result_error(sqliteContext, message, -1) + } + sqlite3_result_error_code(sqliteContext, error.extendedResultCode.rawValue) + } else { + sqlite3_result_error(sqliteContext, "\(error)", -1) + } + } +} + +extension DatabaseFunction : Hashable { + /// The hash value + public var hashValue: Int { + return name.hashValue ^ nArg.hashValue + } + + /// Two functions are equal if they share the same name and arity. + public static func == (lhs: DatabaseFunction, rhs: DatabaseFunction) -> Bool { + return lhs.name == rhs.name && lhs.nArg == rhs.nArg + } +} + +/// The protocol for custom SQLite aggregates. +/// +/// For example: +/// +/// struct MySum : DatabaseAggregate { +/// var sum: Int = 0 +/// +/// mutating func step(_ dbValues: [DatabaseValue]) { +/// if let int = Int.fromDatabaseValue(dbValues[0]) { +/// sum += int +/// } +/// } +/// +/// func finalize() -> DatabaseValueConvertible? { +/// return sum +/// } +/// } +/// +/// let dbQueue = DatabaseQueue() +/// let fn = DatabaseFunction("mysum", argumentCount: 1, aggregate: MySum.self) +/// dbQueue.add(function: fn) +/// try dbQueue.inDatabase { db in +/// try db.execute("CREATE TABLE test(i)") +/// try db.execute("INSERT INTO test(i) VALUES (1)") +/// try db.execute("INSERT INTO test(i) VALUES (2)") +/// try Int.fetchOne(db, "SELECT mysum(i) FROM test")! // 3 +/// } +public protocol DatabaseAggregate { + /// Creates an aggregate. + init() + + /// This method is called at each step of the aggregation. + /// + /// The dbValues argument contains as many values as given to the SQL + /// aggregate function. + /// + /// -- One value + /// SELECT maxLength(name) FROM persons + /// + /// -- Two values + /// SELECT maxFullNameLength(firstName, lastName) FROM persons + /// + /// This method is never called after the finalize() method has been called. + mutating func step(_ dbValues: [DatabaseValue]) throws + + /// Returns the final result + func finalize() throws -> DatabaseValueConvertible? +} diff --git a/README.md b/README.md index cfd5fc84cc..f6ab36ab2d 100644 --- a/README.md +++ b/README.md @@ -479,7 +479,7 @@ Advanced topics: - [Custom Value Types](#custom-value-types) - [Prepared Statements](#prepared-statements) -- [Custom SQL Functions](#custom-sql-functions) +- [Custom SQL Functions and Aggregates](#custom-sql-functions-and-aggregates) - [Database Schema Introspection](#database-schema-introspection) - [Row Adapters](#row-adapters) - [Raw SQLite Pointers](#raw-sqlite-pointers) @@ -1381,14 +1381,25 @@ let selectStatement = try db.cachedSelectStatement(sql) Should a cached prepared statement throw an error, don't reuse it (it is a programmer error). Instead, reload it from the cache. -## Custom SQL Functions +## Custom SQL Functions and Aggregates + +**SQLite lets you define SQL functions and aggregates.** + +A custom SQL function or aggregate extends SQLite: -**SQLite lets you define SQL functions.** +```sql +SELECT reverse(name) FROM persons; -- custom function +SELECT maxLength(name) FROM persons; -- custom aggregate +``` -A custom SQL function extends SQLite. It can be used in raw SQL queries. And when SQLite needs to evaluate it, it calls your custom code. +- [Custom SQL Functions](#custom-sql-functions) +- [Custom Aggregates](#custom-aggregates) + + +### Custom SQL Functions ```swift -let reverseString = DatabaseFunction("reverseString", argumentCount: 1, pure: true) { (values: [DatabaseValue]) in +let reverse = DatabaseFunction("reverse", argumentCount: 1, pure: true) { (values: [DatabaseValue]) in // Extract string value, if any... guard let string = String.fromDatabaseValue(values[0]) else { return nil @@ -1396,11 +1407,11 @@ let reverseString = DatabaseFunction("reverseString", argumentCount: 1, pure: tr // ... and return reversed string: return String(string.characters.reversed()) } -dbQueue.add(function: reverseString) // Or dbPool.add(function: ...) +dbQueue.add(function: reverse) // Or dbPool.add(function: ...) try dbQueue.inDatabase { db in // "oof" - try String.fetchOne(db, "SELECT reverseString('foo')")! + try String.fetchOne(db, "SELECT reverse('foo')")! } ``` @@ -1459,6 +1470,77 @@ Person.select(reverseString.apply(nameColumn)) **GRDB ships with built-in SQL functions that perform unicode-aware string transformations.** See [Unicode](#unicode). +### Custom Aggregates + +Before registering a custom aggregate, you need to define a type that adopts the `DatabaseAggregate` protocol: + +```swift +protocol DatabaseAggregate { + // Initializes an aggregate + init() + + // Called at each step of the aggregation + mutating func step(_ dbValues: [DatabaseValue]) throws + + // Returns the final result + func finalize() throws -> DatabaseValueConvertible? +} +``` + +For example: + +```swift +struct MaxLength : DatabaseAggregate { + var maxLength: Int = 0 + + mutating func step(_ dbValues: [DatabaseValue]) { + // At each step, extract string value, if any... + guard let string = String.fromDatabaseValue(dbValues[0]) else { + return + } + // ... and update the result + let length = string.characters.count + if length > maxLength { + maxLength = length + } + } + + func finalize() -> DatabaseValueConvertible? { + return maxLength + } +} + +let maxLength = DatabaseFunction( + "maxLength", + argumentCount: 1, + pure: true, + aggregate: MaxLength.self) + +dbQueue.add(function: maxLength) // Or dbPool.add(function: ...) + +try dbQueue.inDatabase { db in + // Some Int + try Int.fetchOne(db, "SELECT maxLength(name) FROM persons")! +} +``` + +The `step` method of the aggregate takes an array of [DatabaseValue](#databasevalue). This array contains as many values as the *argumentCount* parameter (or any number of values, when *argumentCount* is omitted). + +The `finalize` method of the aggregate returns the final aggregated [value](#values) (Bool, Int, String, Date, Swift enums, etc.). + +SQLite has the opportunity to perform additional optimizations when aggregates are "pure", which means that their result only depends on their inputs. So make sure to set the *pure* argument to true when possible. + + +**Use custom aggregates in the [query interface](#the-query-interface):** + +```swift +// SELECT maxLength("name") FROM persons +Person.select(maxLength.apply(nameColumn)) + .asRequest(of: Int.self) + .fetchOne(db) // Int? +``` + + ## Database Schema Introspection **SQLite provides database schema introspection tools**, such as the [sqlite_master](https://www.sqlite.org/faq.html#q7) table, and the pragma [table_info](https://www.sqlite.org/pragma.html#pragma_table_info): @@ -1671,6 +1753,7 @@ try dbQueue.inDatabase { db in Before jumping in the low-level wagon, here is the list of all SQLite APIs used by GRDB: +- `sqlite3_aggregate_context`, `sqlite3_create_function_v2`, `sqlite3_result_blob`, `sqlite3_result_double`, `sqlite3_result_error`, `sqlite3_result_error_code`, `sqlite3_result_int64`, `sqlite3_result_null`, `sqlite3_result_text`, `sqlite3_user_data`, `sqlite3_value_blob`, `sqlite3_value_bytes`, `sqlite3_value_double`, `sqlite3_value_int64`, `sqlite3_value_text`, `sqlite3_value_type`: see [Custom SQL Functions and Aggregates](#custom-sql-functions-and-aggregates) - `sqlite3_backup_finish`, `sqlite3_backup_init`, `sqlite3_backup_step`: see [Backup](#backup) - `sqlite3_bind_blob`, `sqlite3_bind_double`, `sqlite3_bind_int64`, `sqlite3_bind_null`, `sqlite3_bind_parameter_count`, `sqlite3_bind_parameter_name`, `sqlite3_bind_text`, `sqlite3_clear_bindings`, `sqlite3_column_blob`, `sqlite3_column_bytes`, `sqlite3_column_count`, `sqlite3_column_double`, `sqlite3_column_int64`, `sqlite3_column_name`, `sqlite3_column_text`, `sqlite3_column_type`, `sqlite3_exec`, `sqlite3_finalize`, `sqlite3_prepare_v2`, `sqlite3_reset`, `sqlite3_step`: see [Executing Updates](#executing-updates), [Fetch Queries](#fetch-queries), [Prepared Statements](#prepared-statements), [Values](#values) - `sqlite3_busy_handler`, `sqlite3_busy_timeout`: see [Configuration.busyMode](http://groue.github.io/GRDB.swift/docs/1.0/Structs/Configuration.html) @@ -1679,7 +1762,6 @@ Before jumping in the low-level wagon, here is the list of all SQLite APIs used - `sqlite3_commit_hook`, `sqlite3_rollback_hook`, `sqlite3_update_hook`: see [TransactionObserver Protocol](#transactionobserver-protocol), [FetchedRecordsController](#fetchedrecordscontroller) - `sqlite3_config`: see [Error Log](#error-log) - `sqlite3_create_collation_v2`: see [String Comparison](#string-comparison) -- `sqlite3_create_function_v2`, `sqlite3_result_blob`, `sqlite3_result_double`, `sqlite3_result_error`, `sqlite3_result_error_code`, `sqlite3_result_int64`, `sqlite3_result_null`, `sqlite3_result_text`, `sqlite3_user_data`, `sqlite3_value_blob`, `sqlite3_value_bytes`, `sqlite3_value_double`, `sqlite3_value_int64`, `sqlite3_value_text`, `sqlite3_value_type`: see [Custom SQL Functions](#custom-sql-functions) - `sqlite3_db_release_memory`: see [Memory Management](#memory-management) - `sqlite3_errcode`, `sqlite3_errmsg`, `sqlite3_errstr`, `sqlite3_extended_result_codes`: see [Error Handling](#error-handling) - `sqlite3_key`, `sqlite3_rekey`: see [Encryption](#encryption) @@ -3039,9 +3121,9 @@ Feed [requests](#requests) with SQL expressions built from your Swift code: > nameColumn.collating(.caseInsensitiveCompare) == name > ``` -- Custom SQL functions +- Custom SQL functions and aggregates - You can apply your own [custom SQL functions](#custom-sql-functions): + You can apply your own [custom SQL functions and aggregates](#custom-functions-): ```swift let f = DatabaseFunction("f", ...) @@ -3276,7 +3358,7 @@ try String.fetchAll(db, request) // [String] try Person.fetchOne(db, request) // Person? ``` -A TypedRequest knows exactly what it has to do when its RowDecoder associated type can decode database rows ([Row](#fetching-rows) itself, [values](#value-queries), or [records](#records)): +On top of that, a TypedRequest knows exactly what it has to do when its RowDecoder associated type can decode database rows ([Row](#fetching-rows) itself, [values](#value-queries), or [records](#records)): ```swift let request = ... // Some TypedRequest that fetches Person @@ -3290,8 +3372,6 @@ try request.fetchOne(db) // Person? **To build custom requests**, you can create your own type that adopts the protocols, or derive requests from other requests, or use one of the built-in concrete types: -- [Request](http://groue.github.io/GRDB.swift/docs/1.0/Protocols/Request.html): the protocol for all requests -- [TypedRequest](http://groue.github.io/GRDB.swift/docs/1.0/Protocols/TypedRequest.html): the protocol for all typed requests - [SQLRequest](http://groue.github.io/GRDB.swift/docs/1.0/Structs/SQLRequest.html): a Request built from raw SQL - [AnyRequest](http://groue.github.io/GRDB.swift/docs/1.0/Structs/AnyRequest.html): a type-erased Request - [AnyTypedRequest](http://groue.github.io/GRDB.swift/docs/1.0/Structs/AnyTypedRequest.html): a type-erased TypedRequest @@ -5078,7 +5158,7 @@ The `UPPER` and `LOWER` built-in SQLite functions are not unicode-aware: try String.fetchOne(db, "SELECT UPPER('Jérôme')") ``` -GRDB extends SQLite with [SQL functions](#custom-sql-functions) that call the Swift built-in string functions `capitalized`, `lowercased`, `uppercased`, `localizedCapitalized`, `localizedLowercased` and `localizedUppercased`: +GRDB extends SQLite with [SQL functions](#custom-sql-functions-and-aggregates) that call the Swift built-in string functions `capitalized`, `lowercased`, `uppercased`, `localizedCapitalized`, `localizedLowercased` and `localizedUppercased`: ```swift // "JÉRÔME" diff --git a/Tests/GRDBTests/DatabaseAggregateTests.swift b/Tests/GRDBTests/DatabaseAggregateTests.swift new file mode 100644 index 0000000000..5ae2a732e9 --- /dev/null +++ b/Tests/GRDBTests/DatabaseAggregateTests.swift @@ -0,0 +1,671 @@ +import XCTest +#if GRDBCIPHER + import GRDBCipher +#elseif GRDBCUSTOMSQLITE + import GRDBCustomSQLite +#else + import GRDB +#endif + +private struct CustomValueType : DatabaseValueConvertible { + var databaseValue: DatabaseValue { + return "CustomValueType".databaseValue + } + static func fromDatabaseValue(_ dbValue: DatabaseValue) -> CustomValueType? { + guard let string = String.fromDatabaseValue(dbValue), string == "CustomValueType" else { + return nil + } + return CustomValueType() + } +} + +class DatabaseAggregateTests: GRDBTestCase { + + // MARK: - Return values + + func testAggregateReturningNull() throws { + struct Aggregate : DatabaseAggregate { + func step(_ values: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return nil + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try DatabaseValue.fetchOne(db, "SELECT f()")!.isNull) + } + } + + func testAggregateReturningInt64() throws { + struct Aggregate : DatabaseAggregate { + func step(_ values: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return Int64(1) + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try Int64.fetchOne(db, "SELECT f()")!, Int64(1)) + } + } + + func testAggregateReturningDouble() throws { + let dbQueue = try makeDatabaseQueue() + struct Aggregate : DatabaseAggregate { + func step(_ values: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return 1e100 + } + } + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try Double.fetchOne(db, "SELECT f()")!, 1e100) + } + } + + func testAggregateReturningString() throws { + struct Aggregate : DatabaseAggregate { + func step(_ values: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return "foo" + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try String.fetchOne(db, "SELECT f()")!, "foo") + } + } + + func testAggregateReturningData() throws { + struct Aggregate : DatabaseAggregate { + func step(_ values: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return "foo".data(using: .utf8) + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try Data.fetchOne(db, "SELECT f()")!, "foo".data(using: .utf8)) + } + } + + func testAggregateReturningCustomValueType() throws { + struct Aggregate : DatabaseAggregate { + func step(_ values: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return CustomValueType() + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT f()") != nil) + } + } + + // MARK: - Argument values + + func testAggregateArgumentNil() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = dbValues[0].isNull + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try Bool.fetchOne(db, "SELECT f(NULL)")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f(1)")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f(1.1)")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f('foo')")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f(?)", arguments: ["foo".data(using: .utf8)])!) + } + } + + func testAggregateArgumentInt64() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = Int64.fromDatabaseValue(dbValues[0]) + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try Int64.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try Int64.fetchOne(db, "SELECT f(1)")!, 1) + XCTAssertEqual(try Int64.fetchOne(db, "SELECT f(1.1)")!, 1) + } + } + + func testAggregateArgumentDouble() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = Double.fromDatabaseValue(dbValues[0]) + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try Double.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try Double.fetchOne(db, "SELECT f(1)")!, 1.0) + XCTAssertEqual(try Double.fetchOne(db, "SELECT f(1.1)")!, 1.1) + } + } + + func testAggregateArgumentString() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = String.fromDatabaseValue(dbValues[0]) + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try String.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try String.fetchOne(db, "SELECT f('foo')")!, "foo") + } + } + + func testAggregateArgumentBlob() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = Data.fromDatabaseValue(dbValues[0]) + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try Data.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try Data.fetchOne(db, "SELECT f(?)", arguments: ["foo".data(using: .utf8)])!, "foo".data(using: .utf8)) + } + } + + func testAggregateArgumentCustomValueType() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = CustomValueType.fromDatabaseValue(dbValues[0]) + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT f('CustomValueType')") != nil) + } + } + + // MARK: - Argument count + + func testAggregateWithoutArgument() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return "foo" + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try String.fetchOne(db, "SELECT f()")!, "foo") + do { + try db.execute("SELECT f(1)") + XCTFail("Expected error") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message!, "wrong number of arguments to function f()") + XCTAssertEqual(error.sql!, "SELECT f(1)") + XCTAssertEqual(error.description, "SQLite error 1 with statement `SELECT f(1)`: wrong number of arguments to function f()") + } + } + } + + func testAggregateOfOneArgument() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = String.fromDatabaseValue(dbValues[0]).map { $0.uppercased() } + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try String.fetchOne(db, "SELECT upper(?)", arguments: ["Roué"])!, "ROUé") + XCTAssertEqual(try String.fetchOne(db, "SELECT f(?)", arguments: ["Roué"])!, "ROUÉ") + XCTAssertTrue(try String.fetchOne(db, "SELECT f(NULL)") == nil) + do { + try db.execute("SELECT f()") + XCTFail("Expected error") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message!, "wrong number of arguments to function f()") + XCTAssertEqual(error.sql!, "SELECT f()") + XCTAssertEqual(error.description, "SQLite error 1 with statement `SELECT f()`: wrong number of arguments to function f()") + } + } + } + + func testAggregateOfTwoArguments() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + let ints = dbValues.flatMap { Int.fromDatabaseValue($0) } + result = ints.reduce(0, +) + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 2, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try Int.fetchOne(db, "SELECT f(1, 2)")!, 3) + do { + try db.execute("SELECT f()") + XCTFail("Expected error") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message!, "wrong number of arguments to function f()") + XCTAssertEqual(error.sql!, "SELECT f()") + XCTAssertEqual(error.description, "SQLite error 1 with statement `SELECT f()`: wrong number of arguments to function f()") + } + } + } + + func testVariadicFunction() throws { + struct Aggregate : DatabaseAggregate { + var result: DatabaseValueConvertible? + mutating func step(_ dbValues: [DatabaseValue]) { + result = dbValues.count + } + func finalize() -> DatabaseValueConvertible? { + return result + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try Int.fetchOne(db, "SELECT f()")!, 0) + XCTAssertEqual(try Int.fetchOne(db, "SELECT f(1)")!, 1) + XCTAssertEqual(try Int.fetchOne(db, "SELECT f(1, 1)")!, 2) + } + } + + // MARK: - Step Errors + + func testAggregateStepThrowingDatabaseErrorWithMessage() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) throws { + throw DatabaseError(message: "custom error message") + } + func finalize() -> DatabaseValueConvertible? { + fatalError() + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message, "custom error message") + } + } + } + + func testAggregateStepThrowingDatabaseErrorWithCode() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) throws { + throw DatabaseError(resultCode: ResultCode(rawValue: 123)) + } + func finalize() -> DatabaseValueConvertible? { + fatalError() + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode.rawValue, 123) + XCTAssertEqual(error.message, "unknown error") + } + } + } + + func testAggregateStepThrowingDatabaseErrorWithMessageAndCode() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) throws { + throw DatabaseError(resultCode: ResultCode(rawValue: 123), message: "custom error message") + } + func finalize() -> DatabaseValueConvertible? { + fatalError() + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode.rawValue, 123) + XCTAssertEqual(error.message, "custom error message") + } + } + } + + func testAggregateStepThrowingCustomError() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) throws { + throw NSError(domain: "CustomErrorDomain", code: 123, userInfo: [NSString(string: NSLocalizedDescriptionKey): "custom error message"]) + } + func finalize() -> DatabaseValueConvertible? { + fatalError() + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertTrue(error.message!.contains("CustomErrorDomain")) + XCTAssertTrue(error.message!.contains("123")) + XCTAssertTrue(error.message!.contains("custom error message")) + } + } + } + + // MARK: - Result Errors + + func testAggregateResultThrowingDatabaseErrorWithMessage() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) { } + func finalize() throws -> DatabaseValueConvertible? { + throw DatabaseError(message: "custom error message") + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message, "custom error message") + } + } + } + + func testAggregateResultThrowingDatabaseErrorWithCode() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) { } + func finalize() throws -> DatabaseValueConvertible? { + throw DatabaseError(resultCode: ResultCode(rawValue: 123)) + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode.rawValue, 123) + XCTAssertEqual(error.message, "unknown error") + } + } + } + + func testAggregateResultThrowingDatabaseErrorWithMessageAndCode() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) { } + func finalize() throws -> DatabaseValueConvertible? { + throw DatabaseError(resultCode: ResultCode(rawValue: 123), message: "custom error message") + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode.rawValue, 123) + XCTAssertEqual(error.message, "custom error message") + } + } + } + + func testAggregateResultThrowingCustomError() throws { + struct Aggregate : DatabaseAggregate { + func step(_ dbValues: [DatabaseValue]) { } + func finalize() throws -> DatabaseValueConvertible? { + throw NSError(domain: "CustomErrorDomain", code: 123, userInfo: [NSString(string: NSLocalizedDescriptionKey): "custom error message"]) + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + do { + try db.execute("SELECT f()") + XCTFail("Expected DatabaseError") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertTrue(error.message!.contains("CustomErrorDomain")) + XCTAssertTrue(error.message!.contains("123")) + XCTAssertTrue(error.message!.contains("custom error message")) + } + } + } + + // MARK: - Aggregation + + func testAggregation() throws { + struct Aggregate : DatabaseAggregate { + var sum: Int? + mutating func step(_ dbValues: [DatabaseValue]) { + if let int = Int.fromDatabaseValue(dbValues[0]) { + sum = (sum ?? 0) + int + } + } + func finalize() throws -> DatabaseValueConvertible? { + return sum + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(try Int.fetchOne(db, "SELECT f(a) FROM (SELECT 1 AS a UNION ALL SELECT 2 UNION ALL SELECT 3)")!, 6) + } + } + + func testParallelAggregation() throws { + struct Aggregate : DatabaseAggregate { + var sum: Int? + mutating func step(_ dbValues: [DatabaseValue]) { + if let int = Int.fromDatabaseValue(dbValues[0]) { + sum = (sum ?? 0) + int + } + } + func finalize() throws -> DatabaseValueConvertible? { + return sum + } + } + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 1, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + let row = try Row.fetchOne(db, "SELECT f(a), f(b) FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 2, 4 UNION ALL SELECT 3, 6)")! + XCTAssertEqual(row.value(atIndex: 0), 6) + XCTAssertEqual(row.value(atIndex: 1), 12) + } + } + + // MARK: - Deallocation + + func testDeallocationAfterSuccess() throws { + final class Aggregate : DatabaseAggregate { + static var onInit: (() -> ())? + static var onDeinit: (() -> ())? + init() { Aggregate.onInit?() } + deinit { Aggregate.onDeinit?() } + func step(_ dbValues: [DatabaseValue]) { } + func finalize() -> DatabaseValueConvertible? { + return nil + } + } + var allocationCount = 0 + var aliveCount = 0 + Aggregate.onInit = { + allocationCount += 1 + aliveCount += 1 + } + Aggregate.onDeinit = { + aliveCount -= 1 + } + + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + try dbQueue.inDatabase { db in + XCTAssertEqual(allocationCount, 0) + XCTAssertEqual(aliveCount, 0) + try db.execute("SELECT f()") + XCTAssertEqual(allocationCount, 1) + XCTAssertEqual(aliveCount, 0) + } + } + + func testDeallocationAfterStepError() throws { + final class Aggregate : DatabaseAggregate { + static var onInit: (() -> ())? + static var onDeinit: (() -> ())? + init() { Aggregate.onInit?() } + deinit { Aggregate.onDeinit?() } + func step(_ dbValues: [DatabaseValue]) throws { + throw DatabaseError(message: "boo") + } + func finalize() -> DatabaseValueConvertible? { + fatalError() + } + } + var allocationCount = 0 + var aliveCount = 0 + Aggregate.onInit = { + allocationCount += 1 + aliveCount += 1 + } + Aggregate.onDeinit = { + aliveCount -= 1 + } + + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + dbQueue.inDatabase { db in + XCTAssertEqual(allocationCount, 0) + XCTAssertEqual(aliveCount, 0) + _ = try? db.execute("SELECT f()") + XCTAssertEqual(allocationCount, 1) + XCTAssertEqual(aliveCount, 0) + } + } + + func testDeallocationAfterResultError() throws { + final class Aggregate : DatabaseAggregate { + static var onInit: (() -> ())? + static var onDeinit: (() -> ())? + init() { Aggregate.onInit?() } + deinit { Aggregate.onDeinit?() } + func step(_ dbValues: [DatabaseValue]) { } + func finalize() throws -> DatabaseValueConvertible? { + throw DatabaseError(message: "boo") + } + } + + var allocationCount = 0 + var aliveCount = 0 + Aggregate.onInit = { + allocationCount += 1 + aliveCount += 1 + } + Aggregate.onDeinit = { + aliveCount -= 1 + } + + let dbQueue = try makeDatabaseQueue() + let fn = DatabaseFunction("f", argumentCount: 0, aggregate: Aggregate.self) + dbQueue.add(function: fn) + dbQueue.inDatabase { db in + XCTAssertEqual(allocationCount, 0) + XCTAssertEqual(aliveCount, 0) + _ = try? db.execute("SELECT f()") + XCTAssertEqual(allocationCount, 1) + XCTAssertEqual(aliveCount, 0) + } + } +} diff --git a/Tests/GRDBTests/DatabaseFunctionTests.swift b/Tests/GRDBTests/DatabaseFunctionTests.swift index c39b4defed..cb263ee5e9 100644 --- a/Tests/GRDBTests/DatabaseFunctionTests.swift +++ b/Tests/GRDBTests/DatabaseFunctionTests.swift @@ -7,7 +7,7 @@ import XCTest import GRDB #endif -struct CustomValueType : DatabaseValueConvertible { +private struct CustomValueType : DatabaseValueConvertible { var databaseValue: DatabaseValue { return "CustomValueType".databaseValue } @@ -126,78 +126,78 @@ class DatabaseFunctionTests: GRDBTestCase { func testFunctionArgumentNil() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("isNil", argumentCount: 1) { (dbValues: [DatabaseValue]) in + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in return dbValues[0].isNull } dbQueue.add(function: fn) try dbQueue.inDatabase { db in - XCTAssertTrue(try Bool.fetchOne(db, "SELECT isNil(NULL)")!) - XCTAssertFalse(try Bool.fetchOne(db, "SELECT isNil(1)")!) - XCTAssertFalse(try Bool.fetchOne(db, "SELECT isNil(1.1)")!) - XCTAssertFalse(try Bool.fetchOne(db, "SELECT isNil('foo')")!) - XCTAssertFalse(try Bool.fetchOne(db, "SELECT isNil(?)", arguments: ["foo".data(using: .utf8)])!) + XCTAssertTrue(try Bool.fetchOne(db, "SELECT f(NULL)")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f(1)")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f(1.1)")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f('foo')")!) + XCTAssertFalse(try Bool.fetchOne(db, "SELECT f(?)", arguments: ["foo".data(using: .utf8)])!) } } func testFunctionArgumentInt64() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("asInt64", argumentCount: 1) { (dbValues: [DatabaseValue]) in + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in return Int64.fromDatabaseValue(dbValues[0]) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in - XCTAssertTrue(try Int64.fetchOne(db, "SELECT asInt64(NULL)") == nil) - XCTAssertEqual(try Int64.fetchOne(db, "SELECT asInt64(1)")!, 1) - XCTAssertEqual(try Int64.fetchOne(db, "SELECT asInt64(1.1)")!, 1) + XCTAssertTrue(try Int64.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try Int64.fetchOne(db, "SELECT f(1)")!, 1) + XCTAssertEqual(try Int64.fetchOne(db, "SELECT f(1.1)")!, 1) } } func testFunctionArgumentDouble() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("asDouble", argumentCount: 1) { (dbValues: [DatabaseValue]) in + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in return Double.fromDatabaseValue(dbValues[0]) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in - XCTAssertTrue(try Double.fetchOne(db, "SELECT asDouble(NULL)") == nil) - XCTAssertEqual(try Double.fetchOne(db, "SELECT asDouble(1)")!, 1.0) - XCTAssertEqual(try Double.fetchOne(db, "SELECT asDouble(1.1)")!, 1.1) + XCTAssertTrue(try Double.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try Double.fetchOne(db, "SELECT f(1)")!, 1.0) + XCTAssertEqual(try Double.fetchOne(db, "SELECT f(1.1)")!, 1.1) } } func testFunctionArgumentString() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("asString", argumentCount: 1) { (dbValues: [DatabaseValue]) in + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in return String.fromDatabaseValue(dbValues[0]) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in - XCTAssertTrue(try String.fetchOne(db, "SELECT asString(NULL)") == nil) - XCTAssertEqual(try String.fetchOne(db, "SELECT asString('foo')")!, "foo") + XCTAssertTrue(try String.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try String.fetchOne(db, "SELECT f('foo')")!, "foo") } } func testFunctionArgumentBlob() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("asData", argumentCount: 1) { (dbValues: [DatabaseValue]) in + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in return Data.fromDatabaseValue(dbValues[0]) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in - XCTAssertTrue(try Data.fetchOne(db, "SELECT asData(NULL)") == nil) - XCTAssertEqual(try Data.fetchOne(db, "SELECT asData(?)", arguments: ["foo".data(using: .utf8)])!, "foo".data(using: .utf8)) + XCTAssertTrue(try Data.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertEqual(try Data.fetchOne(db, "SELECT f(?)", arguments: ["foo".data(using: .utf8)])!, "foo".data(using: .utf8)) } } func testFunctionArgumentCustomValueType() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("asCustomValueType", argumentCount: 1) { (dbValues: [DatabaseValue]) in + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in return CustomValueType.fromDatabaseValue(dbValues[0]) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in - XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT asCustomValueType(NULL)") == nil) - XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT asCustomValueType('CustomValueType')") != nil) + XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT f(NULL)") == nil) + XCTAssertTrue(try CustomValueType.fetchOne(db, "SELECT f('CustomValueType')") != nil) } } @@ -211,23 +211,37 @@ class DatabaseFunctionTests: GRDBTestCase { dbQueue.add(function: fn) try dbQueue.inDatabase { db in XCTAssertEqual(try String.fetchOne(db, "SELECT f()")!, "foo") + do { + try db.execute("SELECT f(1)") + XCTFail("Expected error") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message!, "wrong number of arguments to function f()") + XCTAssertEqual(error.sql!, "SELECT f(1)") + XCTAssertEqual(error.description, "SQLite error 1 with statement `SELECT f(1)`: wrong number of arguments to function f()") + } } } func testFunctionOfOneArgument() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("unicodeUpper", argumentCount: 1) { (dbValues: [DatabaseValue]) in - let dbValue = dbValues[0] - guard let string = String.fromDatabaseValue(dbValue) else { - return nil - } - return string.uppercased() + let fn = DatabaseFunction("f", argumentCount: 1) { (dbValues: [DatabaseValue]) in + String.fromDatabaseValue(dbValues[0]).map { $0.uppercased() } } dbQueue.add(function: fn) try dbQueue.inDatabase { db in XCTAssertEqual(try String.fetchOne(db, "SELECT upper(?)", arguments: ["Roué"])!, "ROUé") - XCTAssertEqual(try String.fetchOne(db, "SELECT unicodeUpper(?)", arguments: ["Roué"])!, "ROUÉ") - XCTAssertTrue(try String.fetchOne(db, "SELECT unicodeUpper(NULL)") == nil) + XCTAssertEqual(try String.fetchOne(db, "SELECT f(?)", arguments: ["Roué"])!, "ROUÉ") + XCTAssertTrue(try String.fetchOne(db, "SELECT f(NULL)") == nil) + do { + try db.execute("SELECT f()") + XCTFail("Expected error") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message!, "wrong number of arguments to function f()") + XCTAssertEqual(error.sql!, "SELECT f()") + XCTAssertEqual(error.description, "SQLite error 1 with statement `SELECT f()`: wrong number of arguments to function f()") + } } } @@ -240,6 +254,15 @@ class DatabaseFunctionTests: GRDBTestCase { dbQueue.add(function: fn) try dbQueue.inDatabase { db in XCTAssertEqual(try Int.fetchOne(db, "SELECT f(1, 2)")!, 3) + do { + try db.execute("SELECT f()") + XCTFail("Expected error") + } catch let error as DatabaseError { + XCTAssertEqual(error.resultCode, .SQLITE_ERROR) + XCTAssertEqual(error.message!, "wrong number of arguments to function f()") + XCTAssertEqual(error.sql!, "SELECT f()") + XCTAssertEqual(error.description, "SQLite error 1 with statement `SELECT f()`: wrong number of arguments to function f()") + } } } @@ -260,14 +283,13 @@ class DatabaseFunctionTests: GRDBTestCase { func testFunctionThrowingDatabaseErrorWithMessage() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("f", argumentCount: 1) { dbValues in + let fn = DatabaseFunction("f") { dbValues in throw DatabaseError(message: "custom error message") } dbQueue.add(function: fn) try dbQueue.inDatabase { db in do { - try db.execute("CREATE TABLE items (value INT)") - try db.execute("INSERT INTO items VALUES (f(1))") + try db.execute("SELECT f()") XCTFail("Expected DatabaseError") } catch let error as DatabaseError { XCTAssertEqual(error.resultCode, .SQLITE_ERROR) @@ -278,14 +300,13 @@ class DatabaseFunctionTests: GRDBTestCase { func testFunctionThrowingDatabaseErrorWithCode() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("f", argumentCount: 1) { dbValues in + let fn = DatabaseFunction("f") { dbValues in throw DatabaseError(resultCode: ResultCode(rawValue: 123)) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in do { - try db.execute("CREATE TABLE items (value INT)") - try db.execute("INSERT INTO items VALUES (f(1))") + try db.execute("SELECT f()") XCTFail("Expected DatabaseError") } catch let error as DatabaseError { XCTAssertEqual(error.resultCode.rawValue, 123) @@ -296,14 +317,13 @@ class DatabaseFunctionTests: GRDBTestCase { func testFunctionThrowingDatabaseErrorWithMessageAndCode() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("f", argumentCount: 1) { dbValues in + let fn = DatabaseFunction("f") { dbValues in throw DatabaseError(resultCode: ResultCode(rawValue: 123), message: "custom error message") } dbQueue.add(function: fn) try dbQueue.inDatabase { db in do { - try db.execute("CREATE TABLE items (value INT)") - try db.execute("INSERT INTO items VALUES (f(1))") + try db.execute("SELECT f()") XCTFail("Expected DatabaseError") } catch let error as DatabaseError { XCTAssertEqual(error.resultCode.rawValue, 123) @@ -314,14 +334,13 @@ class DatabaseFunctionTests: GRDBTestCase { func testFunctionThrowingCustomError() throws { let dbQueue = try makeDatabaseQueue() - let fn = DatabaseFunction("f", argumentCount: 1) { dbValues in + let fn = DatabaseFunction("f") { dbValues in throw NSError(domain: "CustomErrorDomain", code: 123, userInfo: [NSString(string: NSLocalizedDescriptionKey): "custom error message"]) } dbQueue.add(function: fn) try dbQueue.inDatabase { db in do { - try db.execute("CREATE TABLE items (value INT)") - try db.execute("INSERT INTO items VALUES (f(1))") + try db.execute("SELECT f()") XCTFail("Expected DatabaseError") } catch let error as DatabaseError { XCTAssertEqual(error.resultCode, .SQLITE_ERROR)