diff --git a/pkg/sql/delete.go b/pkg/sql/delete.go index 1572e80e76a6..7b10fd47cbd6 100644 --- a/pkg/sql/delete.go +++ b/pkg/sql/delete.go @@ -125,6 +125,9 @@ func (p *planner) Delete( requestedCols = desc.Columns } + // Since all columns are being returned, use the 1:1 mapping. See todo above. + rowIdxToRetIdx := mutationRowIdxToReturnIdx(requestedCols, requestedCols) + // Create the table deleter, which does the bulk of the work. rd, err := row.MakeDeleter( p.txn, desc, fkTables, requestedCols, row.CheckFKs, p.EvalContext(), &p.alloc, @@ -175,6 +178,7 @@ func (p *planner) Delete( td: tableDeleter{rd: rd, alloc: &p.alloc}, rowsNeeded: rowsNeeded, fastPathInterleaved: canDeleteFastInterleaved(desc, fkTables), + rowIdxToRetIdx: rowIdxToRetIdx, }, } @@ -208,6 +212,13 @@ type deleteRun struct { // traceKV caches the current KV tracing flag. traceKV bool + + // rowIdxToRetIdx is the mapping from the columns returned by the deleter + // to the columns in the resultRowBuffer. A value of -1 is used to indicate + // that the column at that index is not part of the resultRowBuffer + // of the mutation. Otherwise, the value at the i-th index refers to the + // index of the resultRowBuffer where the i-th column is to be returned. + rowIdxToRetIdx []int } // maxDeleteBatchSize is the max number of entries in the KV batch for @@ -331,9 +342,15 @@ func (d *deleteNode) processSourceRow(params runParams, sourceVals tree.Datums) // contain additional columns for every newly dropped column not // visible. We do not want them to be available for RETURNING. // - // d.columns is guaranteed to only contain the requested + // d.run.rows.NumCols() is guaranteed to only contain the requested // public columns. - resultValues := sourceVals[:len(d.columns)] + resultValues := make(tree.Datums, d.run.rows.NumCols()) + for i, retIdx := range d.run.rowIdxToRetIdx { + if retIdx >= 0 { + resultValues[retIdx] = sourceVals[i] + } + } + if _, err := d.run.rows.AddRow(params.ctx, resultValues); err != nil { return err } diff --git a/pkg/sql/insert.go b/pkg/sql/insert.go index f72dcbb99234..9e566d75db73 100644 --- a/pkg/sql/insert.go +++ b/pkg/sql/insert.go @@ -286,6 +286,12 @@ func (p *planner) Insert( columns = sqlbase.ResultColumnsFromColDescs(desc.Columns) } + // Since all columns are being returned, use the 1:1 mapping. + tabColIdxToRetIdx := make([]int, len(desc.Columns)) + for i := range tabColIdxToRetIdx { + tabColIdxToRetIdx[i] = i + } + // At this point, everything is ready for either an insertNode or an upserNode. var node batchedPlanNode @@ -315,8 +321,9 @@ func (p *planner) Insert( Cols: desc.Columns, Mapping: ri.InsertColIDtoRowIndex, }, - defaultExprs: defaultExprs, - insertCols: ri.InsertCols, + defaultExprs: defaultExprs, + insertCols: ri.InsertCols, + tabColIdxToRetIdx: tabColIdxToRetIdx, }, } node = in @@ -368,12 +375,21 @@ type insertRun struct { // into the row container above, when rowsNeeded is set. resultRowBuffer tree.Datums - // rowIdxToRetIdx is the mapping from the ordering of rows in - // insertCols to the ordering in the result rows, used when + // rowIdxToTabColIdx is the mapping from the ordering of rows in + // insertCols to the ordering in the rows in the table, used when // rowsNeeded is set to populate resultRowBuffer and the row // container. The return index is -1 if the column for the row - // index is not public. - rowIdxToRetIdx []int + // index is not public. This is used in conjunction with tabIdxToRetIdx + // to populate the resultRowBuffer. 
+ rowIdxToTabColIdx []int + + // tabColIdxToRetIdx is the mapping from the columns in the table to the + // columns in the resultRowBuffer. A value of -1 is used to indicate + // that the table column at that index is not part of the resultRowBuffer + // of the mutation. Otherwise, the value at the i-th index refers to the + // index of the resultRowBuffer where the i-th column of the table is + // to be returned. + tabColIdxToRetIdx []int // traceKV caches the current KV tracing flag. traceKV bool @@ -405,8 +421,8 @@ func (n *insertNode) startExec(params runParams) error { // re-ordering the data into resultRowBuffer. // // Also we need to re-order the values in the source, ordered by - // insertCols, when writing them to resultRowBuffer, ordered by - // n.columns. This uses the rowIdxToRetIdx mapping. + // insertCols, when writing them to resultRowBuffer, according to + // the rowIdxToTabColIdx mapping. n.run.resultRowBuffer = make(tree.Datums, len(n.columns)) for i := range n.run.resultRowBuffer { @@ -419,13 +435,13 @@ func (n *insertNode) startExec(params runParams) error { colIDToRetIndex[cols[i].ID] = i } - n.run.rowIdxToRetIdx = make([]int, len(n.run.insertCols)) + n.run.rowIdxToTabColIdx = make([]int, len(n.run.insertCols)) for i, col := range n.run.insertCols { if idx, ok := colIDToRetIndex[col.ID]; !ok { // Column must be write only and not public. - n.run.rowIdxToRetIdx[i] = -1 + n.run.rowIdxToTabColIdx[i] = -1 } else { - n.run.rowIdxToRetIdx[i] = idx + n.run.rowIdxToTabColIdx[i] = idx } } } @@ -567,10 +583,13 @@ func (n *insertNode) processSourceRow(params runParams, sourceVals tree.Datums) // The downstream consumer will want the rows in the order of // the table descriptor, not that of insertCols. Reorder them // and ignore non-public columns. 
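A standalone sketch (not part of the diff) of the two-stage re-ordering described in the comment above and implemented in the hunk that continues below: values ordered by insertCols are first mapped to table ordinals via rowIdxToTabColIdx, then to RETURNING positions via tabColIdxToRetIdx, with -1 dropping a value at either stage. The column layout and string datums here are illustrative stand-ins, not the real descriptors.

```go
package main

import "fmt"

func main() {
    // Values as produced by the input row, ordered by insertCols:
    // (b, a, x) where x is a hypothetical write-only mutation column.
    sourceVals := []string{"b-val", "a-val", "x-val"}

    // Stage 1: insertCols ordinal -> table-column ordinal (-1: not public).
    // The table is (a, b, c), so b maps to 1, a maps to 0, x is dropped.
    rowIdxToTabColIdx := []int{1, 0, -1}

    // Stage 2: table-column ordinal -> RETURNING position (-1: not returned).
    // Here only a is returned.
    tabColIdxToRetIdx := []int{0, -1, -1}

    resultRowBuffer := make([]string, 1) // one RETURNING column
    for i, val := range sourceVals {
        if tabIdx := rowIdxToTabColIdx[i]; tabIdx >= 0 {
            if retIdx := tabColIdxToRetIdx[tabIdx]; retIdx >= 0 {
                resultRowBuffer[retIdx] = val
            }
        }
    }
    fmt.Println(resultRowBuffer) // [a-val]
}
```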
- if idx := n.run.rowIdxToRetIdx[i]; idx >= 0 { - n.run.resultRowBuffer[idx] = val + if tabIdx := n.run.rowIdxToTabColIdx[i]; tabIdx >= 0 { + if retIdx := n.run.tabColIdxToRetIdx[tabIdx]; retIdx >= 0 { + n.run.resultRowBuffer[retIdx] = val + } } } + if _, err := n.run.rows.AddRow(params.ctx, n.run.resultRowBuffer); err != nil { return err } diff --git a/pkg/sql/opt/bench/stub_factory.go b/pkg/sql/opt/bench/stub_factory.go index 7eb7b9ad500f..10aad2b9fb90 100644 --- a/pkg/sql/opt/bench/stub_factory.go +++ b/pkg/sql/opt/bench/stub_factory.go @@ -222,8 +222,8 @@ func (f *stubFactory) ConstructInsert( input exec.Node, table cat.Table, insertCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, skipFKChecks bool, ) (exec.Node, error) { return struct{}{}, nil @@ -234,8 +234,8 @@ func (f *stubFactory) ConstructUpdate( table cat.Table, fetchCols exec.ColumnOrdinalSet, updateCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { return struct{}{}, nil } @@ -247,14 +247,17 @@ func (f *stubFactory) ConstructUpsert( insertCols exec.ColumnOrdinalSet, fetchCols exec.ColumnOrdinalSet, updateCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { return struct{}{}, nil } func (f *stubFactory) ConstructDelete( - input exec.Node, table cat.Table, fetchCols exec.ColumnOrdinalSet, rowsNeeded bool, + input exec.Node, + table cat.Table, + fetchCols exec.ColumnOrdinalSet, + returnCols exec.ColumnOrdinalSet, ) (exec.Node, error) { return struct{}{}, nil } diff --git a/pkg/sql/opt/exec/execbuilder/mutation.go b/pkg/sql/opt/exec/execbuilder/mutation.go index 1664bbdd828a..0d52d6898fab 100644 --- a/pkg/sql/opt/exec/execbuilder/mutation.go +++ b/pkg/sql/opt/exec/execbuilder/mutation.go @@ -47,14 +47,15 @@ func (b *Builder) buildInsert(ins *memo.InsertExpr) (execPlan, error) { tab := b.mem.Metadata().Table(ins.Table) insertOrds := ordinalSetFromColList(ins.InsertCols) checkOrds := ordinalSetFromColList(ins.CheckCols) + returnOrds := ordinalSetFromColList(ins.ReturnCols) // If we planned FK checks, disable the execution code for FK checks. 
disableExecFKs := len(ins.Checks) > 0 node, err := b.factory.ConstructInsert( input.root, tab, insertOrds, + returnOrds, checkOrds, - ins.NeedResults(), disableExecFKs, ) if err != nil { @@ -106,14 +107,15 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { tab := md.Table(upd.Table) fetchColOrds := ordinalSetFromColList(upd.FetchCols) updateColOrds := ordinalSetFromColList(upd.UpdateCols) + returnColOrds := ordinalSetFromColList(upd.ReturnCols) checkOrds := ordinalSetFromColList(upd.CheckCols) node, err := b.factory.ConstructUpdate( input.root, tab, fetchColOrds, updateColOrds, + returnColOrds, checkOrds, - upd.NeedResults(), ) if err != nil { return execPlan{}, err @@ -177,6 +179,7 @@ func (b *Builder) buildUpsert(ups *memo.UpsertExpr) (execPlan, error) { insertColOrds := ordinalSetFromColList(ups.InsertCols) fetchColOrds := ordinalSetFromColList(ups.FetchCols) updateColOrds := ordinalSetFromColList(ups.UpdateCols) + returnColOrds := ordinalSetFromColList(ups.ReturnCols) checkOrds := ordinalSetFromColList(ups.CheckCols) node, err := b.factory.ConstructUpsert( input.root, @@ -185,8 +188,8 @@ func (b *Builder) buildUpsert(ups *memo.UpsertExpr) (execPlan, error) { insertColOrds, fetchColOrds, updateColOrds, + returnColOrds, checkOrds, - ups.NeedResults(), ) if err != nil { return execPlan{}, err @@ -230,7 +233,8 @@ func (b *Builder) buildDelete(del *memo.DeleteExpr) (execPlan, error) { md := b.mem.Metadata() tab := md.Table(del.Table) fetchColOrds := ordinalSetFromColList(del.FetchCols) - node, err := b.factory.ConstructDelete(input.root, tab, fetchColOrds, del.NeedResults()) + returnColOrds := ordinalSetFromColList(del.ReturnCols) + node, err := b.factory.ConstructDelete(input.root, tab, fetchColOrds, returnColOrds) if err != nil { return execPlan{}, err } @@ -310,6 +314,9 @@ func appendColsWhenPresent(dst, src opt.ColList) opt.ColList { // indicating columns that are not involved in the mutation. 
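For reference, a minimal self-contained model of ordinalSetFromColList with the nil guard this change adds (the real definition continues in the hunk just below). A plain map stands in for util.FastIntSet; the guard matters because with return-column pruning, ReturnCols may now legitimately be nil.

```go
package main

import "fmt"

// ColList mirrors opt.ColList for this sketch: a slice of column IDs where 0
// means "this ordinal is not involved in the mutation".
type ColList []int

// ordinalSetFromColList returns the set of ordinals with a non-zero column ID.
// A nil list (e.g. pruned-away ReturnCols) yields an empty set.
func ordinalSetFromColList(colList ColList) map[int]bool {
    res := map[int]bool{}
    if colList == nil {
        return res
    }
    for i, col := range colList {
        if col != 0 {
            res[i] = true
        }
    }
    return res
}

func main() {
    fmt.Println(ordinalSetFromColList(ColList{7, 0, 9})) // map[0:true 2:true]
    fmt.Println(ordinalSetFromColList(nil))              // map[]
}
```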
func ordinalSetFromColList(colList opt.ColList) exec.ColumnOrdinalSet { var res util.FastIntSet + if colList == nil { + return res + } for i, col := range colList { if col != 0 { res.Add(i) diff --git a/pkg/sql/opt/exec/execbuilder/testdata/ddl b/pkg/sql/opt/exec/execbuilder/testdata/ddl index 6f6bbfd91f7e..5b75e361dbe6 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/ddl +++ b/pkg/sql/opt/exec/execbuilder/testdata/ddl @@ -233,19 +233,17 @@ COMMIT query TTTTT colnames EXPLAIN (VERBOSE) SELECT * FROM v ---- -tree field description columns ordering -render · · (k) · - │ render 0 k · · - └── run · · (k, v, z) · - └── update · · (k, v, z) · - │ table kv · · - │ set v · · - │ strategy updater · · - └── render · · (k, v, z, column7) · - │ render 0 k · · - │ render 1 v · · - │ render 2 z · · - │ render 3 444 · · - └── scan · · (k, v, z) · -· table kv@primary · · -· spans /1- · · +tree field description columns ordering +run · · (k) · + └── update · · (k) · + │ table kv · · + │ set v · · + │ strategy updater · · + └── render · · (k, v, z, column7) · + │ render 0 k · · + │ render 1 v · · + │ render 2 z · · + │ render 3 444 · · + └── scan · · (k, v, z) · +· table kv@primary · · +· spans /1- · · diff --git a/pkg/sql/opt/exec/execbuilder/testdata/delete b/pkg/sql/opt/exec/execbuilder/testdata/delete index a496b8a4df5e..0df0b0564a73 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/delete +++ b/pkg/sql/opt/exec/execbuilder/testdata/delete @@ -146,14 +146,11 @@ count · · query TTT EXPLAIN DELETE FROM indexed WHERE value = 5 LIMIT 10 RETURNING id ---- -render · · - └── run · · - └── delete · · - │ from indexed - │ strategy deleter - └── index-join · · - │ table indexed@primary - └── scan · · -· table indexed@indexed_value_idx -· spans /5-/6 -· limit 10 +run · · + └── delete · · + │ from indexed + │ strategy deleter + └── scan · · +· table indexed@indexed_value_idx +· spans /5-/6 +· limit 10 diff --git a/pkg/sql/opt/exec/execbuilder/testdata/insert b/pkg/sql/opt/exec/execbuilder/testdata/insert index 16539a03e2b1..e802c9a4bd79 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/insert +++ b/pkg/sql/opt/exec/execbuilder/testdata/insert @@ -480,20 +480,20 @@ CREATE TABLE xyz (x INT, y INT, z INT) query TTTTT EXPLAIN (VERBOSE) SELECT * FROM [INSERT INTO xyz SELECT a, b, c FROM abc RETURNING z] ORDER BY z ---- -render · · (z) +z - │ render 0 z · · - └── run · · (x, y, z, rowid[hidden]) · - └── insert · · (x, y, z, rowid[hidden]) · - │ into xyz(x, y, z, rowid) · · - │ strategy inserter · · - └── render · · (a, b, c, column9) +c - │ render 0 a · · - │ render 1 b · · - │ render 2 c · · - │ render 3 unique_rowid() · · - └── scan · · (a, b, c) +c -· table abc@abc_c_idx · · -· spans ALL · · +render · · (z) +z + │ render 0 z · · + └── run · · (z, rowid[hidden]) · + └── insert · · (z, rowid[hidden]) · + │ into xyz(x, y, z, rowid) · · + │ strategy inserter · · + └── render · · (a, b, c, column9) +c + │ render 0 a · · + │ render 1 b · · + │ render 2 c · · + │ render 3 unique_rowid() · · + └── scan · · (a, b, c) +c +· table abc@abc_c_idx · · +· spans ALL · · # ------------------------------------------------------------------------------ # Regression for #35364. 
This tests behavior that is different between the CBO diff --git a/pkg/sql/opt/exec/execbuilder/testdata/orderby b/pkg/sql/opt/exec/execbuilder/testdata/orderby index 1a370145d790..aac322bf3908 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/orderby +++ b/pkg/sql/opt/exec/execbuilder/testdata/orderby @@ -483,8 +483,8 @@ EXPLAIN (VERBOSE) INSERT INTO t(a, b) SELECT * FROM (SELECT 1 AS x, 2 AS y) ORDE ---- render · · (b) · │ render 0 b · · - └── run · · (a, b, c) · - └── insert · · (a, b, c) · + └── run · · (a, b) · + └── insert · · (a, b) · │ into t(a, b, c) · · │ strategy inserter · · └── values · · (x, y, column6) · @@ -496,23 +496,23 @@ render · · (b) · query TTTTT EXPLAIN (VERBOSE) DELETE FROM t WHERE a = 3 RETURNING b ---- -render · · (b) · - │ render 0 b · · - └── run · · (a, b, c) · - └── delete · · (a, b, c) · - │ from t · · - │ strategy deleter · · - └── scan · · (a, b, c) · -· table t@primary · · -· spans /3-/3/# · · +render · · (b) · + │ render 0 b · · + └── run · · (a, b) · + └── delete · · (a, b) · + │ from t · · + │ strategy deleter · · + └── scan · · (a, b) · +· table t@primary · · +· spans /3-/3/# · · query TTTTT EXPLAIN (VERBOSE) UPDATE t SET c = TRUE RETURNING b ---- render · · (b) · │ render 0 b · · - └── run · · (a, b, c) · - └── update · · (a, b, c) · + └── run · · (a, b) · + └── update · · (a, b) · │ table t · · │ set c · · │ strategy updater · · diff --git a/pkg/sql/opt/exec/execbuilder/testdata/update b/pkg/sql/opt/exec/execbuilder/testdata/update index a55117f4709a..d11f3696a773 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/update +++ b/pkg/sql/opt/exec/execbuilder/testdata/update @@ -308,8 +308,8 @@ EXPLAIN (VERBOSE) SELECT * FROM [ UPDATE abc SET a=c RETURNING a ] ORDER BY a ---- render · · (a) +a │ render 0 a · · - └── run · · (a, b, c, rowid[hidden]) · - └── update · · (a, b, c, rowid[hidden]) · + └── run · · (a, rowid[hidden]) · + └── update · · (a, rowid[hidden]) · │ table abc · · │ set a · · │ strategy updater · · diff --git a/pkg/sql/opt/exec/execbuilder/testdata/upsert b/pkg/sql/opt/exec/execbuilder/testdata/upsert index 41c44a9e95f6..ebe20303b4fe 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/upsert +++ b/pkg/sql/opt/exec/execbuilder/testdata/upsert @@ -327,8 +327,8 @@ EXPLAIN (VERBOSE) SELECT * FROM [UPSERT INTO xyz SELECT a, b, c FROM abc RETURNI ---- render · · (z) +z │ render 0 z · · - └── run · · (x, y, z, rowid[hidden]) · - └── upsert · · (x, y, z, rowid[hidden]) · + └── run · · (z, rowid[hidden]) · + └── upsert · · (z, rowid[hidden]) · │ into xyz(x, y, z, rowid) · · │ strategy opt upserter · · └── render · · (a, b, c, column9, a, b, c) +c diff --git a/pkg/sql/opt/exec/factory.go b/pkg/sql/opt/exec/factory.go index f8344ad064be..be7e6a8a596f 100644 --- a/pkg/sql/opt/exec/factory.go +++ b/pkg/sql/opt/exec/factory.go @@ -304,8 +304,8 @@ type Factory interface { input Node, table cat.Table, insertCols ColumnOrdinalSet, + returnCols ColumnOrdinalSet, checks CheckOrdinalSet, - rowsNeeded bool, skipFKChecks bool, ) (Node, error) @@ -326,8 +326,8 @@ type Factory interface { table cat.Table, fetchCols ColumnOrdinalSet, updateCols ColumnOrdinalSet, + returnCols ColumnOrdinalSet, checks CheckOrdinalSet, - rowsNeeded bool, ) (Node, error) // ConstructUpsert creates a node that implements an INSERT..ON CONFLICT or @@ -360,8 +360,8 @@ type Factory interface { insertCols ColumnOrdinalSet, fetchCols ColumnOrdinalSet, updateCols ColumnOrdinalSet, + returnCols ColumnOrdinalSet, checks CheckOrdinalSet, - rowsNeeded bool, ) (Node, error) // ConstructDelete 
creates a node that implements a DELETE statement. The @@ -373,7 +373,7 @@ type Factory interface { // as they appear in the table schema. The rowsNeeded parameter is true if a // RETURNING clause needs the deleted row(s) as output. ConstructDelete( - input Node, table cat.Table, fetchCols ColumnOrdinalSet, rowsNeeded bool, + input Node, table cat.Table, fetchCols ColumnOrdinalSet, returnCols ColumnOrdinalSet, ) (Node, error) // ConstructDeleteRange creates a node that efficiently deletes contiguous diff --git a/pkg/sql/opt/memo/expr.go b/pkg/sql/opt/memo/expr.go index 9e1503afcf26..4c1bdc79986c 100644 --- a/pkg/sql/opt/memo/expr.go +++ b/pkg/sql/opt/memo/expr.go @@ -381,6 +381,12 @@ func (m *MutationPrivate) NeedResults() bool { return m.ReturnCols != nil } +// IsColumnOutput returns true if the i-th ordinal column should be part of the +// mutation's output columns. +func (m *MutationPrivate) IsColumnOutput(i int) bool { + return i < len(m.ReturnCols) && m.ReturnCols[i] != 0 +} + // MapToInputID maps from the ID of a returned column to the ID of the // corresponding input column that provides the value for it. If there is no // matching input column ID, MapToInputID returns 0. diff --git a/pkg/sql/opt/memo/logical_props_builder.go b/pkg/sql/opt/memo/logical_props_builder.go index 19326edff425..4b4d99fc52e6 100644 --- a/pkg/sql/opt/memo/logical_props_builder.go +++ b/pkg/sql/opt/memo/logical_props_builder.go @@ -1088,8 +1088,10 @@ func (b *logicalPropsBuilder) buildMutationProps(mutation RelExpr, rel *props.Re // -------------- // Only non-mutation columns are output columns. for i, n := 0, tab.ColumnCount(); i < n; i++ { - colID := private.Table.ColumnID(i) - rel.OutputCols.Add(colID) + if private.IsColumnOutput(i) { + colID := private.Table.ColumnID(i) + rel.OutputCols.Add(colID) + } } // Not Null Columns diff --git a/pkg/sql/opt/memo/testdata/logprops/delete b/pkg/sql/opt/memo/testdata/logprops/delete index e09e4a16736a..c6adbfdc53f7 100644 --- a/pkg/sql/opt/memo/testdata/logprops/delete +++ b/pkg/sql/opt/memo/testdata/logprops/delete @@ -57,6 +57,7 @@ project ├── side-effects, mutations ├── key: (5) ├── fd: ()-->(1), (5)-->(2-4) + ├── prune: (1-4) └── select ├── columns: a:7(int!null) b:8(int) c:9(int!null) d:10(int) rowid:11(int!null) e:12(int) ├── key: (11) @@ -92,6 +93,7 @@ project ├── side-effects, mutations ├── key: () ├── fd: ()-->(1-5) + ├── prune: (1-4) └── select ├── columns: a:7(int!null) b:8(int) c:9(int!null) d:10(int) rowid:11(int!null) e:12(int) ├── cardinality: [0 - 1] @@ -125,6 +127,7 @@ project ├── side-effects, mutations ├── key: (5) ├── fd: (2)==(3), (3)==(2), (5)-->(1-4) + ├── prune: (1-4) └── select ├── columns: a:7(int!null) b:8(int!null) c:9(int!null) d:10(int) rowid:11(int!null) e:12(int) ├── key: (11) diff --git a/pkg/sql/opt/norm/custom_funcs.go b/pkg/sql/opt/norm/custom_funcs.go index 9e21703db674..d3ffa6c678d8 100644 --- a/pkg/sql/opt/norm/custom_funcs.go +++ b/pkg/sql/opt/norm/custom_funcs.go @@ -16,6 +16,7 @@ import ( "github.com/cockroachdb/apd" "github.com/cockroachdb/cockroach/pkg/sql/opt" + "github.com/cockroachdb/cockroach/pkg/sql/opt/cat" "github.com/cockroachdb/cockroach/pkg/sql/opt/constraint" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/opt/props" @@ -350,6 +351,17 @@ func (c *CustomFuncs) sharedProps(e opt.Expr) *props.Shared { panic(errors.AssertionFailedf("no logical properties available for node: %v", e)) } +// MutationTable returns the table upon which the mutation is 
applied. +func (c *CustomFuncs) MutationTable(private *memo.MutationPrivate) opt.TableID { + return private.Table +} + +// PrimaryKeyCols returns the key columns of the primary key of the table. +func (c *CustomFuncs) PrimaryKeyCols(table opt.TableID) opt.ColSet { + tabMeta := c.mem.Metadata().TableMeta(table) + return tabMeta.IndexKeyColumns(cat.PrimaryIndex) +} + // ---------------------------------------------------------------------- // // Ordering functions diff --git a/pkg/sql/opt/norm/prune_cols.go b/pkg/sql/opt/norm/prune_cols.go index 0f0818e01bf3..0b6c61c9e173 100644 --- a/pkg/sql/opt/norm/prune_cols.go +++ b/pkg/sql/opt/norm/prune_cols.go @@ -73,9 +73,17 @@ func (c *CustomFuncs) NeededMutationCols(private *memo.MutationPrivate) opt.ColS func (c *CustomFuncs) NeededMutationFetchCols( op opt.Operator, private *memo.MutationPrivate, ) opt.ColSet { + return neededMutationFetchCols(c.mem, op, private) +} + +// neededMutationFetchCols returns the set of columns needed by the given +// mutation operator. +func neededMutationFetchCols( + mem *memo.Memo, op opt.Operator, private *memo.MutationPrivate, +) opt.ColSet { var cols opt.ColSet - tabMeta := c.mem.Metadata().TableMeta(private.Table) + tabMeta := mem.Metadata().TableMeta(private.Table) // familyCols returns the columns in the given family. familyCols := func(fam cat.Family) opt.ColSet { @@ -501,6 +509,17 @@ func DerivePruneCols(e memo.RelExpr) opt.ColSet { relProps.Rule.PruneCols.DifferenceWith(w.ScalarProps(e.Memo()).OuterCols) } + case opt.UpdateOp, opt.UpsertOp, opt.DeleteOp: + // Find the columns that would need to be fetched, if no returning + // clause were present. + withoutReturningPrivate := *e.Private().(*memo.MutationPrivate) + withoutReturningPrivate.ReturnCols = opt.ColList{} + neededCols := neededMutationFetchCols(e.Memo(), e.Op(), &withoutReturningPrivate) + + // Only the "free" RETURNING columns can be pruned away (i.e. the columns + // required by the mutation only because they're being returned). + relProps.Rule.PruneCols = relProps.OutputCols.Difference(neededCols) + default: // Don't allow any columns to be pruned, since that would trigger the // creation of a wrapper Project around an operator that does not have @@ -509,3 +528,43 @@ func DerivePruneCols(e memo.RelExpr) opt.ColSet { return relProps.Rule.PruneCols } + +// CanPruneMutationReturnCols checks whether the mutation's return columns can +// be pruned. This is the pre-condition for the PruneMutationReturnCols rule. +func (c *CustomFuncs) CanPruneMutationReturnCols( + private *memo.MutationPrivate, needed opt.ColSet, +) bool { + if private.ReturnCols == nil { + return false + } + + tabID := c.mem.Metadata().TableMeta(private.Table).MetaID + for i := range private.ReturnCols { + if private.ReturnCols[i] != 0 && !needed.Contains(tabID.ColumnID(i)) { + return true + } + } + + return false +} + +// PruneMutationReturnCols rewrites the given mutation private to no longer +// keep ReturnCols that are not referenced by the RETURNING clause or are not +// part of the primary key. The caller must have already done the analysis to +// prove that such columns exist, by calling CanPruneMutationReturnCols. 
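Before the definition that follows in the hunk below, a simplified standalone model of the check/rewrite pair: a return column survives only if its column ID is in the needed set (columns referenced by RETURNING plus primary-key columns). Plain ints and a map stand in for opt.ColList, opt.ColSet, and the tabID.ColumnID lookup.

```go
package main

import "fmt"

// pruneReturnCols zeroes out return-column IDs that are not needed, mirroring
// the shape of CanPruneMutationReturnCols/PruneMutationReturnCols: index i is
// the table ordinal, and a zero entry means "not returned".
func pruneReturnCols(returnCols []int, needed map[int]bool) (newCols []int, changed bool) {
    newCols = make([]int, len(returnCols))
    for i, col := range returnCols {
        if col != 0 && needed[col] {
            newCols[i] = col
        } else if col != 0 {
            changed = true // this return column can be pruned
        }
    }
    return newCols, changed
}

func main() {
    // Table ordinals 0..3 carry column IDs 1..4. RETURNING references b (id 2)
    // and the primary key is k (id 1), so c and d can be pruned.
    returnCols := []int{1, 2, 3, 4}
    needed := map[int]bool{1: true, 2: true}
    newCols, changed := pruneReturnCols(returnCols, needed)
    fmt.Println(newCols, changed) // [1 2 0 0] true
}
```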
+func (c *CustomFuncs) PruneMutationReturnCols( + private *memo.MutationPrivate, needed opt.ColSet, +) *memo.MutationPrivate { + newPrivate := *private + newReturnCols := make(opt.ColList, len(private.ReturnCols)) + tabID := c.mem.Metadata().TableMeta(private.Table).MetaID + + for i := range private.ReturnCols { + if needed.Contains(tabID.ColumnID(i)) { + newReturnCols[i] = private.ReturnCols[i] + } + } + + newPrivate.ReturnCols = newReturnCols + return &newPrivate +} diff --git a/pkg/sql/opt/norm/rules/prune_cols.opt b/pkg/sql/opt/norm/rules/prune_cols.opt index 9e3bd969f087..d010462d003c 100644 --- a/pkg/sql/opt/norm/rules/prune_cols.opt +++ b/pkg/sql/opt/norm/rules/prune_cols.opt @@ -462,3 +462,39 @@ $checks $mutationPrivate ) + +# PruneReturningCols removes columns from the mutation operator's ReturnCols +# set if they are not used in the RETURNING clause of the mutation. +# Removing ReturnCols will then allow the PruneMutationFetchCols to be more +# conservative with the fetch columns. +# TODO(ridwanmsharif): Mutations shouldn't need to return the primary key +# columns. Make appropriate changes to SQL execution to accommodate this. +[PruneMutationReturnCols, Normalize] +(Project + $input:(Insert | Update | Upsert | Delete + $innerInput:* + $checks:* + $mutationPrivate:* + ) + $projections:* + $passthrough:* & + (CanPruneMutationReturnCols + $mutationPrivate + $needed:(UnionCols3 + (PrimaryKeyCols (MutationTable $mutationPrivate)) + (ProjectionOuterCols $projections) + $passthrough + ) + ) +) +=> +(Project + ((OpName $input) + $innerInput + $checks + (PruneMutationReturnCols $mutationPrivate $needed) + ) + $projections + $passthrough +) + diff --git a/pkg/sql/opt/norm/testdata/rules/prune_cols b/pkg/sql/opt/norm/testdata/rules/prune_cols index 9c585e0c0052..d2c938f82151 100644 --- a/pkg/sql/opt/norm/testdata/rules/prune_cols +++ b/pkg/sql/opt/norm/testdata/rules/prune_cols @@ -1883,26 +1883,19 @@ delete mutation ├── key: (6) └── fd: (6)-->(7,9,10) -# No pruning when RETURNING clause is present. -# TODO(andyk): Need to prune output columns. -opt expect-not=(PruneMutationFetchCols,PruneMutationInputCols) +opt expect=(PruneMutationFetchCols,PruneMutationInputCols) DELETE FROM a RETURNING k, s ---- -project +delete a ├── columns: k:1(int!null) s:4(string) + ├── fetch columns: k:5(int) s:8(string) ├── side-effects, mutations ├── key: (1) ├── fd: (1)-->(4) - └── delete a - ├── columns: k:1(int!null) i:2(int) f:3(float) s:4(string) - ├── fetch columns: k:5(int) i:6(int) f:7(float) s:8(string) - ├── side-effects, mutations - ├── key: (1) - ├── fd: (1)-->(2-4) - └── scan a - ├── columns: k:5(int!null) i:6(int) f:7(float) s:8(string) - ├── key: (5) - └── fd: (5)-->(6-8) + └── scan a + ├── columns: k:5(int!null) s:8(string) + ├── key: (5) + └── fd: (5)-->(8) # Prune secondary family column not needed for the update. opt expect=(PruneMutationFetchCols,PruneMutationInputCols) @@ -1945,29 +1938,28 @@ update "family" └── a + 1 [type=int, outer=(6)] # Do not prune columns that must be returned. -# TODO(justin): in order to prune e here we need a PruneMutationReturnCols rule. 
-opt expect-not=PruneMutationFetchCols +opt expect=(PruneMutationFetchCols, PruneMutationReturnCols) UPDATE family SET c=c+1 RETURNING b ---- project ├── columns: b:2(int) ├── side-effects, mutations └── update "family" - ├── columns: a:1(int!null) b:2(int) c:3(int) d:4(int) e:5(int) - ├── fetch columns: a:6(int) b:7(int) c:8(int) d:9(int) e:10(int) + ├── columns: a:1(int!null) b:2(int) + ├── fetch columns: a:6(int) b:7(int) c:8(int) d:9(int) ├── update-mapping: │ └── column11:11 => c:3 ├── side-effects, mutations ├── key: (1) - ├── fd: (1)-->(2-5) + ├── fd: (1)-->(2) └── project - ├── columns: column11:11(int) a:6(int!null) b:7(int) c:8(int) d:9(int) e:10(int) + ├── columns: column11:11(int) a:6(int!null) b:7(int) c:8(int) d:9(int) ├── key: (6) - ├── fd: (6)-->(7-10), (8)-->(11) + ├── fd: (6)-->(7-9), (8)-->(11) ├── scan "family" - │ ├── columns: a:6(int!null) b:7(int) c:8(int) d:9(int) e:10(int) + │ ├── columns: a:6(int!null) b:7(int) c:8(int) d:9(int) │ ├── key: (6) - │ └── fd: (6)-->(7-10) + │ └── fd: (6)-->(7-9) └── projections └── c + 1 [type=int, outer=(8)] @@ -2115,9 +2107,9 @@ project ├── key: () ├── fd: ()-->(5) └── upsert "family" - ├── columns: a:1(int!null) b:2(int) c:3(int) d:4(int) e:5(int) + ├── columns: a:1(int!null) e:5(int) ├── canary column: 11 - ├── fetch columns: a:11(int) b:12(int) c:13(int) d:14(int) e:15(int) + ├── fetch columns: a:11(int) c:13(int) d:14(int) e:15(int) ├── insert-mapping: │ ├── column1:6 => a:1 │ ├── column2:7 => b:2 @@ -2128,24 +2120,21 @@ project │ └── upsert_c:19 => c:3 ├── return-mapping: │ ├── upsert_a:17 => a:1 - │ ├── upsert_b:18 => b:2 - │ ├── upsert_c:19 => c:3 - │ ├── upsert_d:20 => d:4 │ └── upsert_e:21 => e:5 ├── cardinality: [1 - 1] ├── side-effects, mutations ├── key: () - ├── fd: ()-->(1-5) + ├── fd: ()-->(1,5) └── project - ├── columns: upsert_a:17(int) upsert_b:18(int) upsert_c:19(int) upsert_d:20(int) upsert_e:21(int) column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) b:12(int) c:13(int) d:14(int) e:15(int) + ├── columns: upsert_a:17(int) upsert_c:19(int) upsert_e:21(int) column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) c:13(int) d:14(int) e:15(int) ├── cardinality: [1 - 1] ├── key: () - ├── fd: ()-->(6-15,17-21) + ├── fd: ()-->(6-11,13-15,17,19,21) ├── left-join (hash) - │ ├── columns: column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) b:12(int) c:13(int) d:14(int) e:15(int) + │ ├── columns: column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) a:11(int) c:13(int) d:14(int) e:15(int) │ ├── cardinality: [1 - 1] │ ├── key: () - │ ├── fd: ()-->(6-15) + │ ├── fd: ()-->(6-11,13-15) │ ├── values │ │ ├── columns: column1:6(int!null) column2:7(int!null) column3:8(int!null) column4:9(int!null) column5:10(int!null) │ │ ├── cardinality: [1 - 1] @@ -2153,20 +2142,17 @@ project │ │ ├── fd: ()-->(6-10) │ │ └── (1, 2, 3, 4, 5) [type=tuple{int, int, int, int, int}] │ ├── scan "family" - │ │ ├── columns: a:11(int!null) b:12(int) c:13(int) d:14(int) e:15(int) + │ │ ├── columns: a:11(int!null) c:13(int) d:14(int) e:15(int) │ │ ├── constraint: /11: [/1 - /1] │ │ ├── cardinality: [0 - 1] │ │ ├── key: () - │ │ └── fd: ()-->(11-15) + │ │ └── fd: ()-->(11,13-15) │ └── filters (true) └── projections ├── CASE WHEN a IS NULL THEN column1 ELSE a END [type=int, outer=(6,11)] - ├── CASE WHEN a IS NULL THEN column2 ELSE b END 
[type=int, outer=(7,11,12)] ├── CASE WHEN a IS NULL THEN column3 ELSE 10 END [type=int, outer=(8,11)] - ├── CASE WHEN a IS NULL THEN column4 ELSE d END [type=int, outer=(9,11,14)] └── CASE WHEN a IS NULL THEN column5 ELSE e END [type=int, outer=(10,11,15)] - # Do not prune column in same secondary family as updated column. But prune # non-key column in primary family. opt expect=(PruneMutationFetchCols,PruneMutationInputCols) @@ -2254,3 +2240,403 @@ upsert mutation │ └── filters (true) └── projections └── CASE WHEN a IS NULL THEN column2 ELSE 10 END [type=int, outer=(7,10)] + +# ------------------------------------------------------------------------------ +# PruneMutationReturnCols +# ------------------------------------------------------------------------------ + +# Create a table with multiple column families the mutations can take advantage of. +exec-ddl +CREATE TABLE returning_test ( + a INT, + b INT, + c STRING, + d INT, + e INT, + f INT, + g INT, + FAMILY (a), + FAMILY (b), + FAMILY (c), + FAMILY (d, e, f, g), + UNIQUE (a) +) +---- + +# Fetch all the columns for the RETURN expression. +opt +UPDATE returning_test SET a = a + 1 RETURNING * +---- +project + ├── columns: a:1(int) b:2(int) c:3(string) d:4(int) e:5(int) f:6(int) g:7(int) + ├── side-effects, mutations + └── update returning_test + ├── columns: a:1(int) b:2(int) c:3(string) d:4(int) e:5(int) f:6(int) g:7(int) rowid:8(int!null) + ├── fetch columns: a:9(int) b:10(int) c:11(string) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => a:1 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1-7) + └── project + ├── columns: column17:17(int) a:9(int) b:10(int) c:11(string) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9-15), (9)~~>(10-16), (9)-->(17) + ├── scan returning_test + │ ├── columns: a:9(int) b:10(int) c:11(string) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9-15), (9)~~>(10-16) + └── projections + └── a + 1 [type=int, outer=(9)] + + +# Fetch all the columns in the (d, e, f, g) family as d is being set. +opt +UPDATE returning_test SET d = a + d RETURNING a, d +---- +project + ├── columns: a:1(int) d:4(int) + ├── side-effects, mutations + ├── lax-key: (1,4) + ├── fd: (1)~~>(4) + └── update returning_test + ├── columns: a:1(int) d:4(int) rowid:8(int!null) + ├── fetch columns: a:9(int) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => d:4 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1,4), (1)~~>(4,8) + └── project + ├── columns: column17:17(int) a:9(int) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9,12-15), (9)~~>(12-16), (9,12)-->(17) + ├── scan returning_test + │ ├── columns: a:9(int) d:12(int) e:13(int) f:14(int) g:15(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9,12-15), (9)~~>(12-16) + └── projections + └── a + d [type=int, outer=(9,12)] + +# Fetch only whats being updated (not the (d, e, f, g) family). 
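The opt block that follows exercises the column-family behavior: only a is written, so the (d, e, f, g) family need not be fetched. A rough standalone sketch of that family rule follows, with a hypothetical layout; this is a simplification, not the real neededMutationFetchCols logic.

```go
package main

import "fmt"

// fetchColsForUpdate returns the columns an UPDATE must fetch under the
// simplified rule used in this sketch: every key column, plus every column in
// a family that contains an updated column.
func fetchColsForUpdate(families map[string][]string, keyCols, updated []string) map[string]bool {
    fetch := map[string]bool{}
    for _, c := range keyCols {
        fetch[c] = true
    }
    updatedSet := map[string]bool{}
    for _, c := range updated {
        updatedSet[c] = true
    }
    for _, cols := range families {
        touched := false
        for _, c := range cols {
            if updatedSet[c] {
                touched = true
            }
        }
        if touched {
            for _, c := range cols {
                fetch[c] = true
            }
        }
    }
    return fetch
}

func main() {
    families := map[string][]string{
        "fam_a": {"a"}, "fam_b": {"b"}, "fam_c": {"c"}, "fam_defg": {"d", "e", "f", "g"},
    }
    // Only a is written, so only fam_a and the key column are fetched.
    fmt.Println(fetchColsForUpdate(families, []string{"rowid"}, []string{"a"}))
}
```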
+opt +UPDATE returning_test SET a = a + d RETURNING a +---- +project + ├── columns: a:1(int) + ├── side-effects, mutations + └── update returning_test + ├── columns: a:1(int) rowid:8(int!null) + ├── fetch columns: a:9(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => a:1 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1) + └── project + ├── columns: column17:17(int) a:9(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9,17), (9)~~>(16,17) + ├── scan returning_test + │ ├── columns: a:9(int) d:12(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9,12), (9)~~>(12,16) + └── projections + └── a + d [type=int, outer=(9,12)] + +# We only fetch the minimal set of columns which is (a, b, c, rowid). +opt +UPDATE returning_test SET (b, a) = (a, a + b) RETURNING a, b, c +---- +project + ├── columns: a:1(int) b:2(int) c:3(string) + ├── side-effects, mutations + ├── lax-key: (1-3) + ├── fd: (2)~~>(1,3) + └── update returning_test + ├── columns: a:1(int) b:2(int) c:3(string) rowid:8(int!null) + ├── fetch columns: a:9(int) b:10(int) c:11(string) rowid:16(int) + ├── update-mapping: + │ ├── column17:17 => a:1 + │ └── a:9 => b:2 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1-3), (2)~~>(1,3,8) + └── project + ├── columns: column17:17(int) a:9(int) b:10(int) c:11(string) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9-11), (9)~~>(10,11,16), (9,10)-->(17) + ├── scan returning_test + │ ├── columns: a:9(int) b:10(int) c:11(string) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9-11), (9)~~>(10,11,16) + └── projections + └── a + b [type=int, outer=(9,10)] + + +# We apply the PruneMutationReturnCols rule multiple times, to get +# the minimal set of columns which is (a, rowid). Notice how c and b +# are pruned away. +opt +SELECT a FROM [SELECT a, b FROM [UPDATE returning_test SET a = a + 1 RETURNING a, b, c]] +---- +project + ├── columns: a:1(int) + ├── side-effects, mutations + └── update returning_test + ├── columns: a:1(int) rowid:8(int!null) + ├── fetch columns: a:9(int) rowid:16(int) + ├── update-mapping: + │ └── column17:17 => a:1 + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1) + └── project + ├── columns: column17:17(int) a:9(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9), (9)~~>(16), (9)-->(17) + ├── scan returning_test@secondary + │ ├── columns: a:9(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9), (9)~~>(16) + └── projections + └── a + 1 [type=int, outer=(9)] + +# We derive the prune cols for the mutation appropriately so we +# can prune away columns even when the mutation is not under a +# projection. Another rule will fire to add the appropriate +# projection when this happens. 
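The comment above refers to the new DerivePruneCols case for mutations; a simplified standalone model of that computation (prunable columns are the mutation's output columns minus the fetch columns it would need even without a RETURNING clause), over plain integer column IDs:

```go
package main

import "fmt"

// mutationPruneCols mirrors the new DerivePruneCols case for Update/Upsert/
// Delete: only the output columns that are not needed as fetch columns
// (ignoring RETURNING) may be pruned.
func mutationPruneCols(outputCols, neededFetchCols []int) []int {
    needed := map[int]bool{}
    for _, c := range neededFetchCols {
        needed[c] = true
    }
    var prunable []int
    for _, c := range outputCols {
        if !needed[c] {
            prunable = append(prunable, c)
        }
    }
    return prunable
}

func main() {
    // Output columns (a, b, c, rowid) = 1..4; the mutation itself only needs
    // the key column rowid=4, so a, b, and c are candidates for pruning.
    fmt.Println(mutationPruneCols([]int{1, 2, 3, 4}, []int{4})) // [1 2 3]
}
```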
+opt +SELECT a FROM [SELECT a, b FROM [UPDATE returning_test SET a = a + 1 RETURNING a, b, c] WHERE a > 1] +---- +project + ├── columns: a:1(int!null) + ├── side-effects, mutations + └── select + ├── columns: a:1(int!null) rowid:8(int!null) + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1) + ├── update returning_test + │ ├── columns: a:1(int) rowid:8(int!null) + │ ├── fetch columns: a:9(int) rowid:16(int) + │ ├── update-mapping: + │ │ └── column17:17 => a:1 + │ ├── side-effects, mutations + │ ├── key: (8) + │ ├── fd: (8)-->(1) + │ └── project + │ ├── columns: column17:17(int) a:9(int) rowid:16(int!null) + │ ├── key: (16) + │ ├── fd: (16)-->(9), (9)~~>(16), (9)-->(17) + │ ├── scan returning_test@secondary + │ │ ├── columns: a:9(int) rowid:16(int!null) + │ │ ├── key: (16) + │ │ └── fd: (16)-->(9), (9)~~>(16) + │ └── projections + │ └── a + 1 [type=int, outer=(9)] + └── filters + └── a > 1 [type=bool, outer=(1), constraints=(/1: [/2 - ]; tight)] + +opt +SELECT + * +FROM + [SELECT a, b FROM returning_test] AS x + JOIN [SELECT a, b FROM [UPDATE returning_test SET a = a + 1 RETURNING a, b, c] WHERE a > 1] + AS y ON true +---- +project + ├── columns: a:1(int) b:2(int) a:9(int!null) b:10(int) + ├── side-effects, mutations + ├── fd: (1)~~>(2) + └── inner-join (hash) + ├── columns: x.a:1(int) x.b:2(int) returning_test.a:9(int!null) returning_test.b:10(int) returning_test.rowid:16(int!null) + ├── side-effects, mutations + ├── lax-key: (1,2,16) + ├── fd: (1)~~>(2), (16)-->(9,10) + ├── scan x + │ ├── columns: x.a:1(int) x.b:2(int) + │ ├── lax-key: (1,2) + │ └── fd: (1)~~>(2) + ├── select + │ ├── columns: returning_test.a:9(int!null) returning_test.b:10(int) returning_test.rowid:16(int!null) + │ ├── side-effects, mutations + │ ├── key: (16) + │ ├── fd: (16)-->(9,10) + │ ├── update returning_test + │ │ ├── columns: returning_test.a:9(int) returning_test.b:10(int) returning_test.rowid:16(int!null) + │ │ ├── fetch columns: returning_test.a:17(int) returning_test.b:18(int) returning_test.rowid:24(int) + │ │ ├── update-mapping: + │ │ │ └── column25:25 => returning_test.a:9 + │ │ ├── side-effects, mutations + │ │ ├── key: (16) + │ │ ├── fd: (16)-->(9,10) + │ │ └── project + │ │ ├── columns: column25:25(int) returning_test.a:17(int) returning_test.b:18(int) returning_test.rowid:24(int!null) + │ │ ├── key: (24) + │ │ ├── fd: (24)-->(17,18), (17)~~>(18,24), (17)-->(25) + │ │ ├── scan returning_test + │ │ │ ├── columns: returning_test.a:17(int) returning_test.b:18(int) returning_test.rowid:24(int!null) + │ │ │ ├── key: (24) + │ │ │ └── fd: (24)-->(17,18), (17)~~>(18,24) + │ │ └── projections + │ │ └── returning_test.a + 1 [type=int, outer=(17)] + │ └── filters + │ └── returning_test.a > 1 [type=bool, outer=(9), constraints=(/9: [/2 - ]; tight)] + └── filters (true) + +# Check if the rule works as desired for other mutations. 
+opt +INSERT INTO returning_test VALUES (1, 2, 'c') ON CONFLICT (a) DO UPDATE SET a = excluded.a + returning_test.a RETURNING a, b, c +---- +project + ├── columns: a:1(int) b:2(int) c:3(string) + ├── cardinality: [1 - 1] + ├── side-effects, mutations + ├── key: () + ├── fd: ()-->(1-3) + └── upsert returning_test + ├── columns: a:1(int) b:2(int) c:3(string) rowid:8(int!null) + ├── canary column: 21 + ├── fetch columns: a:14(int) b:15(int) c:16(string) rowid:21(int) + ├── insert-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ ├── column3:11 => c:3 + │ ├── column12:12 => d:4 + │ ├── column12:12 => e:5 + │ ├── column12:12 => f:6 + │ ├── column12:12 => g:7 + │ └── column13:13 => rowid:8 + ├── update-mapping: + │ └── upsert_a:23 => a:1 + ├── return-mapping: + │ ├── upsert_a:23 => a:1 + │ ├── upsert_b:24 => b:2 + │ ├── upsert_c:25 => c:3 + │ └── upsert_rowid:30 => rowid:8 + ├── cardinality: [1 - 1] + ├── side-effects, mutations + ├── key: () + ├── fd: ()-->(1-3,8) + └── project + ├── columns: upsert_a:23(int) upsert_b:24(int) upsert_c:25(string) upsert_rowid:30(int) column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) rowid:21(int) + ├── cardinality: [1 - 1] + ├── side-effects + ├── key: () + ├── fd: ()-->(9-16,21,23-25,30) + ├── left-join (hash) + │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) rowid:21(int) + │ ├── cardinality: [1 - 1] + │ ├── side-effects + │ ├── key: () + │ ├── fd: ()-->(9-16,21) + │ ├── values + │ │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) + │ │ ├── cardinality: [1 - 1] + │ │ ├── side-effects + │ │ ├── key: () + │ │ ├── fd: ()-->(9-13) + │ │ └── (1, 2, 'c', CAST(NULL AS INT8), unique_rowid()) [type=tuple{int, int, string, int, int}] + │ ├── index-join returning_test + │ │ ├── columns: a:14(int!null) b:15(int) c:16(string) rowid:21(int!null) + │ │ ├── cardinality: [0 - 1] + │ │ ├── key: () + │ │ ├── fd: ()-->(14-16,21) + │ │ └── scan returning_test@secondary + │ │ ├── columns: a:14(int!null) rowid:21(int!null) + │ │ ├── constraint: /14: [/1 - /1] + │ │ ├── cardinality: [0 - 1] + │ │ ├── key: () + │ │ └── fd: ()-->(14,21) + │ └── filters (true) + └── projections + ├── CASE WHEN rowid IS NULL THEN column1 ELSE column1 + a END [type=int, outer=(9,14,21)] + ├── CASE WHEN rowid IS NULL THEN column2 ELSE b END [type=int, outer=(10,15,21)] + ├── CASE WHEN rowid IS NULL THEN column3 ELSE c END [type=string, outer=(11,16,21)] + └── CASE WHEN rowid IS NULL THEN column13 ELSE rowid END [type=int, outer=(13,21)] + +opt +DELETE FROM returning_test WHERE a < b + d RETURNING a, b, d +---- +project + ├── columns: a:1(int!null) b:2(int) d:4(int) + ├── side-effects, mutations + ├── key: (1) + ├── fd: (1)-->(2,4) + └── delete returning_test + ├── columns: a:1(int!null) b:2(int) d:4(int) rowid:8(int!null) + ├── fetch columns: a:9(int) b:10(int) d:12(int) rowid:16(int) + ├── side-effects, mutations + ├── key: (8) + ├── fd: (8)-->(1,2,4), (1)-->(2,4,8) + └── select + ├── columns: a:9(int!null) b:10(int) d:12(int) rowid:16(int!null) + ├── key: (16) + ├── fd: (16)-->(9,10,12), (9)-->(10,12,16) + ├── scan returning_test + │ ├── columns: a:9(int) b:10(int) d:12(int) rowid:16(int!null) + │ ├── key: (16) + │ └── fd: (16)-->(9,10,12), (9)~~>(10,12,16) + └── filters + └── a < (b + d) [type=bool, outer=(9,10,12), constraints=(/9: (/NULL - ])] + 
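As a small aside on how the PruneMutationReturnCols rule assembles its needed set (primary-key columns, columns referenced by the RETURNING projections, and passthrough columns, per the UnionCols3 expression in prune_cols.opt), a standalone illustration over plain integer column IDs:

```go
package main

import "fmt"

// union3 mirrors the UnionCols3 shape used by the PruneMutationReturnCols
// rule: the needed set is the primary key plus every column the RETURNING
// projections reference plus the passthrough columns.
func union3(pk, projectionOuter, passthrough []int) map[int]bool {
    needed := map[int]bool{}
    for _, s := range [][]int{pk, projectionOuter, passthrough} {
        for _, c := range s {
            needed[c] = true
        }
    }
    return needed
}

func main() {
    // DELETE FROM t ... RETURNING b: pk = {k=1}, projections reference {b=2},
    // nothing is passed through, so only k and b must be returned upward.
    fmt.Println(union3([]int{1}, []int{2}, nil)) // map[1:true 2:true]
}
```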
+opt +UPSERT INTO returning_test (a, b, c) VALUES (1, 2, 'c') RETURNING a, b, c, d +---- +project + ├── columns: a:1(int!null) b:2(int!null) c:3(string!null) d:4(int) + ├── cardinality: [1 - ] + ├── side-effects, mutations + ├── fd: ()-->(1-3) + └── upsert returning_test + ├── columns: a:1(int!null) b:2(int!null) c:3(string!null) d:4(int) rowid:8(int!null) + ├── canary column: 21 + ├── fetch columns: a:14(int) b:15(int) c:16(string) d:17(int) rowid:21(int) + ├── insert-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ ├── column3:11 => c:3 + │ ├── column12:12 => d:4 + │ ├── column12:12 => e:5 + │ ├── column12:12 => f:6 + │ ├── column12:12 => g:7 + │ └── column13:13 => rowid:8 + ├── update-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ └── column3:11 => c:3 + ├── return-mapping: + │ ├── column1:9 => a:1 + │ ├── column2:10 => b:2 + │ ├── column3:11 => c:3 + │ ├── upsert_d:22 => d:4 + │ └── upsert_rowid:26 => rowid:8 + ├── cardinality: [1 - ] + ├── side-effects, mutations + ├── fd: ()-->(1-3) + └── project + ├── columns: upsert_d:22(int) upsert_rowid:26(int) column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) d:17(int) rowid:21(int) + ├── cardinality: [1 - ] + ├── side-effects + ├── key: (21) + ├── fd: ()-->(9-13), (21)-->(14-17), (14)~~>(15-17,21), (17,21)-->(22), (21)-->(26) + ├── left-join (lookup returning_test) + │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) a:14(int) b:15(int) c:16(string) d:17(int) rowid:21(int) + │ ├── key columns: [13] = [21] + │ ├── cardinality: [1 - ] + │ ├── side-effects + │ ├── key: (21) + │ ├── fd: ()-->(9-13), (21)-->(14-17), (14)~~>(15-17,21) + │ ├── values + │ │ ├── columns: column1:9(int!null) column2:10(int!null) column3:11(string!null) column12:12(int) column13:13(int) + │ │ ├── cardinality: [1 - 1] + │ │ ├── side-effects + │ │ ├── key: () + │ │ ├── fd: ()-->(9-13) + │ │ └── (1, 2, 'c', CAST(NULL AS INT8), unique_rowid()) [type=tuple{int, int, string, int, int}] + │ └── filters (true) + └── projections + ├── CASE WHEN rowid IS NULL THEN column12 ELSE d END [type=int, outer=(12,17,21)] + └── CASE WHEN rowid IS NULL THEN column13 ELSE rowid END [type=int, outer=(13,21)] diff --git a/pkg/sql/opt/xform/testdata/rules/join b/pkg/sql/opt/xform/testdata/rules/join index 57a27fa7e5a5..72bf8c44af5a 100644 --- a/pkg/sql/opt/xform/testdata/rules/join +++ b/pkg/sql/opt/xform/testdata/rules/join @@ -2116,24 +2116,21 @@ project ├── side-effects, mutations ├── fd: ()-->(21) ├── inner-join (hash) - │ ├── columns: abc.a:5(int!null) abc.b:6(int) abc.c:7(int) abc.rowid:8(int!null) + │ ├── columns: abc.rowid:8(int!null) │ ├── cardinality: [0 - 0] │ ├── side-effects, mutations - │ ├── fd: ()-->(5-7) │ ├── select - │ │ ├── columns: abc.a:5(int!null) abc.b:6(int) abc.c:7(int) abc.rowid:8(int!null) + │ │ ├── columns: abc.rowid:8(int!null) │ │ ├── cardinality: [0 - 0] │ │ ├── side-effects, mutations - │ │ ├── fd: ()-->(5-7) │ │ ├── insert abc - │ │ │ ├── columns: abc.a:5(int!null) abc.b:6(int) abc.c:7(int) abc.rowid:8(int!null) + │ │ │ ├── columns: abc.rowid:8(int!null) │ │ │ ├── insert-mapping: │ │ │ │ ├── "?column?":13 => abc.a:5 │ │ │ │ ├── column14:14 => abc.b:6 │ │ │ │ ├── column14:14 => abc.c:7 │ │ │ │ └── column15:15 => abc.rowid:8 │ │ │ ├── side-effects, mutations - │ │ │ ├── fd: ()-->(5-7) │ │ │ └── project │ │ │ ├── columns: column14:14(int) column15:15(int) "?column?":13(int!null) │ │ │ ├── 
side-effects diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index f7ace19fdd9f..94315ead1b31 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -1154,20 +1154,46 @@ func (ef *execFactory) ConstructShowTrace(typ tree.ShowTraceType, compact bool) return node, nil } +// mutationRowIdxToReturnIdx returns the mapping from the origColDescs to the +// returnColDescs (where returnColDescs is a subset of the origColDescs). +// -1 is used for columns not part of the returnColDescs. +// It is the responsibility of the caller to ensure a mapping is possible. +func mutationRowIdxToReturnIdx(origColDescs, returnColDescs []sqlbase.ColumnDescriptor) []int { + // Create a ColumnID to index map. + colIDToRetIndex := row.ColIDtoRowIndexFromCols(origColDescs) + + // Initialize the rowIdxToTabColIdx array. + rowIdxToRetIdx := make([]int, len(origColDescs)) + for i := range rowIdxToRetIdx { + // -1 value indicates that this column is not being returned. + rowIdxToRetIdx[i] = -1 + } + + // Set the appropriate index values for the returning columns. + for i := range returnColDescs { + if idx, ok := colIDToRetIndex[returnColDescs[i].ID]; ok { + rowIdxToRetIdx[idx] = i + } + } + + return rowIdxToRetIdx +} + func (ef *execFactory) ConstructInsert( input exec.Node, table cat.Table, - insertCols exec.ColumnOrdinalSet, - checks exec.CheckOrdinalSet, - rowsNeeded bool, + insertColOrdSet exec.ColumnOrdinalSet, + returnColOrdSet exec.ColumnOrdinalSet, + checkOrdSet exec.CheckOrdinalSet, skipFKChecks bool, ) (exec.Node, error) { // Derive insert table and column descriptors. + rowsNeeded := !returnColOrdSet.Empty() tabDesc := table.(*optTable).desc - colDescs := makeColDescList(table, insertCols) + colDescs := makeColDescList(table, insertColOrdSet) // Construct the check helper if there are any check constraints. - checkHelper := sqlbase.NewInputCheckHelper(checks, tabDesc) + checkHelper := sqlbase.NewInputCheckHelper(checkOrdSet, tabDesc) // Determine the foreign key tables involved in the update. fkTables, err := ef.makeFkMetadata(tabDesc, row.CheckInserts, checkHelper) @@ -1189,10 +1215,25 @@ func (ef *execFactory) ConstructInsert( // Determine the relational type of the generated insert node. // If rows are not needed, no columns are returned. var returnCols sqlbase.ResultColumns + var tabColIdxToRetIdx []int if rowsNeeded { - // Insert always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs := makeColDescList(table, returnColOrdSet) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the delete runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Update the tabColIdxToRetIdx for the mutation. Insert always + // returns non-mutation columns in the same order they are defined in + // the table. + tabColIdxToRetIdx = mutationRowIdxToReturnIdx(tabDesc.Columns, returnColDescs) } // Regular path for INSERT. 
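To make mutationRowIdxToReturnIdx (added above) concrete, a standalone usage sketch; a small struct stands in for sqlbase.ColumnDescriptor and the IDs are hypothetical:

```go
package main

import "fmt"

type colDesc struct {
    ID   int
    Name string
}

// rowIdxToReturnIdx mirrors mutationRowIdxToReturnIdx: for each original
// column, record the position it occupies among the returned columns, or -1
// if it is not returned.
func rowIdxToReturnIdx(orig, returned []colDesc) []int {
    idxByID := map[int]int{}
    for i, c := range orig {
        idxByID[c.ID] = i
    }
    mapping := make([]int, len(orig))
    for i := range mapping {
        mapping[i] = -1
    }
    for retIdx, c := range returned {
        if origIdx, ok := idxByID[c.ID]; ok {
            mapping[origIdx] = retIdx
        }
    }
    return mapping
}

func main() {
    table := []colDesc{{1, "a"}, {2, "b"}, {3, "c"}, {4, "rowid"}}
    returning := []colDesc{{1, "a"}, {3, "c"}} // RETURNING a, c
    fmt.Println(rowIdxToReturnIdx(table, returning)) // [0 -1 1 -1]
}
```

Returned columns are kept in table order, so the mapping values increase along the retained ordinals.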
@@ -1208,7 +1249,8 @@ func (ef *execFactory) ConstructInsert( Cols: tabDesc.Columns, Mapping: ri.InsertColIDtoRowIndex, }, - insertCols: ri.InsertCols, + insertCols: ri.InsertCols, + tabColIdxToRetIdx: tabColIdxToRetIdx, }, } @@ -1227,19 +1269,20 @@ func (ef *execFactory) ConstructInsert( func (ef *execFactory) ConstructUpdate( input exec.Node, table cat.Table, - fetchCols exec.ColumnOrdinalSet, - updateCols exec.ColumnOrdinalSet, + fetchColOrdSet exec.ColumnOrdinalSet, + updateColOrdSet exec.ColumnOrdinalSet, + returnColOrdSet exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { // Derive table and column descriptors. + rowsNeeded := !returnColOrdSet.Empty() tabDesc := table.(*optTable).desc - fetchColDescs := makeColDescList(table, fetchCols) + fetchColDescs := makeColDescList(table, fetchColOrdSet) // Add each column to update as a sourceSlot. The CBO only uses scalarSlot, // since it compiles tuples and subqueries into a simple sequence of target // columns. - updateColDescs := makeColDescList(table, updateCols) + updateColDescs := makeColDescList(table, updateColOrdSet) sourceSlots := make([]sourceSlot, len(updateColDescs)) for i := range sourceSlots { sourceSlots[i] = scalarSlot{column: updateColDescs[i], sourceIndex: len(fetchColDescs) + i} @@ -1287,10 +1330,30 @@ func (ef *execFactory) ConstructUpdate( // Determine the relational type of the generated update node. // If rows are not needed, no columns are returned. var returnCols sqlbase.ResultColumns + var rowIdxToRetIdx []int if rowsNeeded { - // Update always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs := makeColDescList(table, returnColOrdSet) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the update runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Update the rowIdxToRetIdx for the mutation. Update returns + // the non-mutation columns specified, in the same order they are + // defined in the table. + // + // The Updater derives/stores the fetch columns of the mutation and + // since the return columns are always a subset of the fetch columns, + // we can use use the fetch columns to generate the mapping for the + // returned rows. + rowIdxToRetIdx = mutationRowIdxToReturnIdx(ru.FetchCols, returnColDescs) } // updateColsIdx inverts the mapping of UpdateCols to FetchCols. 
See @@ -1314,9 +1377,10 @@ func (ef *execFactory) ConstructUpdate( Cols: ru.FetchCols, Mapping: ru.FetchColIDtoRowIndex, }, - sourceSlots: sourceSlots, - updateValues: make(tree.Datums, len(ru.UpdateCols)), - updateColsIdx: updateColsIdx, + sourceSlots: sourceSlots, + updateValues: make(tree.Datums, len(ru.UpdateCols)), + updateColsIdx: updateColsIdx, + rowIdxToRetIdx: rowIdxToRetIdx, }, } @@ -1353,17 +1417,18 @@ func (ef *execFactory) ConstructUpsert( input exec.Node, table cat.Table, canaryCol exec.ColumnOrdinal, - insertCols exec.ColumnOrdinalSet, - fetchCols exec.ColumnOrdinalSet, - updateCols exec.ColumnOrdinalSet, + insertColOrdSet exec.ColumnOrdinalSet, + fetchColOrdSet exec.ColumnOrdinalSet, + updateColOrdSet exec.ColumnOrdinalSet, + returnColOrdSet exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, - rowsNeeded bool, ) (exec.Node, error) { // Derive table and column descriptors. + rowsNeeded := !returnColOrdSet.Empty() tabDesc := table.(*optTable).desc - insertColDescs := makeColDescList(table, insertCols) - fetchColDescs := makeColDescList(table, fetchCols) - updateColDescs := makeColDescList(table, updateCols) + insertColDescs := makeColDescList(table, insertColOrdSet) + fetchColDescs := makeColDescList(table, fetchColOrdSet) + updateColDescs := makeColDescList(table, updateColOrdSet) // Construct the check helper if there are any check constraints. checkHelper := sqlbase.NewInputCheckHelper(checks, tabDesc) @@ -1407,10 +1472,26 @@ func (ef *execFactory) ConstructUpsert( // Determine the relational type of the generated upsert node. // If rows are not needed, no columns are returned. var returnCols sqlbase.ResultColumns + var returnColDescs []sqlbase.ColumnDescriptor + var tabColIdxToRetIdx []int if rowsNeeded { - // Upsert always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs = makeColDescList(table, returnColOrdSet) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the delete runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Update the tabColIdxToRetIdx for the mutation. Upsert returns + // non-mutation columns specified, in the same order they are defined + // in the table. + tabColIdxToRetIdx = mutationRowIdxToReturnIdx(tabDesc.Columns, returnColDescs) } // updateColsIdx inverts the mapping of UpdateCols to FetchCols. See @@ -1439,11 +1520,13 @@ func (ef *execFactory) ConstructUpsert( alloc: &ef.planner.alloc, collectRows: rowsNeeded, }, - canaryOrdinal: int(canaryCol), - fkTables: fkTables, - fetchCols: fetchColDescs, - updateCols: updateColDescs, - ru: ru, + canaryOrdinal: int(canaryCol), + fkTables: fkTables, + fetchCols: fetchColDescs, + updateCols: updateColDescs, + returnCols: returnColDescs, + ru: ru, + tabColIdxToRetIdx: tabColIdxToRetIdx, }, }, } @@ -1460,12 +1543,65 @@ func (ef *execFactory) ConstructUpsert( return &rowCountNode{source: ups}, nil } +// colsRequiredForDelete returns all the columns required to perform a delete +// of a row on the table. This will include the returnColDescs columns that +// are referenced in the RETURNING clause of the delete mutation. 
This +// is different from the fetch columns of the delete mutation as the +// fetch columns includes more columns. Specifically, the fetch columns also +// include columns that are not part of index keys or the RETURNING columns +// (columns, for example, referenced in the WHERE clause). +func colsRequiredForDelete( + table cat.Table, tableColDescs, returnColDescs []sqlbase.ColumnDescriptor, +) []sqlbase.ColumnDescriptor { + // Find all the columns that are part of the rows returned by the delete. + deleteDescs := make([]sqlbase.ColumnDescriptor, 0, len(tableColDescs)) + var deleteCols util.FastIntSet + for i := 0; i < table.IndexCount(); i++ { + index := table.Index(i) + for j := 0; j < index.KeyColumnCount(); j++ { + col := *index.Column(j).Column.(*sqlbase.ColumnDescriptor) + if deleteCols.Contains(int(col.ID)) { + continue + } + + deleteDescs = append(deleteDescs, col) + deleteCols.Add(int(col.ID)) + } + } + + // Add columns specified in the RETURNING clause. + for _, col := range returnColDescs { + if deleteCols.Contains(int(col.ID)) { + continue + } + + deleteDescs = append(deleteDescs, col) + deleteCols.Add(int(col.ID)) + } + + // The order of the columns processed by the delete must be in the order they + // are present in the table. + tabDescs := make([]sqlbase.ColumnDescriptor, 0, len(deleteDescs)) + for i := 0; i < len(tableColDescs); i++ { + col := tableColDescs[i] + if deleteCols.Contains(int(col.ID)) { + tabDescs = append(tabDescs, col) + } + } + + return tabDescs +} + func (ef *execFactory) ConstructDelete( - input exec.Node, table cat.Table, fetchCols exec.ColumnOrdinalSet, rowsNeeded bool, + input exec.Node, + table cat.Table, + fetchColOrdSet exec.ColumnOrdinalSet, + returnColOrdSet exec.ColumnOrdinalSet, ) (exec.Node, error) { // Derive table and column descriptors. + rowsNeeded := !returnColOrdSet.Empty() tabDesc := table.(*optTable).desc - fetchColDescs := makeColDescList(table, fetchCols) + fetchColDescs := makeColDescList(table, fetchColOrdSet) // Determine the foreign key tables involved in the update. fkTables, err := ef.makeFkMetadata(tabDesc, row.CheckDeletes, nil /* checkHelper */) @@ -1503,10 +1639,29 @@ func (ef *execFactory) ConstructDelete( // Determine the relational type of the generated delete node. // If rows are not needed, no columns are returned. var returnCols sqlbase.ResultColumns + var rowIdxToRetIdx []int if rowsNeeded { - // Delete always returns all non-mutation columns, in the same order they - // are defined in the table. - returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns) + returnColDescs := makeColDescList(table, returnColOrdSet) + + // Only return the columns that are part of the table descriptor. + // This is important when columns are added and being back-filled + // as part of the same transaction when the delete runs. + // In such cases, the newly added columns shouldn't be returned. + // See regression logic tests for #29494. + if len(tabDesc.Columns) < len(returnColDescs) { + returnColDescs = returnColDescs[:len(tabDesc.Columns)] + } + + // Delete returns the non-mutation columns specified, in the same + // order they are defined in the table. + returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs) + + // Find all the columns that the deleteNode receives. The returning + // columns of the mutation are a subset of this column set. + requiredDeleteColumns := colsRequiredForDelete(table, tabDesc.Columns, returnColDescs) + + // Update the rowIdxToReturnIdx for the mutation. 
 func (ef *execFactory) ConstructDelete(
-	input exec.Node, table cat.Table, fetchCols exec.ColumnOrdinalSet, rowsNeeded bool,
+	input exec.Node,
+	table cat.Table,
+	fetchColOrdSet exec.ColumnOrdinalSet,
+	returnColOrdSet exec.ColumnOrdinalSet,
 ) (exec.Node, error) {
 	// Derive table and column descriptors.
+	rowsNeeded := !returnColOrdSet.Empty()
 	tabDesc := table.(*optTable).desc
-	fetchColDescs := makeColDescList(table, fetchCols)
+	fetchColDescs := makeColDescList(table, fetchColOrdSet)
 
 	// Determine the foreign key tables involved in the update.
 	fkTables, err := ef.makeFkMetadata(tabDesc, row.CheckDeletes, nil /* checkHelper */)
@@ -1503,10 +1639,29 @@ func (ef *execFactory) ConstructDelete(
 	// Determine the relational type of the generated delete node.
 	// If rows are not needed, no columns are returned.
 	var returnCols sqlbase.ResultColumns
+	var rowIdxToRetIdx []int
 	if rowsNeeded {
-		// Delete always returns all non-mutation columns, in the same order they
-		// are defined in the table.
-		returnCols = sqlbase.ResultColumnsFromColDescs(tabDesc.Columns)
+		returnColDescs := makeColDescList(table, returnColOrdSet)
+
+		// Only return the columns that are part of the table descriptor.
+		// This is important when columns are added and being back-filled
+		// as part of the same transaction when the delete runs.
+		// In such cases, the newly added columns shouldn't be returned.
+		// See regression logic tests for #29494.
+		if len(tabDesc.Columns) < len(returnColDescs) {
+			returnColDescs = returnColDescs[:len(tabDesc.Columns)]
+		}
+
+		// Delete returns the non-mutation columns specified, in the same
+		// order they are defined in the table.
+		returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs)
+
+		// Find all the columns that the deleteNode receives. The returning
+		// columns of the mutation are a subset of this column set.
+		requiredDeleteColumns := colsRequiredForDelete(table, tabDesc.Columns, returnColDescs)
+
+		// Update the rowIdxToRetIdx for the mutation.
+		rowIdxToRetIdx = mutationRowIdxToReturnIdx(requiredDeleteColumns, returnColDescs)
 	}
 
 	// Now make a delete node. We use a pool.
@@ -1515,8 +1670,9 @@ func (ef *execFactory) ConstructDelete(
 		source:  input.(planNode),
 		columns: returnCols,
 		run: deleteRun{
-			td:         tableDeleter{rd: rd, alloc: &ef.planner.alloc},
-			rowsNeeded: rowsNeeded,
+			td:             tableDeleter{rd: rd, alloc: &ef.planner.alloc},
+			rowsNeeded:     rowsNeeded,
+			rowIdxToRetIdx: rowIdxToRetIdx,
 		},
 	}
 
diff --git a/pkg/sql/rowcontainer/datum_row_container.go b/pkg/sql/rowcontainer/datum_row_container.go
index 6460e58975c9..9944c36f42da 100644
--- a/pkg/sql/rowcontainer/datum_row_container.go
+++ b/pkg/sql/rowcontainer/datum_row_container.go
@@ -244,6 +244,11 @@ func (c *RowContainer) Len() int {
 	return c.numRows
 }
 
+// NumCols reports the number of columns for each row in the container.
+func (c *RowContainer) NumCols() int {
+	return c.numCols
+}
+
 // At accesses a row at a specific index.
 func (c *RowContainer) At(i int) tree.Datums {
 	// This is a hot-path: do not add additional checks here.
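NumCols lets a caller size its result row from the container itself rather than from a column list it may not have. The mutation nodes in this change pair that with a rowIdxToRetIdx-style mapping to scatter source values into the returned row; a minimal generic sketch of that idiom (not CockroachDB API):

package main

import "fmt"

// scatter copies the values produced for each processed column (src) into a
// result row sized for the RETURNING columns, where rowIdxToRetIdx[i] is the
// destination index of column i, or -1 if that column is not returned.
func scatter(src []string, rowIdxToRetIdx []int, numRetCols int) []string {
	res := make([]string, numRetCols)
	for i, retIdx := range rowIdxToRetIdx {
		if retIdx >= 0 {
			res[retIdx] = src[i]
		}
	}
	return res
}

func main() {
	// Three columns flow through the mutation; RETURNING asks for the first and last.
	src := []string{"a-val", "b-val", "c-val"}
	fmt.Println(scatter(src, []int{0, -1, 1}, 2))
	// Output: [a-val c-val]
}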
diff --git a/pkg/sql/tablewriter_upsert_opt.go b/pkg/sql/tablewriter_upsert_opt.go
index 64121acbfd3c..bc04821bc1f4 100644
--- a/pkg/sql/tablewriter_upsert_opt.go
+++ b/pkg/sql/tablewriter_upsert_opt.go
@@ -15,9 +15,9 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/internal/client"
 	"github.com/cockroachdb/cockroach/pkg/sql/row"
+	"github.com/cockroachdb/cockroach/pkg/sql/rowcontainer"
 	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
 	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
-	"github.com/cockroachdb/errors"
 )
 
 // optTableUpserter implements the upsert operation when it is planned by the
@@ -53,6 +53,9 @@ type optTableUpserter struct {
 	// updateCols indicate which columns need an update during a conflict.
 	updateCols []sqlbase.ColumnDescriptor
 
+	// returnCols indicate which columns need to be returned by the Upsert.
+	returnCols []sqlbase.ColumnDescriptor
+
 	// canaryOrdinal is the ordinal position of the column within the input row
 	// that is used to decide whether to execute an insert or update operation.
 	// If the canary column is null, then an insert will be performed; otherwise,
@@ -67,6 +70,14 @@ type optTableUpserter struct {
 
 	// ru is used when updating rows.
 	ru row.Updater
+
+	// tabColIdxToRetIdx is the mapping from the columns in the table to the
+	// columns in the resultRowBuffer. A value of -1 is used to indicate
+	// that the table column at that index is not part of the resultRowBuffer
+	// of the mutation. Otherwise, the value at the i-th index refers to the
+	// index of the resultRowBuffer where the i-th column of the table is
+	// to be returned.
+	tabColIdxToRetIdx []int
 }
 
 // init is part of the tableWriter interface.
@@ -77,7 +88,12 @@ func (tu *optTableUpserter) init(txn *client.Txn, evalCtx *tree.EvalContext) err
 	}
 
 	if tu.collectRows {
-		tu.resultRow = make(tree.Datums, len(tu.colIDToReturnIndex))
+		tu.resultRow = make(tree.Datums, len(tu.returnCols))
+		tu.rowsUpserted = rowcontainer.NewRowContainer(
+			evalCtx.Mon.MakeBoundAccount(),
+			sqlbase.ColTypeInfoFromColDescs(tu.returnCols),
+			tu.insertRows.Len(),
+		)
 	}
 
 	tu.ru, err = row.MakeUpdater(
@@ -161,12 +177,27 @@ func (tu *optTableUpserter) insertNonConflictingRow(
 	// Reshape the row if needed.
 	if tu.insertReorderingRequired {
-		resultRow := tu.makeResultFromRow(insertRow, tu.ri.InsertColIDtoRowIndex)
-		_, err := tu.rowsUpserted.AddRow(ctx, resultRow)
+		tableRow := tu.makeResultFromRow(insertRow, tu.ri.InsertColIDtoRowIndex)
+
+		// TODO(ridwanmsharif): Why didn't they update the value of tu.resultRow
+		// before? Is it safe to be doing it now?
+		// Map the upserted columns into the result row before adding it.
+		for tabIdx := range tableRow {
+			if retIdx := tu.tabColIdxToRetIdx[tabIdx]; retIdx >= 0 {
+				tu.resultRow[retIdx] = tableRow[tabIdx]
+			}
+		}
+		_, err := tu.rowsUpserted.AddRow(ctx, tu.resultRow)
 		return err
 	}
 
-	_, err := tu.rowsUpserted.AddRow(ctx, insertRow)
+	// Map the upserted columns into the result row before adding it.
+	for tabIdx := range insertRow {
+		if retIdx := tu.tabColIdxToRetIdx[tabIdx]; retIdx >= 0 {
+			tu.resultRow[retIdx] = insertRow[tabIdx]
+		}
+	}
+	_, err := tu.rowsUpserted.AddRow(ctx, tu.resultRow)
 	return err
 }
 
@@ -208,22 +239,30 @@ func (tu *optTableUpserter) updateConflictingRow(
 		return nil
 	}
 
-	// We now need a row that has the shape of the result row.
+	// We now need a row that has the shape of the result row with
+	// the appropriate return columns. Make sure all the fetch columns
+	// are present.
+	tableRow := tu.makeResultFromRow(fetchRow, tu.ru.FetchColIDtoRowIndex)
+
+	// Make sure all the updated columns are present.
 	for colID, returnIndex := range tu.colIDToReturnIndex {
 		// If an update value for a given column exists, use that; else use the
-		// existing value of that column.
+		// existing value of that column if it has been fetched.
 		rowIndex, ok := tu.ru.UpdateColIDtoRowIndex[colID]
 		if ok {
-			tu.resultRow[returnIndex] = updateValues[rowIndex]
-		} else {
-			rowIndex, ok = tu.ru.FetchColIDtoRowIndex[colID]
-			if !ok {
-				return errors.AssertionFailedf("no existing value is available for column")
-			}
-			tu.resultRow[returnIndex] = fetchRow[rowIndex]
+			tableRow[returnIndex] = updateValues[rowIndex]
+		}
+	}
+
+	// Map the upserted columns into the result row before adding it.
+	for tabIdx := range tableRow {
+		if retIdx := tu.tabColIdxToRetIdx[tabIdx]; retIdx >= 0 {
+			tu.resultRow[retIdx] = tableRow[tabIdx]
 		}
 	}
 
+	// The resulting row may have nil values for columns that aren't
+	// being upserted, updated or fetched.
 	_, err = tu.rowsUpserted.AddRow(ctx, tu.resultRow)
 	return err
 }
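On the conflict path above, the result row is assembled in two steps: overlay the update values onto the fetched row (both in table-column shape), then remap through tabColIdxToRetIdx into the RETURNING shape. A toy version of that sequence, with plain slices and a hypothetical updateColToTabIdx mapping in place of the ColumnID-keyed maps used by row.Updater:

package main

import "fmt"

func upsertResultRow(
	tableRow []string, // fetched values, in table-column order
	updateValues []string, // new values for the updated columns
	updateColToTabIdx map[int]int, // hypothetical: update index -> position in tableRow
	tabColIdxToRetIdx []int, // -1 means the table column is not in RETURNING
	numRetCols int,
) []string {
	// Overlay the updated values onto the fetched row.
	for updIdx, tabIdx := range updateColToTabIdx {
		tableRow[tabIdx] = updateValues[updIdx]
	}
	// Remap table columns into the RETURNING shape.
	res := make([]string, numRetCols)
	for tabIdx, retIdx := range tabColIdxToRetIdx {
		if retIdx >= 0 {
			res[retIdx] = tableRow[tabIdx]
		}
	}
	return res
}

func main() {
	fetched := []string{"1", "old-b", "old-c"} // columns k, b, c
	updates := []string{"new-b"}               // only b is updated
	fmt.Println(upsertResultRow(fetched, updates, map[int]int{0: 1}, []int{0, 1, -1}, 2))
	// Output: [1 new-b]
}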
diff --git a/pkg/sql/update.go b/pkg/sql/update.go
index 778476c9c0f8..1f9a67705565 100644
--- a/pkg/sql/update.go
+++ b/pkg/sql/update.go
@@ -382,6 +382,12 @@ func (p *planner) Update(
 		updateColsIdx[id] = i
 	}
 
+	// Since all columns are being returned, use the 1:1 mapping.
+	rowIdxToRetIdx := make([]int, len(desc.Columns))
+	for i := range rowIdxToRetIdx {
+		rowIdxToRetIdx[i] = i
+	}
+
 	un := updateNodePool.Get().(*updateNode)
 	*un = updateNode{
 		source: rows,
@@ -397,9 +403,10 @@ func (p *planner) Update(
 				Cols:    desc.Columns,
 				Mapping: ru.FetchColIDtoRowIndex,
 			},
-			sourceSlots:   sourceSlots,
-			updateValues:  make(tree.Datums, len(ru.UpdateCols)),
-			updateColsIdx: updateColsIdx,
+			sourceSlots:    sourceSlots,
+			updateValues:   make(tree.Datums, len(ru.UpdateCols)),
+			updateColsIdx:  updateColsIdx,
+			rowIdxToRetIdx: rowIdxToRetIdx,
 		},
 	}
 
@@ -480,6 +487,13 @@ type updateRun struct {
 	// This provides the inverse mapping of sourceSlots.
 	//
 	updateColsIdx map[sqlbase.ColumnID]int
+
+	// rowIdxToRetIdx is the mapping from the columns in ru.FetchCols to the
+	// columns in the resultRowBuffer. A value of -1 is used to indicate
+	// that the column at that index is not part of the resultRowBuffer
+	// of the mutation. Otherwise, the value at the i-th index refers to the
+	// index of the resultRowBuffer where the i-th column is to be returned.
+	rowIdxToRetIdx []int
 }
 
 // maxUpdateBatchSize is the max number of entries in the KV batch for
@@ -701,7 +715,14 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums)
 	//
 	// MakeUpdater guarantees that the first columns of the new values
 	// are those specified u.columns.
-	resultValues := newValues[:len(u.columns)]
+	resultValues := make([]tree.Datum, len(u.columns))
+	for i := range u.run.rowIdxToRetIdx {
+		retIdx := u.run.rowIdxToRetIdx[i]
+		if retIdx >= 0 {
+			resultValues[retIdx] = newValues[i]
+		}
+	}
+
 	if _, err := u.run.rows.AddRow(params.ctx, resultValues); err != nil {
 		return err
 	}
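With the identity mapping built in the heuristic-planner path above, the new remapping loop is expected to reproduce the old newValues[:len(u.columns)] prefix exactly; only the optimizer path supplies a sparser mapping. A quick standalone check of that equivalence (toy values, not planner code):

package main

import (
	"fmt"
	"reflect"
)

// remap applies a rowIdxToRetIdx-style mapping, as in processSourceRow above.
func remap(newValues, rowIdxToRetIdx []int, numCols int) []int {
	res := make([]int, numCols)
	for i, retIdx := range rowIdxToRetIdx {
		if retIdx >= 0 {
			res[retIdx] = newValues[i]
		}
	}
	return res
}

func main() {
	newValues := []int{10, 20, 30, 40} // trailing values stand in for non-public columns
	identity := []int{0, 1, 2}         // 1:1 mapping over the 3 public columns
	old := newValues[:3]               // previous behavior: plain prefix slice
	fmt.Println(reflect.DeepEqual(remap(newValues, identity, 3), old)) // true
}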