Skip to content

Commit

Permalink
cmd/action: add schema flag (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
yonidavidson authored Jan 26, 2022
1 parent a732feb commit 4210db9
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 28 deletions.
24 changes: 16 additions & 8 deletions cmd/action/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ package action
import (
"context"

"ariga.io/atlas/sql/schema"
"github.com/spf13/cobra"
)

var (
// InspectFlags are the flags used in Inspect command.
InspectFlags struct {
DSN string
Web bool
Addr string
DSN string
Web bool
Addr string
Schema []string
}
// InspectCmd represents the inspect command.
InspectCmd = &cobra.Command{
Expand All @@ -29,7 +31,7 @@ and execute schema migrations against the given database.
Run: CmdInspectRun,
Example: `
atlas schema inspect -d "mysql://user:pass@tcp(localhost:3306)/dbname"
atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/dbname"
atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/" --schema=schemaA,schemaB -s schemaC
atlas schema inspect --dsn "postgres://user:pass@host:port/dbname"
atlas schema inspect -d "sqlite://file:ex1.db?_fk=1"`,
}
Expand All @@ -45,7 +47,8 @@ func init() {
"[driver://username:password@protocol(address)/dbname?param=value] Select data source using the dsn format",
)
InspectCmd.Flags().BoolVarP(&InspectFlags.Web, "web", "w", false, "Open in a local Atlas UI")
InspectCmd.Flags().StringVarP(&InspectFlags.Addr, "addr", "", "127.0.0.1:5800", "used with -w, local address to bind the server to")
InspectCmd.Flags().StringVarP(&InspectFlags.Addr, "addr", "", "127.0.0.1:5800", "Used with -w, local address to bind the server to")
InspectCmd.Flags().StringSliceVarP(&InspectFlags.Schema, "schema", "s", nil, "Set schema name")
cobra.CheckErr(InspectCmd.MarkFlagRequired("dsn"))
}

Expand All @@ -62,9 +65,14 @@ func CmdInspectRun(_ *cobra.Command, _ []string) {

func inspectRun(d *Driver, dsn string) {
ctx := context.Background()
name, err := SchemaNameFromDSN(dsn)
cobra.CheckErr(err)
s, err := d.InspectSchema(ctx, name, nil)
schemas := InspectFlags.Schema
if n, err := SchemaNameFromDSN(dsn); n != "" {
cobra.CheckErr(err)
schemas = append(schemas, n)
}
s, err := d.InspectRealm(ctx, &schema.InspectRealmOption{
Schemas: schemas,
})
cobra.CheckErr(err)
ddl, err := d.MarshalSpec(s)
cobra.CheckErr(err)
Expand Down
7 changes: 2 additions & 5 deletions cmd/action/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ func SchemaNameFromDSN(url string) (string, error) {
func postgresSchema(dsn string) (string, error) {
url, err := url.Parse(dsn)
if err != nil {
// For backwards compatibility, we default to "public" when failing to
// parse.
return "public", nil
return "", nil
}
// lib/pq supports setting default schemas via the `search_path` parameter
// in a dsn.
Expand All @@ -113,8 +111,7 @@ func postgresSchema(dsn string) (string, error) {
if schema := url.Query().Get("search_path"); schema != "" {
return schema, nil
}

return "public", nil
return "", nil
}

func schemaName(dsn string) (string, error) {
Expand Down
9 changes: 5 additions & 4 deletions doc/md/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,16 @@ and execute schema migrations against the given database.
```
atlas schema inspect -d "mysql://user:pass@tcp(localhost:3306)/dbname"
atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/dbname"
atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/" --schema=schemaA,schemaB -s schemaC
atlas schema inspect --dsn "postgres://user:pass@host:port/dbname"
atlas schema inspect -d "sqlite://file:ex1.db?_fk=1"
```
#### Flags
```
--addr string used with -w, local address to bind the server to (default "127.0.0.1:5800")
-d, --dsn string [driver://username:password@protocol(address)/dbname?param=value] Select data source using the dsn format
-w, --web Open in a local Atlas UI
--addr string Used with -w, local address to bind the server to (default "127.0.0.1:5800")
-d, --dsn string [driver://username:password@protocol(address)/dbname?param=value] Select data source using the dsn format
-s, --schema strings Set schema name
-w, --web Open in a local Atlas UI
```

Expand Down
29 changes: 29 additions & 0 deletions internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/ioutil"
"os"
"os/exec"
"strings"
"testing"
"time"

Expand All @@ -33,6 +34,7 @@ type T interface {
migrate(...schema.Change)
diff(*schema.Table, *schema.Table) []schema.Change
applyHcl(spec string)
applyRealmHcl(spec string)
}

func testAddDrop(t T) {
Expand Down Expand Up @@ -169,6 +171,33 @@ func testCLISchemaInspect(t T, h string, dsn string, unmarshaler schemaspec.Unma
require.Equal(t, expected, actual)
}

func testCLIMultiSchemaInspect(t T, h string, dsn string, schemas []string, unmarshaler schemaspec.Unmarshaler) {
// Required to have a clean "stderr" while running first time.
err := exec.Command("go", "run", "-mod=mod", "ariga.io/atlas/cmd/atlas").Run()
require.NoError(t, err)
var expected schema.Realm
err = unmarshaler.UnmarshalSpec([]byte(h), &expected)
require.NoError(t, err)
t.applyRealmHcl(h)
cmd := exec.Command("go", "run", "ariga.io/atlas/cmd/atlas",
"schema",
"inspect",
"-d",
dsn,
"-s",
strings.Join(schemas, ","),
)
stdout, stderr := bytes.NewBuffer(nil), bytes.NewBuffer(nil)
cmd.Stderr = stderr
cmd.Stdout = stdout
require.NoError(t, cmd.Run(), stderr.String())
var actual schema.Realm
err = unmarshaler.UnmarshalSpec(stdout.Bytes(), &actual)
require.NoError(t, err)
require.Empty(t, stderr.String())
require.Equal(t, expected, actual)
}

func testCLISchemaApply(t T, h string, dsn string) {
// Required to have a clean "stderr" while running first time.
err := exec.Command("go", "run", "-mod=mod", "ariga.io/atlas/cmd/atlas").Run()
Expand Down
57 changes: 49 additions & 8 deletions internal/integration/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,26 +527,65 @@ func TestMySQL_CLI(t *testing.T) {
myRun(t, func(t *myTest) {
attrs := t.defaultAttrs()
charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation)
testCLISchemaInspect(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn(), mysql.UnmarshalHCL)
testCLISchemaInspect(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn("test"), mysql.UnmarshalHCL)
})
})
t.Run("SchemaApply", func(t *testing.T) {
myRun(t, func(t *myTest) {
attrs := t.defaultAttrs()
charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation)
testCLISchemaApply(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn())
testCLISchemaApply(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn("test"))
})
})
t.Run("SchemaApplyDryRun", func(t *testing.T) {
myRun(t, func(t *myTest) {
attrs := t.defaultAttrs()
charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation)
testCLISchemaApplyDry(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn())
testCLISchemaApplyDry(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn("test"))
})
})
t.Run("SchemaDiffRun", func(t *testing.T) {
myRun(t, func(t *myTest) {
testCLISchemaDiff(t, t.dsn())
testCLISchemaDiff(t, t.dsn("test"))
})
})
}

func TestMySQL_CLI_MultiSchema(t *testing.T) {
t.Run("SchemaInspect", func(t *testing.T) {
h := `
schema "test" {
charset = "%s"
collation = "%s"
}
table "users" {
schema = schema.test
column "id" {
type = int
}
primary_key {
columns = [table.users.column.id]
}
}
schema "test2" {
charset = "%s"
collation = "%s"
}
table "users" {
schema = schema.test2
column "id" {
type = int
}
primary_key {
columns = [table.users.column.id]
}
}`
myRun(t, func(t *myTest) {
t.dropDB("test2")
t.dropTables("users")
attrs := t.defaultAttrs()
charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation)
testCLIMultiSchemaInspect(t, fmt.Sprintf(h, charset.V, collate.V, charset.V, collate.V), t.dsn(""), []string{"test", "test2"}, mysql.UnmarshalHCL)
})
})
}
Expand Down Expand Up @@ -1040,12 +1079,12 @@ create table atlas_types_sanity
})
}

