From 0f17aca19500ceb122c665cdcd87460369a87f42 Mon Sep 17 00:00:00 2001 From: XIE Long Date: Thu, 30 Sep 2021 17:27:19 +0800 Subject: [PATCH 1/5] add bitwise operation supporting --- dialect/mysql/mysql.go | 8 ++ dialect/mysql/mysql_dialect_test.go | 13 +++ dialect/sqlite3/sqlite3.go | 6 ++ dialect/sqlite3/sqlite3_dialect_test.go | 13 +++ dialect/sqlserver/sqlserver.go | 6 ++ dialect/sqlserver/sqlserver_dialect_test.go | 60 ++++++++++++++ exp/bitwise.go | 89 +++++++++++++++++++++ exp/bitwise_test.go | 84 +++++++++++++++++++ exp/exp.go | 68 ++++++++++++++++ exp/ident.go | 22 +++++ exp/ident_test.go | 7 ++ exp/literal.go | 9 +++ exp/literal_test.go | 7 ++ sqlgen/expression_sql_generator.go | 28 +++++++ sqlgen/expression_sql_generator_test.go | 35 ++++++++ sqlgen/sql_dialect_options.go | 18 +++++ 16 files changed, 473 insertions(+) create mode 100644 dialect/sqlserver/sqlserver_dialect_test.go create mode 100644 exp/bitwise.go create mode 100644 exp/bitwise_test.go diff --git a/dialect/mysql/mysql.go b/dialect/mysql/mysql.go index 2a8934a2..d5566e49 100644 --- a/dialect/mysql/mysql.go +++ b/dialect/mysql/mysql.go @@ -51,6 +51,14 @@ func DialectOptions() *goqu.SQLDialectOptions { exp.RegexpILikeOp: []byte("REGEXP"), exp.RegexpNotILikeOp: []byte("NOT REGEXP"), } + opts.BitwiseOperatorLookup = map[exp.BitwiseOperation][]byte{ + exp.BitwiseInversionOp: []byte("~"), + exp.BitwiseOrOp: []byte("|"), + exp.BitwiseAndOp: []byte("&"), + exp.BitwiseXorOp: []byte("^"), + exp.BitwiseLeftShiftOp: []byte("<<"), + exp.BitwiseRightShiftOp: []byte(">>"), + } opts.EscapedRunes = map[rune][]byte{ '\'': []byte("\\'"), '"': []byte("\\\""), diff --git a/dialect/mysql/mysql_dialect_test.go b/dialect/mysql/mysql_dialect_test.go index 397a0845..b57184d4 100644 --- a/dialect/mysql/mysql_dialect_test.go +++ b/dialect/mysql/mysql_dialect_test.go @@ -112,6 +112,19 @@ func (mds *mysqlDialectSuite) TestBooleanOperations() { ) } +func (mds *mysqlDialectSuite) TestBitwiseOperations() { + col := goqu.C("a") + ds := mds.GetDs("test") + mds.assertSQL( + sqlTestCase{ds: ds.Where(col.BitwiseInversion()), sql: "SELECT * FROM `test` WHERE (~ `a`)"}, + sqlTestCase{ds: ds.Where(col.BitwiseAnd(1)), sql: "SELECT * FROM `test` WHERE (`a` & 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseOr(1)), sql: "SELECT * FROM `test` WHERE (`a` | 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseXor(1)), sql: "SELECT * FROM `test` WHERE (`a` ^ 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseLeftShift(1)), sql: "SELECT * FROM `test` WHERE (`a` << 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseRightShift(1)), sql: "SELECT * FROM `test` WHERE (`a` >> 1)"}, + ) +} + func (mds *mysqlDialectSuite) TestUpdateSQL() { ds := mds.GetDs("test").Update() mds.assertSQL( diff --git a/dialect/sqlite3/sqlite3.go b/dialect/sqlite3/sqlite3.go index e70222fb..08df9284 100644 --- a/dialect/sqlite3/sqlite3.go +++ b/dialect/sqlite3/sqlite3.go @@ -52,6 +52,12 @@ func DialectOptions() *goqu.SQLDialectOptions { exp.RegexpNotILikeOp: []byte("NOT REGEXP"), } opts.UseLiteralIsBools = false + opts.BitwiseOperatorLookup = map[exp.BitwiseOperation][]byte{ + exp.BitwiseOrOp: []byte("|"), + exp.BitwiseAndOp: []byte("&"), + exp.BitwiseLeftShiftOp: []byte("<<"), + exp.BitwiseRightShiftOp: []byte(">>"), + } opts.EscapedRunes = map[rune][]byte{ '\'': []byte("''"), } diff --git a/dialect/sqlite3/sqlite3_dialect_test.go b/dialect/sqlite3/sqlite3_dialect_test.go index 6a83b5f1..8aa291d2 100644 --- a/dialect/sqlite3/sqlite3_dialect_test.go +++ b/dialect/sqlite3/sqlite3_dialect_test.go @@ -132,6 +132,19 @@ func (sds *sqlite3DialectSuite) TestBooleanOperations() { ) } +func (sds *sqlite3DialectSuite) TestBitwiseOperations() { + col := goqu.C("a") + ds := sds.GetDs("test") + sds.assertSQL( + sqlTestCase{ds: ds.Where(col.BitwiseInversion()), err: "goqu: bitwise operator 'Inversion' not supported"}, + sqlTestCase{ds: ds.Where(col.BitwiseAnd(1)), sql: "SELECT * FROM `test` WHERE (`a` & 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseOr(1)), sql: "SELECT * FROM `test` WHERE (`a` | 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseXor(1)), err: "goqu: bitwise operator 'XOR' not supported"}, + sqlTestCase{ds: ds.Where(col.BitwiseLeftShift(1)), sql: "SELECT * FROM `test` WHERE (`a` << 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseRightShift(1)), sql: "SELECT * FROM `test` WHERE (`a` >> 1)"}, + ) +} + func (sds *sqlite3DialectSuite) TestForUpdate() { ds := sds.GetDs("test") sds.assertSQL( diff --git a/dialect/sqlserver/sqlserver.go b/dialect/sqlserver/sqlserver.go index 58f9ad22..9cb20d65 100644 --- a/dialect/sqlserver/sqlserver.go +++ b/dialect/sqlserver/sqlserver.go @@ -53,6 +53,12 @@ func DialectOptions() *goqu.SQLDialectOptions { exp.RegexpILikeOp: []byte("REGEXP"), exp.RegexpNotILikeOp: []byte("NOT REGEXP"), } + opts.BitwiseOperatorLookup = map[exp.BitwiseOperation][]byte{ + exp.BitwiseInversionOp: []byte("~"), + exp.BitwiseOrOp: []byte("|"), + exp.BitwiseAndOp: []byte("&"), + exp.BitwiseXorOp: []byte("^"), + } opts.FetchFragment = []byte(" FETCH FIRST ") diff --git a/dialect/sqlserver/sqlserver_dialect_test.go b/dialect/sqlserver/sqlserver_dialect_test.go new file mode 100644 index 00000000..1a150c36 --- /dev/null +++ b/dialect/sqlserver/sqlserver_dialect_test.go @@ -0,0 +1,60 @@ +package sqlserver_test + +import ( + "testing" + + "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" + "github.com/stretchr/testify/suite" +) + +type ( + sqlserverDialectSuite struct { + suite.Suite + } + sqlTestCase struct { + ds exp.SQLExpression + sql string + err string + isPrepared bool + args []interface{} + } +) + +func (sds *sqlserverDialectSuite) GetDs(table string) *goqu.SelectDataset { + return goqu.Dialect("sqlserver").From(table) +} + +func (sds *sqlserverDialectSuite) assertSQL(cases ...sqlTestCase) { + for i, c := range cases { + actualSQL, actualArgs, err := c.ds.ToSQL() + if c.err == "" { + sds.NoError(err, "test case %d failed", i) + } else { + sds.EqualError(err, c.err, "test case %d failed", i) + } + sds.Equal(c.sql, actualSQL, "test case %d failed", i) + if c.isPrepared && c.args != nil || len(c.args) > 0 { + sds.Equal(c.args, actualArgs, "test case %d failed", i) + } else { + sds.Empty(actualArgs, "test case %d failed", i) + } + } +} + +func (sds *sqlserverDialectSuite) TestBitwiseOperations() { + col := goqu.C("a") + ds := sds.GetDs("test") + sds.assertSQL( + sqlTestCase{ds: ds.Where(col.BitwiseInversion()), sql: "SELECT * FROM \"test\" WHERE (~ \"a\")"}, + sqlTestCase{ds: ds.Where(col.BitwiseAnd(1)), sql: "SELECT * FROM \"test\" WHERE (\"a\" & 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseOr(1)), sql: "SELECT * FROM \"test\" WHERE (\"a\" | 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseXor(1)), sql: "SELECT * FROM \"test\" WHERE (\"a\" ^ 1)"}, + sqlTestCase{ds: ds.Where(col.BitwiseLeftShift(1)), err: "goqu: bitwise operator 'Left Shift' not supported"}, + sqlTestCase{ds: ds.Where(col.BitwiseRightShift(1)), err: "goqu: bitwise operator 'Right Shift' not supported"}, + ) +} + +func TestDatasetAdapterSuite(t *testing.T) { + suite.Run(t, new(sqlserverDialectSuite)) +} diff --git a/exp/bitwise.go b/exp/bitwise.go new file mode 100644 index 00000000..eede5a2f --- /dev/null +++ b/exp/bitwise.go @@ -0,0 +1,89 @@ +package exp + +type bitwise struct { + lhs Expression + rhs interface{} + op BitwiseOperation +} + +func NewBitwiseExpression(op BitwiseOperation, lhs Expression, rhs interface{}) BitwiseExpression { + return bitwise{op: op, lhs: lhs, rhs: rhs} +} + +func (b bitwise) Clone() Expression { + return NewBitwiseExpression(b.op, b.lhs.Clone(), b.rhs) +} + +func (b bitwise) RHS() interface{} { + return b.rhs +} + +func (b bitwise) LHS() Expression { + return b.lhs +} + +func (b bitwise) Op() BitwiseOperation { + return b.op +} + +func (b bitwise) Expression() Expression { return b } +func (b bitwise) As(val interface{}) AliasedExpression { return NewAliasExpression(b, val) } +func (b bitwise) Eq(val interface{}) BooleanExpression { return eq(b, val) } +func (b bitwise) Neq(val interface{}) BooleanExpression { return neq(b, val) } +func (b bitwise) Gt(val interface{}) BooleanExpression { return gt(b, val) } +func (b bitwise) Gte(val interface{}) BooleanExpression { return gte(b, val) } +func (b bitwise) Lt(val interface{}) BooleanExpression { return lt(b, val) } +func (b bitwise) Lte(val interface{}) BooleanExpression { return lte(b, val) } +func (b bitwise) Asc() OrderedExpression { return asc(b) } +func (b bitwise) Desc() OrderedExpression { return desc(b) } +func (b bitwise) Like(i interface{}) BooleanExpression { return like(b, i) } +func (b bitwise) NotLike(i interface{}) BooleanExpression { return notLike(b, i) } +func (b bitwise) ILike(i interface{}) BooleanExpression { return iLike(b, i) } +func (b bitwise) NotILike(i interface{}) BooleanExpression { return notILike(b, i) } +func (b bitwise) RegexpLike(val interface{}) BooleanExpression { return regexpLike(b, val) } +func (b bitwise) RegexpNotLike(val interface{}) BooleanExpression { return regexpNotLike(b, val) } +func (b bitwise) RegexpILike(val interface{}) BooleanExpression { return regexpILike(b, val) } +func (b bitwise) RegexpNotILike(val interface{}) BooleanExpression { return regexpNotILike(b, val) } +func (b bitwise) In(i ...interface{}) BooleanExpression { return in(b, i...) } +func (b bitwise) NotIn(i ...interface{}) BooleanExpression { return notIn(b, i...) } +func (b bitwise) Is(i interface{}) BooleanExpression { return is(b, i) } +func (b bitwise) IsNot(i interface{}) BooleanExpression { return isNot(b, i) } +func (b bitwise) IsNull() BooleanExpression { return is(b, nil) } +func (b bitwise) IsNotNull() BooleanExpression { return isNot(b, nil) } +func (b bitwise) IsTrue() BooleanExpression { return is(b, true) } +func (b bitwise) IsNotTrue() BooleanExpression { return isNot(b, true) } +func (b bitwise) IsFalse() BooleanExpression { return is(b, false) } +func (b bitwise) IsNotFalse() BooleanExpression { return isNot(b, false) } +func (b bitwise) Distinct() SQLFunctionExpression { return NewSQLFunctionExpression("DISTINCT", b) } +func (b bitwise) Between(val RangeVal) RangeExpression { return between(b, val) } +func (b bitwise) NotBetween(val RangeVal) RangeExpression { return notBetween(b, val) } + +// used internally to create a Bitwise Inversion BitwiseExpression +func bitwiseInversion(rhs Expression) BitwiseExpression { + return NewBitwiseExpression(BitwiseInversionOp, nil, rhs) +} + +// used internally to create a Bitwise OR BitwiseExpression +func bitwiseOr(lhs Expression, rhs interface{}) BitwiseExpression { + return NewBitwiseExpression(BitwiseOrOp, lhs, rhs) +} + +// used internally to create a Bitwise AND BitwiseExpression +func bitwiseAnd(lhs Expression, rhs interface{}) BitwiseExpression { + return NewBitwiseExpression(BitwiseAndOp, lhs, rhs) +} + +// used internally to create a Bitwise XOR BitwiseExpression +func bitwiseXor(lhs Expression, rhs interface{}) BitwiseExpression { + return NewBitwiseExpression(BitwiseXorOp, lhs, rhs) +} + +// used internally to create a Bitwise LEFT SHIFT BitwiseExpression +func bitwiseLeftShift(lhs Expression, rhs interface{}) BitwiseExpression { + return NewBitwiseExpression(BitwiseLeftShiftOp, lhs, rhs) +} + +// used internally to create a Bitwise RIGHT SHIFT BitwiseExpression +func bitwiseRightShift(lhs Expression, rhs interface{}) BitwiseExpression { + return NewBitwiseExpression(BitwiseRightShiftOp, lhs, rhs) +} diff --git a/exp/bitwise_test.go b/exp/bitwise_test.go new file mode 100644 index 00000000..5dcbd86f --- /dev/null +++ b/exp/bitwise_test.go @@ -0,0 +1,84 @@ +package exp_test + +import ( + "testing" + + "github.com/doug-martin/goqu/v9/exp" + "github.com/stretchr/testify/suite" +) + +type bitwiseExpressionSuite struct { + suite.Suite +} + +func TestBitwiseExpressionSuite(t *testing.T) { + suite.Run(t, &bitwiseExpressionSuite{}) +} + +func (bes *bitwiseExpressionSuite) TestClone() { + be := exp.NewBitwiseExpression(exp.BitwiseAndOp, exp.NewIdentifierExpression("", "", "col"), 1) + bes.Equal(be, be.Clone()) +} + +func (bes *bitwiseExpressionSuite) TestExpression() { + be := exp.NewBitwiseExpression(exp.BitwiseAndOp, exp.NewIdentifierExpression("", "", "col"), 1) + bes.Equal(be, be.Expression()) +} + +func (bes *bitwiseExpressionSuite) TestAs() { + be := exp.NewBitwiseExpression(exp.BitwiseInversionOp, exp.NewIdentifierExpression("", "", "col"), 1) + bes.Equal(exp.NewAliasExpression(be, "a"), be.As("a")) +} + +func (bes *bitwiseExpressionSuite) TestAsc() { + be := exp.NewBitwiseExpression(exp.BitwiseAndOp, exp.NewIdentifierExpression("", "", "col"), 1) + bes.Equal(exp.NewOrderedExpression(be, exp.AscDir, exp.NoNullsSortType), be.Asc()) +} + +func (bes *bitwiseExpressionSuite) TestDesc() { + be := exp.NewBitwiseExpression(exp.BitwiseOrOp, exp.NewIdentifierExpression("", "", "col"), 1) + bes.Equal(exp.NewOrderedExpression(be, exp.DescSortDir, exp.NoNullsSortType), be.Desc()) +} + +func (bes *bitwiseExpressionSuite) TestAllOthers() { + be := exp.NewBitwiseExpression(exp.BitwiseRightShiftOp, exp.NewIdentifierExpression("", "", "col"), 1) + rv := exp.NewRangeVal(1, 2) + pattern := "cast like%" + inVals := []interface{}{1, 2} + testCases := []struct { + Ex exp.Expression + Expected exp.Expression + }{ + {Ex: be.Eq(1), Expected: exp.NewBooleanExpression(exp.EqOp, be, 1)}, + {Ex: be.Neq(1), Expected: exp.NewBooleanExpression(exp.NeqOp, be, 1)}, + {Ex: be.Gt(1), Expected: exp.NewBooleanExpression(exp.GtOp, be, 1)}, + {Ex: be.Gte(1), Expected: exp.NewBooleanExpression(exp.GteOp, be, 1)}, + {Ex: be.Lt(1), Expected: exp.NewBooleanExpression(exp.LtOp, be, 1)}, + {Ex: be.Lte(1), Expected: exp.NewBooleanExpression(exp.LteOp, be, 1)}, + {Ex: be.Between(rv), Expected: exp.NewRangeExpression(exp.BetweenOp, be, rv)}, + {Ex: be.NotBetween(rv), Expected: exp.NewRangeExpression(exp.NotBetweenOp, be, rv)}, + {Ex: be.Like(pattern), Expected: exp.NewBooleanExpression(exp.LikeOp, be, pattern)}, + {Ex: be.NotLike(pattern), Expected: exp.NewBooleanExpression(exp.NotLikeOp, be, pattern)}, + {Ex: be.ILike(pattern), Expected: exp.NewBooleanExpression(exp.ILikeOp, be, pattern)}, + {Ex: be.NotILike(pattern), Expected: exp.NewBooleanExpression(exp.NotILikeOp, be, pattern)}, + {Ex: be.RegexpLike(pattern), Expected: exp.NewBooleanExpression(exp.RegexpLikeOp, be, pattern)}, + {Ex: be.RegexpNotLike(pattern), Expected: exp.NewBooleanExpression(exp.RegexpNotLikeOp, be, pattern)}, + {Ex: be.RegexpILike(pattern), Expected: exp.NewBooleanExpression(exp.RegexpILikeOp, be, pattern)}, + {Ex: be.RegexpNotILike(pattern), Expected: exp.NewBooleanExpression(exp.RegexpNotILikeOp, be, pattern)}, + {Ex: be.In(inVals), Expected: exp.NewBooleanExpression(exp.InOp, be, inVals)}, + {Ex: be.NotIn(inVals), Expected: exp.NewBooleanExpression(exp.NotInOp, be, inVals)}, + {Ex: be.Is(true), Expected: exp.NewBooleanExpression(exp.IsOp, be, true)}, + {Ex: be.IsNot(true), Expected: exp.NewBooleanExpression(exp.IsNotOp, be, true)}, + {Ex: be.IsNull(), Expected: exp.NewBooleanExpression(exp.IsOp, be, nil)}, + {Ex: be.IsNotNull(), Expected: exp.NewBooleanExpression(exp.IsNotOp, be, nil)}, + {Ex: be.IsTrue(), Expected: exp.NewBooleanExpression(exp.IsOp, be, true)}, + {Ex: be.IsNotTrue(), Expected: exp.NewBooleanExpression(exp.IsNotOp, be, true)}, + {Ex: be.IsFalse(), Expected: exp.NewBooleanExpression(exp.IsOp, be, false)}, + {Ex: be.IsNotFalse(), Expected: exp.NewBooleanExpression(exp.IsNotOp, be, false)}, + {Ex: be.Distinct(), Expected: exp.NewSQLFunctionExpression("DISTINCT", be)}, + } + + for _, tc := range testCases { + bes.Equal(tc.Expected, tc.Ex) + } +} diff --git a/exp/exp.go b/exp/exp.go index 50e5473b..ec33c143 100644 --- a/exp/exp.go +++ b/exp/exp.go @@ -138,6 +138,27 @@ type ( // Used internally by update sql Set(interface{}) UpdateExpression } + + Bitwiseable interface { + // Creates a Bit Operation Expresion for sql ~ + // I("col").BitiInversion() // (~ "col") + BitwiseInversion() BitwiseExpression + // Creates a Bit Operation Expresion for sql | + // I("col").BitOr(1) // ("col" | 1) + BitwiseOr(interface{}) BitwiseExpression + // Creates a Bit Operation Expresion for sql & + // I("col").BitAnd(1) // ("col" & 1) + BitwiseAnd(interface{}) BitwiseExpression + // Creates a Bit Operation Expresion for sql ^ + // I("col").BitXor(1) // ("col" ^ 1) + BitwiseXor(interface{}) BitwiseExpression + // Creates a Bit Operation Expresion for sql << + // I("col").BitLeftShift(1) // ("col" << 1) + BitwiseLeftShift(interface{}) BitwiseExpression + // Creates a Bit Operation Expresion for sql >> + // I("col").BitRighttShift(1) // ("col" >> 1) + BitwiseRightShift(interface{}) BitwiseExpression + } ) type ( @@ -195,6 +216,26 @@ type ( // The right hand side of the expression could be a primitive value, dataset, or expression RHS() interface{} } + + BitwiseOperation int + BitwiseExpression interface { + Expression + Aliaseable + Comparable + Isable + Inable + Likeable + Rangeable + Orderable + Distinctable + // Returns the operator for the expression + Op() BitwiseOperation + // The left hand side of the expression (e.g. I("a") + LHS() Expression + // The right hand side of the expression could be a primitive value, dataset, or expression + RHS() interface{} + } + // An Expression that represents another Expression casted to a SQL type CastExpression interface { Expression @@ -276,6 +317,7 @@ type ( Updateable Distinctable Castable + Bitwiseable // returns true if this identifier has more more than on part (Schema, Table or Col) // "schema" -> true //cant qualify anymore // "schema.table" -> true @@ -345,6 +387,7 @@ type ( Likeable Rangeable Orderable + Bitwiseable // Returns the literal sql Literal() string // Arguments to be replaced within the sql @@ -547,6 +590,13 @@ const ( RegexpNotILikeOp betweenStr = "between" + + BitwiseInversionOp BitwiseOperation = iota + BitwiseOrOp + BitwiseAndOp + BitwiseXorOp + BitwiseLeftShiftOp + BitwiseRightShiftOp ) var ( @@ -624,6 +674,24 @@ func (bo BooleanOperation) String() string { return fmt.Sprintf("%d", bo) } +func (bi BitwiseOperation) String() string { + switch bi { + case BitwiseInversionOp: + return "Inversion" + case BitwiseOrOp: + return "OR" + case BitwiseAndOp: + return "AND" + case BitwiseXorOp: + return "XOR" + case BitwiseLeftShiftOp: + return "Left Shift" + case BitwiseRightShiftOp: + return "Right Shift" + } + return fmt.Sprintf("%d", bi) +} + func (ro RangeOperation) String() string { switch ro { case BetweenOp: diff --git a/exp/ident.go b/exp/ident.go index 1cb9d1c4..aebbbd5a 100644 --- a/exp/ident.go +++ b/exp/ident.go @@ -160,6 +160,28 @@ func (i identifier) Lt(val interface{}) BooleanExpression { return lt(i, val) } // (e.g "my_col" <= 1) func (i identifier) Lte(val interface{}) BooleanExpression { return lte(i, val) } +// Returns a BooleanExpression for bit inversion (e.g ~ "my_col") +func (i identifier) BitwiseInversion() BitwiseExpression { return bitwiseInversion(i) } + +// Returns a BooleanExpression for bit OR (e.g "my_col" | 1) +func (i identifier) BitwiseOr(val interface{}) BitwiseExpression { return bitwiseOr(i, val) } + +// Returns a BooleanExpression for bit AND (e.g "my_col" & 1) +func (i identifier) BitwiseAnd(val interface{}) BitwiseExpression { return bitwiseAnd(i, val) } + +// Returns a BooleanExpression for bit XOR (e.g "my_col" ^ 1) +func (i identifier) BitwiseXor(val interface{}) BitwiseExpression { return bitwiseXor(i, val) } + +// Returns a BooleanExpression for bit LEFT shift (e.g "my_col" << 1) +func (i identifier) BitwiseLeftShift(val interface{}) BitwiseExpression { + return bitwiseLeftShift(i, val) +} + +// Returns a BooleanExpression for bit RIGHT shift (e.g "my_col" >> 1) +func (i identifier) BitwiseRightShift(val interface{}) BitwiseExpression { + return bitwiseRightShift(i, val) +} + // Returns a BooleanExpression for checking that a identifier is in a list of values or (e.g "my_col" > 1) func (i identifier) In(vals ...interface{}) BooleanExpression { return in(i, vals...) } func (i identifier) NotIn(vals ...interface{}) BooleanExpression { return notIn(i, vals...) } diff --git a/exp/ident_test.go b/exp/ident_test.go index 3e2d7ac2..d74d83ed 100644 --- a/exp/ident_test.go +++ b/exp/ident_test.go @@ -198,6 +198,7 @@ func (ies *identifierExpressionSuite) TestAllOthers() { rv := exp.NewRangeVal(1, 2) pattern := "ident like%" inVals := []interface{}{1, 2} + bitwiseVals := 2 testCases := []struct { Ex exp.Expression Expected exp.Expression @@ -232,6 +233,12 @@ func (ies *identifierExpressionSuite) TestAllOthers() { {Ex: ident.IsFalse(), Expected: exp.NewBooleanExpression(exp.IsOp, ident, false)}, {Ex: ident.IsNotFalse(), Expected: exp.NewBooleanExpression(exp.IsNotOp, ident, false)}, {Ex: ident.Distinct(), Expected: exp.NewSQLFunctionExpression("DISTINCT", ident)}, + {Ex: ident.BitwiseInversion(), Expected: exp.NewBitwiseExpression(exp.BitwiseInversionOp, nil, ident)}, + {Ex: ident.BitwiseOr(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseOrOp, ident, bitwiseVals)}, + {Ex: ident.BitwiseAnd(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseAndOp, ident, bitwiseVals)}, + {Ex: ident.BitwiseXor(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseXorOp, ident, bitwiseVals)}, + {Ex: ident.BitwiseLeftShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseLeftShiftOp, ident, bitwiseVals)}, + {Ex: ident.BitwiseRightShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseRightShiftOp, ident, bitwiseVals)}, } for _, tc := range testCases { diff --git a/exp/literal.go b/exp/literal.go index f19775d6..da087751 100644 --- a/exp/literal.go +++ b/exp/literal.go @@ -69,3 +69,12 @@ func (l literal) IsTrue() BooleanExpression { return is(l func (l literal) IsNotTrue() BooleanExpression { return isNot(l, true) } func (l literal) IsFalse() BooleanExpression { return is(l, false) } func (l literal) IsNotFalse() BooleanExpression { return isNot(l, false) } + +func (l literal) BitwiseInversion() BitwiseExpression { return bitwiseInversion(l) } +func (l literal) BitwiseOr(val interface{}) BitwiseExpression { return bitwiseOr(l, val) } +func (l literal) BitwiseAnd(val interface{}) BitwiseExpression { return bitwiseAnd(l, val) } +func (l literal) BitwiseXor(val interface{}) BitwiseExpression { return bitwiseXor(l, val) } +func (l literal) BitwiseLeftShift(val interface{}) BitwiseExpression { return bitwiseLeftShift(l, val) } +func (l literal) BitwiseRightShift(val interface{}) BitwiseExpression { + return bitwiseRightShift(l, val) +} diff --git a/exp/literal_test.go b/exp/literal_test.go index e64166b2..42f00a8c 100644 --- a/exp/literal_test.go +++ b/exp/literal_test.go @@ -39,6 +39,7 @@ func (les *literalExpressionSuite) TestAllOthers() { rv := exp.NewRangeVal(1, 2) pattern := "literal like%" inVals := []interface{}{1, 2} + bitwiseVals := 2 testCases := []struct { Ex exp.Expression Expected exp.Expression @@ -72,6 +73,12 @@ func (les *literalExpressionSuite) TestAllOthers() { {Ex: le.IsNotTrue(), Expected: exp.NewBooleanExpression(exp.IsNotOp, le, true)}, {Ex: le.IsFalse(), Expected: exp.NewBooleanExpression(exp.IsOp, le, false)}, {Ex: le.IsNotFalse(), Expected: exp.NewBooleanExpression(exp.IsNotOp, le, false)}, + {Ex: le.BitwiseInversion(), Expected: exp.NewBitwiseExpression(exp.BitwiseInversionOp, nil, le)}, + {Ex: le.BitwiseOr(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseOrOp, le, bitwiseVals)}, + {Ex: le.BitwiseAnd(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseAndOp, le, bitwiseVals)}, + {Ex: le.BitwiseXor(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseXorOp, le, bitwiseVals)}, + {Ex: le.BitwiseLeftShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseLeftShiftOp, le, bitwiseVals)}, + {Ex: le.BitwiseRightShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseRightShiftOp, le, bitwiseVals)}, } for _, tc := range testCases { diff --git a/sqlgen/expression_sql_generator.go b/sqlgen/expression_sql_generator.go index a5f10261..82ce15c5 100644 --- a/sqlgen/expression_sql_generator.go +++ b/sqlgen/expression_sql_generator.go @@ -53,6 +53,10 @@ func errUnsupportedBooleanExpressionOperator(op exp.BooleanOperation) error { return errors.New("boolean operator '%+v' not supported", op) } +func errUnsupportedBitwiseExpressionOperator(op exp.BitwiseOperation) error { + return errors.New("bitwise operator '%+v' not supported", op) +} + func errUnsupportedRangeExpressionOperator(op exp.RangeOperation) error { return errors.New("range operator %+v not supported", op) } @@ -170,6 +174,8 @@ func (esg *expressionSQLGenerator) expressionSQL(b sb.SQLBuilder, expression exp esg.aliasedExpressionSQL(b, e) case exp.BooleanExpression: esg.booleanExpressionSQL(b, e) + case exp.BitwiseExpression: + esg.bitwiseExpressionSQL(b, e) case exp.RangeExpression: esg.rangeExpressionSQL(b, e) case exp.OrderedExpression: @@ -420,6 +426,28 @@ func (esg *expressionSQLGenerator) booleanExpressionSQL(b sb.SQLBuilder, operato b.WriteRunes(esg.dialectOptions.RightParenRune) } +// Generates SQL for a BitwiseExpresion (e.g. I("a").BitwiseOr(2) - > "a" | 2) +func (esg *expressionSQLGenerator) bitwiseExpressionSQL(b sb.SQLBuilder, operator exp.BitwiseExpression) { + b.WriteRunes(esg.dialectOptions.LeftParenRune) + + if operator.LHS() != nil { + esg.Generate(b, operator.LHS()) + b.WriteRunes(esg.dialectOptions.SpaceRune) + } + + operatorOp := operator.Op() + if val, ok := esg.dialectOptions.BitwiseOperatorLookup[operatorOp]; ok { + b.Write(val) + } else { + b.SetError(errUnsupportedBitwiseExpressionOperator(operatorOp)) + return + } + + b.WriteRunes(esg.dialectOptions.SpaceRune) + esg.Generate(b, operator.RHS()) + b.WriteRunes(esg.dialectOptions.RightParenRune) +} + // Generates SQL for a RangeExpresion (e.g. I("a").Between(RangeVal{Start:2,End:5}) -> "a" BETWEEN 2 AND 5) func (esg *expressionSQLGenerator) rangeExpressionSQL(b sb.SQLBuilder, operator exp.RangeExpression) { b.WriteRunes(esg.dialectOptions.LeftParenRune) diff --git a/sqlgen/expression_sql_generator_test.go b/sqlgen/expression_sql_generator_test.go index 1b88eada..b07dc72d 100644 --- a/sqlgen/expression_sql_generator_test.go +++ b/sqlgen/expression_sql_generator_test.go @@ -616,6 +616,41 @@ func (esgs *expressionSQLGeneratorSuite) TestGenerate_BooleanExpression() { ) } +func (esgs *expressionSQLGeneratorSuite) TestGenerate_BitwiseExpression() { + ident := exp.NewIdentifierExpression("", "", "a") + esgs.assertCases( + sqlgen.NewExpressionSQLGenerator("test", sqlgen.DefaultDialectOptions()), + expressionTestCase{val: ident.BitwiseInversion(), sql: `(~ "a")`}, + expressionTestCase{val: ident.BitwiseInversion(), sql: `(~ "a")`, isPrepared: true}, + + expressionTestCase{val: ident.BitwiseAnd(1), sql: `("a" & 1)`}, + expressionTestCase{val: ident.BitwiseAnd(1), sql: `("a" & ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.BitwiseOr(1), sql: `("a" | 1)`}, + expressionTestCase{val: ident.BitwiseOr(1), sql: `("a" | ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.BitwiseXor(1), sql: `("a" # 1)`}, + expressionTestCase{val: ident.BitwiseXor(1), sql: `("a" # ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.BitwiseLeftShift(1), sql: `("a" << 1)`}, + expressionTestCase{val: ident.BitwiseLeftShift(1), sql: `("a" << ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + + expressionTestCase{val: ident.BitwiseRightShift(1), sql: `("a" >> 1)`}, + expressionTestCase{val: ident.BitwiseRightShift(1), sql: `("a" >> ?)`, isPrepared: true, args: []interface{}{int64(1)}}, + ) + + opts := sqlgen.DefaultDialectOptions() + opts.BitwiseOperatorLookup = map[exp.BitwiseOperation][]byte{} + esgs.assertCases( + sqlgen.NewExpressionSQLGenerator("test", opts), + expressionTestCase{val: ident.BitwiseInversion(), err: "goqu: bitwise operator 'Inversion' not supported"}, + expressionTestCase{val: ident.BitwiseAnd(1), err: "goqu: bitwise operator 'AND' not supported"}, + expressionTestCase{val: ident.BitwiseOr(1), err: "goqu: bitwise operator 'OR' not supported"}, + expressionTestCase{val: ident.BitwiseXor(1), err: "goqu: bitwise operator 'XOR' not supported"}, + expressionTestCase{val: ident.BitwiseLeftShift(1), err: "goqu: bitwise operator 'Left Shift' not supported"}, + expressionTestCase{val: ident.BitwiseRightShift(1), err: "goqu: bitwise operator 'Right Shift' not supported"}, + ) +} func (esgs *expressionSQLGeneratorSuite) TestGenerate_RangeExpression() { betweenNum := exp.NewIdentifierExpression("", "", "a"). Between(exp.NewRangeVal(1, 2)) diff --git a/sqlgen/sql_dialect_options.go b/sqlgen/sql_dialect_options.go index b8292e5e..3e6947d1 100644 --- a/sqlgen/sql_dialect_options.go +++ b/sqlgen/sql_dialect_options.go @@ -215,6 +215,16 @@ type ( // exp.RegexpNotILikeOp: []byte("!~*"), // }) BooleanOperatorLookup map[exp.BooleanOperation][]byte + // A map used to look up BitwiseOperations and their SQL equivalents + // (Default=map[exp.BitwiseOperation][]byte{ + // exp.BitwiseInversionOp: []byte("~"), + // exp.BitwiseOrOp: []byte("|"), + // exp.BitwiseAndOp: []byte("&"), + // exp.BitwiseXorOp: []byte("#"), + // exp.BitwiseLeftShiftOp: []byte("<<"), + // exp.BitwiseRightShiftOp: []byte(">>"), + // }), + BitwiseOperatorLookup map[exp.BitwiseOperation][]byte // A map used to look up RangeOperations and their SQL equivalents // (Default=map[exp.RangeOperation][]byte{ // exp.BetweenOp: []byte("BETWEEN"), @@ -509,6 +519,14 @@ func DefaultDialectOptions() *SQLDialectOptions { exp.RegexpILikeOp: []byte("~*"), exp.RegexpNotILikeOp: []byte("!~*"), }, + BitwiseOperatorLookup: map[exp.BitwiseOperation][]byte{ + exp.BitwiseInversionOp: []byte("~"), + exp.BitwiseOrOp: []byte("|"), + exp.BitwiseAndOp: []byte("&"), + exp.BitwiseXorOp: []byte("#"), + exp.BitwiseLeftShiftOp: []byte("<<"), + exp.BitwiseRightShiftOp: []byte(">>"), + }, RangeOperatorLookup: map[exp.RangeOperation][]byte{ exp.BetweenOp: []byte("BETWEEN"), exp.NotBetweenOp: []byte("NOT BETWEEN"), From a30353d7fe9facf2344b504a096195ca97f549a8 Mon Sep 17 00:00:00 2001 From: Doug Martin Date: Wed, 6 Oct 2021 15:14:33 -0600 Subject: [PATCH 2/5] chore: fix linting issues --- exp/bitwise_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exp/bitwise_test.go b/exp/bitwise_test.go index 5dcbd86f..419728ff 100644 --- a/exp/bitwise_test.go +++ b/exp/bitwise_test.go @@ -43,7 +43,7 @@ func (bes *bitwiseExpressionSuite) TestDesc() { func (bes *bitwiseExpressionSuite) TestAllOthers() { be := exp.NewBitwiseExpression(exp.BitwiseRightShiftOp, exp.NewIdentifierExpression("", "", "col"), 1) rv := exp.NewRangeVal(1, 2) - pattern := "cast like%" + pattern := "bitwiseExp like%" inVals := []interface{}{1, 2} testCases := []struct { Ex exp.Expression From aa6e818e61619e523ea83cab6e70600bf8fd2970 Mon Sep 17 00:00:00 2001 From: Doug Martin Date: Wed, 6 Oct 2021 15:16:36 -0600 Subject: [PATCH 3/5] chore: release v9.17.0 * [Feature] Add support bitwise operations #303 - @XIELongDragon --- HISTORY.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 44fed621..a92b020d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,6 @@ +# v9.17.0 +* [Feature] Add support bitwise operations [#303](https://github.com/doug-martin/goqu/pull/303) - [@XIELongDragon](https://github.com/XIELongDragon) + # v9.16.0 * [FEATURE] Allow ordering by case expression [#282](https://github.com/doug-martin/goqu/issues/282), [#292](https://github.com/doug-martin/goqu/pull/292) From e65902a950772f8b68fa68ea4ed35262e291b189 Mon Sep 17 00:00:00 2001 From: Juraj Bubniak Date: Fri, 24 Sep 2021 13:36:50 +0200 Subject: [PATCH 4/5] feat: add support for specifying tables to be locked in ForUpdate, ForNoKeyUpdate, ForKeyShare, ForShare --- README.md | 4 +-- dialect/sqlite3/sqlite3.go | 1 + dialect/sqlserver/sqlserver.go | 1 + docs/selecting.md | 26 ++++++++++++++++ exp/lock.go | 9 +++++- select_dataset.go | 20 ++++++------ select_dataset_example_test.go | 32 +++++++++++++++++++ select_dataset_test.go | 48 +++++++++++++++++++++++++++++ sqlgen/select_sql_generator.go | 17 +++++++++- sqlgen/select_sql_generator_test.go | 13 ++++++++ sqlgen/sql_dialect_options.go | 3 ++ 11 files changed, 160 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index ee99c476..1f470fdb 100644 --- a/README.md +++ b/README.md @@ -277,7 +277,7 @@ New features and/or enhancements are great and I encourage you to either submit 1. The use case 2. A short example -If you are issuing a PR also also include the following +If you are issuing a PR also include the following 1. Tests - otherwise the PR will not be merged 2. Documentation - otherwise the PR will not be merged @@ -297,7 +297,7 @@ go test -v -race ./... You can also run the tests in a container using [docker-compose](https://docs.docker.com/compose/). ```sh -GO_VERSION=latest docker-compose run goqu +MYSQL_VERSION=8 POSTGRES_VERSION=13.4 SQLSERVER_VERSION=2017-CU8-ubuntu GO_VERSION=latest docker-compose run goqu ``` ## License diff --git a/dialect/sqlite3/sqlite3.go b/dialect/sqlite3/sqlite3.go index 08df9284..40ddb2a5 100644 --- a/dialect/sqlite3/sqlite3.go +++ b/dialect/sqlite3/sqlite3.go @@ -66,6 +66,7 @@ func DialectOptions() *goqu.SQLDialectOptions { opts.ConflictDoUpdateFragment = []byte(" DO UPDATE SET ") opts.ConflictDoNothingFragment = []byte(" DO NOTHING ") opts.ForUpdateFragment = []byte("") + opts.OfFragment = []byte("") opts.NowaitFragment = []byte("") return opts } diff --git a/dialect/sqlserver/sqlserver.go b/dialect/sqlserver/sqlserver.go index 9cb20d65..c1fca904 100644 --- a/dialect/sqlserver/sqlserver.go +++ b/dialect/sqlserver/sqlserver.go @@ -86,6 +86,7 @@ func DialectOptions() *goqu.SQLDialectOptions { 0x1a: []byte("\\x1a"), } + opts.OfFragment = []byte("") opts.ConflictFragment = []byte("") opts.ConflictDoUpdateFragment = []byte("") opts.ConflictDoNothingFragment = []byte("") diff --git a/docs/selecting.md b/docs/selecting.md index ebab06c4..f35f2aa1 100644 --- a/docs/selecting.md +++ b/docs/selecting.md @@ -14,6 +14,7 @@ * [`Window`](#window) * [`With`](#with) * [`SetError`](#seterror) + * [`ForUpdate`](#forupdate) * Executing Queries * [`ScanStructs`](#scan-structs) - Scans rows into a slice of structs * [`ScanStruct`](#scan-struct) - Scans a row into a slice a struct, returns false if a row wasnt found @@ -875,6 +876,31 @@ name is empty name is empty ``` + +**[`ForUpdate`](https://godoc.org/github.com/doug-martin/goqu/#SelectDataset.ForUpdate)** + +```go +sql, _, _ := goqu.From("test").ForUpdate(exp.Wait).ToSQL() +fmt.Println(sql) +``` + +Output: +```sql +SELECT * FROM "test" FOR UPDATE +``` + +If your dialect supports FOR UPDATE OF you provide tables to be locked as variable arguments to the ForUpdate method. + +```go +sql, _, _ := goqu.From("test").ForUpdate(exp.Wait, goqu.T("test")).ToSQL() +fmt.Println(sql) +``` + +Output: +```sql +SELECT * FROM "test" FOR UPDATE OF "test" +``` + ## Executing Queries To execute your query use [`goqu.Database#From`](https://godoc.org/github.com/doug-martin/goqu/#Database.From) to create your dataset diff --git a/exp/lock.go b/exp/lock.go index e4548a22..9b8bf72e 100644 --- a/exp/lock.go +++ b/exp/lock.go @@ -6,10 +6,12 @@ type ( Lock interface { Strength() LockStrength WaitOption() WaitOption + Of() []IdentifierExpression } lock struct { strength LockStrength waitOption WaitOption + of []IdentifierExpression } ) @@ -25,10 +27,11 @@ const ( SkipLocked ) -func NewLock(strength LockStrength, option WaitOption) Lock { +func NewLock(strength LockStrength, option WaitOption, of ...IdentifierExpression) Lock { return lock{ strength: strength, waitOption: option, + of: of, } } @@ -39,3 +42,7 @@ func (l lock) Strength() LockStrength { func (l lock) WaitOption() WaitOption { return l.waitOption } + +func (l lock) Of() []IdentifierExpression { + return l.of +} diff --git a/select_dataset.go b/select_dataset.go index d027a372..775c387d 100644 --- a/select_dataset.go +++ b/select_dataset.go @@ -359,27 +359,27 @@ func (sd *SelectDataset) ClearWhere() *SelectDataset { } // Adds a FOR UPDATE clause. See examples. -func (sd *SelectDataset) ForUpdate(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForUpdate, waitOption) +func (sd *SelectDataset) ForUpdate(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForUpdate, waitOption, of...) } // Adds a FOR NO KEY UPDATE clause. See examples. -func (sd *SelectDataset) ForNoKeyUpdate(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForNoKeyUpdate, waitOption) +func (sd *SelectDataset) ForNoKeyUpdate(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForNoKeyUpdate, waitOption, of...) } // Adds a FOR KEY SHARE clause. See examples. -func (sd *SelectDataset) ForKeyShare(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForKeyShare, waitOption) +func (sd *SelectDataset) ForKeyShare(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForKeyShare, waitOption, of...) } // Adds a FOR SHARE clause. See examples. -func (sd *SelectDataset) ForShare(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForShare, waitOption) +func (sd *SelectDataset) ForShare(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForShare, waitOption, of...) } -func (sd *SelectDataset) withLock(strength exp.LockStrength, option exp.WaitOption) *SelectDataset { - return sd.copy(sd.clauses.SetLock(exp.NewLock(strength, option))) +func (sd *SelectDataset) withLock(strength exp.LockStrength, option exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.copy(sd.clauses.SetLock(exp.NewLock(strength, option, of...))) } // Adds a GROUP BY clause. See examples. diff --git a/select_dataset_example_test.go b/select_dataset_example_test.go index 35f62794..e8d7fc2d 100644 --- a/select_dataset_example_test.go +++ b/select_dataset_example_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" "github.com/lib/pq" ) @@ -1651,3 +1652,34 @@ func ExampleSelectDataset_Executor_scannerScanVal() { // Sally // Vinita } + +func ExampleForUpdate() { + sql, args, _ := goqu.From("test").ForUpdate(exp.Wait).ToSQL() + fmt.Println(sql, args) + + // Output: + // SELECT * FROM "test" FOR UPDATE [] +} + +func ExampleForUpdate_of() { + sql, args, _ := goqu.From("test").ForUpdate(exp.Wait, goqu.T("test")).ToSQL() + fmt.Println(sql, args) + + // Output: + // SELECT * FROM "test" FOR UPDATE OF "test" [] +} + +func ExampleForUpdate_ofMultiple() { + sql, args, _ := goqu.From("table1").Join( + goqu.T("table2"), + goqu.On(goqu.I("table2.id").Eq(goqu.I("table1.id"))), + ).ForUpdate( + exp.Wait, + goqu.T("table1"), + goqu.T("table2"), + ).ToSQL() + fmt.Println(sql, args) + + // Output: + // SELECT * FROM "table1" INNER JOIN "table2" ON ("table2"."id" = "table1"."id") FOR UPDATE OF "table1", "table2" [] +} diff --git a/select_dataset_test.go b/select_dataset_test.go index 63fdd997..6579cfad 100644 --- a/select_dataset_test.go +++ b/select_dataset_test.go @@ -674,6 +674,18 @@ func (sds *selectDatasetSuite) TestForUpdate() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForUpdate, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForUpdate(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForUpdate, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForUpdate(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForUpdate, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), @@ -690,6 +702,18 @@ func (sds *selectDatasetSuite) TestForNoKeyUpdate() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForNoKeyUpdate, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForNoKeyUpdate(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForNoKeyUpdate, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForNoKeyUpdate(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForNoKeyUpdate, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), @@ -706,6 +730,18 @@ func (sds *selectDatasetSuite) TestForKeyShare() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForKeyShare, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForKeyShare(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForKeyShare, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForKeyShare(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForKeyShare, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), @@ -722,6 +758,18 @@ func (sds *selectDatasetSuite) TestForShare() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForShare, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForShare(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForShare, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForShare(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForShare, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), diff --git a/sqlgen/select_sql_generator.go b/sqlgen/select_sql_generator.go index 7bbb4142..de322910 100644 --- a/sqlgen/select_sql_generator.go +++ b/sqlgen/select_sql_generator.go @@ -196,8 +196,23 @@ func (ssg *selectSQLGenerator) ForSQL(b sb.SQLBuilder, lockingClause exp.Lock) { case exp.ForKeyShare: b.Write(ssg.DialectOptions().ForKeyShareFragment) } + + of := lockingClause.Of() + if ofLen := len(of); ofLen > 0 { + if ofFragment := ssg.DialectOptions().OfFragment; len(ofFragment) > 0 { + b.Write(ofFragment) + for i, table := range of { + ssg.ExpressionSQLGenerator().Generate(b, table) + if i < ofLen-1 { + b.WriteRunes(ssg.DialectOptions().CommaRune, ssg.DialectOptions().SpaceRune) + } + } + b.WriteRunes(ssg.DialectOptions().SpaceRune) + } + } + // the WAIT case is the default in Postgres, and is what you get if you don't specify NOWAIT or - // SKIP LOCKED. There's no special syntax for it in PG, so we don't do anything for it here + // SKIP LOCKED. There's no special syntax for it in PG, so we don't do anything for it here switch lockingClause.WaitOption() { case exp.Wait: return diff --git a/sqlgen/select_sql_generator_test.go b/sqlgen/select_sql_generator_test.go index ce048545..90394b18 100644 --- a/sqlgen/select_sql_generator_test.go +++ b/sqlgen/select_sql_generator_test.go @@ -3,6 +3,7 @@ package sqlgen_test import ( "testing" + "github.com/doug-martin/goqu/v9" "github.com/doug-martin/goqu/v9/exp" "github.com/doug-martin/goqu/v9/internal/errors" "github.com/doug-martin/goqu/v9/internal/sb" @@ -506,6 +507,7 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { opts.ForNoKeyUpdateFragment = []byte(" for no key update ") opts.ForShareFragment = []byte(" for share ") opts.ForKeyShareFragment = []byte(" for key share ") + opts.OfFragment = []byte("of ") opts.NowaitFragment = []byte("nowait") opts.SkipLockedFragment = []byte("skip locked") @@ -513,10 +515,13 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { scFnW := sc.SetLock(exp.NewLock(exp.ForNolock, exp.Wait)) scFnNw := sc.SetLock(exp.NewLock(exp.ForNolock, exp.NoWait)) scFnSl := sc.SetLock(exp.NewLock(exp.ForNolock, exp.SkipLocked)) + scFnSlOf := sc.SetLock(exp.NewLock(exp.ForNolock, exp.SkipLocked, goqu.T("my_table"))) scFsW := sc.SetLock(exp.NewLock(exp.ForShare, exp.Wait)) scFsNw := sc.SetLock(exp.NewLock(exp.ForShare, exp.NoWait)) scFsSl := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked)) + scFsSlOf := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked, goqu.T("my_table"))) + scFsSlOfMulti := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked, goqu.T("my_table"), goqu.T("table2"))) scFksW := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.Wait)) scFksNw := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.NoWait)) @@ -539,6 +544,8 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { selectTestCase{clause: scFnSl, sql: `SELECT * FROM "test"`}, selectTestCase{clause: scFnSl, sql: `SELECT * FROM "test"`, isPrepared: true}, + selectTestCase{clause: scFnSlOf, sql: `SELECT * FROM "test"`}, + selectTestCase{clause: scFnSlOf, sql: `SELECT * FROM "test"`, isPrepared: true, args: []interface{}{}}, selectTestCase{clause: scFsW, sql: `SELECT * FROM "test" for share `}, selectTestCase{clause: scFsW, sql: `SELECT * FROM "test" for share `, isPrepared: true}, @@ -549,6 +556,12 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { selectTestCase{clause: scFsSl, sql: `SELECT * FROM "test" for share skip locked`}, selectTestCase{clause: scFsSl, sql: `SELECT * FROM "test" for share skip locked`, isPrepared: true}, + selectTestCase{clause: scFsSlOf, sql: `SELECT * FROM "test" for share of "my_table" skip locked`}, + selectTestCase{clause: scFsSlOf, sql: `SELECT * FROM "test" for share of "my_table" skip locked`, isPrepared: true}, + + selectTestCase{clause: scFsSlOfMulti, sql: `SELECT * FROM "test" for share of "my_table", "table2" skip locked`}, + selectTestCase{clause: scFsSlOfMulti, sql: `SELECT * FROM "test" for share of "my_table", "table2" skip locked`, isPrepared: true}, + selectTestCase{clause: scFksW, sql: `SELECT * FROM "test" for key share `}, selectTestCase{clause: scFksW, sql: `SELECT * FROM "test" for key share `, isPrepared: true}, diff --git a/sqlgen/sql_dialect_options.go b/sqlgen/sql_dialect_options.go index 3e6947d1..a0df394e 100644 --- a/sqlgen/sql_dialect_options.go +++ b/sqlgen/sql_dialect_options.go @@ -119,6 +119,8 @@ type ( ForNoKeyUpdateFragment []byte // The SQL FOR SHARE fragment(DEFAULT=[]byte(" FOR SHARE ")) ForShareFragment []byte + // The SQL OF fragment(DEFAULT=[]byte("OF ")) + OfFragment []byte // The SQL FOR KEY SHARE fragment(DEFAULT=[]byte(" FOR KEY SHARE ")) ForKeyShareFragment []byte // The SQL NOWAIT fragment(DEFAULT=[]byte("NOWAIT")) @@ -460,6 +462,7 @@ func DefaultDialectOptions() *SQLDialectOptions { ForNoKeyUpdateFragment: []byte(" FOR NO KEY UPDATE "), ForShareFragment: []byte(" FOR SHARE "), ForKeyShareFragment: []byte(" FOR KEY SHARE "), + OfFragment: []byte("OF "), NowaitFragment: []byte("NOWAIT"), SkipLockedFragment: []byte("SKIP LOCKED"), LateralFragment: []byte("LATERAL "), From a7630f7155a53a140693e25bb1d73868ee92c54f Mon Sep 17 00:00:00 2001 From: Doug Martin Date: Wed, 6 Oct 2021 15:25:05 -0600 Subject: [PATCH 5/5] chore: update release notes --- HISTORY.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index a92b020d..080379bb 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,6 @@ # v9.17.0 -* [Feature] Add support bitwise operations [#303](https://github.com/doug-martin/goqu/pull/303) - [@XIELongDragon](https://github.com/XIELongDragon) +* [FEATURE] Add support bitwise operations [#303](https://github.com/doug-martin/goqu/pull/303) - [@XIELongDragon](https://github.com/XIELongDragon) +* [FEATURE] Add support for specifying tables to be locked in ForUpdate, ForNoKeyUpdate, ForKeyShare, ForShare [#299](https://github.com/doug-martin/goqu/pull/299) - [@jbub](https://github.com/jbub) # v9.16.0 * [FEATURE] Allow ordering by case expression [#282](https://github.com/doug-martin/goqu/issues/282), [#292](https://github.com/doug-martin/goqu/pull/292)