From 4b0a06495852503c1448b0a7d7fc6fea259a4adc Mon Sep 17 00:00:00 2001 From: Alexander Ignition Date: Sat, 11 Jan 2025 20:25:34 +0300 Subject: [PATCH] RowDecoderTests --- Playgrounds/README.playground/Contents.swift | 15 +- README.md | 15 +- Sources/SQLyra/PreparedStatement.swift | 74 ++-- Sources/SQLyra/RowDecoder.swift | 81 ++--- Sources/SQLyra/SQLParameter.swift | 2 +- .../SQLyraTests/PreparedStatementTests.swift | 8 +- Tests/SQLyraTests/RowDecoderTests.swift | 329 ++++++++++++++++++ Tests/SQLyraTests/SQLParameter+Testing.swift | 32 ++ 8 files changed, 456 insertions(+), 100 deletions(-) create mode 100644 Tests/SQLyraTests/RowDecoderTests.swift create mode 100644 Tests/SQLyraTests/SQLParameter+Testing.swift diff --git a/Playgrounds/README.playground/Contents.swift b/Playgrounds/README.playground/Contents.swift index 0ef56fa..9a13b6f 100644 --- a/Playgrounds/README.playground/Contents.swift +++ b/Playgrounds/README.playground/Contents.swift @@ -9,7 +9,7 @@ [Documentation](https://alexander-ignition.github.io/SQLyra/documentation/sqlyra/) - - Note: this readme file is available as Xcode playground in Playgrounds/README.playground + > this readme file is available as Xcode playground in Playgrounds/README.playground ## Open @@ -26,21 +26,21 @@ let database = try Database.open( Create table for contacts with fields `id` and `name`. */ -try database.execute( - """ +let sql = """ CREATE TABLE contacts( id INT PRIMARY KEY NOT NULL, name TEXT ); """ -) +try database.execute(sql) /*: ## Insert Insert new contacts Paul and John. */ -try database.execute("INSERT INTO contacts (id, name) VALUES (1, 'Paul');") -try database.execute("INSERT INTO contacts (id, name) VALUES (2, 'John');") +let insert = try database.prepare("INSERT INTO contacts (id, name) VALUES (?, ?);") +try insert.bind(parameters: 1, "Paul").execute().reset() +try insert.bind(parameters: 2, "John").execute() /*: ## Select @@ -51,4 +51,5 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) +let select = try database.prepare("SELECT * FROM contacts;") +let contacts = try select.array(Contact.self) diff --git a/README.md b/README.md index c0bf5ec..391411a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Swift SQLite wrapper. [Documentation](https://alexander-ignition.github.io/SQLyra/documentation/sqlyra/) -- Note: this readme file is available as Xcode playground in Playgrounds/README.playground +> this readme file is available as Xcode playground in Playgrounds/README.playground ## Open @@ -25,21 +25,21 @@ let database = try Database.open( Create table for contacts with fields `id` and `name`. ```swift -try database.execute( - """ +let sql = """ CREATE TABLE contacts( id INT PRIMARY KEY NOT NULL, name TEXT ); """ -) +try database.execute(sql) ``` ## Insert Insert new contacts Paul and John. ```swift -try database.execute("INSERT INTO contacts (id, name) VALUES (1, 'Paul');") -try database.execute("INSERT INTO contacts (id, name) VALUES (2, 'John');") +let insert = try database.prepare("INSERT INTO contacts (id, name) VALUES (?, ?);") +try insert.bind(parameters: 1, "Paul").execute().reset() +try insert.bind(parameters: 2, "John").execute() ``` ## Select @@ -50,5 +50,6 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) +let select = try database.prepare("SELECT * FROM contacts;") +let contacts = try select.array(Contact.self) ``` diff --git a/Sources/SQLyra/PreparedStatement.swift b/Sources/SQLyra/PreparedStatement.swift index 9914076..7dba370 100644 --- a/Sources/SQLyra/PreparedStatement.swift +++ b/Sources/SQLyra/PreparedStatement.swift @@ -11,7 +11,7 @@ public final class PreparedStatement: DatabaseHandle { /// Find the database handle of a prepared statement. var db: OpaquePointer! { sqlite3_db_handle(stmt) } - private(set) lazy var columnIndexByName = [String: Int32]( + private(set) lazy var columnIndexByName = [String: Int]( uniqueKeysWithValues: (0.. String? { - sqlite3_bind_parameter_name(stmt, index).map { String(cString: $0) } + public func parameterName(at index: Int) -> String? { + sqlite3_bind_parameter_name(stmt, Int32(index)).map { String(cString: $0) } } /// Index of a parameter with a given name. - public func parameterIndex(for name: String) -> Int32 { - sqlite3_bind_parameter_index(stmt, name) + public func parameterIndex(for name: String) -> Int { + Int(sqlite3_bind_parameter_index(stmt, name)) } } @@ -98,13 +98,14 @@ extension PreparedStatement { @discardableResult public func bind(parameters: SQLParameter...) throws -> PreparedStatement { for (index, parameter) in parameters.enumerated() { - try bind(index: Int32(index + 1), parameter: parameter) + try bind(index: index + 1, parameter: parameter) } return self } @discardableResult - public func bind(index: Int32, parameter: SQLParameter) throws -> PreparedStatement { + public func bind(index: Int, parameter: SQLParameter) throws -> PreparedStatement { + let index = Int32(index) let code = switch parameter { case .null: @@ -139,10 +140,14 @@ extension PreparedStatement { extension PreparedStatement { /// Return the number of columns in the result set. - public var columnCount: Int32 { sqlite3_column_count(stmt) } + public var columnCount: Int { Int(sqlite3_column_count(stmt)) } - public func columnName(at index: Int32) -> String? { - sqlite3_column_name(stmt, index).string + /// Returns the name assigned to a specific column in the result set of the SELECT statement. + /// + /// The name of a result column is the value of the "AS" clause for that column, if there is an AS clause. + /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. + public func columnName(at index: Int) -> String? { + sqlite3_column_name(stmt, Int32(index)).string } } @@ -154,12 +159,9 @@ extension PreparedStatement { /// - Throws: ``DatabaseError`` public func row() throws -> Row? { switch sqlite3_step(stmt) { - case SQLITE_DONE: - return nil - case SQLITE_ROW: - return Row(statement: self) - case let code: - throw DatabaseError(code: code, db: db) + case SQLITE_DONE: nil + case SQLITE_ROW: Row(statement: self) + case let code: throw DatabaseError(code: code, db: db) } } @@ -180,16 +182,19 @@ extension PreparedStatement { public struct Row { let statement: PreparedStatement - public subscript(dynamicMember name: String) -> Column! { + public subscript(dynamicMember name: String) -> Value? { self[name] } - public subscript(name: String) -> Column? { - statement.columnIndexByName[name].map { self[$0] } + public subscript(name: String) -> Value? { + statement.columnIndexByName[name].flatMap { self[$0] } } - public subscript(index: Int32) -> Column { - Column(index: index, statement: statement) + public subscript(index: Int) -> Value? { + if sqlite3_column_type(statement.stmt, Int32(index)) == SQLITE_NULL { + return nil + } + return Value(index: Int32(index), statement: statement) } public func decode(_ type: T.Type) throws -> T where T: Decodable { @@ -201,26 +206,27 @@ extension PreparedStatement { } } - /// Information about a single column of the current result row of a query. - public struct Column { + /// Result value from a query. + public struct Value { let index: Int32 let statement: PreparedStatement private var stmt: OpaquePointer { statement.stmt } - /// Returns the name assigned to a specific column in the result set of the SELECT statement. - /// - /// The name of a result column is the value of the "AS" clause for that column, if there is an AS clause. - /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. - public var name: String? { sqlite3_column_name(stmt, index).string } - - public var isNull: Bool { sqlite3_column_type(stmt, index) == SQLITE_NULL } - /// 64-bit INTEGER result. public var int64: Int64 { sqlite3_column_int64(stmt, index) } + /// 32-bit INTEGER result. + public var int32: Int32 { sqlite3_column_int(stmt, index) } + + /// A platform-specific integer. + public var int: Int { Int(int64) } + /// 64-bit IEEE floating point number. public var double: Double { sqlite3_column_double(stmt, index) } + /// Size of a BLOB or a UTF-8 TEXT result in bytes. + public var count: Int { Int(sqlite3_column_bytes(stmt, index)) } + /// UTF-8 TEXT result. public var string: String? { sqlite3_column_text(stmt, index).flatMap { String(cString: $0) } @@ -228,9 +234,7 @@ extension PreparedStatement { /// BLOB result. public var blob: Data? { - sqlite3_column_blob(stmt, index).map { bytes in - Data(bytes: bytes, count: Int(sqlite3_column_bytes(stmt, index))) - } + sqlite3_column_blob(stmt, index).map { Data(bytes: $0, count: count) } } } } diff --git a/Sources/SQLyra/RowDecoder.swift b/Sources/SQLyra/RowDecoder.swift index f5cdfaa..4376dd0 100644 --- a/Sources/SQLyra/RowDecoder.swift +++ b/Sources/SQLyra/RowDecoder.swift @@ -30,13 +30,11 @@ private struct _RowDecoder: Decoder { } func unkeyedContainer() throws -> any UnkeyedDecodingContainer { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.self, context) + throw DecodingError.typeMismatch(PreparedStatement.self, .context(codingPath, "")) } func singleValueContainer() throws -> any SingleValueDecodingContainer { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.self, context) + throw DecodingError.typeMismatch(PreparedStatement.self, .context(codingPath, "")) } // MARK: - KeyedDecodingContainer @@ -63,7 +61,7 @@ private struct _RowDecoder: Decoder { func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { try decoder.integer(type, forKey: key) } func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { try decoder.integer(type, forKey: key) } func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - return try decoder.decode(type, forKey: key) + try decoder.decode(type, forKey: key) } func superDecoder() throws -> any Decoder { fatalError() } @@ -81,7 +79,7 @@ private struct _RowDecoder: Decoder { @inline(__always) func null(for key: K) -> Bool where K: CodingKey { - row[key.stringValue]?.isNull ?? true + row[key.stringValue] == nil } @inline(__always) @@ -90,68 +88,55 @@ private struct _RowDecoder: Decoder { } @inline(__always) - func integer(_ type: T.Type, forKey key: K) throws -> T where T: Numeric, K: CodingKey { - let value = try column(forKey: key) - guard let number = type.init(exactly: value.int64) else { - let message = "Parsed SQL integer <\(value)> does not fit in \(type)." - let context = DecodingError.Context(codingPath: [key], debugDescription: message) - throw DecodingError.dataCorrupted(context) - } - return number + func string(forKey key: K) throws -> String where K: CodingKey { + try columnValue(String.self, forKey: key).string ?? "" } @inline(__always) - func floating(_ type: T.Type, forKey key: K) throws -> T where T: BinaryFloatingPoint, K: CodingKey { - let value = try column(forKey: key) - guard let number = type.init(exactly: value.double) else { - let message = "Parsed SQL double <\(value)> does not fit in \(type)." - let context = DecodingError.Context(codingPath: [key], debugDescription: message) - throw DecodingError.dataCorrupted(context) + func integer(_ type: T.Type, forKey key: K) throws -> T where T: Numeric, K: CodingKey { + let value = try columnValue(type, forKey: key) + let int64 = value.int64 + guard let number = type.init(exactly: int64) else { + throw DecodingError.dataCorrupted(.context([key], "Parsed SQL int64 <\(int64)> does not fit in \(type).")) } return number } @inline(__always) - func string(forKey key: K) throws -> String where K: CodingKey { - let value = try column(forKey: key) - guard let value = value.string else { - throw DecodingError.valueNotFound(String.self, .codingPath([key])) + func floating(_ type: T.Type, forKey key: K) throws -> T where T: BinaryFloatingPoint, K: CodingKey { + let value = try columnValue(type, forKey: key) + let double = value.double + guard let number = type.init(exactly: double) else { + throw DecodingError.dataCorrupted(.context([key], "Parsed SQL double <\(double)> does not fit in \(type).")) } - return value + return number } @inline(__always) func decode(_ type: T.Type, forKey key: K) throws -> T where T: Decodable, K: CodingKey { if type == Data.self { - let value = try column(forKey: key) - guard let data = value.blob else { - throw DecodingError.valueNotFound(Data.self, .codingPath([key])) - } + let value = try columnValue(type, forKey: key) + let data = value.blob ?? Data() // swift-format-ignore: NeverForceUnwrap return data as! T } - let decoder = _ColumnDecoder(key: key, decoder: self) + let decoder = _ValueDecoder(key: key, decoder: self) return try type.init(from: decoder) } @inline(__always) - private func column(forKey key: K) throws -> PreparedStatement.Column where K: CodingKey { + private func columnValue(_ type: T.Type, forKey key: K) throws -> PreparedStatement.Value where K: CodingKey { guard let index = row.statement.columnIndexByName[key.stringValue] else { - let message = "Column index not found for key: \(key)" - let context = DecodingError.Context(codingPath: [key], debugDescription: message) - throw DecodingError.keyNotFound(key, context) + throw DecodingError.keyNotFound(key, .context([key], "Column index not found for key: \(key)")) } - return row[index] - } -} - -private extension DecodingError.Context { - static func codingPath(_ path: [any CodingKey]) -> DecodingError.Context { - DecodingError.Context(codingPath: path, debugDescription: "") + guard let column = row[index] else { + throw DecodingError.valueNotFound(type, .context([key], "Column value not found for key: \(key)")) + } + return column } } -private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { +private struct _ValueDecoder: Decoder, SingleValueDecodingContainer { let key: any CodingKey let decoder: _RowDecoder @@ -161,13 +146,11 @@ private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { var codingPath: [any CodingKey] { [key] } func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.Column.self, context) + throw DecodingError.typeMismatch(PreparedStatement.Value.self, .context(codingPath, "")) } func unkeyedContainer() throws -> any UnkeyedDecodingContainer { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.Column.self, context) + throw DecodingError.typeMismatch(PreparedStatement.Value.self, .context(codingPath, "")) } func singleValueContainer() throws -> any SingleValueDecodingContainer { @@ -193,3 +176,9 @@ private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { func decode(_ type: UInt64.Type) throws -> UInt64 { try decoder.integer(type, forKey: key) } func decode(_ type: T.Type) throws -> T where T: Decodable { try decoder.decode(type, forKey: key) } } + +private extension DecodingError.Context { + static func context(_ codingPath: [any CodingKey], _ message: String) -> DecodingError.Context { + DecodingError.Context(codingPath: codingPath, debugDescription: message) + } +} diff --git a/Sources/SQLyra/SQLParameter.swift b/Sources/SQLyra/SQLParameter.swift index 764a5f0..fc61051 100644 --- a/Sources/SQLyra/SQLParameter.swift +++ b/Sources/SQLyra/SQLParameter.swift @@ -1,7 +1,7 @@ import Foundation /// SQL parameters. -public enum SQLParameter: Equatable { +public enum SQLParameter: Equatable, Sendable { case null /// 64-bit signed integer. diff --git a/Tests/SQLyraTests/PreparedStatementTests.swift b/Tests/SQLyraTests/PreparedStatementTests.swift index 84782a6..4872345 100644 --- a/Tests/SQLyraTests/PreparedStatementTests.swift +++ b/Tests/SQLyraTests/PreparedStatementTests.swift @@ -80,10 +80,10 @@ struct PreparedStatementTests { var contracts: [Contact] = [] while let row = try select.row() { let contact = Contact( - id: Int(row.id.int64), - name: row.name.string ?? "", - rating: row.rating.double, - image: row.image.blob + id: row.id?.int ?? 0, + name: row.name?.string ?? "", + rating: row.rating?.double ?? 0, + image: row.image?.blob ) contracts.append(contact) } diff --git a/Tests/SQLyraTests/RowDecoderTests.swift b/Tests/SQLyraTests/RowDecoderTests.swift new file mode 100644 index 0000000..d73474b --- /dev/null +++ b/Tests/SQLyraTests/RowDecoderTests.swift @@ -0,0 +1,329 @@ +import Foundation +import SQLyra +import Testing + +struct RowDecoderTests { + struct SignedIntegers { + /// valid parameters for all signed integers + static let arguments: [(SQLParameter, Int)] = [ + (-1, -1), + (0, 0), + (1, 1), + (0.9, 0), + ("2", 2), + (.blob(Data("3".utf8)), 3), + ] + + struct IntTests: DecodableValueSuite { + let value: Int + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int(expected)) + } + } + + struct Int8Tests: DecodableValueSuite { + let value: Int8 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int8(expected)) + } + } + + struct Int16Tests: DecodableValueSuite { + let value: Int16 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int16(expected)) + } + } + + struct Int32Tests: DecodableValueSuite { + let value: Int32 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int32(expected)) + } + } + + struct Int64Tests: DecodableValueSuite { + let value: Int64 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int64(expected)) + } + } + } + + struct UnsignedIntegers { + /// valid parameters for all unsigned integers + static let arguments: [(SQLParameter, UInt)] = [ + (0, 0), + (1, 1), + (0.9, 0), + ("2", 2), + (.blob(Data("3".utf8)), 3), + ] + + struct UIntTests: DecodableValueSuite { + let value: UInt + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt(expected)) + } + } + + struct UInt8Tests: DecodableValueSuite { + let value: UInt8 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt8(expected)) + } + + @Test(arguments: [ + SQLParameter.int64(-1), + SQLParameter.int64(Int64.max), + ]) + static func dataCorrupted(_ parameter: SQLParameter) throws { + try _dataCorrupted(parameter, "Parsed SQL int64 <\(parameter)> does not fit in UInt8.") + } + } + + struct UInt16Tests: DecodableValueSuite { + let value: UInt16 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt16(expected)) + } + } + + struct UInt32Tests: DecodableValueSuite { + let value: UInt32 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt32(expected)) + } + } + + struct UInt64Tests: DecodableValueSuite { + let value: UInt64 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt64(expected)) + } + } + } + + struct FloatingPointNumerics { + static let arguments: [(SQLParameter, Double)] = [ + (-1, -1.0), + (0, 0.0), + (1, 1.0), + (0.5, 0.5), + ("0.5", 0.5), + ("1.0", 1.0), + (.blob(Data("1".utf8)), 1.0), + ] + + struct DoubleTests: DecodableValueSuite { + let value: Double + + @Test(arguments: FloatingPointNumerics.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Double) throws { + try _decode(parameter, Double(expected)) + } + } + + struct FloatTests: DecodableValueSuite { + let value: Float + + @Test(arguments: FloatingPointNumerics.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Double) throws { + try _decode(parameter, Float(expected)) + } + + @Test static func dataCorrupted() throws { + let parameter = SQLParameter.double(Double.greatestFiniteMagnitude) + try _dataCorrupted(parameter, "Parsed SQL double <\(parameter)> does not fit in Float.") + } + } + } + + struct BoolTests: DecodableValueSuite { + let value: Bool + + @Test(arguments: [ + (SQLParameter.int64(0), false), + (SQLParameter.int64(1), true), + (SQLParameter.double(0.9), false), + (SQLParameter.text("abc"), false), + (SQLParameter.text("true"), false), + (SQLParameter.blob(Data("zxc".utf8)), false), + (SQLParameter.blob(Data("1".utf8)), true), + ]) + static func decode(_ parameter: SQLParameter, _ expected: Bool) throws { + try _decode(parameter, expected) + } + } + + struct StringTests: DecodableValueSuite { + let value: String + + @Test(arguments: [ + (SQLParameter.int64(0), "0"), + (SQLParameter.int64(1), "1"), + (SQLParameter.double(0.9), "0.9"), + (SQLParameter.text("abc"), "abc"), + (SQLParameter.blob(Data("zxc".utf8)), "zxc"), + ]) + static func decode(_ parameter: SQLParameter, _ expected: String) throws { + try _decode(parameter, expected) + } + } + + struct DataTests: DecodableValueSuite { + let value: Data + + @Test(arguments: [ + (SQLParameter.int64(1), Data("1".utf8)), + (SQLParameter.double(1.1), Data("1.1".utf8)), + (SQLParameter.text("123"), Data("123".utf8)), + (SQLParameter.blob(Data("zxc".utf8)), Data("zxc".utf8)), + ]) + static func decode(_ parameter: SQLParameter, _ expected: Data) throws { + try _decode(parameter, expected) + } + } + + struct OptionalTests: DecodableValueSuite { + let value: Int? + + @Test static func decode() throws { + try _decode(1, 1) + } + } + + struct DecodingErrorTests: Decodable { + let item: Int // invalid + + @Test static func keyNotFound() throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.int64(1)).row()) + #expect { + try row.decode(DecodingErrorTests.self) + } throws: { error in + guard case .keyNotFound(let key, let context) = error as? DecodingError else { + return false + } + return key.stringValue == "item" + && context.codingPath.map(\.stringValue) == ["item"] + && context.debugDescription == "Column index not found for key: \(key)" + && context.underlyingError == nil + } + } + + @Test static func typeMismatch() throws { + let errorMatcher = { (error: any Error) -> Bool in + guard case .typeMismatch(_, let context) = error as? DecodingError else { + return false + } + return context.codingPath.isEmpty && context.debugDescription == "" && context.underlyingError == nil + } + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.int64(1)).row()) + #expect(performing: { try row.decode(Int.self) }, throws: errorMatcher) + #expect(performing: { try row.decode([Int].self) }, throws: errorMatcher) + } + + @Test static func valueNotFound() throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.null).row()) + #expect { + try row.decode(Single.self) + } throws: { error in + guard case .valueNotFound(let type, let context) = error as? DecodingError else { + return false + } + return type == Int8.self + && context.codingPath.map(\.stringValue) == ["value"] + && context.debugDescription == "Column value not found for key: \(context.codingPath[0])" + && context.underlyingError == nil + } + } + } +} + +// MARK: - Test Suite + +protocol DecodableValueSuite: Decodable { + associatedtype Value: Decodable, Equatable + + var value: Value { get } +} + +extension DecodableValueSuite { + static func _decode( + _ parameter: SQLParameter, + _ expected: Value, + sourceLocation: SourceLocation = #_sourceLocation + ) throws { + let repo = try ItemRepository(datatype: "ANY") + let select = try repo.select(parameter) + let row = try #require(try select.row(), sourceLocation: sourceLocation) + + let keyed = try row.decode(Self.self) + #expect(keyed.value == expected, sourceLocation: sourceLocation) + + let single = try row.decode(Single.self) + #expect(single.value == expected, sourceLocation: sourceLocation) + + #expect(try select.row() == nil, sourceLocation: sourceLocation) + } + + static func _dataCorrupted( + _ parameter: SQLParameter, + _ message: String, + sourceLocation: SourceLocation = #_sourceLocation + ) throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(parameter).row(), sourceLocation: sourceLocation) + #expect(sourceLocation: sourceLocation) { + try row.decode(Self.self) + } throws: { error in + guard case .dataCorrupted(let context) = error as? DecodingError else { + return false + } + return context.codingPath.map(\.stringValue) == ["value"] + && context.debugDescription == message + && context.underlyingError == nil + } + } +} + +struct Single: Decodable { + let value: T +} + +struct ItemRepository { + private let db: Database + + init(datatype: String) throws { + db = try Database.open(at: ":memory:", options: [.readwrite, .memory]) + try db.execute("CREATE TABLE items (value \(datatype));") + } + + func select(_ parameter: SQLParameter) throws -> PreparedStatement { + try db.prepare("INSERT INTO items (value) VALUES (?);").bind(index: 1, parameter: parameter).execute() + return try db.prepare("SELECT value FROM items;") + } +} diff --git a/Tests/SQLyraTests/SQLParameter+Testing.swift b/Tests/SQLyraTests/SQLParameter+Testing.swift new file mode 100644 index 0000000..eb5ab61 --- /dev/null +++ b/Tests/SQLyraTests/SQLParameter+Testing.swift @@ -0,0 +1,32 @@ +import SQLyra +import Testing + +extension SQLParameter: CustomTestStringConvertible { + public var testDescription: String { + switch self { + case .null: "NULL" + case .int64(let value): "INT(\(value))" + case .double(let value): "DOUBLE(\(value))" + case .text(let value): "TEXT(\(value))" + case .blob(let value): "BLOB(\(Array(value))" + } + } +} + +extension SQLParameter: CustomTestArgumentEncodable { + public func encodeTestArgument(to encoder: some Encoder) throws { + switch self { + case .null: + var container = encoder.singleValueContainer() + try container.encodeNil() + case .int64(let value): + try value.encode(to: encoder) + case .double(let value): + try value.encode(to: encoder) + case .text(let value): + try value.encode(to: encoder) + case .blob(let value): + try value.encode(to: encoder) + } + } +}