diff --git a/canned/canned.go b/canned/canned.go new file mode 100644 index 0000000000..5af0a15375 --- /dev/null +++ b/canned/canned.go @@ -0,0 +1,3 @@ +// Package canned offers some common containers +// and straight forward integrations for your tests. +package canned diff --git a/canned/postgresql.go b/canned/postgresql.go new file mode 100644 index 0000000000..77c4828ce3 --- /dev/null +++ b/canned/postgresql.go @@ -0,0 +1,158 @@ +package canned + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + + "github.com/docker/go-connections/nat" + "github.com/lib/pq" + "github.com/pkg/errors" + testcontainers "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + postgresUser = "user" + postgresPassword = "password" + postgresDatabase = "database" + postgresImage = "postgres" + postgresDefaultTag = "11.5" + postgresPort = "5432/tcp" +) + +// PostgreSQLContainerRequest completes GenericContainerRequest +// with PostgreSQL specific parameters +type PostgreSQLContainerRequest struct { + testcontainers.GenericContainerRequest + User string + Password string + Database string +} + +// PostgreSQLContainer should always be created via NewPostgreSQLContainer +type PostgreSQLContainer struct { + Container testcontainers.Container + db *sql.DB + req PostgreSQLContainerRequest +} + +// GetDriver returns a sql.DB connecting to the previously started Postgres DB. +// All the parameters are taken from the previous PostgreSQLContainerRequest +func (c *PostgreSQLContainer) GetDriver(ctx context.Context) (*sql.DB, error) { + + host, err := c.Container.Host(ctx) + if err != nil { + return nil, err + } + + mappedPort, err := c.Container.MappedPort(ctx, postgresPort) + if err != nil { + return nil, err + } + + db, err := sql.Open("postgres", fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, + mappedPort.Int(), + c.req.User, + c.req.Password, + c.req.Database, + )) + if err != nil { + return nil, err + } + + return db, nil +} + +// NewPostgreSQLContainer creates and (optionally) starts a Postgres database. +// If autostarted, the function returns only after a successful execution of a query +// (confirming that the database is ready) +func NewPostgreSQLContainer(ctx context.Context, req PostgreSQLContainerRequest) (*PostgreSQLContainer, error) { + + provider, err := req.ProviderType.GetProvider() + if err != nil { + return nil, err + } + + // With the current logic it's not really possible to allow other ports... + req.ExposedPorts = []string{postgresPort} + + if req.Env == nil { + req.Env = map[string]string{} + } + + // Set the default values if none were provided in the request + if req.Image == "" && req.FromDockerfile.Context == "" { + req.Image = fmt.Sprintf("%s:%s", postgresImage, postgresDefaultTag) + } + + if req.User == "" { + req.User = postgresUser + } + + if req.Password == "" { + req.Password = postgresPassword + } + + if req.Database == "" { + req.Database = postgresDatabase + } + + req.Env["POSTGRES_USER"] = req.User + req.Env["POSTGRES_PASSWORD"] = req.Password + req.Env["POSTGRES_DB"] = req.Database + + connectorVars := map[string]interface{}{ + "port": postgresPort, + "user": req.User, + "password": req.Password, + "database": req.Database, + } + + req.WaitingFor = wait.ForSQL(postgresConnectorFromTarget, connectorVars) + + c, err := provider.CreateContainer(ctx, req.ContainerRequest) + if err != nil { + return nil, errors.Wrap(err, "failed to create container") + } + + postgresC := &PostgreSQLContainer{ + Container: c, + req: req, + } + + if req.Started { + if err := c.Start(ctx); err != nil { + return postgresC, errors.Wrap(err, "failed to start container") + } + } + + return postgresC, nil +} + +func postgresConnectorFromTarget(ctx context.Context, target wait.StrategyTarget, variables wait.SQLVariables) (driver.Connector, error) { + + host, err := target.Host(ctx) + if err != nil { + return nil, err + } + + mappedPort, err := target.MappedPort(ctx, nat.Port(variables["port"].(string))) + if err != nil { + return nil, err + } + + connString := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, + mappedPort.Int(), + variables["user"], + variables["password"], + variables["database"], + ) + + return pq.NewConnector(connString) +} diff --git a/canned/postgresql_test.go b/canned/postgresql_test.go new file mode 100644 index 0000000000..d6423142ff --- /dev/null +++ b/canned/postgresql_test.go @@ -0,0 +1,90 @@ +package canned + +import ( + "context" + "testing" + + testcontainers "github.com/testcontainers/testcontainers-go" +) + +func TestWriteIntoAPostgreSQLContainerViaDriver(t *testing.T) { + + ctx := context.Background() + + c, err := NewPostgreSQLContainer(ctx, PostgreSQLContainerRequest{ + Database: "hello", + GenericContainerRequest: testcontainers.GenericContainerRequest{ + Started: true, + }, + }) + if err != nil { + t.Fatal(err.Error()) + } + defer c.Container.Terminate(ctx) + + sqlC, err := c.GetDriver(ctx) + if err != nil { + t.Fatal(err.Error()) + } + + _, err = sqlC.ExecContext(ctx, "CREATE TABLE example ( id integer, data varchar(32) )") + if err != nil { + t.Fatal(err.Error()) + } +} + +func ExamplePostgreSQLContainerRequest() { + + // Optional + containerRequest := testcontainers.ContainerRequest{ + Image: "docker.io/library/postgres:11.5", + } + + genericContainerRequest := testcontainers.GenericContainerRequest{ + Started: true, + ContainerRequest: containerRequest, + } + + // Database, User, and Password are optional, + // the driver will use default ones if not provided + postgreSQLContainerRequest := PostgreSQLContainerRequest{ + Database: "mycustomdatabase", + User: "anyuser", + Password: "yoursecurepassword", + GenericContainerRequest: genericContainerRequest, + } + + postgreSQLContainerRequest.Validate() +} + +func ExampleNewPostgreSQLContainer() { + ctx := context.Background() + + // Create your PostgreSQL database, + // by setting Started this function will not return + // until a test connection has been established + c, _ := NewPostgreSQLContainer(ctx, PostgreSQLContainerRequest{ + Database: "hello", + GenericContainerRequest: testcontainers.GenericContainerRequest{ + Started: true, + }, + }) + defer c.Container.Terminate(ctx) +} + +func ExamplePostgreSQLContainer_GetDriver() { + ctx := context.Background() + + c, _ := NewPostgreSQLContainer(ctx, PostgreSQLContainerRequest{ + Database: "hello", + GenericContainerRequest: testcontainers.GenericContainerRequest{ + Started: true, + }, + }) + defer c.Container.Terminate(ctx) + + // Now you can simply interact with your DB + db, _ := c.GetDriver(ctx) + + db.Ping() +} diff --git a/go.mod b/go.mod index a6474453d7..49f313fe03 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/mux v1.6.2 // indirect github.com/kr/pretty v0.1.0 // indirect + github.com/lib/pq v1.2.0 github.com/morikuni/aec v0.0.0-20170113033406-39771216ff4c // indirect github.com/onsi/ginkgo v1.8.0 // indirect github.com/onsi/gomega v1.5.0 // indirect diff --git a/go.sum b/go.sum index 39255a906d..b742aebdaf 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,7 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc h1:TP+534wVlf61smEIq1nwLLAjQVEK2EADoW3CX9AuT+8= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= @@ -38,6 +39,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/morikuni/aec v0.0.0-20170113033406-39771216ff4c/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0 h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w= @@ -62,6 +65,7 @@ github.com/sirupsen/logrus v1.2.0 h1:juTguoYk5qI21pwyTXY3B3Y5cOTH3ZUyZCg1v/mihuo github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/wait/sql.go b/wait/sql.go new file mode 100644 index 0000000000..541f7cfeff --- /dev/null +++ b/wait/sql.go @@ -0,0 +1,80 @@ +package wait + +import ( + "context" + "time" + + "database/sql" + "database/sql/driver" + + "github.com/pkg/errors" +) + +type SQLVariables map[string]interface{} + +type SQLConnectorFromTarget func(ctx context.Context, target StrategyTarget, variables SQLVariables) (driver.Connector, error) + +var _ Strategy = (*SQLStrategy)(nil) + +type SQLStrategy struct { + startupTimeout time.Duration + PollInterval time.Duration + ConnectorSource SQLConnectorFromTarget + SQLVariables SQLVariables +} + +func NewSQLStrategy(ds SQLConnectorFromTarget, sv SQLVariables) *SQLStrategy { + return &SQLStrategy{ + startupTimeout: defaultStartupTimeout(), + PollInterval: 500 * time.Millisecond, + ConnectorSource: ds, + SQLVariables: sv, + } +} + +func ForSQL(ds SQLConnectorFromTarget, sv SQLVariables) *SQLStrategy { + return NewSQLStrategy(ds, sv) +} + +// WithStartupTimeout can be used to change the default startup timeout +func (ws *SQLStrategy) WithStartupTimeout(startupTimeout time.Duration) *SQLStrategy { + ws.startupTimeout = startupTimeout + return ws +} + +// WithPollInterval can be used to override the default polling interval of 100 milliseconds +func (ws *SQLStrategy) WithPollInterval(pollInterval time.Duration) *SQLStrategy { + ws.PollInterval = pollInterval + return ws +} + +func (ws *SQLStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error { + + ctx, cancelContext := context.WithTimeout(ctx, ws.startupTimeout) + defer cancelContext() + + conn, err := ws.ConnectorSource(ctx, target, ws.SQLVariables) + if err != nil { + return errors.Wrap(err, "could not retrieve the SQL connector from the provided function") + } + + db := sql.OpenDB(conn) + +LOOP: + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + _, err := db.ExecContext(ctx, "SELECT 1") + if err != nil { + time.Sleep(ws.PollInterval) + continue + } + break LOOP + } + + } + + return nil +}