Skip to content

Commit

Permalink
Experimental tokens in preparation for AST parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
aodin committed May 24, 2016
1 parent 479d0f1 commit dd562c0
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 27 deletions.
8 changes: 4 additions & 4 deletions join.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ func (j JoinClause) String() string {
// Compile compiles a JoinClause
func (j JoinClause) Compile(d dialect.Dialect, ps *Parameters) (string, error) {
// Ignore clauses if CROSS
if j.method == "CROSS JOIN" {
return fmt.Sprintf(` CROSS JOIN "%s"`, j.table.Name()), nil
if j.method == CROSSJOIN {
return fmt.Sprintf(`%s "%s"`, CROSSJOIN, j.table.Name()), nil
}

// If no clauses were given, assume the join is NATURAL
if len(j.ArrayClause.clauses) == 0 {
return fmt.Sprintf(
` NATURAL %s "%s"`, j.method, j.table.Name(),
`NATURAL %s "%s"`, j.method, j.table.Name(),
), nil
}

Expand All @@ -41,6 +41,6 @@ func (j JoinClause) Compile(d dialect.Dialect, ps *Parameters) (string, error) {
}

return fmt.Sprintf(
` %s "%s" ON %s`, j.method, j.table.Name(), clauses,
`%s "%s" ON %s`, j.method, j.table.Name(), clauses,
), nil
}
5 changes: 4 additions & 1 deletion join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ func TestJoinClause(t *testing.T) {
)

expect.SQL(
`SELECT "a"."id", "a"."value" FROM "a" LEFT OUTER JOIN "relations" ON "a"."id" = "relations"."a_id" AND "a"."id" = $1`,
`SELECT "a"."id", "a"."value" FROM "a" LEFT OUTER JOIN "relations" ON "a"."id" = "relations"."a_id" AND "a"."id" = $1 LEFT OUTER JOIN "b" ON "b"."id" = "relations"."b_id"`,
Select(tableA).LeftOuterJoin(
relations,
tableA.C("id").Equals(relations.C("a_id")),
tableA.C("id").Equals(2),
).LeftOuterJoin(
tableB,
tableB.C("id").Equals(relations.C("b_id")),
),
2,
)
Expand Down
44 changes: 22 additions & 22 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error
return "", err
}

compiled := "SELECT"
// The finished statement will be joined with spaces
var compiled = []string{SELECT}

if stmt.isDistinct {
compiled += " DISTINCT"
compiled = append(compiled, DISTINCT)
if stmt.distincts.Exists() {
compiled += fmt.Sprintf(
" ON (%s)", strings.Join(stmt.distincts.Names(), ", "),
)
compiled = append(compiled, fmt.Sprintf(
"ON (%s)", strings.Join(stmt.distincts.Names(), ", "),
))
}
}

Expand All @@ -64,17 +65,16 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error
return "", nil
}

compiled += fmt.Sprintf(
" %s FROM %s", selections, strings.Join(stmt.compileTables(), ", "),
)
tables := strings.Join(stmt.compileTables(), ", ") // TODO use compilation?
compiled = append(compiled, selections, FROM, tables)

if len(stmt.joins) > 0 {
for _, j := range stmt.joins {
jc, err := j.Compile(d, ps)
if err != nil {
return "", err
}
compiled += jc
compiled = append(compiled, jc)
}
}

Expand All @@ -83,12 +83,12 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error
if err != nil {
return "", err
}
compiled += fmt.Sprintf(" WHERE %s", conditional)
compiled = append(compiled, WHERE, conditional)
}

if stmt.groupBy.Exists() {
compiled += fmt.Sprintf(
" GROUP BY %s", strings.Join(stmt.groupBy.Names(), ", "),
compiled = append(
compiled, GROUPBY, strings.Join(stmt.groupBy.Names(), ", "),
)
}

Expand All @@ -97,25 +97,25 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error
if err != nil {
return "", err
}
compiled += fmt.Sprintf(" HAVING %s", conditional)
compiled = append(compiled, HAVING, conditional)
}

