Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql/postgres: unmarshal HCL enum #445

Merged
merged 3 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 71 additions & 13 deletions sql/postgres/sqlspec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"

"ariga.io/atlas/schema/schemaspec"
"ariga.io/atlas/schema/schemaspec/schemahcl"
Expand Down Expand Up @@ -40,7 +41,7 @@ func UnmarshalSpec(data []byte, unmarshaler schemaspec.Unmarshaler, v interface{
}
switch v := v.(type) {
case *schema.Realm:
realm, err := Realm(d.Schemas, d.Tables)
realm, err := Realm(d.Schemas, d.Tables, d.Enums)
if err != nil {
return fmt.Errorf("specutil: failed converting to *schema.Realm: %w", err)
}
Expand All @@ -49,7 +50,7 @@ func UnmarshalSpec(data []byte, unmarshaler schemaspec.Unmarshaler, v interface{
if len(d.Schemas) != 1 {
return fmt.Errorf("specutil: expecting document to contain a single schema, got %d", len(d.Schemas))
}
conv, err := Schema(d.Schemas[0], d.Tables)
conv, err := Schema(d.Schemas[0], d.Tables, d.Enums)
if err != nil {
return fmt.Errorf("specutil: failed converting to *schema.Schema: %w", err)
}
Expand All @@ -65,11 +66,14 @@ func MarshalSpec(v interface{}, marshaler schemaspec.Marshaler) ([]byte, error)
return specutil.Marshal(v, marshaler, schemaSpec)
}

// Realm converts the schemas and tables into a schema.Realm.
func Realm(schemas []*sqlspec.Schema, tables []*sqlspec.Table) (*schema.Realm, error) {
// Realm converts the schemas and tables of the doc into a schema.Realm.
func Realm(schemas []*sqlspec.Schema, tables []*sqlspec.Table, enums []*Enum) (*schema.Realm, error) {
r := &schema.Realm{}
for _, schemaSpec := range schemas {
var schemaTables []*sqlspec.Table
var (
schemaTables []*sqlspec.Table
schemaEnums []*Enum
)
for _, tableSpec := range tables {
name, err := specutil.SchemaName(tableSpec.Schema)
if err != nil {
Expand All @@ -79,7 +83,16 @@ func Realm(schemas []*sqlspec.Schema, tables []*sqlspec.Table) (*schema.Realm, e
schemaTables = append(schemaTables, tableSpec)
}
}
sch, err := Schema(schemaSpec, schemaTables)
for _, enum := range enums {
name, err := specutil.SchemaName(enum.Schema)
if err != nil {
return nil, fmt.Errorf("specutil: cannot extract schema name for table %q: %w", enum.Name, err)
}
if name == schemaSpec.Name {
schemaEnums = append(schemaEnums, enum)
}
}
sch, err := Schema(schemaSpec, schemaTables, schemaEnums)
if err != nil {
return nil, err
}
Expand All @@ -88,9 +101,8 @@ func Realm(schemas []*sqlspec.Schema, tables []*sqlspec.Table) (*schema.Realm, e
return r, nil
}

// Schema converts a sqlspec.Schema with its relevant []sqlspec.Tables
// into a schema.Schema.
func Schema(spec *sqlspec.Schema, tables []*sqlspec.Table) (*schema.Schema, error) {
// Schema converts a sqlspec.Schema with its relevant []sqlspec.Tables and []Enum into a schema.Schema.
func Schema(spec *sqlspec.Schema, tables []*sqlspec.Table, enums []*Enum) (*schema.Schema, error) {
sch := &schema.Schema{
Name: spec.Name,
}
Expand All @@ -108,8 +120,10 @@ func Schema(spec *sqlspec.Schema, tables []*sqlspec.Table) (*schema.Schema, erro
return nil, err
}
}
if err := convertEnums(tables, sch); err != nil {
return nil, err
if len(enums) > 0 {
if err := convertEnums(tables, enums, sch); err != nil {
return nil, err
}
}
return sch, nil
}
Expand Down Expand Up @@ -153,11 +167,55 @@ func convertColumnType(spec *sqlspec.Column) (schema.Type, error) {

// convertEnums converts possibly referenced column types (like enums) to
// an actual schema.Type and sets it on the correct schema.Column.
func convertEnums(tbls []*sqlspec.Table, sch *schema.Schema) error {
// TODO(masseelch): implement
func convertEnums(tbls []*sqlspec.Table, enums []*Enum, sch *schema.Schema) error {
for _, tbl := range tbls {
for _, col := range tbl.Columns {
if col.Type.IsRef {
e, err := resolveEnum(col.Type, enums)
if err != nil {
return err
}
t, ok := sch.Table(tbl.Name)
if !ok {
return fmt.Errorf("postgrs: table %q not found in schema %q", tbl.Name, sch.Name)
}
c, ok := t.Column(col.Name)
if !ok {
return fmt.Errorf("postgrs: column %q not found in table %q", col.Name, t.Name)
}
c.Type.Type = &EnumType{
T: e.Name,
Values: e.Values,
}
}
}
}
return nil
}

// resolveEnum returns the first Enum that matches the name referenced by the given column type.
func resolveEnum(ref *schemaspec.Type, enums []*Enum) (*Enum, error) {
n, err := enumName(ref)
if err != nil {
return nil, err
}
for _, e := range enums {
if e.Name == n {
return e, err
}
}
return nil, fmt.Errorf("postgres: enum %q not found", n)
}

// enumName extracts the name of the referenced Enum from the reference string.
func enumName(ref *schemaspec.Type) (string, error) {
s := strings.Split(ref.T, "$enum.")
if len(s) != 2 {
return "", fmt.Errorf("postgres: failed to extract enum name from %q", ref.T)
}
return s[1], nil
}

// schemaSpec converts from a concrete Postgres schema to Atlas specification.
func schemaSpec(schem *schema.Schema) (*sqlspec.Schema, []*sqlspec.Table, error) {
return specutil.FromSchema(schem, tableSpec)
Expand Down
16 changes: 16 additions & 0 deletions sql/postgres/sqlspec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,17 @@ table "accounts" {
column "name" {
type = varchar(32)
}
column "type" {
type = enum.account_type
}
primary_key {
columns = [table.accounts.column.name]
}
}

enum "account_type" {
values = ["private", "business"]
}
`
var s schema.Schema
err := UnmarshalHCL([]byte(f), &s)
Expand Down Expand Up @@ -143,6 +150,15 @@ table "accounts" {
},
},
},
{
Name: "type",
Type: &schema.ColumnType{
Type: &EnumType{
T: "account_type",
Values: []string{"private", "business"},
},
},
},
},
},
}
Expand Down