diff --git a/dialect/mssqldialect/dialect.go b/dialect/mssqldialect/dialect.go index e14824b8f..a56eb0c29 100755 --- a/dialect/mssqldialect/dialect.go +++ b/dialect/mssqldialect/dialect.go @@ -127,10 +127,12 @@ func (*Dialect) AppendBool(b []byte, v bool) []byte { return strconv.AppendUint(b, uint64(num), 10) } +func (d *Dialect) DefaultVarcharLen() int { + return 255 +} + func sqlType(field *schema.Field) string { switch field.DiscoveredSQLType { - case sqltype.VarChar: - return field.DiscoveredSQLType + "(255)" case sqltype.Timestamp: return datetimeType case sqltype.Boolean: diff --git a/dialect/mysqldialect/dialect.go b/dialect/mysqldialect/dialect.go index 4b16b4a22..9e9032e2c 100644 --- a/dialect/mysqldialect/dialect.go +++ b/dialect/mysqldialect/dialect.go @@ -172,11 +172,12 @@ func (*Dialect) AppendJSON(b, jsonb []byte) []byte { return b } +func (d *Dialect) DefaultVarcharLen() int { + return 255 +} + func sqlType(field *schema.Field) string { - switch field.DiscoveredSQLType { - case sqltype.VarChar: - return field.DiscoveredSQLType + "(255)" - case sqltype.Timestamp: + if field.DiscoveredSQLType == sqltype.Timestamp { return datetimeType } return field.DiscoveredSQLType diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go index 6c6294d71..dadea5c1c 100644 --- a/dialect/pgdialect/sqltype.go +++ b/dialect/pgdialect/sqltype.go @@ -45,6 +45,10 @@ var ( jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() ) +func (d *Dialect) DefaultVarcharLen() int { + return 0 +} + func fieldSQLType(field *schema.Field) string { if field.UserSQLType != "" { return field.UserSQLType diff --git a/dialect/sqlitedialect/dialect.go b/dialect/sqlitedialect/dialect.go index 720e979f5..3c809e7a7 100644 --- a/dialect/sqlitedialect/dialect.go +++ b/dialect/sqlitedialect/dialect.go @@ -87,6 +87,10 @@ func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte { return b } +func (d *Dialect) DefaultVarcharLen() int { + return 0 +} + func fieldSQLType(field *schema.Field) string { switch field.DiscoveredSQLType { case sqltype.SmallInt, sqltype.BigInt: diff --git a/go.work.sum b/go.work.sum index 05eb5544f..1e0326e1e 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,7 +1,12 @@ +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 0c8448568..24cfd9195 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -969,6 +969,23 @@ func TestQuery(t *testing.T) { When("NOT MATCHED THEN INSERT (name, value) VALUES (_data.name, _data.value)"). Returning("$action") }, + func(db *bun.DB) schema.QueryAppender { + // Note: not all dialects require specifying VARCHAR length + type Model struct { + // ID has the reflection-based type (DiscoveredSQLType) with default length + ID string + // Name has specific type and length defined (UserSQLType) + Name string `bun:",type:varchar(50)"` + // Title has user-defined type (UserSQLType) with default length + Title string `bun:",type:varchar"` + } + // Set default VARCHAR length to 10 + return db.NewCreateTable().Model((*Model)(nil)).Varchar(10) + }, + func(db *bun.DB) schema.QueryAppender { + // Non-positive VARCHAR length is illegal + return db.NewCreateTable().Model((*Model)(nil)).Varchar(-20) + }, } timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mariadb-156 b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-156 new file mode 100644 index 000000000..117474bd2 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-156 @@ -0,0 +1 @@ +CREATE TABLE `models` (`id` VARCHAR(10), `name` varchar(50), `title` VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mariadb-157 b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-156 b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-156 new file mode 100644 index 000000000..04ad92c73 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-156 @@ -0,0 +1 @@ +CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-157 b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-156 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-156 new file mode 100644 index 000000000..117474bd2 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-156 @@ -0,0 +1 @@ +CREATE TABLE `models` (`id` VARCHAR(10), `name` varchar(50), `title` VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-157 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-156 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-156 new file mode 100644 index 000000000..117474bd2 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-156 @@ -0,0 +1 @@ +CREATE TABLE `models` (`id` VARCHAR(10), `name` varchar(50), `title` VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-157 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-156 b/internal/dbtest/testdata/snapshots/TestQuery-pg-156 new file mode 100644 index 000000000..04ad92c73 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-156 @@ -0,0 +1 @@ +CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-157 b/internal/dbtest/testdata/snapshots/TestQuery-pg-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-156 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-156 new file mode 100644 index 000000000..04ad92c73 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-156 @@ -0,0 +1 @@ +CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-157 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-156 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-156 new file mode 100644 index 000000000..04ad92c73 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-156 @@ -0,0 +1 @@ +CREATE TABLE "models" ("id" VARCHAR(10), "name" varchar(50), "title" VARCHAR(10)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-157 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-157 new file mode 100644 index 000000000..7860dc22e --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-157 @@ -0,0 +1 @@ +bun: illegal VARCHAR length: -20 diff --git a/query_table_create.go b/query_table_create.go index 002250bc1..0fe3013c7 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -3,8 +3,10 @@ package bun import ( "context" "database/sql" + "fmt" "sort" "strconv" + "strings" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/sqltype" @@ -17,7 +19,12 @@ type CreateTableQuery struct { temp bool ifNotExists bool - varchar int + + // varchar changes the default length for VARCHAR columns. + // Because some dialects require that length is always specified for VARCHAR type, + // we will use the exact user-defined type if length is set explicitly, as in `bun:",type:varchar(5)"`, + // but assume the new default length when it's omitted, e.g. `bun:",type:varchar"`. + varchar int fks []schema.QueryWithArgs partitionBy schema.QueryWithArgs @@ -32,6 +39,7 @@ func NewCreateTableQuery(db *DB) *CreateTableQuery { db: db, conn: db.DB, }, + varchar: db.Dialect().DefaultVarcharLen(), } return q } @@ -82,7 +90,11 @@ func (q *CreateTableQuery) IfNotExists() *CreateTableQuery { return q } +// Varchar sets default length for VARCHAR columns. func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery { + if n <= 0 { + q.setErr(fmt.Errorf("bun: illegal VARCHAR length: %d", n)) + } q.varchar = n return q } @@ -120,7 +132,7 @@ func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery { return q } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ func (q *CreateTableQuery) Operation() string { return "CREATE TABLE" @@ -221,19 +233,23 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by } func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { - if field.CreateTableSQLType != field.DiscoveredSQLType { + // Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific, + // e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"` + if !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType) { return append(b, field.CreateTableSQLType...) } - if q.varchar > 0 && - field.CreateTableSQLType == sqltype.VarChar { - b = append(b, "varchar("...) - b = strconv.AppendInt(b, int64(q.varchar), 10) - b = append(b, ")"...) - return b + // For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type, + // and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int). + if !strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar) || q.varchar <= 0 { + return append(b, field.CreateTableSQLType...) } - return append(b, field.CreateTableSQLType...) + b = append(b, sqltype.VarChar...) + b = append(b, "("...) + b = strconv.AppendInt(b, int64(q.varchar), 10) + b = append(b, ")"...) + return b } func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte { diff --git a/schema/dialect.go b/schema/dialect.go index b73d89bd0..fea8238dc 100644 --- a/schema/dialect.go +++ b/schema/dialect.go @@ -30,9 +30,14 @@ type Dialect interface { AppendBytes(b []byte, bs []byte) []byte AppendJSON(b, jsonb []byte) []byte AppendBool(b []byte, v bool) []byte + + // DefaultVarcharLen should be returned for dialects in which specifying VARCHAR length + // is mandatory in queries that modify the schema (CREATE TABLE / ADD COLUMN, etc). + // Dialects that do not have such requirement may return 0, which should be interpreted so by the caller. + DefaultVarcharLen() int } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ type BaseDialect struct{} @@ -131,7 +136,7 @@ func (BaseDialect) AppendBool(b []byte, v bool) []byte { return dialect.AppendBool(b, v) } -//------------------------------------------------------------------------------ +// ------------------------------------------------------------------------------ type nopDialect struct { BaseDialect @@ -168,3 +173,7 @@ func (d *nopDialect) OnTable(table *Table) {} func (d *nopDialect) IdentQuote() byte { return '"' } + +func (d *nopDialect) DefaultVarcharLen() int { + return 0 +}