func (t *myTest) dsn() string {
func (t *myTest) dsn(dbname string) string {
d := "mysql"
if t.mariadb() {
d = "mariadb"
}
return fmt.Sprintf("%s://root:pass@tcp(localhost:%d)/test", d, t.port)
return fmt.Sprintf("%s://root:pass@tcp(localhost:%d)/%s", d, t.port, dbname)
}

func (t *myTest) applyHcl(spec string) {
Expand Down Expand Up @@ -1092,8 +1131,10 @@ func (t *myTest) dropTables(names ...string) {

func (t *myTest) dropDB(names ...string) {
t.Cleanup(func() {
_, err := t.db.Exec("DROP DATABASE IF EXISTS " + strings.Join(names, ", "))
require.NoError(t.T, err, "drop db %q", names)
for _, n := range names {
_, err := t.db.Exec("DROP DATABASE IF EXISTS " + n)
require.NoError(t.T, err, "drop db %q", names)
}
})
}

Expand Down
37 changes: 35 additions & 2 deletions internal/integration/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ schema "second" {
func (t *pgTest) applyRealmHcl(spec string) {
realm := t.loadRealm()
var desired schema.Realm
err := mysql.UnmarshalHCL([]byte(spec), &desired)
err := postgres.UnmarshalHCL([]byte(spec), &desired)
require.NoError(t, err)
diff, err := t.drv.RealmDiff(realm, &desired)
require.NoError(t, err)
Expand Down Expand Up @@ -480,6 +480,39 @@ func TestPostgres_CLI(t *testing.T) {
})
}

func TestPostgres_CLI_MultiSchema(t *testing.T) {
t.Run("SchemaInspect", func(t *testing.T) {
h := `
schema "public" {
}
table "users" {
schema = schema.public
column "id" {
type = integer
}
primary_key {
columns = [table.users.column.id]
}
}
schema "test2" {
}
table "users" {
schema = schema.test2
column "id" {
type = integer
}
primary_key {
columns = [table.users.column.id]
}
}`
pgRun(t, func(t *pgTest) {
t.dropSchemas("test2")
t.dropTables("users")
testCLIMultiSchemaInspect(t, h, t.dsn(), []string{"public", "test2"}, postgres.UnmarshalHCL)
})
})
}

func TestPostgres_DefaultsHCL(t *testing.T) {
n := "atlas_defaults"
pgRun(t, func(t *pgTest) {
Expand Down Expand Up @@ -1042,7 +1075,7 @@ func (t *pgTest) dropTables(names ...string) {

func (t *pgTest) dropSchemas(names ...string) {
t.Cleanup(func() {
_, err := t.db.Exec("DROP SCHEMA IF EXISTS " + strings.Join(names, ", "))
_, err := t.db.Exec("DROP SCHEMA IF EXISTS " + strings.Join(names, ", ") + " CASCADE")
require.NoError(t.T, err, "drop schema %q", names)
})
}
4 changes: 4 additions & 0 deletions internal/integration/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,7 @@ func (t *liteTest) dropTables(names ...string) {
func (t *liteTest) dsn() string {
return fmt.Sprintf("sqlite://file:%s?cache=shared&_fk=1", t.file)
}

func (t *liteTest) applyRealmHcl(spec string) {
t.applyHcl(spec)
}
2 changes: 1 addition & 1 deletion sql/postgres/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ const (
paramsQuery = `SELECT setting FROM pg_settings WHERE name IN ('lc_collate', 'lc_ctype', 'server_version_num') ORDER BY name`

// Query to list database schemas.
schemasQuery = "SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast') ORDER BY schema_name"
schemasQuery = "SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast') AND schema_name NOT LIKE 'pg_%temp_%' ORDER BY schema_name"

// Query to list specific database schemas.
schemasQueryArgs = "SELECT schema_name FROM information_schema.schemata WHERE schema_name %s ORDER BY schema_name"
Expand Down

0 comments on commit 4210db9

Please sign in to comment.