Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd/atlas: add support for simple mysql url format #608

Merged
merged 2 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/action/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func CmdApplyRun(cmd *cobra.Command, args []string) {
schemaCmd.PrintErrln("The Atlas UI is not available in this release.")
return
}
d, err := defaultMux.OpenAtlas(ApplyFlags.URL)
d, err := DefaultMux.OpenAtlas(ApplyFlags.URL)
cobra.CheckErr(err)
applyRun(d, ApplyFlags.URL, ApplyFlags.File, ApplyFlags.DryRun, ApplyFlags.AutoApprove)
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/action/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ func init() {
// cmdDiffRun connects to the given databases, and prints an SQL plan to get from
// the "from" schema to the "to" schema.
func cmdDiffRun(cmd *cobra.Command, flags *diffCmdOpts) {
fromDriver, err := defaultMux.OpenAtlas(flags.fromURL)
fromDriver, err := DefaultMux.OpenAtlas(flags.fromURL)
cobra.CheckErr(err)
toDriver, err := defaultMux.OpenAtlas(flags.toURL)
toDriver, err := DefaultMux.OpenAtlas(flags.toURL)
cobra.CheckErr(err)
fromName, err := SchemaNameFromURL(flags.fromURL)
cobra.CheckErr(err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/action/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func CmdInspectRun(_ *cobra.Command, _ []string) {
schemaCmd.PrintErrln("The Atlas UI is not available in this release.")
return
}
d, err := defaultMux.OpenAtlas(InspectFlags.URL)
d, err := DefaultMux.OpenAtlas(InspectFlags.URL)
cobra.CheckErr(err)
inspectRun(d, InspectFlags.URL)
}
Expand Down
36 changes: 33 additions & 3 deletions cmd/action/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ func NewMux() *Mux {
}

var (
defaultMux = NewMux()
inMemory = regexp.MustCompile("^file:.*:memory:$|:memory:|^file:.*mode=memory.*")
// DefaultMux is the default Mux that is used by the different commands.
DefaultMux = NewMux()
a8m marked this conversation as resolved.
Show resolved Hide resolved
reMemMode = regexp.MustCompile(":memory:|^file:.*mode=memory.*")
)

// RegisterProvider is used to register a Driver provider by key.
Expand Down Expand Up @@ -142,7 +143,7 @@ func schemaName(dsn string) (string, error) {
}

func sqliteFileExists(dsn string) error {
if !inMemory.MatchString(dsn) {
if !reMemMode.MatchString(dsn) {
return fileExists(dsn)
}
return nil
Expand All @@ -163,3 +164,32 @@ func fileExists(dsn string) error {
}
return nil
}

func mysqlDSN(d string) (string, error) {
cfg, err := mysql.ParseDSN(d)
// A standard MySQL DSN.
if err == nil {
return d, nil
}
u, err := url.Parse("mysql://" + d)
if err != nil {
return "", nil
}
schema := strings.TrimPrefix(u.Path, "/")
// In case of a URL (non-standard DSN),
// parse the options from query string.
if u.RawQuery != "" {
cfg, err = mysql.ParseDSN(fmt.Sprintf("/%s?%s", schema, u.RawQuery))
if err != nil {
return "", err
}
} else {
cfg = mysql.NewConfig()
}
cfg.Net = "tcp"
cfg.Addr = u.Host
cfg.User = u.User.Username()
cfg.Passwd, _ = u.User.Password()
cfg.DBName = schema
return cfg.FormatDSN(), nil
}
42 changes: 42 additions & 0 deletions cmd/action/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
package action_test

import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"testing"

"ariga.io/atlas/cmd/action"

mysqld "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -153,3 +157,41 @@ func Test_PostgresSchemaDSN(t *testing.T) {
})
}
}

func TestMux_OpenAtlas(t *testing.T) {
t.Run("MySQL", func(t *testing.T) {
for _, u := range []string{
"mysql://root:pass@tcp(%s)/",
"mysql://root:pass@tcp(%s)/test",
"mysql://root:pass@%s",
"mysql://root:pass@%s/",
"mysql://root:pass@%s/test",
"mysql://%s/test",
} {
calls, l := mockServer(t)
require.NoError(t, mysqld.SetLogger(log.New(ioutil.Discard, "", 1)))
_, err := action.DefaultMux.OpenAtlas(fmt.Sprintf(u, l.Addr()))
require.Error(t, err, "mock server rejects all incoming connections")
require.NotZero(t, *calls)
}
})
}

func mockServer(t *testing.T) (*int, net.Listener) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't really like it, but wanted to add tests to the OpenAtlas part.

There's no way to register a "mock driver" with the same name, because the mysql package already registers it, and database/sql.Register will panic in case of multiple calls with the same name.

Another alternative was to call sql.unregisterAllDrivers, but this requires using the unsafe package which I prefer to avoid.

var (
calls int
l, err = net.Listen("tcp", "localhost:")
)
require.NoError(t, err)
go func() {
for {
conn, err := l.Accept()
if err != nil {
return
}
calls++
require.NoError(t, conn.Close())
}
}()
return &calls, l
}
19 changes: 12 additions & 7 deletions cmd/action/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ import (
)

func init() {
defaultMux.RegisterProvider("mysql", mysqlProvider)
defaultMux.RegisterProvider("mariadb", mysqlProvider)
defaultMux.RegisterProvider("postgres", postgresProvider)
defaultMux.RegisterProvider("sqlite", sqliteProvider)
DefaultMux.RegisterProvider("mysql", mysqlProvider)
DefaultMux.RegisterProvider("mariadb", mysqlProvider)
DefaultMux.RegisterProvider("postgres", postgresProvider)
DefaultMux.RegisterProvider("sqlite", sqliteProvider)
}

func mysqlProvider(dsn string) (*Driver, error) {
db, err := sql.Open("mysql", dsn)
d, err := mysqlDSN(dsn)
if err != nil {
return nil, err
}
db, err := sql.Open("mysql", d)
if err != nil {
return nil, err
}
Expand All @@ -34,9 +38,10 @@ func mysqlProvider(dsn string) (*Driver, error) {
Unmarshaler: mysql.UnmarshalHCL,
}, nil
}

func postgresProvider(dsn string) (*Driver, error) {
url := "postgres://" + dsn
db, err := sql.Open("postgres", url)
u := "postgres://" + dsn
db, err := sql.Open("postgres", u)
if err != nil {
return nil, err
}
Expand Down
8 changes: 1 addition & 7 deletions cmd/atlas/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/spf13/cobra"
)

func main() {
cobra.OnInitialize(initConfig)
action.RootCmd.SetOut(os.Stdout)
err := action.RootCmd.Execute()
// Print error from command
if err != nil {
Expand All @@ -26,8 +25,3 @@ func main() {
}

}

// initConfig reads in config file and ENV variables if set.
func initConfig() {
action.RootCmd.SetOut(os.Stdout)
}