Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: ON CONFLICT DO UPDATE form of upsert #6591

Merged
merged 1 commit into from
May 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions sql/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (p *planner) Insert(
columns: makeResultColumns(en.tableDesc.Columns),
}
for i := range checkExprs {
expr, err := resolveQNames(checkExprs[i], &table, qvals, &p.qnameVisitor)
expr, err := resolveQNames(checkExprs[i], []*tableInfo{&table}, qvals, &p.qnameVisitor)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -203,35 +203,26 @@ func (p *planner) Insert(
if n.OnConflict == nil {
tw = &tableInserter{ri: ri, autoCommit: autoCommit}
} else {
// TODO(dan): These are both implied by the short form of UPSERT. When the
// INSERT INTO ON CONFLICT form is implemented, get these values from
// n.OnConfict.
upsertConflictIndex := en.tableDesc.PrimaryIndex
insertCols := ri.insertCols

indexColSet := make(map[sqlbase.ColumnID]struct{}, len(upsertConflictIndex.ColumnIDs))
for _, colID := range upsertConflictIndex.ColumnIDs {
indexColSet[colID] = struct{}{}
}

// updateCols contains the columns that will be updated when a conflict is
// found. For the UPSERT short form, it is the set of columns in insertCols
// minus any columns in the conflict index. Example:
// `UPSERT INTO abc VALUES (1, 2, 3)` is syntactic sugar for
// `INSERT INTO abc VALUES (1, 2, 3) ON CONFLICT a DO UPDATE SET b = 2, c = 3`.
updateCols := make([]sqlbase.ColumnDescriptor, 0, len(insertCols))
for _, c := range insertCols {
if _, ok := indexColSet[c.ID]; !ok {
updateCols = append(updateCols, c)
}
updateExprs, conflictIndex, err := upsertExprsAndIndex(en.tableDesc, *n.OnConflict, ri.insertCols)
if err != nil {
return nil, err
}
ru, err := makeRowUpdater(en.tableDesc, updateCols)

names, err := p.namesForExprs(updateExprs)
if err != nil {
return nil, err
}
updateCols, err := p.processColumns(en.tableDesc, names)
if err != nil {
return nil, err
}

helper, err := p.makeUpsertHelper(en.tableDesc, ri.insertCols, updateExprs, conflictIndex)
if err != nil {
return nil, err
}
// TODO(dan): Use ru.fetchCols to compute the fetch selectors.

tw = &tableUpserter{ri: ri, ru: ru, autoCommit: autoCommit}
tw = &tableUpserter{ri: ri, updateCols: updateCols, conflictIndex: *conflictIndex, evaler: helper}
}

in := &insertNode{
Expand Down
29 changes: 23 additions & 6 deletions sql/parser/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ type Insert struct {
Table TableExpr
Columns QualifiedNames
Rows *Select
Returning ReturningExprs
OnConflict *OnConflict
Returning ReturningExprs
}

// Format implements the NodeFormatter interface.
func (node *Insert) Format(buf *bytes.Buffer, f FmtFlags) {
if node.OnConflict != nil {
if node.OnConflict.IsUpsertAlias() {
buf.WriteString("UPSERT")
} else {
buf.WriteString("INSERT")
Expand All @@ -53,6 +53,15 @@ func (node *Insert) Format(buf *bytes.Buffer, f FmtFlags) {
buf.WriteByte(' ')
FormatNode(buf, f, node.Rows)
}
if node.OnConflict != nil && !node.OnConflict.IsUpsertAlias() {
buf.WriteString(" ON CONFLICT (")
FormatNode(buf, f, node.OnConflict.Columns)
buf.WriteString(") DO UPDATE SET ")
FormatNode(buf, f, node.OnConflict.Exprs)
if node.OnConflict.Where != nil {
FormatNode(buf, f, node.OnConflict.Where)
}
}
FormatNode(buf, f, node.Returning)
}

Expand All @@ -61,11 +70,19 @@ func (node *Insert) DefaultValues() bool {
return node.Rows.Select == nil
}

// OnConflict represents an `ON CONFLICT index DO UPDATE SET` clause.
// OnConflict represents an `ON CONFLICT (columns) DO UPDATE SET exprs WHERE
// where` clause.
//
// The zero value for OnConflict is used to signal the UPSERT short form, which
// uses the primary key for Index and the values being inserted for Exprs.
// uses the primary key for as the conflict index and the values being inserted
// for Exprs.
type OnConflict struct {
Index Name
Exprs UpdateExprs
Columns NameList
Exprs UpdateExprs
Where *Where
}

// IsUpsertAlias returns true if the UPSERT syntactic sugar was used.
func (oc *OnConflict) IsUpsertAlias() bool {
return oc != nil && oc.Columns == nil && oc.Exprs == nil && oc.Where == nil
}
10 changes: 10 additions & 0 deletions sql/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,16 @@ func TestParse(t *testing.T) {
{`UPSERT INTO a(a, a.b) VALUES (1, 2)`},
{`UPSERT INTO a SELECT b, c FROM d`},
{`UPSERT INTO a DEFAULT VALUES`},
{`UPSERT INTO a VALUES (1) RETURNING a, b`},
{`UPSERT INTO a VALUES (1, 2) RETURNING 1, 2`},
{`UPSERT INTO a VALUES (1, 2) RETURNING a + b, c`},

{`INSERT INTO a VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 1`},
{`INSERT INTO a VALUES (1) ON CONFLICT (a, b) DO UPDATE SET a = 1`},
{`INSERT INTO a VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 1, b = excluded.a`},
{`INSERT INTO a VALUES (1) ON CONFLICT (a) DO UPDATE SET a = 1 WHERE b > 2`},
{`INSERT INTO a VALUES (1) ON CONFLICT (a) DO UPDATE SET a = DEFAULT`},
{`INSERT INTO a VALUES (1) ON CONFLICT (a) DO UPDATE SET (a, b) = (SELECT 1, 2)`},

{`SELECT 1 + 1`},
{`SELECT - 1`},
Expand Down
Loading