Skip to content

Commit

Permalink
s/Statement/Statements/
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian Kassouf committed Jan 4, 2017
1 parent 1ee5087 commit 95e5091
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 43 deletions.
13 changes: 9 additions & 4 deletions builtin/logical/database/dbs/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
"github.com/hashicorp/vault/helper/strutil"
)

const (
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
defaultRollbackCQL = `DROP USER '{{username}}';`
)

type Cassandra struct {
// Session is goroutine safe, however, since we reinitialize
// it when connection info changes, we want to make sure we
Expand All @@ -31,7 +36,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) {
return session.(*gocql.Session), nil
}

func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
// Get the connection
session, err := c.getConnection()
if err != nil {
Expand All @@ -54,7 +59,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp
}*/

// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
Expand All @@ -65,7 +70,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
Expand All @@ -88,7 +93,7 @@ func (c *Cassandra) RenewUser(username, expiration string) error {
return nil
}

func (c *Cassandra) RevokeUser(username, revocationSQL string) error {
func (c *Cassandra) RevokeUser(username, revocationStmts string) error {
session, err := c.getConnection()
if err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions builtin/logical/database/dbs/connectionproducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"sync"
"time"

// Import sql drivers
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"

"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil"
Expand Down
14 changes: 7 additions & 7 deletions builtin/logical/database/dbs/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/hashicorp/vault/helper/strutil"
)

const defaultRevocationSQL = `
const defaultRevocationStmts = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%'
`
Expand All @@ -34,7 +34,7 @@ func (p *MySQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}

func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
// Get the connection
db, err := p.getConnection()
if err != nil {
Expand All @@ -54,7 +54,7 @@ func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expirat
defer tx.Rollback()

// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
Expand Down Expand Up @@ -86,7 +86,7 @@ func (p *MySQL) RenewUser(username, expiration string) error {
return nil
}

func (p *MySQL) RevokeUser(username, revocationStmt string) error {
func (p *MySQL) RevokeUser(username, revocationStmts string) error {
// Get the connection
db, err := p.getConnection()
if err != nil {
Expand All @@ -99,8 +99,8 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error {

// Use a default SQL statement for revocation if one cannot be fetched from the role

if revocationStmt == "" {
revocationStmt = defaultRevocationSQL
if revocationStmts == "" {
revocationStmts = defaultRevocationStmts
}

// Start a transaction
Expand All @@ -110,7 +110,7 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error {
}
defer tx.Rollback()

for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
Expand Down
14 changes: 7 additions & 7 deletions builtin/logical/database/dbs/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
return db.(*sql.DB), nil
}

func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error {
func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error {
// Get the connection
db, err := p.getConnection()
if err != nil {
Expand All @@ -56,7 +56,7 @@ func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, ex
// Return the secret

// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
Expand Down Expand Up @@ -115,19 +115,19 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error {
return nil
}

func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error {
func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error {
// Grab the read lock
p.RLock()
defer p.RUnlock()

if revocationStmt == "" {
if revocationStmts == "" {
return p.defaultRevokeUser(username)
}

return p.customRevokeUser(username, revocationStmt)
return p.customRevokeUser(username, revocationStmts)
}

func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
db, err := p.getConnection()
if err != nil {
return err
Expand All @@ -141,7 +141,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error {
tx.Rollback()
}()

for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") {
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
Expand Down
3 changes: 1 addition & 2 deletions builtin/logical/database/path_role_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
)

func pathRoleCreate(b *databaseBackend) *framework.Path {
Expand Down Expand Up @@ -68,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo

expiration := db.GenerateExpiration(role.DefaultTTL)

err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration)
err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration)
if err != nil {
return nil, err
}
Expand Down
46 changes: 23 additions & 23 deletions builtin/logical/database/path_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ func pathRoles(b *databaseBackend) *framework.Path {
Description: "Name of the database this role acts on.",
},

"creation_statement": {
"creation_statements": {
Type: framework.TypeString,
Description: "SQL string to create a user. See help for more info.",
},

"revocation_statement": {
"revocation_statements": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},

"rollback_statement": {
"rollback_statements": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
Expand Down Expand Up @@ -98,11 +98,11 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie

return &logical.Response{
Data: map[string]interface{}{
"creation_statment": role.CreationStatement,
"revocation_statement": role.RevocationStatement,
"rollback_statement": role.RollbackStatement,
"default_ttl": role.DefaultTTL.String(),
"max_ttl": role.MaxTTL.String(),
"creation_statments": role.CreationStatements,
"revocation_statements": role.RevocationStatements,
"rollback_statements": role.RollbackStatements,
"default_ttl": role.DefaultTTL.String(),
"max_ttl": role.MaxTTL.String(),
},
}, nil
}
Expand All @@ -119,9 +119,9 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
dbName := data.Get("db_name").(string)
creationStmt := data.Get("creation_statement").(string)
revocationStmt := data.Get("revocation_statement").(string)
rollbackStmt := data.Get("rollback_statement").(string)
creationStmts := data.Get("creation_statements").(string)
revocationStmts := data.Get("revocation_statements").(string)
rollbackStmts := data.Get("rollback_statements").(string)
defaultTTLRaw := data.Get("default_ttl").(string)
maxTTLRaw := data.Get("max_ttl").(string)

Expand All @@ -140,12 +140,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F

// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
DBName: dbName,
CreationStatement: creationStmt,
RevocationStatement: revocationStmt,
RollbackStatement: rollbackStmt,
DefaultTTL: defaultTTL,
MaxTTL: maxTTL,
DBName: dbName,
CreationStatements: creationStmts,
RevocationStatements: revocationStmts,
RollbackStatements: rollbackStmts,
DefaultTTL: defaultTTL,
MaxTTL: maxTTL,
})
if err != nil {
return nil, err
Expand All @@ -158,12 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
}

type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
}

const pathRoleHelpSyn = `
Expand Down

0 comments on commit 95e5091

Please sign in to comment.