From 5748b0154089fde62ddb473ebaed308c71a87dcb Mon Sep 17 00:00:00 2001 From: Alexander Ignition Date: Tue, 7 Jan 2025 12:17:18 +0300 Subject: [PATCH 1/3] Create Row and RowDecoder --- Playgrounds/README.playground/Contents.swift | 2 +- README.md | 2 +- Sources/SQLyra/DatabaseError.swift | 2 +- Sources/SQLyra/PreparedStatement.swift | 115 ++++++---- Sources/SQLyra/RowDecoder.swift | 195 +++++++++++++++++ Sources/SQLyra/StatementDecoder.swift | 205 ------------------ Tests/SQLyraTests/DatabaseTests.swift | 3 +- .../SQLyraTests/PreparedStatementTests.swift | 10 +- 8 files changed, 273 insertions(+), 261 deletions(-) create mode 100644 Sources/SQLyra/RowDecoder.swift delete mode 100644 Sources/SQLyra/StatementDecoder.swift diff --git a/Playgrounds/README.playground/Contents.swift b/Playgrounds/README.playground/Contents.swift index 4d07064..0ef56fa 100644 --- a/Playgrounds/README.playground/Contents.swift +++ b/Playgrounds/README.playground/Contents.swift @@ -51,4 +51,4 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(decoding: Contact.self) +let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) diff --git a/README.md b/README.md index 4cf31df..c0bf5ec 100644 --- a/README.md +++ b/README.md @@ -50,5 +50,5 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(decoding: Contact.self) +let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) ``` diff --git a/Sources/SQLyra/DatabaseError.swift b/Sources/SQLyra/DatabaseError.swift index cc03556..e6154b8 100644 --- a/Sources/SQLyra/DatabaseError.swift +++ b/Sources/SQLyra/DatabaseError.swift @@ -7,7 +7,7 @@ public struct DatabaseError: Error, Equatable, Hashable { public let code: Int32 /// A short error description. - public var message: String? + public let message: String? /// A complete sentence (or more) describing why the operation failed. public let details: String? diff --git a/Sources/SQLyra/PreparedStatement.swift b/Sources/SQLyra/PreparedStatement.swift index d62bd9c..9914076 100644 --- a/Sources/SQLyra/PreparedStatement.swift +++ b/Sources/SQLyra/PreparedStatement.swift @@ -13,7 +13,7 @@ public final class PreparedStatement: DatabaseHandle { private(set) lazy var columnIndexByName = [String: Int32]( uniqueKeysWithValues: (0.. Bool { - switch sqlite3_step(stmt) { - case SQLITE_DONE: - return false - case SQLITE_ROW: - return true - case let code: - throw DatabaseError(code: code, db: db) - } - } - /// Reset the prepared statement. /// /// The ``PreparedStatement/reset()`` function is called to reset a prepared statement object back to its initial state, ready to be re-executed. @@ -59,32 +45,6 @@ public final class PreparedStatement: DatabaseHandle { public func reset() throws -> PreparedStatement { try check(sqlite3_reset(stmt)) } - - /// Reset all bindings on a prepared statement. - /// - /// Contrary to the intuition of many, ``PreparedStatement/reset()`` does not reset the bindings on a prepared statement. - /// Use this routine to reset all host parameters to NULL. - /// - /// - Throws: ``DatabaseError`` - @discardableResult - public func clearBindings() throws -> PreparedStatement { - try check(sqlite3_clear_bindings(stmt)) - } - - // MARK: - Decodable - - public func array(decoding type: T.Type) throws -> [T] where T: Decodable { - var array: [T] = [] - while try step() { - let value = try decode(type) - array.append(value) - } - return array - } - - public func decode(_ type: T.Type) throws -> T where T: Decodable { - try StatementDecoder().decode(type, from: self) - } } // MARK: - Retrieving Statement SQL @@ -162,20 +122,83 @@ extension PreparedStatement { } return try check(code) } + + /// Reset all bindings on a prepared statement. + /// + /// Contrary to the intuition of many, ``PreparedStatement/reset()`` does not reset the bindings on a prepared statement. + /// Use this routine to reset all host parameters to NULL. + /// + /// - Throws: ``DatabaseError`` + @discardableResult + public func clearBindings() throws -> PreparedStatement { + try check(sqlite3_clear_bindings(stmt)) + } } -// MARK: - Result values from a Query +// MARK: - Columns extension PreparedStatement { /// Return the number of columns in the result set. public var columnCount: Int32 { sqlite3_column_count(stmt) } - public func column(at index: Int32) -> Column { - Column(index: index, statement: self) + public func columnName(at index: Int32) -> String? { + sqlite3_column_name(stmt, index).string + } +} + +// MARK: - Result values from a Query + +extension PreparedStatement { + /// The new row of data is ready for processing. + /// + /// - Throws: ``DatabaseError`` + public func row() throws -> Row? { + switch sqlite3_step(stmt) { + case SQLITE_DONE: + return nil + case SQLITE_ROW: + return Row(statement: self) + case let code: + throw DatabaseError(code: code, db: db) + } + } + + public func array(_ type: T.Type) throws -> [T] where T: Decodable { + try array(type, using: RowDecoder.default) + } + + public func array(_ type: T.Type, using decoder: RowDecoder) throws -> [T] where T: Decodable { + var array: [T] = [] + while let row = try row() { + let value = try row.decode(type, using: decoder) + array.append(value) + } + return array } - public func column(for name: String) -> Column? { - columnIndexByName[name].map { Column(index: $0, statement: self) } + @dynamicMemberLookup + public struct Row { + let statement: PreparedStatement + + public subscript(dynamicMember name: String) -> Column! { + self[name] + } + + public subscript(name: String) -> Column? { + statement.columnIndexByName[name].map { self[$0] } + } + + public subscript(index: Int32) -> Column { + Column(index: index, statement: statement) + } + + public func decode(_ type: T.Type) throws -> T where T: Decodable { + try decode(type, using: RowDecoder.default) + } + + public func decode(_ type: T.Type, using decoder: RowDecoder) throws -> T where T: Decodable { + try decoder.decode(type, from: self) + } } /// Information about a single column of the current result row of a query. diff --git a/Sources/SQLyra/RowDecoder.swift b/Sources/SQLyra/RowDecoder.swift new file mode 100644 index 0000000..f5cdfaa --- /dev/null +++ b/Sources/SQLyra/RowDecoder.swift @@ -0,0 +1,195 @@ +import Foundation +import SQLite3 + +/// An object that decodes instances of a data type from ``PreparedStatement``. +public final class RowDecoder { + nonisolated(unsafe) static let `default` = RowDecoder() + + /// A dictionary you use to customize the decoding process by providing contextual information. + public var userInfo: [CodingUserInfoKey: Any] = [:] + + /// Creates a new, reusable row decoder. + public init() {} + + public func decode(_ type: T.Type, from row: PreparedStatement.Row) throws -> T where T: Decodable { + let decoder = _RowDecoder(row: row, userInfo: userInfo) + return try type.init(from: decoder) + } +} + +private struct _RowDecoder: Decoder { + let row: PreparedStatement.Row + + // MARK: - Decoder + + let userInfo: [CodingUserInfoKey: Any] + var codingPath: [any CodingKey] { [] } + + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + KeyedDecodingContainer(KeyedContainer(decoder: self)) + } + + func unkeyedContainer() throws -> any UnkeyedDecodingContainer { + let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") + throw DecodingError.typeMismatch(PreparedStatement.self, context) + } + + func singleValueContainer() throws -> any SingleValueDecodingContainer { + let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") + throw DecodingError.typeMismatch(PreparedStatement.self, context) + } + + // MARK: - KeyedDecodingContainer + + struct KeyedContainer: KeyedDecodingContainerProtocol { + let decoder: _RowDecoder + var codingPath: [any CodingKey] { decoder.codingPath } + var allKeys: [Key] { decoder.row.statement.columnIndexByName.keys.compactMap { Key(stringValue: $0) } } + + func contains(_ key: Key) -> Bool { decoder.row.statement.columnIndexByName.keys.contains(key.stringValue) } + func decodeNil(forKey key: Key) throws -> Bool { decoder.null(for: key) } + func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { try decoder.bool(forKey: key) } + func decode(_ type: String.Type, forKey key: Key) throws -> String { try decoder.string(forKey: key) } + func decode(_ type: Double.Type, forKey key: Key) throws -> Double { try decoder.floating(type, forKey: key) } + func decode(_ type: Float.Type, forKey key: Key) throws -> Float { try decoder.floating(type, forKey: key) } + func decode(_ type: Int.Type, forKey key: Key) throws -> Int { try decoder.integer(type, forKey: key) } + func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { try decoder.integer(type, forKey: key) } + func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { + return try decoder.decode(type, forKey: key) + } + + func superDecoder() throws -> any Decoder { fatalError() } + func superDecoder(forKey key: Key) throws -> any Decoder { fatalError() } + func nestedUnkeyedContainer(forKey key: Key) throws -> any UnkeyedDecodingContainer { fatalError() } + func nestedContainer( + keyedBy type: NestedKey.Type, + forKey key: Key + ) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + fatalError() + } + } + + // MARK: - Decoding Values + + @inline(__always) + func null(for key: K) -> Bool where K: CodingKey { + row[key.stringValue]?.isNull ?? true + } + + @inline(__always) + func bool(forKey key: K) throws -> Bool where K: CodingKey { + try integer(Int64.self, forKey: key) != 0 + } + + @inline(__always) + func integer(_ type: T.Type, forKey key: K) throws -> T where T: Numeric, K: CodingKey { + let value = try column(forKey: key) + guard let number = type.init(exactly: value.int64) else { + let message = "Parsed SQL integer <\(value)> does not fit in \(type)." + let context = DecodingError.Context(codingPath: [key], debugDescription: message) + throw DecodingError.dataCorrupted(context) + } + return number + } + + @inline(__always) + func floating(_ type: T.Type, forKey key: K) throws -> T where T: BinaryFloatingPoint, K: CodingKey { + let value = try column(forKey: key) + guard let number = type.init(exactly: value.double) else { + let message = "Parsed SQL double <\(value)> does not fit in \(type)." + let context = DecodingError.Context(codingPath: [key], debugDescription: message) + throw DecodingError.dataCorrupted(context) + } + return number + } + + @inline(__always) + func string(forKey key: K) throws -> String where K: CodingKey { + let value = try column(forKey: key) + guard let value = value.string else { + throw DecodingError.valueNotFound(String.self, .codingPath([key])) + } + return value + } + + @inline(__always) + func decode(_ type: T.Type, forKey key: K) throws -> T where T: Decodable, K: CodingKey { + if type == Data.self { + let value = try column(forKey: key) + guard let data = value.blob else { + throw DecodingError.valueNotFound(Data.self, .codingPath([key])) + } + // swift-format-ignore: NeverForceUnwrap + return data as! T + } + let decoder = _ColumnDecoder(key: key, decoder: self) + return try type.init(from: decoder) + } + + @inline(__always) + private func column(forKey key: K) throws -> PreparedStatement.Column where K: CodingKey { + guard let index = row.statement.columnIndexByName[key.stringValue] else { + let message = "Column index not found for key: \(key)" + let context = DecodingError.Context(codingPath: [key], debugDescription: message) + throw DecodingError.keyNotFound(key, context) + } + return row[index] + } +} + +private extension DecodingError.Context { + static func codingPath(_ path: [any CodingKey]) -> DecodingError.Context { + DecodingError.Context(codingPath: path, debugDescription: "") + } +} + +private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { + let key: any CodingKey + let decoder: _RowDecoder + + // MARK: - Decoder + + var userInfo: [CodingUserInfoKey: Any] { decoder.userInfo } + var codingPath: [any CodingKey] { [key] } + + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") + throw DecodingError.typeMismatch(PreparedStatement.Column.self, context) + } + + func unkeyedContainer() throws -> any UnkeyedDecodingContainer { + let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") + throw DecodingError.typeMismatch(PreparedStatement.Column.self, context) + } + + func singleValueContainer() throws -> any SingleValueDecodingContainer { + self + } + + // MARK: - SingleValueDecodingContainer + + func decodeNil() -> Bool { decoder.null(for: key) } + func decode(_ type: Bool.Type) throws -> Bool { try decoder.bool(forKey: key) } + func decode(_ type: String.Type) throws -> String { try decoder.string(forKey: key) } + func decode(_ type: Double.Type) throws -> Double { try decoder.floating(type, forKey: key) } + func decode(_ type: Float.Type) throws -> Float { try decoder.floating(type, forKey: key) } + func decode(_ type: Int.Type) throws -> Int { try decoder.integer(type, forKey: key) } + func decode(_ type: Int8.Type) throws -> Int8 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int16.Type) throws -> Int16 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int32.Type) throws -> Int32 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int64.Type) throws -> Int64 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt.Type) throws -> UInt { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt8.Type) throws -> UInt8 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt16.Type) throws -> UInt16 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt32.Type) throws -> UInt32 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt64.Type) throws -> UInt64 { try decoder.integer(type, forKey: key) } + func decode(_ type: T.Type) throws -> T where T: Decodable { try decoder.decode(type, forKey: key) } +} diff --git a/Sources/SQLyra/StatementDecoder.swift b/Sources/SQLyra/StatementDecoder.swift deleted file mode 100644 index f65a982..0000000 --- a/Sources/SQLyra/StatementDecoder.swift +++ /dev/null @@ -1,205 +0,0 @@ -import Foundation -import SQLite3 - -/// An object that decodes instances of a data type from ``PreparedStatement``. -public struct StatementDecoder { - /// A dictionary you use to customize the decoding process by providing contextual information. - public var userInfo: [CodingUserInfoKey: Any] = [:] - - /// Creates a new, reusable Statement decoder. - public init() {} - - public func decode(_ type: T.Type, from statement: PreparedStatement) throws -> T where T: Decodable { - let decoder = _StatementDecoder( - statement: statement, - userInfo: userInfo - ) - return try type.init(from: decoder) - } -} - -private final class _StatementDecoder { - let statement: PreparedStatement - let userInfo: [CodingUserInfoKey: Any] - private(set) var codingPath: [any CodingKey] = [] - - init(statement: PreparedStatement, userInfo: [CodingUserInfoKey: Any]) { - self.statement = statement - self.userInfo = userInfo - self.codingPath.reserveCapacity(3) - } - - @inline(__always) - func null(for key: K) -> Bool where K: CodingKey { - statement.column(for: key.stringValue)?.isNull ?? true - } - - @inline(__always) - func bool(forKey key: K) throws -> Bool where K: CodingKey { - try integer(Int64.self, forKey: key) != 0 - } - - @inline(__always) - func string(forKey key: K, single: Bool = false) throws -> String where K: CodingKey { - let index = try columnIndex(forKey: key, single: single) - guard let value = statement.column(at: index).string else { - throw DecodingError.valueNotFound(String.self, context(key, single, "")) - } - return value - } - - @inline(__always) - func floating( - _ type: T.Type, - forKey key: K, - single: Bool = false - ) throws -> T where T: BinaryFloatingPoint, K: CodingKey { - let index = try columnIndex(forKey: key, single: single) - let value = statement.column(at: index).double - guard let number = type.init(exactly: value) else { - throw DecodingError.dataCorrupted(context(key, single, numberNotFit(type, value: "\(value)"))) - } - return number - } - - @inline(__always) - func integer(_ type: T.Type, forKey key: K, single: Bool = false) throws -> T where T: Numeric, K: CodingKey { - let index = try columnIndex(forKey: key, single: single) - let value = statement.column(at: index).int64 - guard let number = type.init(exactly: value) else { - throw DecodingError.dataCorrupted(context(key, single, numberNotFit(type, value: "\(value)"))) - } - return number - } - - @inline(__always) - func decode( - _ type: T.Type, - forKey key: K, - single: Bool = false - ) throws -> T where T: Decodable, K: CodingKey { - if type == Data.self { - let index = try columnIndex(forKey: key, single: single) - guard let data = statement.column(at: index).blob else { - throw DecodingError.valueNotFound(Data.self, context(key, single, "")) - } - // swift-format-ignore: NeverForceUnwrap - return data as! T - } - if single { - return try type.init(from: self) - } - codingPath.append(key) - defer { - codingPath.removeLast() - } - return try type.init(from: self) - } - - private func columnIndex(forKey key: K, single: Bool) throws -> Int32 where K: CodingKey { - guard let index = statement.columnIndexByName[key.stringValue] else { - throw DecodingError.keyNotFound(key, context(key, single, "Column index not found for key: \(key)")) - } - return index - } - - private func context(_ key: any CodingKey, _ single: Bool, _ message: String) -> DecodingError.Context { - var path = codingPath - if !single { - path.append(key) - } - return DecodingError.Context(codingPath: path, debugDescription: message) - } -} - -private func numberNotFit(_ type: any Any.Type, value: String) -> String { - "Parsed SQL number <\(value)> does not fit in \(type)." -} - -// MARK: - Decoder - -extension _StatementDecoder: Decoder { - func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { - KeyedDecodingContainer(KeyedContainer(decoder: self)) - } - - func unkeyedContainer() throws -> any UnkeyedDecodingContainer { - let context = DecodingError.Context( - codingPath: codingPath, - debugDescription: "`unkeyedContainer()` not supported" - ) - throw DecodingError.dataCorrupted(context) - } - - func singleValueContainer() throws -> any SingleValueDecodingContainer { - if codingPath.isEmpty { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "key not found") - throw DecodingError.dataCorrupted(context) - } - return self - } -} - -// MARK: - SingleValueDecodingContainer - -extension _StatementDecoder: SingleValueDecodingContainer { - // swift-format-ignore: NeverForceUnwrap - private var key: any CodingKey { codingPath.last! } - - func decodeNil() -> Bool { null(for: key) } - func decode(_ type: Bool.Type) throws -> Bool { try bool(forKey: key) } - func decode(_ type: String.Type) throws -> String { try string(forKey: key, single: true) } - func decode(_ type: Double.Type) throws -> Double { try floating(type, forKey: key, single: true) } - func decode(_ type: Float.Type) throws -> Float { try floating(type, forKey: key, single: true) } - func decode(_ type: Int.Type) throws -> Int { try integer(type, forKey: key, single: true) } - func decode(_ type: Int8.Type) throws -> Int8 { try integer(type, forKey: key, single: true) } - func decode(_ type: Int16.Type) throws -> Int16 { try integer(type, forKey: key, single: true) } - func decode(_ type: Int32.Type) throws -> Int32 { try integer(type, forKey: key, single: true) } - func decode(_ type: Int64.Type) throws -> Int64 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt.Type) throws -> UInt { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt8.Type) throws -> UInt8 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt16.Type) throws -> UInt16 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt32.Type) throws -> UInt32 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt64.Type) throws -> UInt64 { try integer(type, forKey: key, single: true) } - func decode(_ type: T.Type) throws -> T where T: Decodable { try decode(type, forKey: key, single: true) } -} - -// MARK: - KeyedDecodingContainer - -extension _StatementDecoder { - struct KeyedContainer: KeyedDecodingContainerProtocol { - let decoder: _StatementDecoder - var codingPath: [any CodingKey] { decoder.codingPath } - var allKeys: [Key] { decoder.statement.columnIndexByName.keys.compactMap { Key(stringValue: $0) } } - - func contains(_ key: Key) -> Bool { decoder.statement.columnIndexByName.keys.contains(key.stringValue) } - func decodeNil(forKey key: Key) throws -> Bool { decoder.null(for: key) } - func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { try decoder.bool(forKey: key) } - func decode(_ type: String.Type, forKey key: Key) throws -> String { try decoder.string(forKey: key) } - func decode(_ type: Double.Type, forKey key: Key) throws -> Double { try decoder.floating(type, forKey: key) } - func decode(_ type: Float.Type, forKey key: Key) throws -> Float { try decoder.floating(type, forKey: key) } - func decode(_ type: Int.Type, forKey key: Key) throws -> Int { try decoder.integer(type, forKey: key) } - func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { try decoder.integer(type, forKey: key) } - func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { try decoder.integer(type, forKey: key) } - func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { try decoder.integer(type, forKey: key) } - func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { try decoder.integer(type, forKey: key) } - func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - try decoder.decode(type, forKey: key) - } - - func superDecoder() throws -> any Decoder { fatalError() } - func superDecoder(forKey key: Key) throws -> any Decoder { fatalError() } - func nestedUnkeyedContainer(forKey key: Key) throws -> any UnkeyedDecodingContainer { fatalError() } - func nestedContainer( - keyedBy type: NestedKey.Type, - forKey key: Key - ) throws -> KeyedDecodingContainer where NestedKey: CodingKey { - fatalError() - } - } -} diff --git a/Tests/SQLyraTests/DatabaseTests.swift b/Tests/SQLyraTests/DatabaseTests.swift index ba2d19b..0573dd8 100644 --- a/Tests/SQLyraTests/DatabaseTests.swift +++ b/Tests/SQLyraTests/DatabaseTests.swift @@ -78,12 +78,11 @@ struct DatabaseTests { let id: Int let name: String } - let contacts = try database.prepare("SELECT * FROM contacts;").array(decoding: Contact.self) + let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) let expected = [ Contact(id: 1, name: "Paul"), Contact(id: 2, name: "John"), ] #expect(contacts == expected) - // try database.execute("SELECT name FROM sqlite_master WHERE type ='table';") } } diff --git a/Tests/SQLyraTests/PreparedStatementTests.swift b/Tests/SQLyraTests/PreparedStatementTests.swift index eefceee..84782a6 100644 --- a/Tests/SQLyraTests/PreparedStatementTests.swift +++ b/Tests/SQLyraTests/PreparedStatementTests.swift @@ -78,12 +78,12 @@ struct PreparedStatementTests { #expect(select.columnCount == 4) var contracts: [Contact] = [] - while try select.step() { + while let row = try select.row() { let contact = Contact( - id: Int(select.column(at: 0).int64), - name: select.column(at: 1).string ?? "", - rating: select.column(at: 2).double, - image: select.column(at: 3).blob + id: Int(row.id.int64), + name: row.name.string ?? "", + rating: row.rating.double, + image: row.image.blob ) contracts.append(contact) } From 4b0a06495852503c1448b0a7d7fc6fea259a4adc Mon Sep 17 00:00:00 2001 From: Alexander Ignition Date: Sat, 11 Jan 2025 20:25:34 +0300 Subject: [PATCH 2/3] RowDecoderTests --- Playgrounds/README.playground/Contents.swift | 15 +- README.md | 15 +- Sources/SQLyra/PreparedStatement.swift | 74 ++-- Sources/SQLyra/RowDecoder.swift | 81 ++--- Sources/SQLyra/SQLParameter.swift | 2 +- .../SQLyraTests/PreparedStatementTests.swift | 8 +- Tests/SQLyraTests/RowDecoderTests.swift | 329 ++++++++++++++++++ Tests/SQLyraTests/SQLParameter+Testing.swift | 32 ++ 8 files changed, 456 insertions(+), 100 deletions(-) create mode 100644 Tests/SQLyraTests/RowDecoderTests.swift create mode 100644 Tests/SQLyraTests/SQLParameter+Testing.swift diff --git a/Playgrounds/README.playground/Contents.swift b/Playgrounds/README.playground/Contents.swift index 0ef56fa..9a13b6f 100644 --- a/Playgrounds/README.playground/Contents.swift +++ b/Playgrounds/README.playground/Contents.swift @@ -9,7 +9,7 @@ [Documentation](https://alexander-ignition.github.io/SQLyra/documentation/sqlyra/) - - Note: this readme file is available as Xcode playground in Playgrounds/README.playground + > this readme file is available as Xcode playground in Playgrounds/README.playground ## Open @@ -26,21 +26,21 @@ let database = try Database.open( Create table for contacts with fields `id` and `name`. */ -try database.execute( - """ +let sql = """ CREATE TABLE contacts( id INT PRIMARY KEY NOT NULL, name TEXT ); """ -) +try database.execute(sql) /*: ## Insert Insert new contacts Paul and John. */ -try database.execute("INSERT INTO contacts (id, name) VALUES (1, 'Paul');") -try database.execute("INSERT INTO contacts (id, name) VALUES (2, 'John');") +let insert = try database.prepare("INSERT INTO contacts (id, name) VALUES (?, ?);") +try insert.bind(parameters: 1, "Paul").execute().reset() +try insert.bind(parameters: 2, "John").execute() /*: ## Select @@ -51,4 +51,5 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) +let select = try database.prepare("SELECT * FROM contacts;") +let contacts = try select.array(Contact.self) diff --git a/README.md b/README.md index c0bf5ec..391411a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Swift SQLite wrapper. [Documentation](https://alexander-ignition.github.io/SQLyra/documentation/sqlyra/) -- Note: this readme file is available as Xcode playground in Playgrounds/README.playground +> this readme file is available as Xcode playground in Playgrounds/README.playground ## Open @@ -25,21 +25,21 @@ let database = try Database.open( Create table for contacts with fields `id` and `name`. ```swift -try database.execute( - """ +let sql = """ CREATE TABLE contacts( id INT PRIMARY KEY NOT NULL, name TEXT ); """ -) +try database.execute(sql) ``` ## Insert Insert new contacts Paul and John. ```swift -try database.execute("INSERT INTO contacts (id, name) VALUES (1, 'Paul');") -try database.execute("INSERT INTO contacts (id, name) VALUES (2, 'John');") +let insert = try database.prepare("INSERT INTO contacts (id, name) VALUES (?, ?);") +try insert.bind(parameters: 1, "Paul").execute().reset() +try insert.bind(parameters: 2, "John").execute() ``` ## Select @@ -50,5 +50,6 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) +let select = try database.prepare("SELECT * FROM contacts;") +let contacts = try select.array(Contact.self) ``` diff --git a/Sources/SQLyra/PreparedStatement.swift b/Sources/SQLyra/PreparedStatement.swift index 9914076..7dba370 100644 --- a/Sources/SQLyra/PreparedStatement.swift +++ b/Sources/SQLyra/PreparedStatement.swift @@ -11,7 +11,7 @@ public final class PreparedStatement: DatabaseHandle { /// Find the database handle of a prepared statement. var db: OpaquePointer! { sqlite3_db_handle(stmt) } - private(set) lazy var columnIndexByName = [String: Int32]( + private(set) lazy var columnIndexByName = [String: Int]( uniqueKeysWithValues: (0.. String? { - sqlite3_bind_parameter_name(stmt, index).map { String(cString: $0) } + public func parameterName(at index: Int) -> String? { + sqlite3_bind_parameter_name(stmt, Int32(index)).map { String(cString: $0) } } /// Index of a parameter with a given name. - public func parameterIndex(for name: String) -> Int32 { - sqlite3_bind_parameter_index(stmt, name) + public func parameterIndex(for name: String) -> Int { + Int(sqlite3_bind_parameter_index(stmt, name)) } } @@ -98,13 +98,14 @@ extension PreparedStatement { @discardableResult public func bind(parameters: SQLParameter...) throws -> PreparedStatement { for (index, parameter) in parameters.enumerated() { - try bind(index: Int32(index + 1), parameter: parameter) + try bind(index: index + 1, parameter: parameter) } return self } @discardableResult - public func bind(index: Int32, parameter: SQLParameter) throws -> PreparedStatement { + public func bind(index: Int, parameter: SQLParameter) throws -> PreparedStatement { + let index = Int32(index) let code = switch parameter { case .null: @@ -139,10 +140,14 @@ extension PreparedStatement { extension PreparedStatement { /// Return the number of columns in the result set. - public var columnCount: Int32 { sqlite3_column_count(stmt) } + public var columnCount: Int { Int(sqlite3_column_count(stmt)) } - public func columnName(at index: Int32) -> String? { - sqlite3_column_name(stmt, index).string + /// Returns the name assigned to a specific column in the result set of the SELECT statement. + /// + /// The name of a result column is the value of the "AS" clause for that column, if there is an AS clause. + /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. + public func columnName(at index: Int) -> String? { + sqlite3_column_name(stmt, Int32(index)).string } } @@ -154,12 +159,9 @@ extension PreparedStatement { /// - Throws: ``DatabaseError`` public func row() throws -> Row? { switch sqlite3_step(stmt) { - case SQLITE_DONE: - return nil - case SQLITE_ROW: - return Row(statement: self) - case let code: - throw DatabaseError(code: code, db: db) + case SQLITE_DONE: nil + case SQLITE_ROW: Row(statement: self) + case let code: throw DatabaseError(code: code, db: db) } } @@ -180,16 +182,19 @@ extension PreparedStatement { public struct Row { let statement: PreparedStatement - public subscript(dynamicMember name: String) -> Column! { + public subscript(dynamicMember name: String) -> Value? { self[name] } - public subscript(name: String) -> Column? { - statement.columnIndexByName[name].map { self[$0] } + public subscript(name: String) -> Value? { + statement.columnIndexByName[name].flatMap { self[$0] } } - public subscript(index: Int32) -> Column { - Column(index: index, statement: statement) + public subscript(index: Int) -> Value? { + if sqlite3_column_type(statement.stmt, Int32(index)) == SQLITE_NULL { + return nil + } + return Value(index: Int32(index), statement: statement) } public func decode(_ type: T.Type) throws -> T where T: Decodable { @@ -201,26 +206,27 @@ extension PreparedStatement { } } - /// Information about a single column of the current result row of a query. - public struct Column { + /// Result value from a query. + public struct Value { let index: Int32 let statement: PreparedStatement private var stmt: OpaquePointer { statement.stmt } - /// Returns the name assigned to a specific column in the result set of the SELECT statement. - /// - /// The name of a result column is the value of the "AS" clause for that column, if there is an AS clause. - /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. - public var name: String? { sqlite3_column_name(stmt, index).string } - - public var isNull: Bool { sqlite3_column_type(stmt, index) == SQLITE_NULL } - /// 64-bit INTEGER result. public var int64: Int64 { sqlite3_column_int64(stmt, index) } + /// 32-bit INTEGER result. + public var int32: Int32 { sqlite3_column_int(stmt, index) } + + /// A platform-specific integer. + public var int: Int { Int(int64) } + /// 64-bit IEEE floating point number. public var double: Double { sqlite3_column_double(stmt, index) } + /// Size of a BLOB or a UTF-8 TEXT result in bytes. + public var count: Int { Int(sqlite3_column_bytes(stmt, index)) } + /// UTF-8 TEXT result. public var string: String? { sqlite3_column_text(stmt, index).flatMap { String(cString: $0) } @@ -228,9 +234,7 @@ extension PreparedStatement { /// BLOB result. public var blob: Data? { - sqlite3_column_blob(stmt, index).map { bytes in - Data(bytes: bytes, count: Int(sqlite3_column_bytes(stmt, index))) - } + sqlite3_column_blob(stmt, index).map { Data(bytes: $0, count: count) } } } } diff --git a/Sources/SQLyra/RowDecoder.swift b/Sources/SQLyra/RowDecoder.swift index f5cdfaa..4376dd0 100644 --- a/Sources/SQLyra/RowDecoder.swift +++ b/Sources/SQLyra/RowDecoder.swift @@ -30,13 +30,11 @@ private struct _RowDecoder: Decoder { } func unkeyedContainer() throws -> any UnkeyedDecodingContainer { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.self, context) + throw DecodingError.typeMismatch(PreparedStatement.self, .context(codingPath, "")) } func singleValueContainer() throws -> any SingleValueDecodingContainer { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.self, context) + throw DecodingError.typeMismatch(PreparedStatement.self, .context(codingPath, "")) } // MARK: - KeyedDecodingContainer @@ -63,7 +61,7 @@ private struct _RowDecoder: Decoder { func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { try decoder.integer(type, forKey: key) } func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { try decoder.integer(type, forKey: key) } func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - return try decoder.decode(type, forKey: key) + try decoder.decode(type, forKey: key) } func superDecoder() throws -> any Decoder { fatalError() } @@ -81,7 +79,7 @@ private struct _RowDecoder: Decoder { @inline(__always) func null(for key: K) -> Bool where K: CodingKey { - row[key.stringValue]?.isNull ?? true + row[key.stringValue] == nil } @inline(__always) @@ -90,68 +88,55 @@ private struct _RowDecoder: Decoder { } @inline(__always) - func integer(_ type: T.Type, forKey key: K) throws -> T where T: Numeric, K: CodingKey { - let value = try column(forKey: key) - guard let number = type.init(exactly: value.int64) else { - let message = "Parsed SQL integer <\(value)> does not fit in \(type)." - let context = DecodingError.Context(codingPath: [key], debugDescription: message) - throw DecodingError.dataCorrupted(context) - } - return number + func string(forKey key: K) throws -> String where K: CodingKey { + try columnValue(String.self, forKey: key).string ?? "" } @inline(__always) - func floating(_ type: T.Type, forKey key: K) throws -> T where T: BinaryFloatingPoint, K: CodingKey { - let value = try column(forKey: key) - guard let number = type.init(exactly: value.double) else { - let message = "Parsed SQL double <\(value)> does not fit in \(type)." - let context = DecodingError.Context(codingPath: [key], debugDescription: message) - throw DecodingError.dataCorrupted(context) + func integer(_ type: T.Type, forKey key: K) throws -> T where T: Numeric, K: CodingKey { + let value = try columnValue(type, forKey: key) + let int64 = value.int64 + guard let number = type.init(exactly: int64) else { + throw DecodingError.dataCorrupted(.context([key], "Parsed SQL int64 <\(int64)> does not fit in \(type).")) } return number } @inline(__always) - func string(forKey key: K) throws -> String where K: CodingKey { - let value = try column(forKey: key) - guard let value = value.string else { - throw DecodingError.valueNotFound(String.self, .codingPath([key])) + func floating(_ type: T.Type, forKey key: K) throws -> T where T: BinaryFloatingPoint, K: CodingKey { + let value = try columnValue(type, forKey: key) + let double = value.double + guard let number = type.init(exactly: double) else { + throw DecodingError.dataCorrupted(.context([key], "Parsed SQL double <\(double)> does not fit in \(type).")) } - return value + return number } @inline(__always) func decode(_ type: T.Type, forKey key: K) throws -> T where T: Decodable, K: CodingKey { if type == Data.self { - let value = try column(forKey: key) - guard let data = value.blob else { - throw DecodingError.valueNotFound(Data.self, .codingPath([key])) - } + let value = try columnValue(type, forKey: key) + let data = value.blob ?? Data() // swift-format-ignore: NeverForceUnwrap return data as! T } - let decoder = _ColumnDecoder(key: key, decoder: self) + let decoder = _ValueDecoder(key: key, decoder: self) return try type.init(from: decoder) } @inline(__always) - private func column(forKey key: K) throws -> PreparedStatement.Column where K: CodingKey { + private func columnValue(_ type: T.Type, forKey key: K) throws -> PreparedStatement.Value where K: CodingKey { guard let index = row.statement.columnIndexByName[key.stringValue] else { - let message = "Column index not found for key: \(key)" - let context = DecodingError.Context(codingPath: [key], debugDescription: message) - throw DecodingError.keyNotFound(key, context) + throw DecodingError.keyNotFound(key, .context([key], "Column index not found for key: \(key)")) } - return row[index] - } -} - -private extension DecodingError.Context { - static func codingPath(_ path: [any CodingKey]) -> DecodingError.Context { - DecodingError.Context(codingPath: path, debugDescription: "") + guard let column = row[index] else { + throw DecodingError.valueNotFound(type, .context([key], "Column value not found for key: \(key)")) + } + return column } } -private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { +private struct _ValueDecoder: Decoder, SingleValueDecodingContainer { let key: any CodingKey let decoder: _RowDecoder @@ -161,13 +146,11 @@ private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { var codingPath: [any CodingKey] { [key] } func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.Column.self, context) + throw DecodingError.typeMismatch(PreparedStatement.Value.self, .context(codingPath, "")) } func unkeyedContainer() throws -> any UnkeyedDecodingContainer { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "") - throw DecodingError.typeMismatch(PreparedStatement.Column.self, context) + throw DecodingError.typeMismatch(PreparedStatement.Value.self, .context(codingPath, "")) } func singleValueContainer() throws -> any SingleValueDecodingContainer { @@ -193,3 +176,9 @@ private struct _ColumnDecoder: Decoder, SingleValueDecodingContainer { func decode(_ type: UInt64.Type) throws -> UInt64 { try decoder.integer(type, forKey: key) } func decode(_ type: T.Type) throws -> T where T: Decodable { try decoder.decode(type, forKey: key) } } + +private extension DecodingError.Context { + static func context(_ codingPath: [any CodingKey], _ message: String) -> DecodingError.Context { + DecodingError.Context(codingPath: codingPath, debugDescription: message) + } +} diff --git a/Sources/SQLyra/SQLParameter.swift b/Sources/SQLyra/SQLParameter.swift index 764a5f0..fc61051 100644 --- a/Sources/SQLyra/SQLParameter.swift +++ b/Sources/SQLyra/SQLParameter.swift @@ -1,7 +1,7 @@ import Foundation /// SQL parameters. -public enum SQLParameter: Equatable { +public enum SQLParameter: Equatable, Sendable { case null /// 64-bit signed integer. diff --git a/Tests/SQLyraTests/PreparedStatementTests.swift b/Tests/SQLyraTests/PreparedStatementTests.swift index 84782a6..4872345 100644 --- a/Tests/SQLyraTests/PreparedStatementTests.swift +++ b/Tests/SQLyraTests/PreparedStatementTests.swift @@ -80,10 +80,10 @@ struct PreparedStatementTests { var contracts: [Contact] = [] while let row = try select.row() { let contact = Contact( - id: Int(row.id.int64), - name: row.name.string ?? "", - rating: row.rating.double, - image: row.image.blob + id: row.id?.int ?? 0, + name: row.name?.string ?? "", + rating: row.rating?.double ?? 0, + image: row.image?.blob ) contracts.append(contact) } diff --git a/Tests/SQLyraTests/RowDecoderTests.swift b/Tests/SQLyraTests/RowDecoderTests.swift new file mode 100644 index 0000000..d73474b --- /dev/null +++ b/Tests/SQLyraTests/RowDecoderTests.swift @@ -0,0 +1,329 @@ +import Foundation +import SQLyra +import Testing + +struct RowDecoderTests { + struct SignedIntegers { + /// valid parameters for all signed integers + static let arguments: [(SQLParameter, Int)] = [ + (-1, -1), + (0, 0), + (1, 1), + (0.9, 0), + ("2", 2), + (.blob(Data("3".utf8)), 3), + ] + + struct IntTests: DecodableValueSuite { + let value: Int + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int(expected)) + } + } + + struct Int8Tests: DecodableValueSuite { + let value: Int8 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int8(expected)) + } + } + + struct Int16Tests: DecodableValueSuite { + let value: Int16 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int16(expected)) + } + } + + struct Int32Tests: DecodableValueSuite { + let value: Int32 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int32(expected)) + } + } + + struct Int64Tests: DecodableValueSuite { + let value: Int64 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int64(expected)) + } + } + } + + struct UnsignedIntegers { + /// valid parameters for all unsigned integers + static let arguments: [(SQLParameter, UInt)] = [ + (0, 0), + (1, 1), + (0.9, 0), + ("2", 2), + (.blob(Data("3".utf8)), 3), + ] + + struct UIntTests: DecodableValueSuite { + let value: UInt + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt(expected)) + } + } + + struct UInt8Tests: DecodableValueSuite { + let value: UInt8 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt8(expected)) + } + + @Test(arguments: [ + SQLParameter.int64(-1), + SQLParameter.int64(Int64.max), + ]) + static func dataCorrupted(_ parameter: SQLParameter) throws { + try _dataCorrupted(parameter, "Parsed SQL int64 <\(parameter)> does not fit in UInt8.") + } + } + + struct UInt16Tests: DecodableValueSuite { + let value: UInt16 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt16(expected)) + } + } + + struct UInt32Tests: DecodableValueSuite { + let value: UInt32 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt32(expected)) + } + } + + struct UInt64Tests: DecodableValueSuite { + let value: UInt64 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt64(expected)) + } + } + } + + struct FloatingPointNumerics { + static let arguments: [(SQLParameter, Double)] = [ + (-1, -1.0), + (0, 0.0), + (1, 1.0), + (0.5, 0.5), + ("0.5", 0.5), + ("1.0", 1.0), + (.blob(Data("1".utf8)), 1.0), + ] + + struct DoubleTests: DecodableValueSuite { + let value: Double + + @Test(arguments: FloatingPointNumerics.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Double) throws { + try _decode(parameter, Double(expected)) + } + } + + struct FloatTests: DecodableValueSuite { + let value: Float + + @Test(arguments: FloatingPointNumerics.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Double) throws { + try _decode(parameter, Float(expected)) + } + + @Test static func dataCorrupted() throws { + let parameter = SQLParameter.double(Double.greatestFiniteMagnitude) + try _dataCorrupted(parameter, "Parsed SQL double <\(parameter)> does not fit in Float.") + } + } + } + + struct BoolTests: DecodableValueSuite { + let value: Bool + + @Test(arguments: [ + (SQLParameter.int64(0), false), + (SQLParameter.int64(1), true), + (SQLParameter.double(0.9), false), + (SQLParameter.text("abc"), false), + (SQLParameter.text("true"), false), + (SQLParameter.blob(Data("zxc".utf8)), false), + (SQLParameter.blob(Data("1".utf8)), true), + ]) + static func decode(_ parameter: SQLParameter, _ expected: Bool) throws { + try _decode(parameter, expected) + } + } + + struct StringTests: DecodableValueSuite { + let value: String + + @Test(arguments: [ + (SQLParameter.int64(0), "0"), + (SQLParameter.int64(1), "1"), + (SQLParameter.double(0.9), "0.9"), + (SQLParameter.text("abc"), "abc"), + (SQLParameter.blob(Data("zxc".utf8)), "zxc"), + ]) + static func decode(_ parameter: SQLParameter, _ expected: String) throws { + try _decode(parameter, expected) + } + } + + struct DataTests: DecodableValueSuite { + let value: Data + + @Test(arguments: [ + (SQLParameter.int64(1), Data("1".utf8)), + (SQLParameter.double(1.1), Data("1.1".utf8)), + (SQLParameter.text("123"), Data("123".utf8)), + (SQLParameter.blob(Data("zxc".utf8)), Data("zxc".utf8)), + ]) + static func decode(_ parameter: SQLParameter, _ expected: Data) throws { + try _decode(parameter, expected) + } + } + + struct OptionalTests: DecodableValueSuite { + let value: Int? + + @Test static func decode() throws { + try _decode(1, 1) + } + } + + struct DecodingErrorTests: Decodable { + let item: Int // invalid + + @Test static func keyNotFound() throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.int64(1)).row()) + #expect { + try row.decode(DecodingErrorTests.self) + } throws: { error in + guard case .keyNotFound(let key, let context) = error as? DecodingError else { + return false + } + return key.stringValue == "item" + && context.codingPath.map(\.stringValue) == ["item"] + && context.debugDescription == "Column index not found for key: \(key)" + && context.underlyingError == nil + } + } + + @Test static func typeMismatch() throws { + let errorMatcher = { (error: any Error) -> Bool in + guard case .typeMismatch(_, let context) = error as? DecodingError else { + return false + } + return context.codingPath.isEmpty && context.debugDescription == "" && context.underlyingError == nil + } + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.int64(1)).row()) + #expect(performing: { try row.decode(Int.self) }, throws: errorMatcher) + #expect(performing: { try row.decode([Int].self) }, throws: errorMatcher) + } + + @Test static func valueNotFound() throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.null).row()) + #expect { + try row.decode(Single.self) + } throws: { error in + guard case .valueNotFound(let type, let context) = error as? DecodingError else { + return false + } + return type == Int8.self + && context.codingPath.map(\.stringValue) == ["value"] + && context.debugDescription == "Column value not found for key: \(context.codingPath[0])" + && context.underlyingError == nil + } + } + } +} + +// MARK: - Test Suite + +protocol DecodableValueSuite: Decodable { + associatedtype Value: Decodable, Equatable + + var value: Value { get } +} + +extension DecodableValueSuite { + static func _decode( + _ parameter: SQLParameter, + _ expected: Value, + sourceLocation: SourceLocation = #_sourceLocation + ) throws { + let repo = try ItemRepository(datatype: "ANY") + let select = try repo.select(parameter) + let row = try #require(try select.row(), sourceLocation: sourceLocation) + + let keyed = try row.decode(Self.self) + #expect(keyed.value == expected, sourceLocation: sourceLocation) + + let single = try row.decode(Single.self) + #expect(single.value == expected, sourceLocation: sourceLocation) + + #expect(try select.row() == nil, sourceLocation: sourceLocation) + } + + static func _dataCorrupted( + _ parameter: SQLParameter, + _ message: String, + sourceLocation: SourceLocation = #_sourceLocation + ) throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(parameter).row(), sourceLocation: sourceLocation) + #expect(sourceLocation: sourceLocation) { + try row.decode(Self.self) + } throws: { error in + guard case .dataCorrupted(let context) = error as? DecodingError else { + return false + } + return context.codingPath.map(\.stringValue) == ["value"] + && context.debugDescription == message + && context.underlyingError == nil + } + } +} + +struct Single: Decodable { + let value: T +} + +struct ItemRepository { + private let db: Database + + init(datatype: String) throws { + db = try Database.open(at: ":memory:", options: [.readwrite, .memory]) + try db.execute("CREATE TABLE items (value \(datatype));") + } + + func select(_ parameter: SQLParameter) throws -> PreparedStatement { + try db.prepare("INSERT INTO items (value) VALUES (?);").bind(index: 1, parameter: parameter).execute() + return try db.prepare("SELECT value FROM items;") + } +} diff --git a/Tests/SQLyraTests/SQLParameter+Testing.swift b/Tests/SQLyraTests/SQLParameter+Testing.swift new file mode 100644 index 0000000..eb5ab61 --- /dev/null +++ b/Tests/SQLyraTests/SQLParameter+Testing.swift @@ -0,0 +1,32 @@ +import SQLyra +import Testing + +extension SQLParameter: CustomTestStringConvertible { + public var testDescription: String { + switch self { + case .null: "NULL" + case .int64(let value): "INT(\(value))" + case .double(let value): "DOUBLE(\(value))" + case .text(let value): "TEXT(\(value))" + case .blob(let value): "BLOB(\(Array(value))" + } + } +} + +extension SQLParameter: CustomTestArgumentEncodable { + public func encodeTestArgument(to encoder: some Encoder) throws { + switch self { + case .null: + var container = encoder.singleValueContainer() + try container.encodeNil() + case .int64(let value): + try value.encode(to: encoder) + case .double(let value): + try value.encode(to: encoder) + case .text(let value): + try value.encode(to: encoder) + case .blob(let value): + try value.encode(to: encoder) + } + } +} From bbb9518663780f9d2722317f5af5d166732149f5 Mon Sep 17 00:00:00 2001 From: Alexander Ignition Date: Sat, 11 Jan 2025 20:27:35 +0300 Subject: [PATCH 3/3] fix lint --- Tests/SQLyraTests/RowDecoderTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/SQLyraTests/RowDecoderTests.swift b/Tests/SQLyraTests/RowDecoderTests.swift index d73474b..7e3b7fd 100644 --- a/Tests/SQLyraTests/RowDecoderTests.swift +++ b/Tests/SQLyraTests/RowDecoderTests.swift @@ -214,7 +214,7 @@ struct RowDecoderTests { } struct DecodingErrorTests: Decodable { - let item: Int // invalid + let item: Int // invalid @Test static func keyNotFound() throws { let repo = try ItemRepository(datatype: "ANY")