Skip to content

Commit

Permalink
fix(chsql): properly handle column aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
tdakkota committed Jun 7, 2024
1 parent e924a78 commit 5eb101c
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 24 deletions.
1 change: 1 addition & 0 deletions internal/chstorage/chsql/_golden/Test5.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT (toString(name) AS name) FROM spans
1 change: 1 addition & 0 deletions internal/chstorage/chsql/_golden/Test6.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT (column AS name) FROM spans
1 change: 1 addition & 0 deletions internal/chstorage/chsql/_golden/Test7.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT (column) FROM spans
31 changes: 19 additions & 12 deletions internal/chstorage/chsql/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,15 @@ func (q *SelectQuery) WriteSQL(p *Printer) error {
if i != 0 {
p.Comma()
}

cexpr := c.Expr
if cexpr.IsZero() {
// If expression is not defined, assume that column
// name is expected.
cexpr = Ident(c.Name)
}
if needColumnAlias(c.Name, cexpr) {
// Do not alias the name if name is explicitly aliased by user.
cexpr = binaryOp(cexpr, "AS", Ident(c.Name))
}
cexpr = aliasColumn(c.Name, cexpr)

p.OpenParen()
if err := p.WriteExpr(cexpr); err != nil {
return errors.Wrapf(err, "column %q", c.Name)
Expand Down Expand Up @@ -188,15 +189,21 @@ func (q *SelectQuery) WriteSQL(p *Printer) error {
return nil
}

func needColumnAlias(name string, cexpr Expr) bool {
switch cexpr.typ {
case exprBinaryOp:
return !strings.EqualFold(cexpr.tok, "AS")
case exprIdent:
return cexpr.tok != name
default:
return true
func aliasColumn(name string, cexpr Expr) Expr {
if cexpr.typ == exprBinaryOp && strings.EqualFold(cexpr.tok, "AS") {
// If expression already aliased, rename the alias.
if len(cexpr.args) < 2 {
// Return invalid expression as-is.
return cexpr
}
cexpr = cexpr.args[0]
}
if cexpr.typ == exprIdent && cexpr.tok == name {
// Do not alias expression if it is an identifier (column name)
// with same name.
return cexpr
}
return binaryOp(cexpr, "AS", Ident(name))
}

// ResultColumn defines a column result.
Expand Down
42 changes: 42 additions & 0 deletions internal/chstorage/chsql/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func TestSelect(t *testing.T) {
},
false,
},
// Test subqueries.
{
func() *SelectQuery {
return SelectFrom(
Expand All @@ -81,6 +82,7 @@ func TestSelect(t *testing.T) {
},
false,
},
// Test ORDER By.
{
func() *SelectQuery {
return Select("spans",
Expand All @@ -106,6 +108,46 @@ func TestSelect(t *testing.T) {
},
false,
},
// Ensure aliasing is properly handled.
//
// Alias expression.
{
func() *SelectQuery {
return Select("spans", ResultColumn{
Name: "name",
Expr: ToString(Ident("name")),
})
},
false,
},
// User-defined alias.
{
func() *SelectQuery {
return Select("spans", ResultColumn{
Name: "name",
Expr: binaryOp(
Ident("column"),
"AS",
Ident("spanName"),
),
})
},
false,
},
// User-defined alias with the same name as column.
{
func() *SelectQuery {
return Select("spans", ResultColumn{
Name: "column",
Expr: binaryOp(
Ident("column"),
"AS",
Ident("spanName"),
),
})
},
false,
},

// No columns.
{
Expand Down
16 changes: 4 additions & 12 deletions internal/chstorage/chsql/sugar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,17 @@ func TestInTimeRange(t *testing.T) {

func TestJoinAnd(t *testing.T) {
tests := []struct {
args []Expr
want string
wantErr bool
args []Expr
want string
}{
{nil, "true", false},
{nil, "true"},
{[]Expr{Ident("foo")}, "foo"},
{
[]Expr{
Ident("foo"),
Ident("bar"),
},
"foo AND bar",
false,
},
{
[]Expr{
Expand All @@ -54,10 +53,7 @@ func TestJoinAnd(t *testing.T) {
Ident("baz"),
},
"foo AND bar AND baz",
false,
},

{[]Expr{Ident("foo")}, "", true},
}
for i, tt := range tests {
tt := tt
Expand All @@ -66,10 +62,6 @@ func TestJoinAnd(t *testing.T) {

p := GetPrinter()
err := p.WriteExpr(got)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.want, p.String())
})
Expand Down

0 comments on commit 5eb101c

Please sign in to comment.