Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mongodb plugin #2698

Merged
merged 18 commits into from
May 11, 2017
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions plugins/database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ type Cassandra struct {

// New returns a new Cassandra instance
func New() (interface{}, error) {
connProducer := &connutil.CassandraConnectionProducer{}
connProducer := &cassandraConnectionProducer{}
connProducer.Type = cassandraTypeName

credsProducer := &credsutil.CassandraCredentialsProducer{}
credsProducer := &cassandraCredentialsProducer{}

dbType := &Cassandra{
ConnectionProducer: connProducer,
Expand Down
3 changes: 1 addition & 2 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/gocql/gocql"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
dockertest "gopkg.in/ory-am/dockertest.v3"
)

Expand Down Expand Up @@ -82,7 +81,7 @@ func TestCassandra_Initialize(t *testing.T) {

dbRaw, _ := New()
db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)

err := db.Initialize(connectionDetails, true)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package connutil
package cassandra

import (
"crypto/tls"
Expand All @@ -13,11 +13,12 @@ import (
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/tlsutil"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
)

// CassandraConnectionProducer implements ConnectionProducer and provides an
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
type CassandraConnectionProducer struct {
type cassandraConnectionProducer struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
Expand All @@ -41,15 +42,14 @@ type CassandraConnectionProducer struct {
sync.Mutex
}

func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()

err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.Initialized = true

if c.ConnectTimeoutRaw == nil {
c.ConnectTimeoutRaw = "0s"
Expand Down Expand Up @@ -100,17 +100,22 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
c.TLS = true
}

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true

if verifyConnection {
if _, err := c.Connection(); err != nil {
return fmt.Errorf("error Initalizing Connection: %s", err)
return fmt.Errorf("error verifying connection: %s", err)
}
}

return nil
}

func (c *CassandraConnectionProducer) Connection() (interface{}, error) {
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
if !c.Initialized {
return nil, errNotInitialized
return nil, connutil.ErrNotInitialized
}

// If we already have a DB, return it
Expand All @@ -129,7 +134,7 @@ func (c *CassandraConnectionProducer) Connection() (interface{}, error) {
return session, nil
}

func (c *CassandraConnectionProducer) Close() error {
func (c *cassandraConnectionProducer) Close() error {
// Grab the write lock
c.Lock()
defer c.Unlock()
Expand All @@ -143,7 +148,7 @@ func (c *CassandraConnectionProducer) Close() error {
return nil
}

func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) {
func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: c.Username,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package credsutil
package cassandra

import (
"fmt"
Expand All @@ -8,11 +8,11 @@ import (
uuid "github.com/hashicorp/go-uuid"
)

// CassandraCredentialsProducer implements CredentialsProducer and provides an
// cassandraCredentialsProducer implements CredentialsProducer and provides an
// interface for cassandra databases to generate user information.
type CassandraCredentialsProducer struct{}
type cassandraCredentialsProducer struct{}

func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
Expand All @@ -23,7 +23,7 @@ func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (s
return username, nil
}

func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) {
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
Expand All @@ -32,6 +32,6 @@ func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) {
return password, nil
}

func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
return "", nil
}
8 changes: 4 additions & 4 deletions plugins/database/cassandra/test-fixtures/cassandra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ seed_provider:
parameters:
# seeds is actually a comma-delimited list of addresses.
# Ex: "<ip1>,<ip2>,<ip3>"
- seeds: "172.17.0.2"
- seeds: "172.17.0.4"

# For workloads with more data than can fit in memory, Cassandra's
# bottleneck will be reads that need to fetch data from
Expand Down Expand Up @@ -572,7 +572,7 @@ ssl_storage_port: 7001
#
# Setting listen_address to 0.0.0.0 is always wrong.
#
listen_address: 172.17.0.2
listen_address: 172.17.0.4

# Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported.
Expand All @@ -586,7 +586,7 @@ listen_address: 172.17.0.2

# Address to broadcast to other Cassandra nodes
# Leaving this blank will set it to the same value as listen_address
broadcast_address: 172.17.0.2
broadcast_address: 172.17.0.4

# When using multiple physical network interfaces, set this
# to true to listen on broadcast_address in addition to
Expand Down Expand Up @@ -668,7 +668,7 @@ rpc_port: 9160
# be set to 0.0.0.0. If left blank, this will be set to the value of
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
# be set.
broadcast_rpc_address: 172.17.0.2
broadcast_rpc_address: 172.17.0.4

# enable or disable keepalive on rpc/native connections
rpc_keepalive: true
Expand Down
167 changes: 167 additions & 0 deletions plugins/database/mongodb/connection_producer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package mongodb

import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/mitchellh/mapstructure"

"gopkg.in/mgo.v2"
)

// mongoDBConnectionProducer implements ConnectionProducer and provides an
// interface for databases to make connections.
type mongoDBConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`

Initialized bool
Type string
session *mgo.Session
sync.Mutex
}

// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()

err := mapstructure.Decode(conf, c)
if err != nil {
return err
}

if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty")
}

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true

if verifyConnection {
if _, err := c.Connection(); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}

if err := c.session.Ping(); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
}

return nil
}

// Connection creates a database connection.
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

if c.session != nil {
return c.session, nil
}

dialInfo, err := parseMongoURL(c.ConnectionURL)
if err != nil {
return nil, err
}

c.session, err = mgo.DialWithInfo(dialInfo)
if err != nil {
return nil, err
}
c.session.SetSyncTimeout(1 * time.Minute)
c.session.SetSocketTimeout(1 * time.Minute)

return nil, nil
}

// Close terminates the database connection.
func (c *mongoDBConnectionProducer) Close() error {
c.Lock()
defer c.Unlock()

if c.session != nil {
c.session.Close()
}

c.session = nil

return nil
}

func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
url, err := url.Parse(rawURL)
if err != nil {
return nil, err
}

info := mgo.DialInfo{
Addrs: strings.Split(url.Host, ","),
Database: strings.TrimPrefix(url.Path, "/"),
Timeout: 10 * time.Second,
}

if url.User != nil {
info.Username = url.User.Username()
info.Password, _ = url.User.Password()
}

query := url.Query()
for key, values := range query {
var value string
if len(values) > 0 {
value = values[0]
}

switch key {
case "authSource":
info.Source = value
case "authMechanism":
info.Mechanism = value
case "gssapiServiceName":
info.Service = value
case "replicaSet":
info.ReplicaSetName = value
case "maxPoolSize":
poolLimit, err := strconv.Atoi(value)
if err != nil {
return nil, errors.New("bad value for maxPoolSize: " + value)
}
info.PoolLimit = poolLimit
case "ssl":
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
// ourselves. See https://github.com/go-mgo/mgo/issues/84
ssl, err := strconv.ParseBool(value)
if err != nil {
return nil, errors.New("bad value for ssl: " + value)
}
if ssl {
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), &tls.Config{})
}
}
case "connect":
if value == "direct" {
info.Direct = true
break
}
if value == "replicaSet" {
break
}
fallthrough
default:
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
}
}

return &info, nil
}
36 changes: 36 additions & 0 deletions plugins/database/mongodb/credentials_producer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package mongodb

import (
"fmt"
"time"

uuid "github.com/hashicorp/go-uuid"
)

// mongoDBCredentialsProducer implements CredentialsProducer and provides an
// interface for databases to generate user information.
type mongoDBCredentialsProducer struct{}

func (cp *mongoDBCredentialsProducer) GenerateUsername(displayName string) (string, error) {
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
}

username := fmt.Sprintf("vault-%s-%s", displayName, userUUID)

return username, nil
}

func (cp *mongoDBCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
}

return password, nil
}

func (cp *mongoDBCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
return "", nil
}
Loading