if len(stmt.orderBy) > 0 {
order := make([]string, len(stmt.orderBy))
for i, ord := range stmt.orderBy {
order[i], _ = ord.Compile(d, ps)
}
compiled += fmt.Sprintf(" ORDER BY %s", strings.Join(order, ", "))
compiled = append(compiled, ORDERBY, strings.Join(order, ", "))
}

if stmt.limit != 0 {
compiled += fmt.Sprintf(" LIMIT %d", stmt.limit)
compiled = append(compiled, LIMIT, fmt.Sprintf("%d", stmt.limit))
}

if stmt.offset != 0 {
compiled += fmt.Sprintf(" OFFSET %d", stmt.offset)
compiled = append(compiled, OFFSET, fmt.Sprintf("%d", stmt.offset))
}
return compiled, nil
return strings.Join(compiled, " "), nil
}

func (stmt SelectStmt) hasTable(name string) bool {
Expand Down Expand Up @@ -163,31 +163,31 @@ func (stmt SelectStmt) join(table Tabular, method string, clauses ...Clause) Sel

// CrossJoin adds a CROSS JOIN ... clause to the SELECT statement.
func (stmt SelectStmt) CrossJoin(table Tabular) SelectStmt {
return stmt.join(table, "CROSS JOIN")
return stmt.join(table, CROSSJOIN)
}

// InnerJoin adds an INNER JOIN ... ON ... clause to the SELECT statement.
// If no clauses are given, it will assume the clause is NATURAL.
func (stmt SelectStmt) InnerJoin(table Tabular, clauses ...Clause) SelectStmt {
return stmt.join(table, "INNER JOIN", clauses...)
return stmt.join(table, INNERJOIN, clauses...)
}

// LeftOuterJoin adds a LEFT OUTER JOIN ... ON ... clause to the SELECT
// statement. If no clauses are given, it will assume the clause is NATURAL.
func (stmt SelectStmt) LeftOuterJoin(table Tabular, clauses ...Clause) SelectStmt {
return stmt.join(table, "LEFT OUTER JOIN", clauses...)
return stmt.join(table, LEFTOUTERJOIN, clauses...)
}

// RightOuterJoin adds a RIGHT OUTER JOIN ... ON ... clause to the SELECT
// statement. If no clauses are given, it will assume the clause is NATURAL.
func (stmt SelectStmt) RightOuterJoin(table Tabular, clauses ...Clause) SelectStmt {
return stmt.join(table, "RIGHT OUTER JOIN", clauses...)
return stmt.join(table, RIGHTOUTERJOIN, clauses...)
}

// FullOuterJoin adds a FULL OUTER JOIN ... ON ... clause to the SELECT
// statement. If no clauses are given, it will assume the clause is NATURAL.
func (stmt SelectStmt) FullOuterJoin(table Tabular, clauses ...Clause) SelectStmt {
return stmt.join(table, "FULL OUTER JOIN", clauses...)
return stmt.join(table, FULLOUTERJOIN, clauses...)
}

// Where adds a conditional clause to the SELECT statement. Only one WHERE
Expand Down
22 changes: 22 additions & 0 deletions tokens.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package sol

// Tokens holds SQL tokens that are used for compilation / AST parsing
const (
CROSSJOIN = "CROSS JOIN"
DELETE = "DELETE"
DISTINCT = "DISTINCT"
FROM = "FROM"
FULLOUTERJOIN = "FULL OUTER JOIN"
GROUPBY = "GROUP BY"
HAVING = "HAVING"
INNERJOIN = "INNER JOIN"
INSERT = "INSERT"
LEFTOUTERJOIN = "LEFT OUTER JOIN"
LIMIT = "LIMIT"
OFFSET = "OFFSET"
ORDERBY = "ORDER BY"
RIGHTOUTERJOIN = "RIGHT OUTER JOIN"
SELECT = "SELECT"
UPDATE = "UPDATE"
WHERE = "WHERE"
)

0 comments on commit dd562c0

Please sign in to comment.