diff --git a/join.go b/join.go index 26ed6e2..6991007 100644 --- a/join.go +++ b/join.go @@ -22,14 +22,14 @@ func (j JoinClause) String() string { // Compile compiles a JoinClause func (j JoinClause) Compile(d dialect.Dialect, ps *Parameters) (string, error) { // Ignore clauses if CROSS - if j.method == "CROSS JOIN" { - return fmt.Sprintf(` CROSS JOIN "%s"`, j.table.Name()), nil + if j.method == CROSSJOIN { + return fmt.Sprintf(`%s "%s"`, CROSSJOIN, j.table.Name()), nil } // If no clauses were given, assume the join is NATURAL if len(j.ArrayClause.clauses) == 0 { return fmt.Sprintf( - ` NATURAL %s "%s"`, j.method, j.table.Name(), + `NATURAL %s "%s"`, j.method, j.table.Name(), ), nil } @@ -41,6 +41,6 @@ func (j JoinClause) Compile(d dialect.Dialect, ps *Parameters) (string, error) { } return fmt.Sprintf( - ` %s "%s" ON %s`, j.method, j.table.Name(), clauses, + `%s "%s" ON %s`, j.method, j.table.Name(), clauses, ), nil } diff --git a/join_test.go b/join_test.go index 143565f..30f90f8 100644 --- a/join_test.go +++ b/join_test.go @@ -36,11 +36,14 @@ func TestJoinClause(t *testing.T) { ) expect.SQL( - `SELECT "a"."id", "a"."value" FROM "a" LEFT OUTER JOIN "relations" ON "a"."id" = "relations"."a_id" AND "a"."id" = $1`, + `SELECT "a"."id", "a"."value" FROM "a" LEFT OUTER JOIN "relations" ON "a"."id" = "relations"."a_id" AND "a"."id" = $1 LEFT OUTER JOIN "b" ON "b"."id" = "relations"."b_id"`, Select(tableA).LeftOuterJoin( relations, tableA.C("id").Equals(relations.C("a_id")), tableA.C("id").Equals(2), + ).LeftOuterJoin( + tableB, + tableB.C("id").Equals(relations.C("b_id")), ), 2, ) diff --git a/select.go b/select.go index 0f66ed3..9a07a7e 100644 --- a/select.go +++ b/select.go @@ -48,14 +48,15 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error return "", err } - compiled := "SELECT" + // The finished statement will be joined with spaces + var compiled = []string{SELECT} if stmt.isDistinct { - compiled += " DISTINCT" + compiled = append(compiled, DISTINCT) if stmt.distincts.Exists() { - compiled += fmt.Sprintf( - " ON (%s)", strings.Join(stmt.distincts.Names(), ", "), - ) + compiled = append(compiled, fmt.Sprintf( + "ON (%s)", strings.Join(stmt.distincts.Names(), ", "), + )) } } @@ -64,9 +65,8 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error return "", nil } - compiled += fmt.Sprintf( - " %s FROM %s", selections, strings.Join(stmt.compileTables(), ", "), - ) + tables := strings.Join(stmt.compileTables(), ", ") // TODO use compilation? + compiled = append(compiled, selections, FROM, tables) if len(stmt.joins) > 0 { for _, j := range stmt.joins { @@ -74,7 +74,7 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error if err != nil { return "", err } - compiled += jc + compiled = append(compiled, jc) } } @@ -83,12 +83,12 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error if err != nil { return "", err } - compiled += fmt.Sprintf(" WHERE %s", conditional) + compiled = append(compiled, WHERE, conditional) } if stmt.groupBy.Exists() { - compiled += fmt.Sprintf( - " GROUP BY %s", strings.Join(stmt.groupBy.Names(), ", "), + compiled = append( + compiled, GROUPBY, strings.Join(stmt.groupBy.Names(), ", "), ) } @@ -97,7 +97,7 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error if err != nil { return "", err } - compiled += fmt.Sprintf(" HAVING %s", conditional) + compiled = append(compiled, HAVING, conditional) } if len(stmt.orderBy) > 0 { @@ -105,17 +105,17 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error for i, ord := range stmt.orderBy { order[i], _ = ord.Compile(d, ps) } - compiled += fmt.Sprintf(" ORDER BY %s", strings.Join(order, ", ")) + compiled = append(compiled, ORDERBY, strings.Join(order, ", ")) } if stmt.limit != 0 { - compiled += fmt.Sprintf(" LIMIT %d", stmt.limit) + compiled = append(compiled, LIMIT, fmt.Sprintf("%d", stmt.limit)) } if stmt.offset != 0 { - compiled += fmt.Sprintf(" OFFSET %d", stmt.offset) + compiled = append(compiled, OFFSET, fmt.Sprintf("%d", stmt.offset)) } - return compiled, nil + return strings.Join(compiled, " "), nil } func (stmt SelectStmt) hasTable(name string) bool { @@ -163,31 +163,31 @@ func (stmt SelectStmt) join(table Tabular, method string, clauses ...Clause) Sel // CrossJoin adds a CROSS JOIN ... clause to the SELECT statement. func (stmt SelectStmt) CrossJoin(table Tabular) SelectStmt { - return stmt.join(table, "CROSS JOIN") + return stmt.join(table, CROSSJOIN) } // InnerJoin adds an INNER JOIN ... ON ... clause to the SELECT statement. // If no clauses are given, it will assume the clause is NATURAL. func (stmt SelectStmt) InnerJoin(table Tabular, clauses ...Clause) SelectStmt { - return stmt.join(table, "INNER JOIN", clauses...) + return stmt.join(table, INNERJOIN, clauses...) } // LeftOuterJoin adds a LEFT OUTER JOIN ... ON ... clause to the SELECT // statement. If no clauses are given, it will assume the clause is NATURAL. func (stmt SelectStmt) LeftOuterJoin(table Tabular, clauses ...Clause) SelectStmt { - return stmt.join(table, "LEFT OUTER JOIN", clauses...) + return stmt.join(table, LEFTOUTERJOIN, clauses...) } // RightOuterJoin adds a RIGHT OUTER JOIN ... ON ... clause to the SELECT // statement. If no clauses are given, it will assume the clause is NATURAL. func (stmt SelectStmt) RightOuterJoin(table Tabular, clauses ...Clause) SelectStmt { - return stmt.join(table, "RIGHT OUTER JOIN", clauses...) + return stmt.join(table, RIGHTOUTERJOIN, clauses...) } // FullOuterJoin adds a FULL OUTER JOIN ... ON ... clause to the SELECT // statement. If no clauses are given, it will assume the clause is NATURAL. func (stmt SelectStmt) FullOuterJoin(table Tabular, clauses ...Clause) SelectStmt { - return stmt.join(table, "FULL OUTER JOIN", clauses...) + return stmt.join(table, FULLOUTERJOIN, clauses...) } // Where adds a conditional clause to the SELECT statement. Only one WHERE diff --git a/tokens.go b/tokens.go new file mode 100644 index 0000000..9f3a042 --- /dev/null +++ b/tokens.go @@ -0,0 +1,22 @@ +package sol + +// Tokens holds SQL tokens that are used for compilation / AST parsing +const ( + CROSSJOIN = "CROSS JOIN" + DELETE = "DELETE" + DISTINCT = "DISTINCT" + FROM = "FROM" + FULLOUTERJOIN = "FULL OUTER JOIN" + GROUPBY = "GROUP BY" + HAVING = "HAVING" + INNERJOIN = "INNER JOIN" + INSERT = "INSERT" + LEFTOUTERJOIN = "LEFT OUTER JOIN" + LIMIT = "LIMIT" + OFFSET = "OFFSET" + ORDERBY = "ORDER BY" + RIGHTOUTERJOIN = "RIGHT OUTER JOIN" + SELECT = "SELECT" + UPDATE = "UPDATE" + WHERE = "WHERE" +)