diff --git a/physical/mysql.go b/physical/mysql.go index ce1351465ccd..d063df4223aa 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/url" "sort" + "strconv" "strings" "time" @@ -15,6 +16,7 @@ import ( "github.com/armon/go-metrics" mysql "github.com/go-sql-driver/mysql" + "github.com/hashicorp/errwrap" ) // Unreserved tls key @@ -28,11 +30,14 @@ type MySQLBackend struct { client *sql.DB statements map[string]*sql.Stmt logger log.Logger + permitPool *PermitPool } // newMySQLBackend constructs a MySQL backend using the given API client and // server address and credential for accessing mysql database. func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) { + var err error + // Get the MySQL credentials to perform read/write operations. username, ok := conf["username"] if !ok || username == "" { @@ -60,6 +65,18 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) } dbTable := database + "." + table + maxParStr, ok := conf["max_parallel"] + var maxParInt int + if ok { + maxParInt, err = strconv.Atoi(maxParStr) + if err != nil { + return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) + } + if logger.IsDebug() { + logger.Debug("mysql: max_parallel set", "max_parallel", maxParInt) + } + } + dsnParams := url.Values{} tlsCaFile, ok := conf["tls_ca_file"] if ok { @@ -95,6 +112,7 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) client: db, statements: make(map[string]*sql.Stmt), logger: logger, + permitPool: NewPermitPool(maxParInt), } // Prepare all the statements required @@ -110,6 +128,7 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) return nil, err } } + return m, nil } @@ -127,6 +146,9 @@ func (m *MySQLBackend) prepare(name, query string) error { func (m *MySQLBackend) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) + m.permitPool.Acquire() + defer m.permitPool.Release() + _, err := m.statements["put"].Exec(entry.Key, entry.Value) if err != nil { return err @@ -138,6 +160,9 @@ func (m *MySQLBackend) Put(entry *Entry) error { func (m *MySQLBackend) Get(key string) (*Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) + m.permitPool.Acquire() + defer m.permitPool.Release() + var result []byte err := m.statements["get"].QueryRow(key).Scan(&result) if err == sql.ErrNoRows { @@ -158,6 +183,9 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { func (m *MySQLBackend) Delete(key string) error { defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) + m.permitPool.Acquire() + defer m.permitPool.Release() + _, err := m.statements["delete"].Exec(key) if err != nil { return err @@ -170,6 +198,9 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) + m.permitPool.Acquire() + defer m.permitPool.Release() + // Add the % wildcard to the prefix to do the prefix search likePrefix := prefix + "%" rows, err := m.statements["list"].Query(likePrefix) diff --git a/website/source/docs/configuration/storage/mysql.html.md b/website/source/docs/configuration/storage/mysql.html.md index a71f8273757d..9e4ee205ab40 100644 --- a/website/source/docs/configuration/storage/mysql.html.md +++ b/website/source/docs/configuration/storage/mysql.html.md @@ -42,6 +42,9 @@ storage "mysql" { - `tls_ca_file` `(string: "")` – Specifies the path to the CA certificate to connect using TLS. +- `max_parallel` `(string: "128")` – Specifies the maximum number of concurrent + requests to MySQL. + Additionally, Vault requires the following authentication information. - `username` `(string: )` – Specifies the MySQL username to connect to