Skip to content

Commit

Permalink
CompileColumns is now a method of ColumnSet
Browse files Browse the repository at this point in the history
  • Loading branch information
aodin committed May 19, 2016
1 parent 628ee04 commit 479d0f1
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 72 deletions.
55 changes: 48 additions & 7 deletions columnset.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package sol

import "fmt"
import (
"fmt"
"strings"

"github.com/aodin/sol/dialect"
)

// ColumnSet maintains a []ColumnElem. It includes a variety of
// getters and setter. Optionally, it can force unique
Expand All @@ -10,10 +15,25 @@ type ColumnSet struct {
order []ColumnElem
}

// Add adds any number of ColumnElem types to the set and returns the new set.
func (set ColumnSet) Compile(d dialect.Dialect, ps *Parameters) (string, error) {
names := make([]string, len(set.order))
for i, col := range set.order {
compiled, err := col.Compile(d, ps)
if err != nil {
return "", err
}
if col.Alias() != "" {
compiled += fmt.Sprintf(` AS "%s"`, col.Alias())
}
names[i] = compiled
}
return strings.Join(names, ", "), nil
}

// Add adds any number of Columnar types to the set and returns the new set.
// If the set is marked unique, adding a column with the same name
// as an existing column in the set will return an error.
func (set ColumnSet) Add(columns ...ColumnElem) (ColumnSet, error) {
func (set ColumnSet) Add(columns ...Columnar) (ColumnSet, error) {
if set.unique {
for _, column := range columns {
for _, existing := range set.order {
Expand All @@ -32,10 +52,12 @@ func (set ColumnSet) Add(columns ...ColumnElem) (ColumnSet, error) {
)
}
}
set.order = append(set.order, column)
set.order = append(set.order, column.Column())
}
} else {
set.order = append(set.order, columns...)
for _, column := range columns {
set.order = append(set.order, column.Column())
}
}
return set, nil
}
Expand All @@ -45,6 +67,11 @@ func (set ColumnSet) All() []ColumnElem {
return set.order
}

// Exists returns true if there is at least one column in the set
func (set ColumnSet) Exists() bool {
return len(set.order) > 0
}

