Skip to content

Commit

Permalink
cmd/atlas: add support for simple mysql url format
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Mar 1, 2022
1 parent 97437b9 commit 810964b
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cmd/action/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func CmdApplyRun(cmd *cobra.Command, args []string) {
schemaCmd.PrintErrln("The Atlas UI is not available in this release.")
return
}
d, err := defaultMux.OpenAtlas(ApplyFlags.URL)
d, err := DefaultMux.OpenAtlas(ApplyFlags.URL)
cobra.CheckErr(err)
applyRun(d, ApplyFlags.URL, ApplyFlags.File, ApplyFlags.DryRun, ApplyFlags.AutoApprove)
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/action/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ func init() {
// cmdDiffRun connects to the given databases, and prints an SQL plan to get from
// the "from" schema to the "to" schema.
func cmdDiffRun(cmd *cobra.Command, flags *diffCmdOpts) {
fromDriver, err := defaultMux.OpenAtlas(flags.fromURL)
fromDriver, err := DefaultMux.OpenAtlas(flags.fromURL)
cobra.CheckErr(err)
toDriver, err := defaultMux.OpenAtlas(flags.toURL)
toDriver, err := DefaultMux.OpenAtlas(flags.toURL)
cobra.CheckErr(err)
fromName, err := SchemaNameFromURL(flags.fromURL)
cobra.CheckErr(err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/action/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func CmdInspectRun(_ *cobra.Command, _ []string) {
schemaCmd.PrintErrln("The Atlas UI is not available in this release.")
return
}
d, err := defaultMux.OpenAtlas(InspectFlags.URL)
d, err := DefaultMux.OpenAtlas(InspectFlags.URL)
cobra.CheckErr(err)
inspectRun(d, InspectFlags.URL)
}
Expand Down
36 changes: 33 additions & 3 deletions cmd/action/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ func NewMux() *Mux {
}

var (
defaultMux = NewMux()
inMemory = regexp.MustCompile("^file:.*:memory:$|:memory:|^file:.*mode=memory.*")
// DefaultMux is the default Mux that is used by the different commands.
DefaultMux = NewMux()
reMemMode = regexp.MustCompile("^file:.*:memory:$|:memory:|^file:.*mode=memory.*")
)

// RegisterProvider is used to register a Driver provider by key.
Expand Down Expand Up @@ -142,7 +143,7 @@ func schemaName(dsn string) (string, error) {
}

func sqliteFileExists(dsn string) error {
if !inMemory.MatchString(dsn) {
if !reMemMode.MatchString(dsn) {
return fileExists(dsn)
}
return nil
Expand All @@ -163,3 +164,32 @@ func fileExists(dsn string) error {
}
return nil
}

func mysqlDSN(d string) (string, error) {
cfg, err := mysql.ParseDSN(d)
// A standard MySQL DSN.
if err == nil {
return d, nil
}
u, err := url.Parse("mysql://" + d)
if err != nil {
return "", nil
}
schema := strings.TrimPrefix(u.Path, "/")
// In case of a URL (non-standard DSN),
// parse the options from query string.
if u.RawQuery != "" {
cfg, err = mysql.ParseDSN(fmt.Sprintf("/%s?%s", schema, u.RawQuery))
if err != nil {
return "", err
}
} else {
cfg = mysql.NewConfig()
}
cfg.Net = "tcp"
cfg.Addr = u.Host
cfg.User = u.User.Username()
cfg.Passwd, _ = u.User.Password()
cfg.DBName = schema
return cfg.FormatDSN(), nil
}
42 changes: 42 additions & 0 deletions cmd/action/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
package action_test

import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"testing"

"ariga.io/atlas/cmd/action"

mysqld "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -153,3 +157,41 @@ func Test_PostgresSchemaDSN(t *testing.T) {
})
}
}

func TestMux_OpenAtlas(t *testing.T) {
t.Run("MySQL", func(t *testing.T) {
for _, u := range []string{
"mysql://root:pass@tcp(%s)/",
"mysql://root:pass@tcp(%s)/test",
"mysql://root:pass@%s",
"mysql://root:pass@%s/",
"mysql://root:pass@%s/test",
"mysql://%s/test",
} {
calls, l := mockServer(t)
require.NoError(t, mysqld.SetLogger(log.New(ioutil.Discard, "", 1)))
_, err := action.DefaultMux.OpenAtlas(fmt.Sprintf(u, l.Addr()))
require.Error(t, err, "mock server rejects all incoming connections")
require.NotZero(t, *calls)
}
})
}

func mockServer(t *testing.T) (*int, net.Listener) {
var (
calls int
l, err = net.Listen("tcp", "localhost:")
)
require.NoError(t, err)
go func() {
for {
conn, err := l.Accept()
if err != nil {
return
}
calls++
require.NoError(t, conn.Close())
}
}()
return &calls, l
}
19 changes: 12 additions & 7 deletions cmd/action/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ import (
)

func init() {
defaultMux.RegisterProvider("mysql", mysqlProvider)
defaultMux.RegisterProvider("mariadb", mysqlProvider)
defaultMux.RegisterProvider("postgres", postgresProvider)
defaultMux.RegisterProvider("sqlite", sqliteProvider)
DefaultMux.RegisterProvider("mysql", mysqlProvider)
DefaultMux.RegisterProvider("mariadb", mysqlProvider)
DefaultMux.RegisterProvider("postgres", postgresProvider)
DefaultMux.RegisterProvider("sqlite", sqliteProvider)
}

func mysqlProvider(dsn string) (*Driver, error) {
db, err := sql.Open("mysql", dsn)
d, err := mysqlDSN(dsn)
if err != nil {
return nil, err
}
db, err := sql.Open("mysql", d)
if err != nil {
return nil, err
}
Expand All @@ -34,9 +38,10 @@ func mysqlProvider(dsn string) (*Driver, error) {
Unmarshaler: mysql.UnmarshalHCL,
}, nil
}

func postgresProvider(dsn string) (*Driver, error) {
url := "postgres://" + dsn
db, err := sql.Open("postgres", url)
u := "postgres://" + dsn
db, err := sql.Open("postgres", u)
if err != nil {
return nil, err
}
Expand Down
8 changes: 1 addition & 7 deletions cmd/atlas/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/spf13/cobra"
)

func main() {
cobra.OnInitialize(initConfig)
action.RootCmd.SetOut(os.Stdout)
err := action.RootCmd.Execute()
// Print error from command
if err != nil {
Expand All @@ -26,8 +25,3 @@ func main() {
}

}

// initConfig reads in config file and ENV variables if set.
func initConfig() {
action.RootCmd.SetOut(os.Stdout)
}

0 comments on commit 810964b

Please sign in to comment.