diff --git a/pkg/sql/logictest/testdata/logic_test/select_for_update b/pkg/sql/logictest/testdata/logic_test/select_for_update index 27c22fc2793a..73db3e4f31fc 100644 --- a/pkg/sql/logictest/testdata/logic_test/select_for_update +++ b/pkg/sql/logictest/testdata/logic_test/select_for_update @@ -44,6 +44,15 @@ SELECT 1 FOR UPDATE OF a FOR SHARE OF b, c FOR NO KEY UPDATE OF d FOR KEY SHARE ---- 1 +# However, we do mirror Postgres in that we require FOR UPDATE targets to be +# unqualified names and reject anything else. + +query error pgcode 42601 FOR UPDATE must specify unqualified relation names +SELECT 1 FOR UPDATE OF public.a + +query error pgcode 42601 FOR UPDATE must specify unqualified relation names +SELECT 1 FOR UPDATE OF db.public.a + # We can't support SKIP LOCKED or NOWAIT, since they would actually behave # differently - NOWAIT returns an error to the client instead of blocking, # and SKIP LOCKED returns an inconsistent view. @@ -110,30 +119,99 @@ SELECT 1 FOR READ ONLY statement error pgcode 0A000 FOR UPDATE is not allowed with UNION/INTERSECT/EXCEPT SELECT 1 UNION SELECT 1 FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with UNION/INTERSECT/EXCEPT +SELECT * FROM (SELECT 1 UNION SELECT 1) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with VALUES VALUES (1) FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with VALUES +SELECT * FROM (VALUES (1)) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with DISTINCT clause SELECT DISTINCT 1 FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with DISTINCT clause +SELECT * FROM (SELECT DISTINCT 1) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with GROUP BY clause SELECT 1 GROUP BY 1 FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with GROUP BY clause +SELECT * FROM (SELECT 1 GROUP BY 1) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with HAVING clause SELECT 1 HAVING TRUE FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with HAVING clause +SELECT * FROM (SELECT 1 HAVING TRUE) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with aggregate functions SELECT count(1) FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with aggregate functions +SELECT * FROM (SELECT count(1)) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with window functions SELECT count(1) OVER () FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with window functions +SELECT * FROM (SELECT count(1) OVER ()) a FOR UPDATE + statement error pgcode 0A000 FOR UPDATE is not allowed with set-returning functions in the target list SELECT generate_series(1, 2) FOR UPDATE +statement error pgcode 0A000 FOR UPDATE is not allowed with set-returning functions in the target list +SELECT * FROM (SELECT generate_series(1, 2)) a FOR UPDATE + # Set-returning functions in the from list are allowed. query I SELECT * FROM generate_series(1, 2) FOR UPDATE ---- 1 2 + +query I +SELECT * FROM (SELECT * FROM generate_series(1, 2)) a FOR UPDATE +---- +1 +2 + +# Use of SELECT FOR UPDATE/SHARE requires UPDATE privileges, not just SELECT privileges. + +statement ok +CREATE TABLE t (k INT PRIMARY KEY, v int) + +statement ok +GRANT GRANT on t to testuser + +user testuser + +statement error pgcode 42501 user testuser does not have SELECT privilege on relation t +SELECT * FROM t + +statement ok +GRANT SELECT ON t TO testuser + +statement ok +SELECT * FROM t + +statement error pgcode 42501 user testuser does not have UPDATE privilege on relation t +SELECT * FROM t FOR UPDATE + +statement error pgcode 42501 user testuser does not have UPDATE privilege on relation t +SELECT * FROM t FOR SHARE + +statement ok +GRANT UPDATE ON t TO testuser + +statement ok +SELECT * FROM t FOR UPDATE + +statement ok +SELECT * FROM t FOR SHARE + +user root + +statement ok +DROP TABLE t diff --git a/pkg/sql/opt/memo/expr_format.go b/pkg/sql/opt/memo/expr_format.go index 67db3fbe043e..63ed893e51df 100644 --- a/pkg/sql/opt/memo/expr_format.go +++ b/pkg/sql/opt/memo/expr_format.go @@ -360,6 +360,29 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) { tp.Childf("flags: force-index=%s%s", idx.Name(), dir) } } + if t.Locking != nil { + strength := "" + switch t.Locking.Strength { + case tree.ForNone: + case tree.ForKeyShare: + strength = "for-key-share" + case tree.ForShare: + strength = "for-share" + case tree.ForNoKeyUpdate: + strength = "for-no-key-update" + case tree.ForUpdate: + strength = "for-update" + } + wait := "" + switch t.Locking.WaitPolicy { + case tree.LockWaitBlock: + case tree.LockWaitSkip: + wait = ",skip-locked" + case tree.LockWaitError: + wait = ",nowait" + } + tp.Childf("locking: %s%s", strength, wait) + } case *LookupJoinExpr: if !t.Flags.Empty() { diff --git a/pkg/sql/opt/memo/interner.go b/pkg/sql/opt/memo/interner.go index 80fc5abc5294..771b57decd92 100644 --- a/pkg/sql/opt/memo/interner.go +++ b/pkg/sql/opt/memo/interner.go @@ -311,17 +311,24 @@ func (h *hasher) HashFloat64(val float64) { h.hash *= prime64 } +func (h *hasher) HashRune(val rune) { + h.hash ^= internHash(val) + h.hash *= prime64 +} + func (h *hasher) HashString(val string) { for _, c := range val { - h.hash ^= internHash(c) - h.hash *= prime64 + h.HashRune(c) } } +func (h *hasher) HashByte(val byte) { + h.HashRune(rune(val)) +} + func (h *hasher) HashBytes(val []byte) { for _, c := range val { - h.hash ^= internHash(c) - h.hash *= prime64 + h.HashByte(c) } } @@ -540,6 +547,13 @@ func (h *hasher) HashPhysProps(val *physical.Required) { h.HashFloat64(val.LimitHint) } +func (h *hasher) HashLockingItem(val *tree.LockingItem) { + if val != nil { + h.HashByte(byte(val.Strength)) + h.HashByte(byte(val.WaitPolicy)) + } +} + func (h *hasher) HashRelExpr(val RelExpr) { h.HashUint64(uint64(reflect.ValueOf(val).Pointer())) } @@ -646,10 +660,18 @@ func (h *hasher) IsFloat64Equal(l, r float64) bool { return math.Float64bits(l) == math.Float64bits(r) } +func (h *hasher) IsRuneEqual(l, r rune) bool { + return l == r +} + func (h *hasher) IsStringEqual(l, r string) bool { return l == r } +func (h *hasher) IsByteEqual(l, r byte) bool { + return l == r +} + func (h *hasher) IsBytesEqual(l, r []byte) bool { return bytes.Equal(l, r) } @@ -854,6 +876,13 @@ func (h *hasher) IsPhysPropsEqual(l, r *physical.Required) bool { return l.Equals(r) } +func (h *hasher) IsLockingItemEqual(l, r *tree.LockingItem) bool { + if l == nil || r == nil { + return l == r + } + return l.Strength == r.Strength && l.WaitPolicy == r.WaitPolicy +} + func (h *hasher) IsPointerEqual(l, r unsafe.Pointer) bool { return l == r } diff --git a/pkg/sql/opt/memo/interner_test.go b/pkg/sql/opt/memo/interner_test.go index 5a00e30d1428..abc8370e012a 100644 --- a/pkg/sql/opt/memo/interner_test.go +++ b/pkg/sql/opt/memo/interner_test.go @@ -177,6 +177,13 @@ func TestInterner(t *testing.T) { {val1: float64(0), val2: math.Copysign(0, -1), equal: false}, }}, + {hashFn: in.hasher.HashRune, eqFn: in.hasher.IsRuneEqual, variations: []testVariation{ + {val1: rune(0), val2: rune(0), equal: true}, + {val1: rune('a'), val2: rune('b'), equal: false}, + {val1: rune('a'), val2: rune('A'), equal: false}, + {val1: rune('🐛'), val2: rune('🐛'), equal: true}, + }}, + {hashFn: in.hasher.HashString, eqFn: in.hasher.IsStringEqual, variations: []testVariation{ {val1: "", val2: "", equal: true}, {val1: "abc", val2: "abcd", equal: false}, @@ -184,6 +191,13 @@ func TestInterner(t *testing.T) { {val1: "the quick brown fox", val2: "the quick brown fox", equal: true}, }}, + {hashFn: in.hasher.HashByte, eqFn: in.hasher.IsByteEqual, variations: []testVariation{ + {val1: byte(0), val2: byte(0), equal: true}, + {val1: byte('a'), val2: byte('b'), equal: false}, + {val1: byte('a'), val2: byte('A'), equal: false}, + {val1: byte('z'), val2: byte('z'), equal: true}, + }}, + {hashFn: in.hasher.HashBytes, eqFn: in.hasher.IsBytesEqual, variations: []testVariation{ {val1: []byte{}, val2: []byte{}, equal: true}, {val1: []byte{}, val2: []byte{0}, equal: false}, @@ -412,6 +426,30 @@ func TestInterner(t *testing.T) { // PhysProps hash/isEqual methods are tested in TestInternerPhysProps. + {hashFn: in.hasher.HashLockingItem, eqFn: in.hasher.IsLockingItemEqual, variations: []testVariation{ + {val1: (*tree.LockingItem)(nil), val2: (*tree.LockingItem)(nil), equal: true}, + { + val1: (*tree.LockingItem)(nil), + val2: &tree.LockingItem{Strength: tree.ForUpdate}, + equal: false, + }, + { + val1: &tree.LockingItem{Strength: tree.ForShare}, + val2: &tree.LockingItem{Strength: tree.ForUpdate}, + equal: false, + }, + { + val1: &tree.LockingItem{WaitPolicy: tree.LockWaitSkip}, + val2: &tree.LockingItem{WaitPolicy: tree.LockWaitError}, + equal: false, + }, + { + val1: &tree.LockingItem{Strength: tree.ForUpdate, WaitPolicy: tree.LockWaitError}, + val2: &tree.LockingItem{Strength: tree.ForUpdate, WaitPolicy: tree.LockWaitError}, + equal: true, + }, + }}, + {hashFn: in.hasher.HashRelExpr, eqFn: in.hasher.IsRelExprEqual, variations: []testVariation{ {val1: (*ScanExpr)(nil), val2: (*ScanExpr)(nil), equal: true}, {val1: scanNode, val2: scanNode, equal: true}, diff --git a/pkg/sql/opt/ops/relational.opt b/pkg/sql/opt/ops/relational.opt index 50f43e5272d6..38aea6e10363 100644 --- a/pkg/sql/opt/ops/relational.opt +++ b/pkg/sql/opt/ops/relational.opt @@ -65,6 +65,14 @@ define ScanPrivate { # Flags modify how the table is scanned, such as which index is used to scan. Flags ScanFlags + # Locking represents the row-level locking mode of the Scan. Most scans + # leave this unset (Strength = ForNone), which indicates that no row-level + # locking will be performed while scanning the table. Stronger locking modes + # are used by SELECT .. FOR [KEY] UPDATE/SHARE statements and by the initial + # row retrieval of DELETE and UPDATE statements. The locking item's Targets + # list will always be empty when part of a ScanPrivate. + Locking LockingItem + # PartitionConstrainedScan records whether or not we were able to use partitions # to constrain the lookup spans further. This flag is used to record telemetry # about how often this optimization is getting applied. diff --git a/pkg/sql/opt/optbuilder/builder.go b/pkg/sql/opt/optbuilder/builder.go index 23ad7414f8f7..34034c7ceef3 100644 --- a/pkg/sql/opt/optbuilder/builder.go +++ b/pkg/sql/opt/optbuilder/builder.go @@ -250,10 +250,10 @@ func (b *Builder) buildStmt( switch stmt := stmt.(type) { case *tree.Select: - return b.buildSelect(stmt, desiredTypes, inScope) + return b.buildSelect(stmt, noRowLocking, desiredTypes, inScope) case *tree.ParenSelect: - return b.buildSelect(stmt.Select, desiredTypes, inScope) + return b.buildSelect(stmt.Select, noRowLocking, desiredTypes, inScope) case *tree.Delete: return b.processWiths(stmt.With, inScope, func(inScope *scope) *scope { diff --git a/pkg/sql/opt/optbuilder/insert.go b/pkg/sql/opt/optbuilder/insert.go index 48fe4e52fc03..c2ff3a02afb3 100644 --- a/pkg/sql/opt/optbuilder/insert.go +++ b/pkg/sql/opt/optbuilder/insert.go @@ -664,6 +664,7 @@ func (mb *mutationBuilder) buildInputForDoNothing(inScope *scope, onConflict *tr mb.b.addTable(mb.tab, &mb.alias), nil, /* ordinals */ nil, /* indexFlags */ + noRowLocking, excludeMutations, inScope, ) @@ -746,6 +747,7 @@ func (mb *mutationBuilder) buildInputForUpsert( mb.b.addTable(mb.tab, &mb.alias), nil, /* ordinals */ nil, /* indexFlags */ + noRowLocking, includeMutations, inScope, ) diff --git a/pkg/sql/opt/optbuilder/join.go b/pkg/sql/opt/optbuilder/join.go index 8c824b062fee..956607c35598 100644 --- a/pkg/sql/opt/optbuilder/join.go +++ b/pkg/sql/opt/optbuilder/join.go @@ -29,8 +29,10 @@ import ( // // See Builder.buildStmt for a description of the remaining input and // return values. -func (b *Builder) buildJoin(join *tree.JoinTableExpr, inScope *scope) (outScope *scope) { - leftScope := b.buildDataSource(join.Left, nil /* indexFlags */, inScope) +func (b *Builder) buildJoin( + join *tree.JoinTableExpr, locking lockingSpec, inScope *scope, +) (outScope *scope) { + leftScope := b.buildDataSource(join.Left, nil /* indexFlags */, locking, inScope) isLateral := false inScopeRight := inScope @@ -43,7 +45,7 @@ func (b *Builder) buildJoin(join *tree.JoinTableExpr, inScope *scope) (outScope inScopeRight = leftScope } - rightScope := b.buildDataSource(join.Right, nil /* indexFlags */, inScopeRight) + rightScope := b.buildDataSource(join.Right, nil /* indexFlags */, locking, inScopeRight) // Check that the same table name is not used on both sides. b.validateJoinTableNames(leftScope, rightScope) diff --git a/pkg/sql/opt/optbuilder/locking.go b/pkg/sql/opt/optbuilder/locking.go new file mode 100644 index 000000000000..1e82067bfb94 --- /dev/null +++ b/pkg/sql/opt/optbuilder/locking.go @@ -0,0 +1,167 @@ +// Copyright 2020 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package optbuilder + +import "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + +// lockingSpec maintains a collection of FOR [KEY] UPDATE/SHARE items that apply +// to a given scope. Locking clauses can be applied to the lockingSpec as they +// come into scope in the AST. The lockingSpec can then be consolidated down to +// a single row-level locking specification for different tables to determine +// how scans over those tables should perform row-level locking, if at all. +// +// A SELECT statement may contain zero, one, or more than one row-level locking +// clause. Each of these clauses consist of two different properties. +// +// The first property is locking strength (see tree.LockingStrength). Locking +// strength represents the degree of protection that a row-level lock provides. +// The stronger the lock, the more protection it provides for the lock holder +// but the more restrictive it is to concurrent transactions attempting to +// access the same row. In order from weakest to strongest, the lock strength +// variants are: +// +// FOR KEY SHARE +// FOR SHARE +// FOR NO KEY UPDATE +// FOR UPDATE +// +// The second property is the locking wait policy (see tree.LockingWaitPolicy). +// A locking wait policy represents the policy a table scan uses to interact +// with row-level locks held by other transactions. Unlike locking strength, +// locking wait policy is optional to specify in a locking clause. If not +// specified, the policy defaults to blocking and waiting for locks to become +// available. The non-standard policies instruct scans to take other approaches +// to handling locks held by other transactions. These non-standard policies +// are: +// +// SKIP LOCKED +// NOWAIT +// +// In addition to these two properties, locking clauses can contain an optional +// list of target relations. When provided, the locking clause applies only to +// those relations in the target list. When not provided, the locking clause +// applies to all relations in the current scope. +// +// Put together, a complex locking spec might look like: +// +// SELECT ... FROM ... FOR SHARE NOWAIT FOR UPDATE OF t1, t2 +// +// which would be represented as: +// +// [ {ForShare, LockWaitError, []}, {ForUpdate, LockWaitBlock, [t1, t2]} ] +// +type lockingSpec []*tree.LockingItem + +// noRowLocking indicates that no row-level locking has been specified. +var noRowLocking lockingSpec + +// isSet returns whether the spec contains any row-level locking modes. +func (lm lockingSpec) isSet() bool { + return len(lm) != 0 +} + +// get returns the first row-level locking mode in the spec. If the spec was the +// outcome of filter operation, this will be the only locking mode in the spec. +func (lm lockingSpec) get() *tree.LockingItem { + if lm.isSet() { + return lm[0] + } + return nil +} + +// apply merges the locking clause into the current locking spec. The effect of +// applying new locking clauses to an existing spec is always to strengthen the +// locking approaches it represents, either through increasing locking strength +// or using more aggressive wait policies. +func (lm *lockingSpec) apply(locking tree.LockingClause) { + // TODO(nvanbenschoten): If we wanted to eagerly prune superfluous locking + // items so that they don't need to get merged away in each call to filter, + // this would be the place to do it. We don't expect to see multiple FOR + // UPDATE clauses very often, so it's probably not worth it. + if len(*lm) == 0 { + // NB: avoid allocation. + *lm = lockingSpec(locking) + return + } + *lm = append(*lm, locking...) +} + +// filter returns the desired row-level locking mode for the specifies table as +// a new consolidated lockingSpec. If no matching locking mode is found then the +// resulting spec will remain un-set. If a matching locking mode for the table +// is found then the resulting spec will contain exclusively that locking mode +// and will no longer be restricted to specific target relations. +func (lm lockingSpec) filter(alias tree.Name) lockingSpec { + var ret lockingSpec + var copied bool + updateRet := func(li *tree.LockingItem, len1 []*tree.LockingItem) { + if ret == nil && len(li.Targets) == 0 { + // Fast-path. We don't want the resulting spec to include targets, + // so we only allow this if the item we want to copy has none. + ret = len1 + return + } + if !copied { + retCpy := make(lockingSpec, 1) + retCpy[0] = new(tree.LockingItem) + if len(ret) == 1 { + *retCpy[0] = *ret[0] + } + ret = retCpy + copied = true + } + // From https://www.postgresql.org/docs/12/sql-select.html#SQL-FOR-UPDATE-SHARE + // > If the same table is mentioned (or implicitly affected) by more + // > than one locking clause, then it is processed as if it was only + // > specified by the strongest one. + ret[0].Strength = ret[0].Strength.Max(li.Strength) + // > Similarly, a table is processed as NOWAIT if that is specified in + // > any of the clauses affecting it. Otherwise, it is processed as SKIP + // > LOCKED if that is specified in any of the clauses affecting it. + ret[0].WaitPolicy = ret[0].WaitPolicy.Max(li.WaitPolicy) + } + + for i, li := range lm { + len1 := lm[i : i+1 : i+1] + if len(li.Targets) == 0 { + // If no targets are specified, the clause affects all tables. + updateRet(li, len1) + } else { + // If targets are specified, the clause affects only those tables. + for _, target := range li.Targets { + if target.TableName == alias { + updateRet(li, len1) + break + } + } + } + } + return ret +} + +// withoutTargets returns a new lockingSpec with all locking clauses that apply +// only to a subset of tables removed. +func (lm lockingSpec) withoutTargets() lockingSpec { + return lm.filter("") +} + +// ignoreLockingForCTE is a placeholder for the following comment: +// +// We intentionally do not propate any row-level locking information from the +// current scope to the CTE. This mirrors Postgres' behavior. It also avoids a +// number of awkward questions like how row-level locking would interact with +// mutating commong table expressions. +// +// From https://www.postgresql.org/docs/12/sql-select.html#SQL-FOR-UPDATE-SHARE +// > these clauses do not apply to WITH queries referenced by the primary query. +// > If you want row locking to occur within a WITH query, specify a locking +// > clause within the WITH query. +func (lm lockingSpec) ignoreLockingForCTE() {} diff --git a/pkg/sql/opt/optbuilder/mutation_builder.go b/pkg/sql/opt/optbuilder/mutation_builder.go index db279775fead..86f5cd86fed0 100644 --- a/pkg/sql/opt/optbuilder/mutation_builder.go +++ b/pkg/sql/opt/optbuilder/mutation_builder.go @@ -234,6 +234,7 @@ func (mb *mutationBuilder) buildInputForUpdate( mb.b.addTable(mb.tab, &mb.alias), nil, /* ordinals */ indexFlags, + noRowLocking, includeMutations, inScope, ) @@ -244,7 +245,7 @@ func (mb *mutationBuilder) buildInputForUpdate( // If there is a FROM clause present, we must join all the tables // together with the table being updated. if fromClausePresent { - fromScope := mb.b.buildFromTables(from, inScope) + fromScope := mb.b.buildFromTables(from, noRowLocking, inScope) // Check that the same table name is not used multiple times. mb.b.validateJoinTableNames(mb.outScope, fromScope) @@ -335,6 +336,7 @@ func (mb *mutationBuilder) buildInputForDelete( mb.b.addTable(mb.tab, &mb.alias), nil, /* ordinals */ indexFlags, + noRowLocking, includeMutations, inScope, ) @@ -1160,6 +1162,7 @@ func (mb *mutationBuilder) addInsertionCheck(fkOrdinal int, insertCols opt.ColLi refTabMeta, refOrdinals, &tree.IndexFlags{IgnoreForeignKeys: true}, + noRowLocking, includeMutations, mb.b.allocScope(), ) @@ -1303,6 +1306,7 @@ func (mb *mutationBuilder) addDeletionCheck( origTabMeta, origOrdinals, &tree.IndexFlags{IgnoreForeignKeys: true}, + noRowLocking, includeMutations, mb.b.allocScope(), ) diff --git a/pkg/sql/opt/optbuilder/select.go b/pkg/sql/opt/optbuilder/select.go index 2f53de77ca79..986213972653 100644 --- a/pkg/sql/opt/optbuilder/select.go +++ b/pkg/sql/opt/optbuilder/select.go @@ -42,7 +42,7 @@ const ( // See Builder.buildStmt for a description of the remaining input and // return values. func (b *Builder) buildDataSource( - texpr tree.TableExpr, indexFlags *tree.IndexFlags, inScope *scope, + texpr tree.TableExpr, indexFlags *tree.IndexFlags, locking lockingSpec, inScope *scope, ) (outScope *scope) { defer func(prevAtRoot bool) { inScope.atRoot = prevAtRoot @@ -55,8 +55,11 @@ func (b *Builder) buildDataSource( telemetry.Inc(sqltelemetry.IndexHintUseCounter) indexFlags = source.IndexFlags } + if source.As.Alias != "" { + locking = locking.filter(source.As.Alias) + } - outScope = b.buildDataSource(source.Expr, indexFlags, inScope) + outScope = b.buildDataSource(source.Expr, indexFlags, locking, inScope) if source.Ordinality { outScope = b.buildWithOrdinality("ordinality", outScope) @@ -68,13 +71,14 @@ func (b *Builder) buildDataSource( return outScope case *tree.JoinTableExpr: - return b.buildJoin(source, inScope) + return b.buildJoin(source, locking, inScope) case *tree.TableName: tn := source // CTEs take precedence over other data sources. if cte := inScope.resolveCTE(tn); cte != nil { + locking.ignoreLockingForCTE() outScope = inScope.push() inCols := make(opt.ColList, len(cte.cols)) outCols := make(opt.ColList, len(cte.cols)) @@ -100,29 +104,43 @@ func (b *Builder) buildDataSource( return outScope } - ds, resName := b.resolveDataSource(tn, privilege.SELECT) + priv := privilege.SELECT + locking = locking.filter(tn.TableName) + if locking.isSet() { + // SELECT ... FOR [KEY] UPDATE/SHARE requires UPDATE privileges. + priv = privilege.UPDATE + } + + ds, resName := b.resolveDataSource(tn, priv) switch t := ds.(type) { case cat.Table: tabMeta := b.addTable(t, &resName) - return b.buildScan(tabMeta, nil /* ordinals */, indexFlags, excludeMutations, inScope) + return b.buildScan(tabMeta, nil /* ordinals */, indexFlags, locking, excludeMutations, inScope) case cat.Sequence: return b.buildSequenceSelect(t, &resName, inScope) case cat.View: - return b.buildView(t, &resName, inScope) + return b.buildView(t, &resName, locking, inScope) + default: panic(errors.AssertionFailedf("unknown DataSource type %T", ds)) } case *tree.ParenTableExpr: - return b.buildDataSource(source.Expr, indexFlags, inScope) + return b.buildDataSource(source.Expr, indexFlags, locking, inScope) case *tree.RowsFromExpr: return b.buildZip(source.Items, inScope) case *tree.Subquery: - outScope = b.buildSelectStmt(source.Select, nil /* desiredTypes */, inScope) + // Remove any target relations from the current scope's locking spec, as + // those only apply to relations in this statements. Interestingly, this + // would not be necessary if we required all subqueries to have aliases + // like Postgres does. + locking = locking.withoutTargets() + + outScope = b.buildSelectStmt(source.Select, locking, nil /* desiredTypes */, inScope) // Treat the subquery result as an anonymous data source (i.e. column names // are not qualified). Remove hidden columns, as they are not accessible @@ -163,6 +181,7 @@ func (b *Builder) buildDataSource( outCols[i] = b.factory.Metadata().AddColumn(col.Alias, c.Type) } + locking.ignoreLockingForCTE() outScope = inScope.push() // Similar to appendColumnsFromScope, but with re-numbering the column IDs. for i, col := range innerScope.cols { @@ -183,10 +202,17 @@ func (b *Builder) buildDataSource( return outScope case *tree.TableRef: - ds := b.resolveDataSourceRef(source, privilege.SELECT) + priv := privilege.SELECT + locking = locking.filter(source.As.Alias) + if locking.isSet() { + // SELECT ... FOR [KEY] UPDATE/SHARE requires UPDATE privileges. + priv = privilege.UPDATE + } + + ds := b.resolveDataSourceRef(source, priv) switch t := ds.(type) { case cat.Table: - outScope = b.buildScanFromTableRef(t, source, indexFlags, inScope) + outScope = b.buildScanFromTableRef(t, source, indexFlags, locking, inScope) case cat.View: if source.Columns != nil { panic(pgerror.Newf(pgcode.FeatureNotSupported, @@ -194,7 +220,7 @@ func (b *Builder) buildDataSource( } tn := tree.MakeUnqualifiedTableName(t.Name()) - outScope = b.buildView(t, &tn, inScope) + outScope = b.buildView(t, &tn, locking, inScope) case cat.Sequence: tn := tree.MakeUnqualifiedTableName(t.Name()) // Any explicitly listed columns are ignored. @@ -212,7 +238,7 @@ func (b *Builder) buildDataSource( // buildView parses the view query text and builds it as a Select expression. func (b *Builder) buildView( - view cat.View, viewName *tree.TableName, inScope *scope, + view cat.View, viewName *tree.TableName, locking lockingSpec, inScope *scope, ) (outScope *scope) { // Cache the AST so that multiple references won't need to reparse. if b.views == nil { @@ -257,7 +283,7 @@ func (b *Builder) buildView( defer func() { b.trackViewDeps = true }() } - outScope = b.buildSelect(sel, nil /* desiredTypes */, &scope{builder: b}) + outScope = b.buildSelect(sel, locking, nil /* desiredTypes */, &scope{builder: b}) // Update data source name to be the name of the view. And if view columns // are specified, then update names of output columns. @@ -342,7 +368,11 @@ func (b *Builder) renameSource(as tree.AliasClause, scope *scope) { // Note, the query SELECT * FROM [53() as t] is unsupported. Column lists must // be non-empty func (b *Builder) buildScanFromTableRef( - tab cat.Table, ref *tree.TableRef, indexFlags *tree.IndexFlags, inScope *scope, + tab cat.Table, + ref *tree.TableRef, + indexFlags *tree.IndexFlags, + locking lockingSpec, + inScope *scope, ) (outScope *scope) { var ordinals []int if ref.Columns != nil { @@ -358,7 +388,7 @@ func (b *Builder) buildScanFromTableRef( tn := tree.MakeUnqualifiedTableName(tab.Name()) tabMeta := b.addTable(tab, &tn) - return b.buildScan(tabMeta, ordinals, indexFlags, excludeMutations, inScope) + return b.buildScan(tabMeta, ordinals, indexFlags, locking, excludeMutations, inScope) } // addTable adds a table to the metadata and returns the TableMeta. The table @@ -383,6 +413,7 @@ func (b *Builder) buildScan( tabMeta *opt.TableMeta, ordinals []int, indexFlags *tree.IndexFlags, + locking lockingSpec, scanMutationCols bool, inScope *scope, ) (outScope *scope) { @@ -436,13 +467,16 @@ func (b *Builder) buildScan( panic(pgerror.Newf(pgcode.Syntax, "index flags not allowed with virtual tables")) } + if locking.isSet() { + panic(pgerror.Newf(pgcode.Syntax, + "%s not allowed with virtual tables", locking.get().Strength)) + } private := memo.VirtualScanPrivate{Table: tabID, Cols: tabColIDs} outScope.expr = b.factory.ConstructVirtualScan(&private) // Virtual tables should not be collected as view dependencies. } else { private := memo.ScanPrivate{Table: tabID, Cols: tabColIDs} - if indexFlags != nil { private.Flags.NoIndexJoin = indexFlags.NoIndexJoin if indexFlags.Index != "" || indexFlags.IndexID != 0 { @@ -468,6 +502,9 @@ func (b *Builder) buildScan( private.Flags.Direction = indexFlags.Direction } } + if locking.isSet() { + private.Locking = locking.get() + } outScope.expr = b.factory.ConstructScan(&private) b.addCheckConstraintsForTable(tabMeta) @@ -712,15 +749,15 @@ func (b *Builder) flushCTEs(expr memo.RelExpr) memo.RelExpr { // See Builder.buildStmt for a description of the remaining input and // return values. func (b *Builder) buildSelectStmt( - stmt tree.SelectStatement, desiredTypes []*types.T, inScope *scope, + stmt tree.SelectStatement, locking lockingSpec, desiredTypes []*types.T, inScope *scope, ) (outScope *scope) { // NB: The case statements are sorted lexicographically. switch stmt := stmt.(type) { case *tree.ParenSelect: - return b.buildSelect(stmt.Select, desiredTypes, inScope) + return b.buildSelect(stmt.Select, locking, desiredTypes, inScope) case *tree.SelectClause: - return b.buildSelectClause(stmt, nil /* orderBy */, nil /* locking */, desiredTypes, inScope) + return b.buildSelectClause(stmt, nil /* orderBy */, locking, desiredTypes, inScope) case *tree.UnionClause: return b.buildUnionClause(stmt, desiredTypes, inScope) @@ -739,13 +776,13 @@ func (b *Builder) buildSelectStmt( // See Builder.buildStmt for a description of the remaining input and // return values. func (b *Builder) buildSelect( - stmt *tree.Select, desiredTypes []*types.T, inScope *scope, + stmt *tree.Select, locking lockingSpec, desiredTypes []*types.T, inScope *scope, ) (outScope *scope) { wrapped := stmt.Select with := stmt.With orderBy := stmt.OrderBy limit := stmt.Limit - locking := stmt.Locking + locking.apply(stmt.Locking) for s, ok := wrapped.(*tree.ParenSelect); ok; s, ok = wrapped.(*tree.ParenSelect) { stmt = s.Select @@ -778,7 +815,7 @@ func (b *Builder) buildSelect( limit = stmt.Limit } if stmt.Locking != nil { - locking = append(locking, stmt.Locking...) + locking.apply(stmt.Locking) } } @@ -799,7 +836,7 @@ func (b *Builder) buildSelectStmtWithoutParens( wrapped tree.SelectStatement, orderBy tree.OrderBy, limit *tree.Limit, - locking []*tree.LockingItem, + locking lockingSpec, desiredTypes []*types.T, inScope *scope, ) (outScope *scope) { @@ -858,11 +895,11 @@ func (b *Builder) buildSelectStmtWithoutParens( func (b *Builder) buildSelectClause( sel *tree.SelectClause, orderBy tree.OrderBy, - locking []*tree.LockingItem, + locking lockingSpec, desiredTypes []*types.T, inScope *scope, ) (outScope *scope) { - fromScope := b.buildFrom(sel.From, inScope) + fromScope := b.buildFrom(sel.From, locking, inScope) b.processWindowDefs(sel, fromScope) b.buildWhere(sel.Where, fromScope) @@ -905,6 +942,7 @@ func (b *Builder) buildSelectClause( } b.buildWindow(outScope, fromScope) + b.validateLockingInFrom(sel, locking, fromScope) // Construct the projection. b.constructProjectForScope(outScope, projectionsScope) @@ -917,8 +955,6 @@ func (b *Builder) buildSelectClause( outScope = b.buildDistinctOn(projectionsScope.distinctOnCols, outScope) } } - - b.validateLockingForSelectClause(sel, locking, fromScope) return outScope } @@ -926,7 +962,7 @@ func (b *Builder) buildSelectClause( // // See Builder.buildStmt for a description of the remaining input and return // values. -func (b *Builder) buildFrom(from tree.From, inScope *scope) (outScope *scope) { +func (b *Builder) buildFrom(from tree.From, locking lockingSpec, inScope *scope) (outScope *scope) { // The root AS OF clause is recognized and handled by the executor. The only // thing that must be done at this point is to ensure that if any timestamps // are specified, the root SELECT was an AS OF SYSTEM TIME and that the time @@ -936,7 +972,7 @@ func (b *Builder) buildFrom(from tree.From, inScope *scope) (outScope *scope) { } if len(from.Tables) > 0 { - outScope = b.buildFromTables(from.Tables, inScope) + outScope = b.buildFromTables(from.Tables, locking, inScope) } else { outScope = inScope.push() outScope.expr = b.factory.ConstructValues(memo.ScalarListWithEmptyTuple, &memo.ValuesPrivate{ @@ -992,16 +1028,18 @@ func (b *Builder) buildWhere(where *tree.Where, inScope *scope) { // // See Builder.buildStmt for a description of the remaining input and // return values. -func (b *Builder) buildFromTables(tables tree.TableExprs, inScope *scope) (outScope *scope) { +func (b *Builder) buildFromTables( + tables tree.TableExprs, locking lockingSpec, inScope *scope, +) (outScope *scope) { // If there are any lateral data sources, we need to build the join tree // left-deep instead of right-deep. for i := range tables { if b.exprIsLateral(tables[i]) { telemetry.Inc(sqltelemetry.LateralJoinUseCounter) - return b.buildFromWithLateral(tables, inScope) + return b.buildFromWithLateral(tables, locking, inScope) } } - return b.buildFromTablesRightDeep(tables, inScope) + return b.buildFromTablesRightDeep(tables, locking, inScope) } // buildFromTablesRightDeep recursively builds a series of InnerJoin @@ -1022,16 +1060,16 @@ func (b *Builder) buildFromTables(tables tree.TableExprs, inScope *scope) (outSc // See Builder.buildStmt for a description of the remaining input and // return values. func (b *Builder) buildFromTablesRightDeep( - tables tree.TableExprs, inScope *scope, + tables tree.TableExprs, locking lockingSpec, inScope *scope, ) (outScope *scope) { - outScope = b.buildDataSource(tables[0], nil /* indexFlags */, inScope) + outScope = b.buildDataSource(tables[0], nil /* indexFlags */, locking, inScope) // Recursively build table join. tables = tables[1:] if len(tables) == 0 { return outScope } - tableScope := b.buildFromTablesRightDeep(tables, inScope) + tableScope := b.buildFromTablesRightDeep(tables, locking, inScope) // Check that the same table name is not used multiple times. b.validateJoinTableNames(outScope, tableScope) @@ -1070,8 +1108,10 @@ func (b *Builder) exprIsLateral(t tree.TableExpr) bool { // // buildFromTablesRightDeep: a JOIN (b JOIN c) // buildFromWithLateral: (a JOIN b) JOIN c -func (b *Builder) buildFromWithLateral(tables tree.TableExprs, inScope *scope) (outScope *scope) { - outScope = b.buildDataSource(tables[0], nil /* indexFlags */, inScope) +func (b *Builder) buildFromWithLateral( + tables tree.TableExprs, locking lockingSpec, inScope *scope, +) (outScope *scope) { + outScope = b.buildDataSource(tables[0], nil /* indexFlags */, locking, inScope) for i := 1; i < len(tables); i++ { scope := inScope // Lateral expressions need to be able to refer to the expressions that @@ -1079,7 +1119,7 @@ func (b *Builder) buildFromWithLateral(tables tree.TableExprs, inScope *scope) ( if b.exprIsLateral(tables[i]) { scope = outScope } - tableScope := b.buildDataSource(tables[i], nil /* indexFlags */, scope) + tableScope := b.buildDataSource(tables[i], nil /* indexFlags */, locking, scope) // Check that the same table name is not used multiple times. b.validateJoinTableNames(outScope, tableScope) @@ -1113,36 +1153,35 @@ func (b *Builder) validateAsOf(asOf tree.AsOfClause) { } } -// validateLockingForSelectClause checks for operations that are not supported -// with FOR [KEY] UPDATE/SHARE. If a locking clause was specified with the -// select and an incompatible operation is in use, a locking error is raised. -// The method also validates that only supported locking modes are used. -func (b *Builder) validateLockingForSelectClause( - sel *tree.SelectClause, locking []*tree.LockingItem, scope *scope, +// validateLockingInFrom checks for operations that are not supported with FOR +// [KEY] UPDATE/SHARE. If a locking clause was specified with the select and an +// incompatible operation is in use, a locking error is raised. +func (b *Builder) validateLockingInFrom( + sel *tree.SelectClause, locking lockingSpec, fromScope *scope, ) { - if len(locking) == 0 { + if !locking.isSet() { + // No FOR [KEY] UPDATE/SHARE locking modes in scope. return } - first := locking[0] switch { case sel.Distinct: - b.raiseLockingError(first, "DISTINCT clause") + b.raiseLockingContextError(locking, "DISTINCT clause") case sel.GroupBy != nil: - b.raiseLockingError(first, "GROUP BY clause") + b.raiseLockingContextError(locking, "GROUP BY clause") case sel.Having != nil: - b.raiseLockingError(first, "HAVING clause") + b.raiseLockingContextError(locking, "HAVING clause") - case scope.groupby != nil && scope.groupby.hasAggregates(): - b.raiseLockingError(first, "aggregate functions") + case fromScope.groupby != nil && fromScope.groupby.hasAggregates(): + b.raiseLockingContextError(locking, "aggregate functions") - case len(scope.windows) != 0: - b.raiseLockingError(first, "window functions") + case len(fromScope.windows) != 0: + b.raiseLockingContextError(locking, "window functions") - case len(scope.srfs) != 0: - b.raiseLockingError(first, "set-returning functions in the target list") + case len(fromScope.srfs) != 0: + b.raiseLockingContextError(locking, "set-returning functions in the target list") } for _, li := range locking { @@ -1174,18 +1213,36 @@ func (b *Builder) validateLockingForSelectClause( default: panic(errors.AssertionFailedf("unknown locking wait policy: %s", li.WaitPolicy)) } + + // Validate locking targets. + for _, target := range li.Targets { + // Insist on unqualified alias names here. We could probably do + // something smarter, but it's better to just mirror Postgres + // exactly. See transformLockingClause in Postgres' source. + if target.CatalogName != "" || target.SchemaName != "" { + panic(pgerror.Newf(pgcode.Syntax, + "%s must specify unqualified relation names", li.Strength)) + } + } } + + // TODO(nvanbenschoten): Postgres verifies that all locking targets + // point to real relations in the FROM clause. We should verify the + // same thing here. } // rejectIfLocking raises a locking error if a locking clause was specified. -func (b *Builder) rejectIfLocking(locking []*tree.LockingItem, context string) { - if len(locking) == 0 { +func (b *Builder) rejectIfLocking(locking lockingSpec, context string) { + if !locking.isSet() { + // No FOR [KEY] UPDATE/SHARE locking modes in scope. return } - b.raiseLockingError(locking[0], context) + b.raiseLockingContextError(locking, context) } -func (b *Builder) raiseLockingError(first *tree.LockingItem, context string) { +// raiseLockingContextError raises an error indicating that a row-level locking +// clause is not permitted in the specified context. locking.set must be true. +func (b *Builder) raiseLockingContextError(locking lockingSpec, context string) { panic(pgerror.Newf(pgcode.FeatureNotSupported, - "%s is not allowed with %s", first.Strength, context)) + "%s is not allowed with %s", locking.get().Strength, context)) } diff --git a/pkg/sql/opt/optbuilder/testdata/select_for_update b/pkg/sql/opt/optbuilder/testdata/select_for_update new file mode 100644 index 000000000000..4641b42dac54 --- /dev/null +++ b/pkg/sql/opt/optbuilder/testdata/select_for_update @@ -0,0 +1,743 @@ +exec-ddl +CREATE TABLE t (a INT PRIMARY KEY, b INT) +---- + +exec-ddl +CREATE TABLE u (a INT PRIMARY KEY, c INT) +---- + +exec-ddl +CREATE VIEW v AS SELECT a FROM t +---- + +# ------------------------------------------------------------------------------ +# Basic tests. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM t FOR UPDATE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM t FOR NO KEY UPDATE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-no-key-update + +build +SELECT * FROM t FOR SHARE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-share + +build +SELECT * FROM t FOR KEY SHARE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-key-share + +build +SELECT * FROM t FOR KEY SHARE FOR SHARE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-share + +build +SELECT * FROM t FOR KEY SHARE FOR SHARE FOR NO KEY UPDATE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-no-key-update + +build +SELECT * FROM t FOR KEY SHARE FOR SHARE FOR NO KEY UPDATE FOR UPDATE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM t FOR UPDATE OF t +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM t FOR UPDATE OF t2 +---- +scan t + └── columns: a:1(int!null) b:2(int) + +# ------------------------------------------------------------------------------ +# Tests with table aliases. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM t AS t2 FOR UPDATE +---- +scan t2 + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM t AS t2 FOR UPDATE OF t +---- +scan t2 + └── columns: a:1(int!null) b:2(int) + +build +SELECT * FROM t AS t2 FOR UPDATE OF t2 +---- +scan t2 + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +# ------------------------------------------------------------------------------ +# Tests with numeric table references. +# Cockroach numeric references start after 53 for user tables. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM [53 AS t] FOR UPDATE +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM [53 AS t] FOR UPDATE OF t +---- +scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM [53 AS t] FOR UPDATE OF t2 +---- +scan t + └── columns: a:1(int!null) b:2(int) + +# ------------------------------------------------------------------------------ +# Tests with views. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM v FOR UPDATE +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM v FOR UPDATE OF v +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM v FOR UPDATE OF v2 +---- +project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +build +SELECT * FROM v FOR UPDATE OF t +---- +project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +# ------------------------------------------------------------------------------ +# Tests with aliased views. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM v AS v2 FOR UPDATE +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM v AS v2 FOR UPDATE OF v +---- +project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +build +SELECT * FROM v AS v2 FOR UPDATE OF v2 +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +# ------------------------------------------------------------------------------ +# Tests with subqueries. +# +# Row-level locking clauses only apply to subqueries in the FROM clause of a +# SELECT statement. They don't apply to subqueries in the projection or in +# the filter. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM (SELECT a FROM t) FOR UPDATE +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM (SELECT a FROM t FOR UPDATE) +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM (SELECT a FROM t FOR NO KEY UPDATE) FOR KEY SHARE +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-no-key-update + +build +SELECT * FROM (SELECT a FROM t FOR KEY SHARE) FOR NO KEY UPDATE +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-no-key-update + +# TODO(nvanbenschoten): To match Postgres perfectly, this would throw an error. +# It's not clear that it's worth going out of our way to mirror that behavior. +build +SELECT * FROM (SELECT a FROM t) FOR UPDATE OF t +---- +project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +build +SELECT * FROM (SELECT a FROM t FOR UPDATE OF t) +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM (SELECT a FROM t) AS r FOR UPDATE +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM (SELECT a FROM t FOR UPDATE) AS r +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM (SELECT a FROM t) AS r FOR UPDATE OF t +---- +project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +build +SELECT * FROM (SELECT a FROM t FOR UPDATE OF t) AS r +---- +project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT (SELECT a FROM t) FOR UPDATE +---- +project + ├── columns: a:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: t.a:1(int!null) + └── project + ├── columns: t.a:1(int!null) + └── scan t + └── columns: t.a:1(int!null) b:2(int) + +build +SELECT (SELECT a FROM t FOR UPDATE) +---- +project + ├── columns: a:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: t.a:1(int!null) + └── project + ├── columns: t.a:1(int!null) + └── scan t + ├── columns: t.a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT (SELECT a FROM t) FOR UPDATE OF t +---- +project + ├── columns: a:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: t.a:1(int!null) + └── project + ├── columns: t.a:1(int!null) + └── scan t + └── columns: t.a:1(int!null) b:2(int) + +build +SELECT (SELECT a FROM t FOR UPDATE OF t) +---- +project + ├── columns: a:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: t.a:1(int!null) + └── project + ├── columns: t.a:1(int!null) + └── scan t + ├── columns: t.a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT (SELECT a FROM t) AS r FOR UPDATE +---- +project + ├── columns: r:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: a:1(int!null) + └── project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +build +SELECT (SELECT a FROM t FOR UPDATE) AS r +---- +project + ├── columns: r:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: a:1(int!null) + └── project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT (SELECT a FROM t) AS r FOR UPDATE OF t +---- +project + ├── columns: r:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: a:1(int!null) + └── project + ├── columns: a:1(int!null) + └── scan t + └── columns: a:1(int!null) b:2(int) + +build +SELECT (SELECT a FROM t FOR UPDATE OF t) AS r +---- +project + ├── columns: r:3(int) + ├── values + │ └── tuple [type=tuple] + └── projections + └── subquery [type=int] + └── max1-row + ├── columns: a:1(int!null) + └── project + ├── columns: a:1(int!null) + └── scan t + ├── columns: a:1(int!null) b:2(int) + └── locking: for-update + +build +SELECT * FROM t WHERE a IN (SELECT a FROM t) FOR UPDATE +---- +select + ├── columns: a:1(int!null) b:2(int) + ├── scan t + │ ├── columns: a:1(int!null) b:2(int) + │ └── locking: for-update + └── filters + └── any: eq [type=bool] + ├── project + │ ├── columns: a:3(int!null) + │ └── scan t + │ └── columns: a:3(int!null) b:4(int) + └── variable: a [type=int] + +build +SELECT * FROM t WHERE a IN (SELECT a FROM t FOR UPDATE) +---- +select + ├── columns: a:1(int!null) b:2(int) + ├── scan t + │ └── columns: a:1(int!null) b:2(int) + └── filters + └── any: eq [type=bool] + ├── project + │ ├── columns: a:3(int!null) + │ └── scan t + │ ├── columns: a:3(int!null) b:4(int) + │ └── locking: for-update + └── variable: a [type=int] + +build +SELECT * FROM t WHERE a IN (SELECT a FROM t) FOR UPDATE OF t +---- +select + ├── columns: a:1(int!null) b:2(int) + ├── scan t + │ ├── columns: a:1(int!null) b:2(int) + │ └── locking: for-update + └── filters + └── any: eq [type=bool] + ├── project + │ ├── columns: a:3(int!null) + │ └── scan t + │ └── columns: a:3(int!null) b:4(int) + └── variable: a [type=int] + +build +SELECT * FROM t WHERE a IN (SELECT a FROM t FOR UPDATE OF t) +---- +select + ├── columns: a:1(int!null) b:2(int) + ├── scan t + │ └── columns: a:1(int!null) b:2(int) + └── filters + └── any: eq [type=bool] + ├── project + │ ├── columns: a:3(int!null) + │ └── scan t + │ ├── columns: a:3(int!null) b:4(int) + │ └── locking: for-update + └── variable: a [type=int] + +# ------------------------------------------------------------------------------ +# Tests with common-table expressions. +# +# Unlike with subqueries, row-level locking clauses do not apply to WITH queries +# referenced by the primary query. To achieve row locking within a WITH query, a +# locking clause should be specified within the WITH query. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM [SELECT a FROM t] FOR UPDATE +---- +with &1 + ├── columns: a:3(int!null) + ├── project + │ ├── columns: t.a:1(int!null) + │ └── scan t + │ └── columns: t.a:1(int!null) b:2(int) + └── with-scan &1 + ├── columns: a:3(int!null) + └── mapping: + └── t.a:1(int) => a:3(int) + +build +WITH cte AS (SELECT a FROM t) SELECT * FROM cte FOR UPDATE +---- +with &1 (cte) + ├── columns: a:3(int!null) + ├── project + │ ├── columns: t.a:1(int!null) + │ └── scan t + │ └── columns: t.a:1(int!null) b:2(int) + └── with-scan &1 (cte) + ├── columns: a:3(int!null) + └── mapping: + └── t.a:1(int) => a:3(int) + +build +SELECT * FROM [SELECT a FROM t FOR UPDATE] +---- +with &1 + ├── columns: a:3(int!null) + ├── project + │ ├── columns: t.a:1(int!null) + │ └── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + └── with-scan &1 + ├── columns: a:3(int!null) + └── mapping: + └── t.a:1(int) => a:3(int) + +build +WITH cte AS (SELECT a FROM t FOR UPDATE) SELECT * FROM cte +---- +with &1 (cte) + ├── columns: a:3(int!null) + ├── project + │ ├── columns: t.a:1(int!null) + │ └── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + └── with-scan &1 (cte) + ├── columns: a:3(int!null) + └── mapping: + └── t.a:1(int) => a:3(int) + +# ------------------------------------------------------------------------------ +# Tests with joins. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM t JOIN u USING (a) FOR UPDATE +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-update + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR UPDATE OF t +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + ├── scan u + │ └── columns: u.a:3(int!null) c:4(int) + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR UPDATE OF u +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ └── columns: t.a:1(int!null) b:2(int) + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-update + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR UPDATE OF t, u +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-update + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR UPDATE OF t FOR SHARE OF u +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-share + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR UPDATE OF t2 FOR SHARE OF u2 +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ └── columns: t.a:1(int!null) b:2(int) + ├── scan u + │ └── columns: u.a:3(int!null) c:4(int) + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t AS t2 JOIN u AS u2 USING (a) FOR UPDATE OF t2 FOR SHARE OF u2 +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t2.a:1(int!null) b:2(int) u2.a:3(int!null) c:4(int) + ├── scan t2 + │ ├── columns: t2.a:1(int!null) b:2(int) + │ └── locking: for-update + ├── scan u2 + │ ├── columns: u2.a:3(int!null) c:4(int) + │ └── locking: for-share + └── filters + └── eq [type=bool] + ├── variable: t2.a [type=int] + └── variable: u2.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR KEY SHARE FOR UPDATE +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-update + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-update + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR KEY SHARE FOR NO KEY UPDATE OF t +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-no-key-update + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-key-share + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +build +SELECT * FROM t JOIN u USING (a) FOR SHARE FOR NO KEY UPDATE OF t FOR UPDATE OF u +---- +project + ├── columns: a:1(int!null) b:2(int) c:4(int) + └── inner-join (hash) + ├── columns: t.a:1(int!null) b:2(int) u.a:3(int!null) c:4(int) + ├── scan t + │ ├── columns: t.a:1(int!null) b:2(int) + │ └── locking: for-no-key-update + ├── scan u + │ ├── columns: u.a:3(int!null) c:4(int) + │ └── locking: for-update + └── filters + └── eq [type=bool] + ├── variable: t.a [type=int] + └── variable: u.a [type=int] + +# ------------------------------------------------------------------------------ +# Tests with virtual tables. +# ------------------------------------------------------------------------------ + +build +SELECT * FROM information_schema.columns FOR UPDATE +---- +error (42601): FOR UPDATE not allowed with virtual tables diff --git a/pkg/sql/opt/optbuilder/update.go b/pkg/sql/opt/optbuilder/update.go index 6d4316c10d99..52bd180e476b 100644 --- a/pkg/sql/opt/optbuilder/update.go +++ b/pkg/sql/opt/optbuilder/update.go @@ -143,7 +143,7 @@ func (mb *mutationBuilder) addTargetColsForUpdate(exprs tree.UpdateExprs) { for i := range desiredTypes { desiredTypes[i] = mb.md.ColumnMeta(mb.targetColList[targetIdx+i]).Type } - outScope := mb.b.buildSelectStmt(t.Select, desiredTypes, mb.outScope) + outScope := mb.b.buildSelectStmt(t.Select, noRowLocking, desiredTypes, mb.outScope) mb.subqueries = append(mb.subqueries, outScope) n = len(outScope.cols) diff --git a/pkg/sql/opt/optgen/cmd/optgen/metadata.go b/pkg/sql/opt/optgen/cmd/optgen/metadata.go index 038db821bb44..ddc0b3de6052 100644 --- a/pkg/sql/opt/optgen/cmd/optgen/metadata.go +++ b/pkg/sql/opt/optgen/cmd/optgen/metadata.go @@ -183,6 +183,7 @@ func newMetadata(compiled *lang.CompiledExpr, pkg string) *metadata { "JobCommand": {fullName: "tree.JobCommand", passByVal: true}, "IndexOrdinal": {fullName: "cat.IndexOrdinal", passByVal: true}, "ViewDeps": {fullName: "opt.ViewDeps", passByVal: true}, + "LockingItem": {fullName: "*tree.LockingItem", isPointer: true}, } // Add types of generated op and private structs. diff --git a/pkg/sql/opt/xform/testdata/rules/groupby b/pkg/sql/opt/xform/testdata/rules/groupby index df5ef2ee8716..0b7518f1ef6f 100644 --- a/pkg/sql/opt/xform/testdata/rules/groupby +++ b/pkg/sql/opt/xform/testdata/rules/groupby @@ -802,7 +802,7 @@ memo (optimized, ~5KB, required=[presentation: array_agg:5]) memo SELECT array_agg(k) FROM (SELECT * FROM kuvw WHERE u=v ORDER BY u) GROUP BY w ---- -memo (optimized, ~9KB, required=[presentation: array_agg:5]) +memo (optimized, ~10KB, required=[presentation: array_agg:5]) ├── G1: (project G2 G3 array_agg) │ └── [presentation: array_agg:5] │ ├── best: (project G2 G3 array_agg) diff --git a/pkg/sql/opt/xform/testdata/rules/join b/pkg/sql/opt/xform/testdata/rules/join index 6bdd9bb13939..858ad95e605a 100644 --- a/pkg/sql/opt/xform/testdata/rules/join +++ b/pkg/sql/opt/xform/testdata/rules/join @@ -190,7 +190,7 @@ full-join (hash) memo SELECT * FROM abc INNER LOOKUP JOIN xyz ON a=x ---- -memo (optimized, ~9KB, required=[presentation: a:1,b:2,c:3,x:5,y:6,z:7]) +memo (optimized, ~10KB, required=[presentation: a:1,b:2,c:3,x:5,y:6,z:7]) ├── G1: (inner-join G2 G3 G4) (inner-join G3 G2 G4) (lookup-join G2 G5 xyz@xy,keyCols=[1],outCols=(1-3,5-7)) │ └── [presentation: a:1,b:2,c:3,x:5,y:6,z:7] │ ├── best: (lookup-join G2 G5 xyz@xy,keyCols=[1],outCols=(1-3,5-7)) @@ -325,7 +325,7 @@ inner-join (merge) memo SELECT * FROM abc JOIN xyz ON a=x ---- -memo (optimized, ~11KB, required=[presentation: a:1,b:2,c:3,x:5,y:6,z:7]) +memo (optimized, ~12KB, required=[presentation: a:1,b:2,c:3,x:5,y:6,z:7]) ├── G1: (inner-join G2 G3 G4) (inner-join G3 G2 G4) (merge-join G2 G3 G5 inner-join,+1,+5) (lookup-join G2 G5 xyz@xy,keyCols=[1],outCols=(1-3,5-7)) (merge-join G3 G2 G5 inner-join,+5,+1) (lookup-join G3 G5 abc@ab,keyCols=[5],outCols=(1-3,5-7)) │ └── [presentation: a:1,b:2,c:3,x:5,y:6,z:7] │ ├── best: (merge-join G2="[ordering: +1]" G3="[ordering: +5]" G5 inner-join,+1,+5) diff --git a/pkg/sql/sem/tree/select.go b/pkg/sql/sem/tree/select.go index 2fead3a088d3..bd4ba9d0a70a 100644 --- a/pkg/sql/sem/tree/select.go +++ b/pkg/sql/sem/tree/select.go @@ -983,26 +983,28 @@ func (f *LockingItem) Format(ctx *FmtCtx) { // statement. type LockingStrength byte +// The ordering of the variants is important, because the highest numerical +// value takes precedence when row-level locking is specified multiple ways. const ( // ForNone represents the default - no for statement at all. // LockingItem AST nodes are never created with this strength. ForNone LockingStrength = iota - // ForUpdate represents FOR UPDATE. - ForUpdate - // ForNoKeyUpdate represents FOR NO KEY UPDATE. - ForNoKeyUpdate - // ForShare represents FOR SHARE. - ForShare // ForKeyShare represents FOR KEY SHARE. ForKeyShare + // ForShare represents FOR SHARE. + ForShare + // ForNoKeyUpdate represents FOR NO KEY UPDATE. + ForNoKeyUpdate + // ForUpdate represents FOR UPDATE. + ForUpdate ) var lockingStrengthName = [...]string{ ForNone: "", - ForUpdate: "FOR UPDATE", - ForNoKeyUpdate: "FOR NO KEY UPDATE", - ForShare: "FOR SHARE", ForKeyShare: "FOR KEY SHARE", + ForShare: "FOR SHARE", + ForNoKeyUpdate: "FOR NO KEY UPDATE", + ForUpdate: "FOR UPDATE", } func (s LockingStrength) String() string { @@ -1017,11 +1019,18 @@ func (s LockingStrength) Format(ctx *FmtCtx) { } } +// Max returns the maximum of the two locking strengths. +func (s LockingStrength) Max(s2 LockingStrength) LockingStrength { + return LockingStrength(max(byte(s), byte(s2))) +} + // LockingWaitPolicy represents the possible policies for dealing with rows // being locked by FOR UPDATE/SHARE clauses (i.e., it represents the NOWAIT // and SKIP LOCKED options). type LockingWaitPolicy byte +// The ordering of the variants is important, because the highest numerical +// value takes precedence when row-level locking is specified multiple ways. const ( // LockWaitBlock represents the default - wait for the lock to become // available. @@ -1050,3 +1059,15 @@ func (p LockingWaitPolicy) Format(ctx *FmtCtx) { ctx.WriteString(p.String()) } } + +// Max returns the maximum of the two locking wait policies. +func (p LockingWaitPolicy) Max(p2 LockingWaitPolicy) LockingWaitPolicy { + return LockingWaitPolicy(max(byte(p), byte(p2))) +} + +func max(a, b byte) byte { + if a > b { + return a + } + return b +}