// Get returns a ColumnElem - or an invalid ColumnElem if a column
// with the given name does not exist in the set
func (set ColumnSet) Get(name string) ColumnElem {
Expand All @@ -62,13 +89,27 @@ func (set ColumnSet) Has(name string) bool {
return set.Get(name).IsValid()
}

// IsEmpty returns true if there are no columns in this set
func (set ColumnSet) IsEmpty() bool {
return len(set.order) == 0
}

// Names returns the full names of the set's columns without alias
func (set ColumnSet) Names() []string {
names := make([]string, len(set.order))
for i, col := range set.order {
names[i] = fmt.Sprintf(`%s`, col.FullName())
}
return names
}

// UniqueColumns creates a new ColumnSet that can only hold columns
// with unique names
func UniqueColumns() ColumnSet {
return ColumnSet{unique: true}
}

// Columns creates a new ColumnSet
func Columns() ColumnSet {
return ColumnSet{}
func Columns(columns ...ColumnElem) ColumnSet {
return ColumnSet{order: columns}
}
20 changes: 10 additions & 10 deletions postgres/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package postgres

import (
"fmt"
"strings"

"github.com/aodin/sol"
"github.com/aodin/sol/dialect"
Expand All @@ -16,7 +15,7 @@ type InsertStmt struct {
conflictTargets []string
values sol.Values
where sol.Clause
returning []sol.ColumnElem // TODO ColumnMap
returning sol.ColumnSet
}

// String outputs the parameter-less INSERT ... RETURNING statement in the
Expand Down Expand Up @@ -58,11 +57,12 @@ func (stmt InsertStmt) Compile(d dialect.Dialect, ps *sol.Parameters) (string, e
}
}

if len(stmt.returning) > 0 {
compiled += fmt.Sprintf(
" RETURNING %s",
strings.Join(sol.CompileColumns(stmt.returning), ", "),
)
if stmt.returning.Exists() {
selections, err := stmt.returning.Compile(d, ps)
if err != nil {
return "", err
}
compiled += fmt.Sprintf(" RETURNING %s", selections)
}
return compiled, nil
}
Expand Down Expand Up @@ -123,9 +123,9 @@ func (stmt InsertStmt) Returning(selections ...sol.Selectable) InsertStmt {
// http://www.postgresql.org/docs/devel/static/sql-insert.html

// If no selections were provided, default to the table
if len(selections) == 0 {
if len(selections) == 0 && stmt.Table() != nil {
for _, column := range stmt.Table().Columns() {
stmt.returning = append(stmt.returning, column)
stmt.returning, _ = stmt.returning.Add(column)
}
return stmt
}
Expand All @@ -148,7 +148,7 @@ func (stmt InsertStmt) Returning(selections ...sol.Selectable) InsertStmt {
)
break
}
stmt.returning = append(stmt.returning, column)
stmt.returning, _ = stmt.returning.Add(column)
}
}
return stmt
Expand Down
91 changes: 36 additions & 55 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ type Selectable interface {
type SelectStmt struct {
ConditionalStmt
tables []Tabular
columns []ColumnElem // TODO Columns
columns ColumnSet
joins []JoinClause
groupBy []ColumnElem // TODO Columns
groupBy ColumnSet
having Clause
orderBy []OrderedColumn
isDistinct bool
distincts []ColumnElem // TODO Columns
distincts ColumnSet
limit int
offset int
}
Expand All @@ -34,20 +34,7 @@ func (stmt SelectStmt) String() string {
return compiled
}

// TODO where should this function live? Also used in postgres.InsertStmt
func CompileColumns(columns []ColumnElem) []string {
names := make([]string, len(columns))
for i, col := range columns {
// Ignore dialect, parameters and error?
compiled, _ := col.Compile(nil, nil)
if col.Alias() != "" {
compiled += fmt.Sprintf(` AS "%s"`, col.Alias())
}
names[i] = compiled
}
return names
}

// TODO create a TableSet type?
func (stmt SelectStmt) compileTables() []string {
names := make([]string, len(stmt.tables))
for i, table := range stmt.tables {
Expand All @@ -63,24 +50,22 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error

compiled := "SELECT"

// DISTINCT
if stmt.isDistinct {
compiled += " DISTINCT"
if len(stmt.distincts) > 0 {
distincts := make([]string, len(stmt.distincts))
for i, col := range stmt.distincts {
distincts[i] = fmt.Sprintf(`%s`, col.FullName())
}
if stmt.distincts.Exists() {
compiled += fmt.Sprintf(
" ON (%s)", strings.Join(distincts, ", "),
" ON (%s)", strings.Join(stmt.distincts.Names(), ", "),
)
}
}

selections, err := stmt.columns.Compile(d, ps)
if err != nil {
return "", nil
}

compiled += fmt.Sprintf(
" %s FROM %s",
strings.Join(CompileColumns(stmt.columns), ", "),
strings.Join(stmt.compileTables(), ", "),
" %s FROM %s", selections, strings.Join(stmt.compileTables(), ", "),
)

if len(stmt.joins) > 0 {
Expand All @@ -101,16 +86,12 @@ func (stmt SelectStmt) Compile(d dialect.Dialect, ps *Parameters) (string, error
compiled += fmt.Sprintf(" WHERE %s", conditional)
}

// GROUP BY ...
if len(stmt.groupBy) > 0 {
groupBy := make([]string, len(stmt.groupBy))
for i, col := range stmt.groupBy {
groupBy[i] = fmt.Sprintf(`%s`, col.FullName())
}
compiled += fmt.Sprintf(" GROUP BY %s", strings.Join(groupBy, ", "))
if stmt.groupBy.Exists() {
compiled += fmt.Sprintf(
" GROUP BY %s", strings.Join(stmt.groupBy.Names(), ", "),
)
}

// HAVING
if stmt.having != nil {
conditional, err := stmt.having.Compile(d, ps)
if err != nil {
Expand Down Expand Up @@ -155,18 +136,16 @@ func (stmt SelectStmt) From(tables ...Tabular) SelectStmt {
// All removes the DISTINCT clause from the SELECT statement.
func (stmt SelectStmt) All() SelectStmt {
stmt.isDistinct = false
stmt.distincts = nil
stmt.distincts = Columns() // reset
return stmt
}

// Distinct adds a DISTINCT clause to the SELECT statement. If any
// column are provided, the clause will be compiled as a DISTINCT ON.
func (stmt SelectStmt) Distinct(columns ...Columnar) SelectStmt {
stmt.isDistinct = true
// TODO ColumnMap method
for _, column := range columns {
stmt.distincts = append(stmt.distincts, column.Column())
}
// Since the ColumnSet is not unique, any errors can be ignored
stmt.distincts, _ = stmt.distincts.Add(columns...)
return stmt
}

Expand Down Expand Up @@ -231,9 +210,8 @@ func (stmt SelectStmt) Where(conditions ...Clause) SelectStmt {
// is allowed per statement. Additional calls to GroupBy will overwrite the
// existing GROUP BY clause.
func (stmt SelectStmt) GroupBy(columns ...Columnar) SelectStmt {
for _, column := range columns {
stmt.groupBy = append(stmt.groupBy, column.Column())
}
// Since the ColumnSet is not unique, any errors can be ignored
stmt.groupBy, _ = stmt.groupBy.Add(columns...)
return stmt
}

Expand Down Expand Up @@ -287,24 +265,25 @@ func (stmt SelectStmt) Offset(offset int) SelectStmt {
func SelectTable(table Tabular, selects ...Selectable) (stmt SelectStmt) {
stmt.tables = []Tabular{table}

// Add the columns from the alias
stmt.columns = table.Columns()
// Add the columns from the initial table
stmt.columns = Columns(table.Columns()...)

// Add any additional selections
for _, selection := range selects {
if selection == nil {
stmt.AddMeta("sol: received a nil selectable in SelectTable()")
return
}
stmt.columns = append(stmt.columns, selection.Columns()...)
}

for _, column := range stmt.columns {
if column.IsInvalid() {
stmt.AddMeta(
"sol: the column %s does not exist", column.FullName(),
)
return
for _, column := range selection.Columns() {
if column.IsInvalid() {
stmt.AddMeta(
"sol: cannot select invalid column %s", column.FullName(),
)
return
}
// Since selections do not need to be unique, any errors
// from the ColumnSet can be ignored
stmt.columns, _ = stmt.columns.Add(column)
}
}
return
Expand Down Expand Up @@ -333,7 +312,9 @@ func Select(selections ...Selectable) (stmt SelectStmt) {
)
return
}
stmt.columns = append(stmt.columns, column)
// Since selections do not need to be unique, any errors
// from the ColumnSet can be ignored
stmt.columns, _ = stmt.columns.Add(column)

// Add the table to the stmt tables if it does not already exist
if !stmt.hasTable(column.Table().Name()) {
Expand Down

0 comments on commit 479d0f1

Please sign in to comment.