Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CreateTableQuery: use correct VARCHAR length #738

Merged
merged 7 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ func (*Dialect) AppendBool(b []byte, v bool) []byte {
return strconv.AppendUint(b, uint64(num), 10)
}

func (d *Dialect) DefaultVarcharLen() int {
return 255
}

func sqlType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.VarChar:
return field.DiscoveredSQLType + "(255)"
case sqltype.Timestamp:
return datetimeType
case sqltype.Boolean:
Expand Down
9 changes: 5 additions & 4 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,12 @@ func (*Dialect) AppendJSON(b, jsonb []byte) []byte {
return b
}

func (d *Dialect) DefaultVarcharLen() int {
return 255
}

func sqlType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.VarChar:
return field.DiscoveredSQLType + "(255)"
case sqltype.Timestamp:
if field.DiscoveredSQLType == sqltype.Timestamp {
return datetimeType
}
return field.DiscoveredSQLType
Expand Down
4 changes: 4 additions & 0 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ var (
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
)

func (d *Dialect) DefaultVarcharLen() int {
return 0
}

func fieldSQLType(field *schema.Field) string {
if field.UserSQLType != "" {
return field.UserSQLType
Expand Down
4 changes: 4 additions & 0 deletions dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte {
return b
}

func (d *Dialect) DefaultVarcharLen() int {
return 0
}

func fieldSQLType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:
Expand Down
5 changes: 5 additions & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
Expand Down
17 changes: 17 additions & 0 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,23 @@ func TestQuery(t *testing.T) {
When("NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value)").
Returning("$action")
},
func(db *bun.DB) schema.QueryAppender {
// Note: not all dialects require specifying VARCHAR length
type Model struct {
// ID has the reflection-based type (DiscoveredSQLType) with default length
ID string
// Name has specific type and length defined (UserSQLType)
Name string `bun:",type:varchar(50)"`
// Title has user-defined type (UserSQLType) with default length
Title string `bun:",type:varchar"`
}
// Set default VARCHAR length to 10
return db.NewCreateTable().Model((*Model)(nil)).Varchar(10)
},
func(db *bun.DB) schema.QueryAppender {
// Non-positive VARCHAR length is illegal
return db.NewCreateTable().Model((*Model)(nil)).Varchar(-20)
},
}

timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `models` (`id` VARCHAR(10), `name` varchar(50), `title` VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `models` (`id` VARCHAR(10), `name` varchar(50), `title` VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `models` (`id` VARCHAR(10), `name` varchar(50), `title` VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-156
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-157
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bun: illegal VARCHAR length: -20
36 changes: 26 additions & 10 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package bun
import (
"context"
"database/sql"
"fmt"
"sort"
"strconv"
"strings"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
Expand All @@ -17,7 +19,12 @@ type CreateTableQuery struct {

temp bool
ifNotExists bool
varchar int

// varchar changes the default length for VARCHAR columns.
// Because some dialects require that length is always specified for VARCHAR type,
// we will use the exact user-defined type if length is set explicitly, as in `bun:",type:varchar(5)"`,
// but assume the new default length when it's omitted, e.g. `bun:",type:varchar"`.
varchar int

fks []schema.QueryWithArgs
partitionBy schema.QueryWithArgs
Expand All @@ -32,6 +39,7 @@ func NewCreateTableQuery(db *DB) *CreateTableQuery {
db: db,
conn: db.DB,
},
varchar: db.Dialect().DefaultVarcharLen(),
}
return q
}
Expand Down Expand Up @@ -82,7 +90,11 @@ func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
return q
}

// Varchar sets default length for VARCHAR columns.
func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
if n <= 0 {
q.setErr(fmt.Errorf("bun: illegal VARCHAR length: %d", n))
}
q.varchar = n
return q
}
Expand Down Expand Up @@ -120,7 +132,7 @@ func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
return q
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (q *CreateTableQuery) Operation() string {
return "CREATE TABLE"
Expand Down Expand Up @@ -221,19 +233,23 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
}

func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
if field.CreateTableSQLType != field.DiscoveredSQLType {
// Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific,
// e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"`
if !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType) {
return append(b, field.CreateTableSQLType...)
}

if q.varchar > 0 &&
field.CreateTableSQLType == sqltype.VarChar {
b = append(b, "varchar("...)
b = strconv.AppendInt(b, int64(q.varchar), 10)
b = append(b, ")"...)
return b
// For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type,
// and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int).
if !strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar) || q.varchar <= 0 {
return append(b, field.CreateTableSQLType...)
}

return append(b, field.CreateTableSQLType...)
b = append(b, sqltype.VarChar...)
b = append(b, "("...)
b = strconv.AppendInt(b, int64(q.varchar), 10)
b = append(b, ")"...)
return b
}

func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte {
Expand Down
13 changes: 11 additions & 2 deletions schema/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ type Dialect interface {
AppendBytes(b []byte, bs []byte) []byte
AppendJSON(b, jsonb []byte) []byte
AppendBool(b []byte, v bool) []byte

// DefaultVarcharLen should be returned for dialects in which specifying VARCHAR length
// is mandatory in queries that modify the schema (CREATE TABLE / ADD COLUMN, etc).
// Dialects that do not have such requirement may return 0, which should be interpreted so by the caller.
DefaultVarcharLen() int
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type BaseDialect struct{}

Expand Down Expand Up @@ -131,7 +136,7 @@ func (BaseDialect) AppendBool(b []byte, v bool) []byte {
return dialect.AppendBool(b, v)
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type nopDialect struct {
BaseDialect
Expand Down Expand Up @@ -168,3 +173,7 @@ func (d *nopDialect) OnTable(table *Table) {}
func (d *nopDialect) IdentQuote() byte {
return '"'
}

func (d *nopDialect) DefaultVarcharLen() int {
return 0
}