From 4210db9ff2808247d9bdd240f188eb294b7cd376 Mon Sep 17 00:00:00 2001 From: Yoni Davidson Date: Wed, 26 Jan 2022 16:06:59 +0200 Subject: [PATCH] cmd/action: add schema flag (#490) --- cmd/action/inspect.go | 24 ++++++---- cmd/action/mux.go | 7 +-- doc/md/cli.md | 9 ++-- internal/integration/integration_test.go | 29 ++++++++++++ internal/integration/mysql_test.go | 57 ++++++++++++++++++++---- internal/integration/postgres_test.go | 37 ++++++++++++++- internal/integration/sqlite_test.go | 4 ++ sql/postgres/inspect.go | 2 +- 8 files changed, 141 insertions(+), 28 deletions(-) diff --git a/cmd/action/inspect.go b/cmd/action/inspect.go index 64a3d1e4bd2..2a133a3fc48 100644 --- a/cmd/action/inspect.go +++ b/cmd/action/inspect.go @@ -3,15 +3,17 @@ package action import ( "context" + "ariga.io/atlas/sql/schema" "github.com/spf13/cobra" ) var ( // InspectFlags are the flags used in Inspect command. InspectFlags struct { - DSN string - Web bool - Addr string + DSN string + Web bool + Addr string + Schema []string } // InspectCmd represents the inspect command. InspectCmd = &cobra.Command{ @@ -29,7 +31,7 @@ and execute schema migrations against the given database. Run: CmdInspectRun, Example: ` atlas schema inspect -d "mysql://user:pass@tcp(localhost:3306)/dbname" -atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/dbname" +atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/" --schema=schemaA,schemaB -s schemaC atlas schema inspect --dsn "postgres://user:pass@host:port/dbname" atlas schema inspect -d "sqlite://file:ex1.db?_fk=1"`, } @@ -45,7 +47,8 @@ func init() { "[driver://username:password@protocol(address)/dbname?param=value] Select data source using the dsn format", ) InspectCmd.Flags().BoolVarP(&InspectFlags.Web, "web", "w", false, "Open in a local Atlas UI") - InspectCmd.Flags().StringVarP(&InspectFlags.Addr, "addr", "", "127.0.0.1:5800", "used with -w, local address to bind the server to") + InspectCmd.Flags().StringVarP(&InspectFlags.Addr, "addr", "", "127.0.0.1:5800", "Used with -w, local address to bind the server to") + InspectCmd.Flags().StringSliceVarP(&InspectFlags.Schema, "schema", "s", nil, "Set schema name") cobra.CheckErr(InspectCmd.MarkFlagRequired("dsn")) } @@ -62,9 +65,14 @@ func CmdInspectRun(_ *cobra.Command, _ []string) { func inspectRun(d *Driver, dsn string) { ctx := context.Background() - name, err := SchemaNameFromDSN(dsn) - cobra.CheckErr(err) - s, err := d.InspectSchema(ctx, name, nil) + schemas := InspectFlags.Schema + if n, err := SchemaNameFromDSN(dsn); n != "" { + cobra.CheckErr(err) + schemas = append(schemas, n) + } + s, err := d.InspectRealm(ctx, &schema.InspectRealmOption{ + Schemas: schemas, + }) cobra.CheckErr(err) ddl, err := d.MarshalSpec(s) cobra.CheckErr(err) diff --git a/cmd/action/mux.go b/cmd/action/mux.go index 9de7d7d248e..832e01b2a62 100644 --- a/cmd/action/mux.go +++ b/cmd/action/mux.go @@ -102,9 +102,7 @@ func SchemaNameFromDSN(url string) (string, error) { func postgresSchema(dsn string) (string, error) { url, err := url.Parse(dsn) if err != nil { - // For backwards compatibility, we default to "public" when failing to - // parse. - return "public", nil + return "", nil } // lib/pq supports setting default schemas via the `search_path` parameter // in a dsn. @@ -113,8 +111,7 @@ func postgresSchema(dsn string) (string, error) { if schema := url.Query().Get("search_path"); schema != "" { return schema, nil } - - return "public", nil + return "", nil } func schemaName(dsn string) (string, error) { diff --git a/doc/md/cli.md b/doc/md/cli.md index 2abe17a610d..b0a9a032446 100644 --- a/doc/md/cli.md +++ b/doc/md/cli.md @@ -161,15 +161,16 @@ and execute schema migrations against the given database. ``` atlas schema inspect -d "mysql://user:pass@tcp(localhost:3306)/dbname" -atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/dbname" +atlas schema inspect -d "mariadb://user:pass@tcp(localhost:3306)/" --schema=schemaA,schemaB -s schemaC atlas schema inspect --dsn "postgres://user:pass@host:port/dbname" atlas schema inspect -d "sqlite://file:ex1.db?_fk=1" ``` #### Flags ``` - --addr string used with -w, local address to bind the server to (default "127.0.0.1:5800") - -d, --dsn string [driver://username:password@protocol(address)/dbname?param=value] Select data source using the dsn format - -w, --web Open in a local Atlas UI + --addr string Used with -w, local address to bind the server to (default "127.0.0.1:5800") + -d, --dsn string [driver://username:password@protocol(address)/dbname?param=value] Select data source using the dsn format + -s, --schema strings Set schema name + -w, --web Open in a local Atlas UI ``` diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 816e904d3ae..9e1f04342dd 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "os" "os/exec" + "strings" "testing" "time" @@ -33,6 +34,7 @@ type T interface { migrate(...schema.Change) diff(*schema.Table, *schema.Table) []schema.Change applyHcl(spec string) + applyRealmHcl(spec string) } func testAddDrop(t T) { @@ -169,6 +171,33 @@ func testCLISchemaInspect(t T, h string, dsn string, unmarshaler schemaspec.Unma require.Equal(t, expected, actual) } +func testCLIMultiSchemaInspect(t T, h string, dsn string, schemas []string, unmarshaler schemaspec.Unmarshaler) { + // Required to have a clean "stderr" while running first time. + err := exec.Command("go", "run", "-mod=mod", "ariga.io/atlas/cmd/atlas").Run() + require.NoError(t, err) + var expected schema.Realm + err = unmarshaler.UnmarshalSpec([]byte(h), &expected) + require.NoError(t, err) + t.applyRealmHcl(h) + cmd := exec.Command("go", "run", "ariga.io/atlas/cmd/atlas", + "schema", + "inspect", + "-d", + dsn, + "-s", + strings.Join(schemas, ","), + ) + stdout, stderr := bytes.NewBuffer(nil), bytes.NewBuffer(nil) + cmd.Stderr = stderr + cmd.Stdout = stdout + require.NoError(t, cmd.Run(), stderr.String()) + var actual schema.Realm + err = unmarshaler.UnmarshalSpec(stdout.Bytes(), &actual) + require.NoError(t, err) + require.Empty(t, stderr.String()) + require.Equal(t, expected, actual) +} + func testCLISchemaApply(t T, h string, dsn string) { // Required to have a clean "stderr" while running first time. err := exec.Command("go", "run", "-mod=mod", "ariga.io/atlas/cmd/atlas").Run() diff --git a/internal/integration/mysql_test.go b/internal/integration/mysql_test.go index 35539b42b44..1e1420de179 100644 --- a/internal/integration/mysql_test.go +++ b/internal/integration/mysql_test.go @@ -527,26 +527,65 @@ func TestMySQL_CLI(t *testing.T) { myRun(t, func(t *myTest) { attrs := t.defaultAttrs() charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation) - testCLISchemaInspect(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn(), mysql.UnmarshalHCL) + testCLISchemaInspect(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn("test"), mysql.UnmarshalHCL) }) }) t.Run("SchemaApply", func(t *testing.T) { myRun(t, func(t *myTest) { attrs := t.defaultAttrs() charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation) - testCLISchemaApply(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn()) + testCLISchemaApply(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn("test")) }) }) t.Run("SchemaApplyDryRun", func(t *testing.T) { myRun(t, func(t *myTest) { attrs := t.defaultAttrs() charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation) - testCLISchemaApplyDry(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn()) + testCLISchemaApplyDry(t, fmt.Sprintf(h, charset.V, collate.V), t.dsn("test")) }) }) t.Run("SchemaDiffRun", func(t *testing.T) { myRun(t, func(t *myTest) { - testCLISchemaDiff(t, t.dsn()) + testCLISchemaDiff(t, t.dsn("test")) + }) + }) +} + +func TestMySQL_CLI_MultiSchema(t *testing.T) { + t.Run("SchemaInspect", func(t *testing.T) { + h := ` + schema "test" { + charset = "%s" + collation = "%s" + } + table "users" { + schema = schema.test + column "id" { + type = int + } + primary_key { + columns = [table.users.column.id] + } + } + schema "test2" { + charset = "%s" + collation = "%s" + } + table "users" { + schema = schema.test2 + column "id" { + type = int + } + primary_key { + columns = [table.users.column.id] + } + }` + myRun(t, func(t *myTest) { + t.dropDB("test2") + t.dropTables("users") + attrs := t.defaultAttrs() + charset, collate := attrs[0].(*schema.Charset), attrs[1].(*schema.Collation) + testCLIMultiSchemaInspect(t, fmt.Sprintf(h, charset.V, collate.V, charset.V, collate.V), t.dsn(""), []string{"test", "test2"}, mysql.UnmarshalHCL) }) }) } @@ -1040,12 +1079,12 @@ create table atlas_types_sanity }) } -func (t *myTest) dsn() string { +func (t *myTest) dsn(dbname string) string { d := "mysql" if t.mariadb() { d = "mariadb" } - return fmt.Sprintf("%s://root:pass@tcp(localhost:%d)/test", d, t.port) + return fmt.Sprintf("%s://root:pass@tcp(localhost:%d)/%s", d, t.port, dbname) } func (t *myTest) applyHcl(spec string) { @@ -1092,8 +1131,10 @@ func (t *myTest) dropTables(names ...string) { func (t *myTest) dropDB(names ...string) { t.Cleanup(func() { - _, err := t.db.Exec("DROP DATABASE IF EXISTS " + strings.Join(names, ", ")) - require.NoError(t.T, err, "drop db %q", names) + for _, n := range names { + _, err := t.db.Exec("DROP DATABASE IF EXISTS " + n) + require.NoError(t.T, err, "drop db %q", names) + } }) } diff --git a/internal/integration/postgres_test.go b/internal/integration/postgres_test.go index ea776bf0de8..9227c9dc836 100644 --- a/internal/integration/postgres_test.go +++ b/internal/integration/postgres_test.go @@ -437,7 +437,7 @@ schema "second" { func (t *pgTest) applyRealmHcl(spec string) { realm := t.loadRealm() var desired schema.Realm - err := mysql.UnmarshalHCL([]byte(spec), &desired) + err := postgres.UnmarshalHCL([]byte(spec), &desired) require.NoError(t, err) diff, err := t.drv.RealmDiff(realm, &desired) require.NoError(t, err) @@ -480,6 +480,39 @@ func TestPostgres_CLI(t *testing.T) { }) } +func TestPostgres_CLI_MultiSchema(t *testing.T) { + t.Run("SchemaInspect", func(t *testing.T) { + h := ` + schema "public" { + } + table "users" { + schema = schema.public + column "id" { + type = integer + } + primary_key { + columns = [table.users.column.id] + } + } + schema "test2" { + } + table "users" { + schema = schema.test2 + column "id" { + type = integer + } + primary_key { + columns = [table.users.column.id] + } + }` + pgRun(t, func(t *pgTest) { + t.dropSchemas("test2") + t.dropTables("users") + testCLIMultiSchemaInspect(t, h, t.dsn(), []string{"public", "test2"}, postgres.UnmarshalHCL) + }) + }) +} + func TestPostgres_DefaultsHCL(t *testing.T) { n := "atlas_defaults" pgRun(t, func(t *pgTest) { @@ -1042,7 +1075,7 @@ func (t *pgTest) dropTables(names ...string) { func (t *pgTest) dropSchemas(names ...string) { t.Cleanup(func() { - _, err := t.db.Exec("DROP SCHEMA IF EXISTS " + strings.Join(names, ", ")) + _, err := t.db.Exec("DROP SCHEMA IF EXISTS " + strings.Join(names, ", ") + " CASCADE") require.NoError(t.T, err, "drop schema %q", names) }) } diff --git a/internal/integration/sqlite_test.go b/internal/integration/sqlite_test.go index fbfc48b4f5e..9668951e165 100644 --- a/internal/integration/sqlite_test.go +++ b/internal/integration/sqlite_test.go @@ -839,3 +839,7 @@ func (t *liteTest) dropTables(names ...string) { func (t *liteTest) dsn() string { return fmt.Sprintf("sqlite://file:%s?cache=shared&_fk=1", t.file) } + +func (t *liteTest) applyRealmHcl(spec string) { + t.applyHcl(spec) +} diff --git a/sql/postgres/inspect.go b/sql/postgres/inspect.go index 5bbeb3c5050..63c6742a43b 100644 --- a/sql/postgres/inspect.go +++ b/sql/postgres/inspect.go @@ -706,7 +706,7 @@ const ( paramsQuery = `SELECT setting FROM pg_settings WHERE name IN ('lc_collate', 'lc_ctype', 'server_version_num') ORDER BY name` // Query to list database schemas. - schemasQuery = "SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast') ORDER BY schema_name" + schemasQuery = "SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast') AND schema_name NOT LIKE 'pg_%temp_%' ORDER BY schema_name" // Query to list specific database schemas. schemasQueryArgs = "SELECT schema_name FROM information_schema.schemata WHERE schema_name %s ORDER BY schema_name"