diff --git a/README.md b/README.md index e7bb424b..cbef88be 100644 --- a/README.md +++ b/README.md @@ -564,6 +564,15 @@ implemented per database vendor. Dialects are provided for: Each of these three databases pass the test suite. See `gorp_test.go` for example DSNs for these three databases. +Support is also provided for: + +* Oracle (contributed by @klaidliadon) +* SQL Server (contributed by @qrawl) - use driver: github.com/denisenkom/go-mssqldb + +Note that these databases are not covered by CI and I (@coopernurse) have no good way to +test them locally. So please try them and send patches as needed, but expect a bit more +unpredicability. + ## Known Issues ## ### SQL placeholder portability ### diff --git a/dialect.go b/dialect.go index 6b6ef0e8..963a3eab 100644 --- a/dialect.go +++ b/dialect.go @@ -12,6 +12,9 @@ import ( // but this could change in the future type Dialect interface { + // adds a suffix to any query, usually ";" + QuerySuffix() string + // ToSqlType returns the SQL column type to use when creating a // table of the given Go Type. maxsize can be used to switch based on // size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB, @@ -21,6 +24,8 @@ type Dialect interface { // string to append to primary key column definitions AutoIncrStr() string + // string to bind autoincrement columns to. Empty string will + // remove reference to those columns in the INSERT statement. AutoIncrBindValue() string AutoIncrInsertSuffix(col *ColumnMap) string @@ -32,8 +37,6 @@ type Dialect interface { // string to truncate tables TruncateClause() string - InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) - // bind variable string to use when forming SQL statements // in many dbs it is "?", but Postgres appears to use $1 // @@ -53,6 +56,25 @@ type Dialect interface { QuotedTableForQuery(schema string, table string) string } +// IntegerAutoIncrInserter is implemented by dialects that can perform +// inserts with automatically incremented integer primary keys. If +// the dialect can handle automatic assignment of more than just +// integers, see TargetedAutoIncrInserter. +type IntegerAutoIncrInserter interface { + InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) +} + +// TargetedAutoIncrInserter is implemented by dialects that can +// perform automatic assignment of any primary key type (i.e. strings +// for uuids, integers for serials, etc). +type TargetedAutoIncrInserter interface { + // InsertAutoIncrToTarget runs an insert operation and assigns the + // automatically generated primary key directly to the passed in + // target. The target should be a pointer to the primary key + // field of the value being inserted. + InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error +} + func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { res, err := exec.Exec(insertSql, params...) if err != nil { @@ -69,6 +91,8 @@ type SqliteDialect struct { suffix string } +func (d SqliteDialect) QuerySuffix() string { return ";" } + func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: @@ -153,6 +177,8 @@ type PostgresDialect struct { suffix string } +func (d PostgresDialect) QuerySuffix() string { return ";" } + func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: @@ -225,20 +251,19 @@ func (d PostgresDialect) BindVar(i int) string { return fmt.Sprintf("$%d", i+1) } -func (d PostgresDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { +func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { rows, err := exec.query(insertSql, params...) if err != nil { - return 0, err + return err } defer rows.Close() if rows.Next() { - var id int64 - err := rows.Scan(&id) - return id, err + err := rows.Scan(target) + return err } - return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) + return errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) } func (d PostgresDialect) QuoteField(f string) string { @@ -267,10 +292,12 @@ type MySQLDialect struct { Encoding string } -func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { +func (d MySQLDialect) QuerySuffix() string { return ";" } + +func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { switch val.Kind() { case reflect.Ptr: - return m.ToSqlType(val.Elem(), maxsize, isAutoIncr) + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) case reflect.Bool: return "boolean" case reflect.Int8: @@ -315,49 +342,49 @@ func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) } // Returns auto_increment -func (m MySQLDialect) AutoIncrStr() string { +func (d MySQLDialect) AutoIncrStr() string { return "auto_increment" } -func (m MySQLDialect) AutoIncrBindValue() string { +func (d MySQLDialect) AutoIncrBindValue() string { return "null" } -func (m MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { +func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { return "" } // Returns engine=%s charset=%s based on values stored on struct -func (m MySQLDialect) CreateTableSuffix() string { - if m.Engine == "" || m.Encoding == "" { +func (d MySQLDialect) CreateTableSuffix() string { + if d.Engine == "" || d.Encoding == "" { msg := "gorp - undefined" - if m.Engine == "" { + if d.Engine == "" { msg += " MySQLDialect.Engine" } - if m.Engine == "" && m.Encoding == "" { + if d.Engine == "" && d.Encoding == "" { msg += "," } - if m.Encoding == "" { + if d.Encoding == "" { msg += " MySQLDialect.Encoding" } msg += ". Check that your MySQLDialect was correctly initialized when declared." panic(msg) } - return fmt.Sprintf(" engine=%s charset=%s", m.Engine, m.Encoding) + return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding) } -func (m MySQLDialect) TruncateClause() string { +func (d MySQLDialect) TruncateClause() string { return "truncate" } // Returns "?" -func (m MySQLDialect) BindVar(i int) string { +func (d MySQLDialect) BindVar(i int) string { return "?" } -func (m MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { +func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { return standardInsertAutoIncr(exec, insertSql, params...) } @@ -365,7 +392,226 @@ func (d MySQLDialect) QuoteField(f string) string { return "`" + f + "`" } -// MySQL does not have schemas like PostgreSQL does, so just escape it like normal func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { - return d.QuoteField(table) + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) +} + +/////////////////////////////////////////////////////// +// Sql Server // +//////////////// + +// Implementation of Dialect for Microsoft SQL Server databases. +// Tested on SQL Server 2008 with driver: github.com/denisenkom/go-mssqldb +// Presently, it doesn't work with CreateTablesIfNotExists(). + +type SqlServerDialect struct { + suffix string +} + +func (d SqlServerDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "bit" + case reflect.Int8: + return "tinyint" + case reflect.Uint8: + return "smallint" + case reflect.Int16: + return "smallint" + case reflect.Uint16: + return "int" + case reflect.Int, reflect.Int32: + return "int" + case reflect.Uint, reflect.Uint32: + return "bigint" + case reflect.Int64: + return "bigint" + case reflect.Uint64: + return "bigint" + case reflect.Float32: + return "real" + case reflect.Float64: + return "float(53)" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "varbinary" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "float(53)" + case "NullBool": + return "tinyint" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + return fmt.Sprintf("varchar(%d)", maxsize) +} + +// Returns auto_increment +func (d SqlServerDialect) AutoIncrStr() string { + return "identity(0,1)" +} + +// Empty string removes autoincrement columns from the INSERT statements. +func (d SqlServerDialect) AutoIncrBindValue() string { + return "" +} + +func (d SqlServerDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns suffix +func (d SqlServerDialect) CreateTableSuffix() string { + + return d.suffix +} + +func (d SqlServerDialect) TruncateClause() string { + return "delete from" +} + +// Returns "?" +func (d SqlServerDialect) BindVar(i int) string { + return "?" +} + +func (d SqlServerDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) +} + +func (d SqlServerDialect) QuoteField(f string) string { + return `"` + f + `"` +} + +func (d SqlServerDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return table + } + return schema + "." + table +} + +func (d SqlServerDialect) QuerySuffix() string { return ";" } + +/////////////////////////////////////////////////////// +// Oracle // +/////////// + +// Implementation of Dialect for Oracle databases. +type OracleDialect struct{} + +func (d OracleDialect) QuerySuffix() string { return "" } + +func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "NullTime", "Time": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d OracleDialect) AutoIncrStr() string { + return "" +} + +func (d OracleDialect) AutoIncrBindValue() string { + return "default" +} + +func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return " returning " + col.ColumnName +} + +// Returns suffix +func (d OracleDialect) CreateTableSuffix() string { + return "" +} + +func (d OracleDialect) TruncateClause() string { + return "truncate" +} + +// Returns "$(i+1)" +func (d OracleDialect) BindVar(i int) string { + return fmt.Sprintf(":%d", i+1) +} + +func (d OracleDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + rows, err := exec.query(insertSql, params...) + if err != nil { + return 0, err + } + defer rows.Close() + + if rows.Next() { + var id int64 + err := rows.Scan(&id) + return id, err + } + + return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error()) +} + +func (d OracleDialect) QuoteField(f string) string { + return `"` + strings.ToUpper(f) + `"` +} + +func (d OracleDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." + d.QuoteField(table) } diff --git a/gorp.go b/gorp.go old mode 100644 new mode 100755 index c4355db3..53be19fa --- a/gorp.go +++ b/gorp.go @@ -14,13 +14,58 @@ package gorp import ( "bytes" "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "regexp" "strings" + "time" ) +// Oracle String (empty string is null) +type OracleString struct { + sql.NullString +} + +// Scan implements the Scanner interface. +func (os *OracleString) Scan(value interface{}) error { + if value == nil { + os.String, os.Valid = "", false + return nil + } + os.Valid = true + return os.NullString.Scan(value) +} + +// Value implements the driver Valuer interface. +func (os OracleString) Value() (driver.Value, error) { + if !os.Valid || os.String == "" { + return nil, nil + } + return os.String, nil +} + +// A nullable Time value +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (nt *NullTime) Scan(value interface{}) error { + nt.Time, nt.Valid = value.(time.Time) + return nil +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + var zeroVal reflect.Value var versFieldConst = "[gorp_ver_field]" @@ -126,7 +171,7 @@ type TableMap struct { TableName string SchemaName string gotype reflect.Type - columns []*ColumnMap + Columns []*ColumnMap keys []*ColumnMap uniqueTogether [][]string version *ColumnMap @@ -212,7 +257,7 @@ func (t *TableMap) ColMap(field string) *ColumnMap { } func colMapOrNil(t *TableMap, field string) *ColumnMap { - for _, col := range t.columns { + for _, col := range t.Columns { if col.fieldName == field || col.ColumnName == field { return col } @@ -305,42 +350,45 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { x := 0 first := true - for y := range t.columns { - col := t.columns[y] - - if !col.Transient { - if !first { - s.WriteString(",") - s2.WriteString(",") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + for y := range t.Columns { + col := t.Columns[y] + if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") { + if !col.Transient { + if !first { + s.WriteString(",") + s2.WriteString(",") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - if col.isAutoIncr { - s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) - plan.autoIncrIdx = y - plan.autoIncrFieldName = col.fieldName - } else { - s2.WriteString(t.dbmap.Dialect.BindVar(x)) - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) + if col.isAutoIncr { + s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName } else { - plan.argFields = append(plan.argFields, col.fieldName) + s2.WriteString(t.dbmap.Dialect.BindVar(x)) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + + x++ } - - x++ + first = false } - - first = false + } else { + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName } } s.WriteString(") values (") s.WriteString(s2.String()) s.WriteString(")") if plan.autoIncrIdx > -1 { - s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.columns[plan.autoIncrIdx])) + s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx])) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.insertPlan = plan @@ -357,9 +405,9 @@ func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) { s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) x := 0 - for y := range t.columns { - col := t.columns[y] - if !col.isPK && !col.Transient { + for y := range t.Columns { + col := t.Columns[y] + if !col.isAutoIncr && !col.Transient { if x > 0 { s.WriteString(", ") } @@ -398,7 +446,7 @@ func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) { s.WriteString(t.dbmap.Dialect.BindVar(x)) plan.argFields = append(plan.argFields, plan.versField) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.updatePlan = plan @@ -414,8 +462,8 @@ func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { s := bytes.Buffer{} s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - for y := range t.columns { - col := t.columns[y] + for y := range t.Columns { + col := t.Columns[y] if !col.Transient { if col == t.version { plan.versField = col.fieldName @@ -444,7 +492,7 @@ func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { plan.argFields = append(plan.argFields, plan.versField) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.deletePlan = plan @@ -461,7 +509,7 @@ func (t *TableMap) bindGet() bindPlan { s.WriteString("select ") x := 0 - for _, col := range t.columns { + for _, col := range t.Columns { if !col.Transient { if x > 0 { s.WriteString(",") @@ -485,7 +533,7 @@ func (t *TableMap) bindGet() bindPlan { plan.keyFields = append(plan.keyFields, col.fieldName) } - s.WriteString(";") + s.WriteString(t.dbmap.Dialect.QuerySuffix()) plan.query = s.String() t.getPlan = plan @@ -664,7 +712,7 @@ func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name str } tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} - tmap.columns, tmap.version = readStructColumns(t) + tmap.Columns, tmap.version = readStructColumns(t) m.tables = append(m.tables, tmap) return tmap @@ -765,7 +813,7 @@ func (m *DbMap) createTables(ifNotExists bool) error { s.WriteString(fmt.Sprintf("%s %s (", create, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) x := 0 - for _, col := range table.columns { + for _, col := range table.Columns { if !col.Transient { if x > 0 { s.WriteString(", ") @@ -813,13 +861,13 @@ func (m *DbMap) createTables(ifNotExists bool) error { } s.WriteString(") ") s.WriteString(m.Dialect.CreateTableSuffix()) - s.WriteString(";") + s.WriteString(m.Dialect.QuerySuffix()) // use the transaction if it's there. otherwise, use the db connection. if m.Tx != nil { _, err = m.Tx.Exec(s.String()) } else { - _, err = m.Exec(s.String()) + _, err = m.Exec(s.String()) } if err != nil { @@ -1046,7 +1094,10 @@ func (m *DbMap) Begin() (*Transaction, error) { return &Transaction{m, tx, false}, nil } -func (m *DbMap) tableFor(t reflect.Type, checkPK bool) (*TableMap, error) { +// TableFor returns the *TableMap corresponding to the given Go Type +// If no table is mapped to that type an error is returned. +// If checkPK is true and the mapped table has no registered PKs, an error is returned. +func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) { table := tableOrNil(m, t) if table == nil { return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name())) @@ -1080,7 +1131,7 @@ func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, refle } elem := ptrv.Elem() etype := reflect.TypeOf(elem.Interface()) - t, err := m.tableFor(etype, checkPK) + t, err := m.TableFor(etype, checkPK) if err != nil { return nil, reflect.Value{}, err } @@ -1100,8 +1151,33 @@ func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) { func (m *DbMap) trace(query string, args ...interface{}) { if m.logger != nil { - m.logger.Printf("%s%s %v", m.logPrefix, query, args) + var margs = argsString(args...) + m.logger.Printf("%s%s [%s]", m.logPrefix, query, margs) + } +} + +func argsString(args ...interface{}) string { + var margs string + for i, a := range args { + var v interface{} = a + if x, ok := v.(driver.Valuer); ok { + y, err := x.Value() + if err == nil { + v = y + } + } + switch v.(type) { + case string: + v = fmt.Sprintf("%q", v) + default: + v = fmt.Sprintf("%v", v) + } + margs += fmt.Sprintf("%d:%s", i+1, v) + if i+1 < len(args) { + margs += " " + } } + return margs } /////////////// @@ -1387,17 +1463,21 @@ func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string, // Determine where the results are: written to i, or returned in list if t, _ := toSliceType(i); t == nil { for _, v := range list { - err = runHook("PostGet", reflect.ValueOf(v), hookArg(exec)) - if err != nil { - return nil, err + if v, ok := v.(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } } } else { resultsValue := reflect.Indirect(reflect.ValueOf(i)) for i := 0; i < resultsValue.Len(); i++ { - err = runHook("PostGet", resultsValue.Index(i), hookArg(exec)) - if err != nil { - return nil, err + if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } } } @@ -1589,14 +1669,16 @@ func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error // a field in the i struct for x := range cols { colName := strings.ToLower(cols[x]) - field, found := t.FieldByNameFunc(func(fieldName string) bool { + var mappedFieldName string field, _ := t.FieldByName(fieldName) - fieldName = field.Tag.Get("db") - - if fieldName == "-" { + lowerFieldName := strings.ToLower(field.Name) + mappedFieldName = field.Tag.Get("db") + if mappedFieldName == "-" && colName != lowerFieldName { return false - } else if fieldName == "" { + } else if mappedFieldName == "-" && colName == lowerFieldName { + return true + } else if mappedFieldName == "" { fieldName = field.Name } if tableMapped { @@ -1605,7 +1687,6 @@ func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error fieldName = colMap.ColumnName } } - return colName == strings.ToLower(fieldName) }) if found { @@ -1682,7 +1763,7 @@ func get(m *DbMap, exec SqlExecutor, i interface{}, return nil, err } - table, err := m.tableFor(t, true) + table, err := m.TableFor(t, true) if err != nil { return nil, err } @@ -1724,16 +1805,17 @@ func get(m *DbMap, exec SqlExecutor, i interface{}, } } - err = runHook("PostGet", v, hookArg(exec)) - if err != nil { - return nil, err + if v, ok := v.Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } } return v.Interface(), nil } func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - hookarg := hookArg(exec) count := int64(0) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, true) @@ -1741,10 +1823,12 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { return -1, err } - eptr := elem.Addr() - err = runHook("PreDelete", eptr, hookarg) - if err != nil { - return -1, err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreDelete); ok { + err = v.PreDelete(exec) + if err != nil { + return -1, err + } } bi, err := table.bindDelete(elem) @@ -1768,9 +1852,11 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { count += rows - err = runHook("PostDelete", eptr, hookarg) - if err != nil { - return -1, err + if v, ok := eval.(HasPostDelete); ok { + err := v.PostDelete(exec) + if err != nil { + return -1, err + } } } @@ -1778,7 +1864,6 @@ func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { } func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - hookarg := hookArg(exec) count := int64(0) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, true) @@ -1786,10 +1871,12 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { return -1, err } - eptr := elem.Addr() - err = runHook("PreUpdate", eptr, hookarg) - if err != nil { - return -1, err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreUpdate); ok { + err = v.PreUpdate(exec) + if err != nil { + return -1, err + } } bi, err := table.bindUpdate(elem) @@ -1818,26 +1905,29 @@ func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { count += rows - err = runHook("PostUpdate", eptr, hookarg) - if err != nil { - return -1, err + if v, ok := eval.(HasPostUpdate); ok { + err = v.PostUpdate(exec) + if err != nil { + return -1, err + } } } return count, nil } func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { - hookarg := hookArg(exec) for _, ptr := range list { table, elem, err := m.tableForPointer(ptr, false) if err != nil { return err } - eptr := elem.Addr() - err = runHook("PreInsert", eptr, hookarg) - if err != nil { - return err + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreInsert); ok { + err := v.PreInsert(exec) + if err != nil { + return err + } } bi, err := table.bindInsert(elem) @@ -1846,18 +1936,28 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { } if bi.autoIncrIdx > -1 { - id, err := m.Dialect.InsertAutoIncr(exec, bi.query, bi.args...) - if err != nil { - return err - } f := elem.FieldByName(bi.autoIncrFieldName) - k := f.Kind() - if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { - f.SetInt(id) - } else if (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { - f.SetUint(uint64(id)) - } else { - return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) + switch inserter := m.Dialect.(type) { + case IntegerAutoIncrInserter: + id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...) + if err != nil { + return err + } + k := f.Kind() + if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { + f.SetInt(id) + } else if (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { + f.SetUint(uint64(id)) + } else { + return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) + } + case TargetedAutoIncrInserter: + err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...) + if err != nil { + return err + } + default: + return fmt.Errorf("gorp: Cannot use autoincrement fields on dialects that do not implement an autoincrementing interface") } } else { _, err := exec.Exec(bi.query, bi.args...) @@ -1866,25 +1966,11 @@ func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { } } - err = runHook("PostInsert", eptr, hookarg) - if err != nil { - return err - } - } - return nil -} - -func hookArg(exec SqlExecutor) []reflect.Value { - execval := reflect.ValueOf(exec) - return []reflect.Value{execval} -} - -func runHook(name string, eptr reflect.Value, arg []reflect.Value) error { - hook := eptr.MethodByName(name) - if hook != zeroVal { - ret := hook.Call(arg) - if len(ret) > 0 && !ret[0].IsNil() { - return ret[0].Interface().(error) + if v, ok := eval.(HasPostInsert); ok { + err := v.PostInsert(exec) + if err != nil { + return err + } } } return nil @@ -1905,3 +1991,38 @@ func lockError(m *DbMap, exec SqlExecutor, tableName string, } return -1, ole } + +// PostUpdate() will be executed after the GET statement. +type HasPostGet interface { + PostGet(SqlExecutor) error +} + +// PostUpdate() will be executed after the DELETE statement +type HasPostDelete interface { + PostDelete(SqlExecutor) error +} + +// PostUpdate() will be executed after the UPDATE statement +type HasPostUpdate interface { + PostUpdate(SqlExecutor) error +} + +// PostInsert() will be executed after the INSERT statement +type HasPostInsert interface { + PostInsert(SqlExecutor) error +} + +// PreDelete() will be executed before the DELETE statement. +type HasPreDelete interface { + PreDelete(SqlExecutor) error +} + +// PreUpdate() will be executed before UPDATE statement. +type HasPreUpdate interface { + PreUpdate(SqlExecutor) error +} + +// PreInsert() will be executed before INSERT statement. +type HasPreInsert interface { + PreInsert(SqlExecutor) error +} diff --git a/gorp_test.go b/gorp_test.go index 0f1f9588..45587853 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -18,6 +18,13 @@ import ( "time" ) +// verify interface compliance +var _ Dialect = SqliteDialect{} +var _ Dialect = PostgresDialect{} +var _ Dialect = MySQLDialect{} +var _ Dialect = SqlServerDialect{} +var _ Dialect = OracleDialect{} + type Invoice struct { Id int64 Created int64 @@ -64,6 +71,12 @@ type WithIgnoredColumn struct { Created int64 } +type IgnoredColumnExported struct { + Id int64 + External int64 `db:"-"` + Created int64 +} + type WithStringPk struct { Id string Name string @@ -119,6 +132,10 @@ type UniqueColumns struct { ZipCode int64 } +type SingleColumnTable struct { + SomeId string +} + type testTypeConverter struct{} func (me testTypeConverter) ToDb(val interface{}) (interface{}, error) { @@ -306,7 +323,8 @@ func TestSetUniqueTogether(t *testing.T) { t.Error(err) } // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") { + errLower := strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { t.Error(err) } @@ -317,7 +335,8 @@ func TestSetUniqueTogether(t *testing.T) { t.Error(err) } // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - if !strings.Contains(err.Error(), "unique") && !strings.Contains(err.Error(), "Duplicate entry") { + errLower = strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { t.Error(err) } @@ -1425,6 +1444,40 @@ func TestSelectSingleVal(t *testing.T) { } } +func TestSelectAlias(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + + p1 := &IgnoredColumnExported{Id: 1, External: 2, Created: 3} + _insert(dbmap, p1) + + var p2 IgnoredColumnExported + + err := dbmap.SelectOne(&p2, "select * from ignored_column_exported_test where Id=1") + if err != nil { + t.Error(err) + } + if p2.Id != 1 || p2.Created != 3 || p2.External != 0 { + t.Error("Expected ignorred field defaults to not set") + } + + err = dbmap.SelectOne(&p2, "SELECT *, 1 AS external FROM ignored_column_exported_test") + if err != nil { + t.Error(err) + } + if p2.External != 1 { + t.Error("Expected select as can map to exported field.") + } + + var rows *sql.Rows + var cols []string + rows, err = dbmap.Db.Query("SELECT * FROM ignored_column_exported_test") + cols, err = rows.Columns() + if err != nil || len(cols) != 2 { + t.Error("Expected ignored column not created") + } +} + func TestMysqlPanicIfDialectNotInitialized(t *testing.T) { _, driver := dialectAndDriver() // this test only applies to MySQL @@ -1449,6 +1502,37 @@ func TestMysqlPanicIfDialectNotInitialized(t *testing.T) { db.CreateTables() } +func TestSingleColumnKeyDbReturnsZeroRowsUpdatedOnPKChange(t *testing.T) { + dbmap := initDbMap() + defer dropAndClose(dbmap) + dbmap.AddTableWithName(SingleColumnTable{}, "single_column_table").SetKeys(false, "SomeId") + err := dbmap.DropTablesIfExists() + if err != nil { + t.Error("Drop tables failed") + } + err = dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error("Create tables failed") + } + err = dbmap.TruncateTables() + if err != nil { + t.Error("Truncate tables failed") + } + + sct := SingleColumnTable{ + SomeId: "A Unique Id String", + } + + count, err := dbmap.Update(&sct) + if err != nil { + t.Error(err) + } + if count != 0 { + t.Errorf("Expected 0 updated rows, got %d", count) + } + +} + func BenchmarkNativeCrud(b *testing.B) { b.StopTimer() dbmap := initDbMapBench() @@ -1559,6 +1643,7 @@ func initDbMap() *DbMap { dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id") dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id") + dbmap.AddTableWithName(IgnoredColumnExported{}, "ignored_column_exported_test").SetKeys(true, "Id") dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id") dbmap.AddTableWithName(WithEmbeddedStructBeforeAutoincrField{}, "embedded_struct_before_autoincr_test").SetKeys(true, "Id")