Skip to content

Commit

Permalink
Oracle diver develop, new method QuerySuffix() in dialect, better arg…
Browse files Browse the repository at this point in the history
…s string
  • Loading branch information
klaidliadon authored and James Cooper committed May 16, 2014
1 parent d69be84 commit 728a08e
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 6 deletions.
118 changes: 118 additions & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
// but this could change in the future
type Dialect interface {

// adds a suffix to any query, usually ";"
QuerySuffix() string

// ToSqlType returns the SQL column type to use when creating a
// table of the given Go Type. maxsize can be used to switch based on
// size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB,
Expand Down Expand Up @@ -88,6 +91,8 @@ type SqliteDialect struct {
suffix string
}

func (d SqliteDialect) QuerySuffix() string { return ";" }

func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
Expand Down Expand Up @@ -172,6 +177,8 @@ type PostgresDialect struct {
suffix string
}

func (d PostgresDialect) QuerySuffix() string { return ";" }

func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
Expand Down Expand Up @@ -285,6 +292,8 @@ type MySQLDialect struct {
Encoding string
}

func (d MySQLDialect) QuerySuffix() string { return ";" }

func (m MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
Expand Down Expand Up @@ -495,3 +504,112 @@ func (d SqlServerDialect) QuotedTableForQuery(schema string, table string) strin
}
return schema + "." + table
}

///////////////////////////////////////////////////////
// Oracle //
///////////

// Implementation of Dialect for Oracle databases.
type OracleDialect struct{}

func (d OracleDialect) QuerySuffix() string { return "" }

func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64:
return "double precision"
case reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}

switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double precision"
case "NullBool":
return "boolean"
case "NullTime", "Time":
return "timestamp with time zone"
}

if maxsize > 0 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}

}

// Returns empty string
func (d OracleDialect) AutoIncrStr() string {
return ""
}

func (d OracleDialect) AutoIncrBindValue() string {
return "default"
}

func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + col.ColumnName
}

// Returns suffix
func (d OracleDialect) CreateTableSuffix() string {
return ""
}

func (d OracleDialect) TruncateClause() string {
return "truncate"
}

// Returns "$(i+1)"
func (d OracleDialect) BindVar(i int) string {
return fmt.Sprintf(":%d", i+1)
}

func (d OracleDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
rows, err := exec.query(insertSql, params...)
if err != nil {
return 0, err
}
defer rows.Close()

if rows.Next() {
var id int64
err := rows.Scan(&id)
return id, err
}

return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error())
}

func (d OracleDialect) QuoteField(f string) string {
return `"` + strings.ToUpper(f) + `"`
}

func (d OracleDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}

return schema + "." + d.QuoteField(table)
}
82 changes: 76 additions & 6 deletions gorp.go
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,58 @@ package gorp
import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)

// Oracle String (empty string is null)
type OracleString struct {
sql.NullString
}

// Scan implements the Scanner interface.
func (os *OracleString) Scan(value interface{}) error {
if value == nil {
os.String, os.Valid = "", false
return nil
}
os.Valid = true
return os.NullString.Scan(value)
}

// Value implements the driver Valuer interface.
func (os OracleString) Value() (driver.Value, error) {
if !os.Valid || os.String == "" {
return nil, nil
}
return os.String, nil
}

// A nullable Time value
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
return nil
}

// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}

var zeroVal reflect.Value
var versFieldConst = "[gorp_ver_field]"

Expand Down Expand Up @@ -340,7 +385,7 @@ func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) {
if plan.autoIncrIdx > -1 {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.columns[plan.autoIncrIdx]))
}
s.WriteString(";")
s.WriteString(t.dbmap.Dialect.QuerySuffix())

plan.query = s.String()
t.insertPlan = plan
Expand Down Expand Up @@ -398,7 +443,7 @@ func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) {
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(";")
s.WriteString(t.dbmap.Dialect.QuerySuffix())

plan.query = s.String()
t.updatePlan = plan
Expand Down Expand Up @@ -444,7 +489,7 @@ func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) {

plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(";")
s.WriteString(t.dbmap.Dialect.QuerySuffix())

plan.query = s.String()
t.deletePlan = plan
Expand Down Expand Up @@ -485,7 +530,7 @@ func (t *TableMap) bindGet() bindPlan {

plan.keyFields = append(plan.keyFields, col.fieldName)
}
s.WriteString(";")
s.WriteString(t.dbmap.Dialect.QuerySuffix())

plan.query = s.String()
t.getPlan = plan
Expand Down Expand Up @@ -813,7 +858,7 @@ func (m *DbMap) createTables(ifNotExists bool) error {
}
s.WriteString(") ")
s.WriteString(m.Dialect.CreateTableSuffix())
s.WriteString(";")
s.WriteString(m.Dialect.QuerySuffix())
_, err = m.Exec(s.String())
if err != nil {
break
Expand Down Expand Up @@ -1086,8 +1131,33 @@ func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) {

func (m *DbMap) trace(query string, args ...interface{}) {
if m.logger != nil {
m.logger.Printf("%s%s %v", m.logPrefix, query, args)
var margs = argsString(args...)
m.logger.Printf("%s%s [%s]", m.logPrefix, query, margs)
}
}

func argsString(args ...interface{}) string {
var margs string
for i, a := range args {
var v interface{} = a
if x, ok := v.(driver.Valuer); ok {
y, err := x.Value()
if err == nil {
v = y
}
}
switch v.(type) {
case string:
v = fmt.Sprintf("%q", v)
default:
v = fmt.Sprintf("%v", v)
}
margs += fmt.Sprintf("%d:%s", i+1, v)
if i+1 < len(args) {
margs += " "
}
}
return margs
}

///////////////
Expand Down

0 comments on commit 728a08e

Please sign in to comment.