diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 3d757df1dcc7..fe853d3fb4c5 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -25,6 +25,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), + pathResetConnection(&b), }, Secrets: []*framework.Secret{ diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index b53bb0c75732..1e66d27f6edc 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -15,47 +15,40 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" - "github.com/mitchellh/mapstructure" ) type ConnectionProducer interface { Connection() (interface{}, error) Close() - // TODO: Should we make this immutable instead? - Reset(*DatabaseConfig) error } // sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases -type sqlConnectionDetails struct { +type sqlConnectionProducer struct { ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` -} -type sqlConnectionProducer struct { config *DatabaseConfig - // TODO: Should we merge these two structures make it immutable? - connDetails *sqlConnectionDetails db *sql.DB sync.Mutex } -func (cp *sqlConnectionProducer) Connection() (interface{}, error) { +func (c *sqlConnectionProducer) Connection() (interface{}, error) { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() // If we already have a DB, we got it! - if cp.db != nil { - if err := cp.db.Ping(); err == nil { - return cp.db, nil + if c.db != nil { + if err := c.db.Ping(); err == nil { + return c.db, nil } // If the ping was unsuccessful, close it and ignore errors as we'll be // reestablishing anyways - cp.db.Close() + c.db.Close() } // Otherwise, attempt to make connection - conn := cp.connDetails.ConnectionURL + conn := c.ConnectionURL // Ensure timezone is set to UTC for all the conenctions if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { @@ -67,54 +60,33 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { } var err error - cp.db, err = sql.Open(cp.config.DatabaseType, conn) + c.db, err = sql.Open(c.config.DatabaseType, conn) if err != nil { return nil, err } // Set some connection pool settings. We don't need much of this, // since the request rate shouldn't be high. - cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) - cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) - cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime) - - return cp.db, nil -} + c.db.SetMaxOpenConns(c.config.MaxOpenConnections) + c.db.SetMaxIdleConns(c.config.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime) -func (cp *sqlConnectionProducer) Close() { - // Grab the write lock - cp.Lock() - defer cp.Unlock() - - if cp.db != nil { - cp.db.Close() - } - - cp.db = nil + return c.db, nil } -func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { +func (c *sqlConnectionProducer) Close() { // Grab the write lock - cp.Lock() + c.Lock() + defer c.Unlock() - var details *sqlConnectionDetails - err := mapstructure.Decode(config.ConnectionDetails, &details) - if err != nil { - return err + if c.db != nil { + c.db.Close() } - cp.connDetails = details - cp.config = config - - cp.Unlock() - - cp.Close() - _, err = cp.Connection() - return err + c.db = nil } -// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra -type cassandraConnectionDetails struct { +type cassandraConnectionProducer struct { Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` Username string `json:"username" structs:"username" mapstructure:"username"` Password string `json:"password" structs:"password" mapstructure:"password"` @@ -127,90 +99,74 @@ type cassandraConnectionDetails struct { ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` -} -type cassandraConnectionProducer struct { config *DatabaseConfig - // TODO: Should we merge these two structures make it immutable? - connDetails *cassandraConnectionDetails session *gocql.Session sync.Mutex } -func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) Connection() (interface{}, error) { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() // If we already have a DB, we got it! - if cp.session != nil { - return cp.session, nil + if c.session != nil { + return c.session, nil } - session, err := cp.createSession(cp.connDetails) + session, err := c.createSession() if err != nil { return nil, err } // Store the session in backend for reuse - cp.session = session + c.session = session return session, nil } -func (cp *cassandraConnectionProducer) Close() { +func (c *cassandraConnectionProducer) Close() { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() - if cp.session != nil { - cp.session.Close() + if c.session != nil { + c.session.Close() } - cp.session = nil -} - -func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { - // Grab the write lock - cp.Lock() - cp.config = config - cp.Unlock() - - cp.Close() - _, err := cp.Connection() - - return err + c.session = nil } -func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) +func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: cfg.Username, - Password: cfg.Password, + Username: c.Username, + Password: c.Password, } - clusterConfig.ProtoVersion = cfg.ProtocolVersion + clusterConfig.ProtoVersion = c.ProtocolVersion if clusterConfig.ProtoVersion == 0 { clusterConfig.ProtoVersion = 2 } - clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - if cfg.TLS { + if c.TLS { var tlsConfig *tls.Config - if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { - if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { + if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} - if len(cfg.Certificate) > 0 { - certBundle.Certificate = cfg.Certificate - certBundle.PrivateKey = cfg.PrivateKey + if len(c.Certificate) > 0 { + certBundle.Certificate = c.Certificate + certBundle.PrivateKey = c.PrivateKey } - if len(cfg.IssuingCA) > 0 { - certBundle.IssuingCA = cfg.IssuingCA + if len(c.IssuingCA) > 0 { + certBundle.IssuingCA = c.IssuingCA } parsedCertBundle, err := certBundle.ToParsedCertBundle() @@ -222,11 +178,11 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet if err != nil || tlsConfig == nil { return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) } - tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + tlsConfig.InsecureSkipVerify = c.InsecureTLS - if cfg.TLSMinVersion != "" { + if c.TLSMinVersion != "" { var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] if !ok { return nil, fmt.Errorf("invalid 'tls_min_version' in config") } @@ -248,8 +204,8 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet } // Set consistency - if cfg.Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency) + if c.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index d648b776fa3f..4c04c0fd4f9f 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -22,16 +22,12 @@ var ( func Factory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: - var details *sqlConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *sqlConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &sqlConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &sqlCredentialsProducer{ displayNameLen: 23, @@ -44,16 +40,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case mySQLTypeName: - var details *sqlConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *sqlConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &sqlConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &sqlCredentialsProducer{ displayNameLen: 4, @@ -66,16 +58,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case cassandraTypeName: - var details *cassandraConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *cassandraConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &cassandraConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &cassandraCredentialsProducer{} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 9fe9260508d0..085113fe98e7 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "errors" "fmt" "time" @@ -10,6 +11,64 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +func pathResetConnection(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this DB type", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathConnectionReset, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return nil, errors.New("No database name set") + } + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + db, ok := b.connections[name] + if !ok { + return logical.ErrorResponse("Can not change type of existing connection."), nil + } + + db.Close() + db, err = dbs.Factory(&config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + b.connections[name] = db + + return nil, nil +} + func pathConfigConnection(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), @@ -129,13 +188,13 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew if b.connections[name].Type() != connType { return logical.ErrorResponse("Can not change type of existing connection."), nil } - - db = b.connections[name] } else { db, err = dbs.Factory(config) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + + b.connections[name] = db } /* @@ -166,9 +225,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew } // Reset the DB connection - db.Reset(config) - b.connections[name] = db - resp := &logical.Response{} resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")