diff --git a/cmd/action/apply.go b/cmd/action/apply.go index e51ecb82bde..de001a962dc 100644 --- a/cmd/action/apply.go +++ b/cmd/action/apply.go @@ -71,7 +71,7 @@ func CmdApplyRun(cmd *cobra.Command, args []string) { schemaCmd.PrintErrln("The Atlas UI is not available in this release.") return } - d, err := defaultMux.OpenAtlas(ApplyFlags.URL) + d, err := DefaultMux.OpenAtlas(ApplyFlags.URL) cobra.CheckErr(err) applyRun(d, ApplyFlags.URL, ApplyFlags.File, ApplyFlags.DryRun, ApplyFlags.AutoApprove) } diff --git a/cmd/action/diff.go b/cmd/action/diff.go index df5dbd41120..1eeae816614 100644 --- a/cmd/action/diff.go +++ b/cmd/action/diff.go @@ -44,9 +44,9 @@ func init() { // cmdDiffRun connects to the given databases, and prints an SQL plan to get from // the "from" schema to the "to" schema. func cmdDiffRun(cmd *cobra.Command, flags *diffCmdOpts) { - fromDriver, err := defaultMux.OpenAtlas(flags.fromURL) + fromDriver, err := DefaultMux.OpenAtlas(flags.fromURL) cobra.CheckErr(err) - toDriver, err := defaultMux.OpenAtlas(flags.toURL) + toDriver, err := DefaultMux.OpenAtlas(flags.toURL) cobra.CheckErr(err) fromName, err := SchemaNameFromURL(flags.fromURL) cobra.CheckErr(err) diff --git a/cmd/action/inspect.go b/cmd/action/inspect.go index 3ce496443e3..b17058fcd34 100644 --- a/cmd/action/inspect.go +++ b/cmd/action/inspect.go @@ -60,7 +60,7 @@ func CmdInspectRun(_ *cobra.Command, _ []string) { schemaCmd.PrintErrln("The Atlas UI is not available in this release.") return } - d, err := defaultMux.OpenAtlas(InspectFlags.URL) + d, err := DefaultMux.OpenAtlas(InspectFlags.URL) cobra.CheckErr(err) inspectRun(d, InspectFlags.URL) } diff --git a/cmd/action/mux.go b/cmd/action/mux.go index 8d49c9ebb76..0cbaeea9f49 100644 --- a/cmd/action/mux.go +++ b/cmd/action/mux.go @@ -43,8 +43,9 @@ func NewMux() *Mux { } var ( - defaultMux = NewMux() - inMemory = regexp.MustCompile("^file:.*:memory:$|:memory:|^file:.*mode=memory.*") + // DefaultMux is the default Mux that is used by the different commands. + DefaultMux = NewMux() + reMemMode = regexp.MustCompile("^file:.*:memory:$|:memory:|^file:.*mode=memory.*") ) // RegisterProvider is used to register a Driver provider by key. @@ -142,7 +143,7 @@ func schemaName(dsn string) (string, error) { } func sqliteFileExists(dsn string) error { - if !inMemory.MatchString(dsn) { + if !reMemMode.MatchString(dsn) { return fileExists(dsn) } return nil @@ -163,3 +164,32 @@ func fileExists(dsn string) error { } return nil } + +func mysqlDSN(d string) (string, error) { + cfg, err := mysql.ParseDSN(d) + // A standard MySQL DSN. + if err == nil { + return d, nil + } + u, err := url.Parse("mysql://" + d) + if err != nil { + return "", nil + } + schema := strings.TrimPrefix(u.Path, "/") + // In case of a URL (non-standard DSN), + // parse the options from query string. + if u.RawQuery != "" { + cfg, err = mysql.ParseDSN(fmt.Sprintf("/%s?%s", schema, u.RawQuery)) + if err != nil { + return "", err + } + } else { + cfg = mysql.NewConfig() + } + cfg.Net = "tcp" + cfg.Addr = u.Host + cfg.User = u.User.Username() + cfg.Passwd, _ = u.User.Password() + cfg.DBName = schema + return cfg.FormatDSN(), nil +} diff --git a/cmd/action/mux_test.go b/cmd/action/mux_test.go index 3fe55252e6f..56ac5e82192 100644 --- a/cmd/action/mux_test.go +++ b/cmd/action/mux_test.go @@ -5,12 +5,16 @@ package action_test import ( + "fmt" "io/ioutil" + "log" + "net" "os" "testing" "ariga.io/atlas/cmd/action" + mysqld "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/require" ) @@ -153,3 +157,41 @@ func Test_PostgresSchemaDSN(t *testing.T) { }) } } + +func TestMux_OpenAtlas(t *testing.T) { + t.Run("MySQL", func(t *testing.T) { + for _, u := range []string{ + "mysql://root:pass@tcp(%s)/", + "mysql://root:pass@tcp(%s)/test", + "mysql://root:pass@%s", + "mysql://root:pass@%s/", + "mysql://root:pass@%s/test", + "mysql://%s/test", + } { + calls, l := mockServer(t) + require.NoError(t, mysqld.SetLogger(log.New(ioutil.Discard, "", 1))) + _, err := action.DefaultMux.OpenAtlas(fmt.Sprintf(u, l.Addr())) + require.Error(t, err, "mock server rejects all incoming connections") + require.NotZero(t, *calls) + } + }) +} + +func mockServer(t *testing.T) (*int, net.Listener) { + var ( + calls int + l, err = net.Listen("tcp", "localhost:") + ) + require.NoError(t, err) + go func() { + for { + conn, err := l.Accept() + if err != nil { + return + } + calls++ + require.NoError(t, conn.Close()) + } + }() + return &calls, l +} diff --git a/cmd/action/provider.go b/cmd/action/provider.go index f3290bbe255..33e380e4930 100644 --- a/cmd/action/provider.go +++ b/cmd/action/provider.go @@ -13,14 +13,18 @@ import ( ) func init() { - defaultMux.RegisterProvider("mysql", mysqlProvider) - defaultMux.RegisterProvider("mariadb", mysqlProvider) - defaultMux.RegisterProvider("postgres", postgresProvider) - defaultMux.RegisterProvider("sqlite", sqliteProvider) + DefaultMux.RegisterProvider("mysql", mysqlProvider) + DefaultMux.RegisterProvider("mariadb", mysqlProvider) + DefaultMux.RegisterProvider("postgres", postgresProvider) + DefaultMux.RegisterProvider("sqlite", sqliteProvider) } func mysqlProvider(dsn string) (*Driver, error) { - db, err := sql.Open("mysql", dsn) + d, err := mysqlDSN(dsn) + if err != nil { + return nil, err + } + db, err := sql.Open("mysql", d) if err != nil { return nil, err } @@ -34,9 +38,10 @@ func mysqlProvider(dsn string) (*Driver, error) { Unmarshaler: mysql.UnmarshalHCL, }, nil } + func postgresProvider(dsn string) (*Driver, error) { - url := "postgres://" + dsn - db, err := sql.Open("postgres", url) + u := "postgres://" + dsn + db, err := sql.Open("postgres", u) if err != nil { return nil, err } diff --git a/cmd/atlas/main.go b/cmd/atlas/main.go index d856e3a001a..cdb277ce2e2 100644 --- a/cmd/atlas/main.go +++ b/cmd/atlas/main.go @@ -8,11 +8,10 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - "github.com/spf13/cobra" ) func main() { - cobra.OnInitialize(initConfig) + action.RootCmd.SetOut(os.Stdout) err := action.RootCmd.Execute() // Print error from command if err != nil { @@ -26,8 +25,3 @@ func main() { } } - -// initConfig reads in config file and ENV variables if set. -func initConfig() { - action.RootCmd.SetOut(os.Stdout) -}