Skip to content

Commit

Permalink
*: support require-secure-transport startup option (#15341)
Browse files Browse the repository at this point in the history
  • Loading branch information
lysu authored Mar 17, 2020
1 parent 164c1e4 commit aec6143
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 48 deletions.
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,15 @@ func (l *Log) getDisableErrorStack() bool {

// Security is the security section of the config.
type Security struct {
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
SkipGrantTable bool `toml:"skip-grant-table" json:"skip-grant-table"`
SSLCA string `toml:"ssl-ca" json:"ssl-ca"`
SSLCert string `toml:"ssl-cert" json:"ssl-cert"`
SSLKey string `toml:"ssl-key" json:"ssl-key"`
RequireSecureTransport bool `toml:"require-secure-transport" json:"require-secure-transport"`
ClusterSSLCA string `toml:"cluster-ssl-ca" json:"cluster-ssl-ca"`
ClusterSSLCert string `toml:"cluster-ssl-cert" json:"cluster-ssl-cert"`
ClusterSSLKey string `toml:"cluster-ssl-key" json:"cluster-ssl-key"`
ClusterVerifyCN []string `toml:"cluster-verify-cn" json:"cluster-verify-cn"`
}

// The ErrConfigValidationFailed error is used so that external callers can do a type assertion
Expand Down
1 change: 1 addition & 0 deletions errno/errcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ const (
ErrInvalidJSONContainsPathType = 3150
ErrJSONUsedAsKey = 3152
ErrJSONDocumentNULLKey = 3158
ErrSecureTransportRequired = 3159
ErrBadUser = 3162
ErrUserAlreadyExists = 3163
ErrInvalidJSONPathArrayCell = 3165
Expand Down
1 change: 1 addition & 0 deletions errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ var MySQLErrName = map[uint16]string{
ErrInvalidJSONContainsPathType: "The second argument can only be either 'one' or 'all'.",
ErrJSONUsedAsKey: "JSON column '%-.192s' cannot be used in key specification.",
ErrJSONDocumentNULLKey: "JSON documents may not contain NULL member names.",
ErrSecureTransportRequired: "Connections using insecure transport are prohibited while --require_secure_transport=ON.",
ErrBadUser: "User %s does not exist.",
ErrUserAlreadyExists: "User %s already exists.",
ErrInvalidJSONPathArrayCell: "A path expression is not a path to a cell in an array.",
Expand Down
2 changes: 1 addition & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error {
variable.SysVars["ssl_cert"].Value,
)
if err != nil {
if !s.NoRollbackOnError {
if !s.NoRollbackOnError || config.GetGlobalConfig().Security.RequireSecureTransport {
return err
}
logutil.BgLogger().Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'")
Expand Down
3 changes: 3 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
Expand Down Expand Up @@ -521,6 +522,8 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con
return err
}
}
} else if config.GetGlobalConfig().Security.RequireSecureTransport {
return errSecureTransportRequired.FastGenByArgs()
}

// Read the remaining part of the packet.
Expand Down
16 changes: 9 additions & 7 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(errno.ErrUnknownFieldType, errno.MySQLErrName[errno.ErrUnknownFieldType])
errInvalidSequence = terror.ClassServer.New(errno.ErrInvalidSequence, errno.MySQLErrName[errno.ErrInvalidSequence])
errInvalidType = terror.ClassServer.New(errno.ErrInvalidType, errno.MySQLErrName[errno.ErrInvalidType])
errNotAllowedCommand = terror.ClassServer.New(errno.ErrNotAllowedCommand, errno.MySQLErrName[errno.ErrNotAllowedCommand])
errAccessDenied = terror.ClassServer.New(errno.ErrAccessDenied, errno.MySQLErrName[errno.ErrAccessDenied])
errConCount = terror.ClassServer.New(errno.ErrConCount, errno.MySQLErrName[errno.ErrConCount])
errUnknownFieldType = terror.ClassServer.New(errno.ErrUnknownFieldType, errno.MySQLErrName[errno.ErrUnknownFieldType])
errInvalidSequence = terror.ClassServer.New(errno.ErrInvalidSequence, errno.MySQLErrName[errno.ErrInvalidSequence])
errInvalidType = terror.ClassServer.New(errno.ErrInvalidType, errno.MySQLErrName[errno.ErrInvalidType])
errNotAllowedCommand = terror.ClassServer.New(errno.ErrNotAllowedCommand, errno.MySQLErrName[errno.ErrNotAllowedCommand])
errAccessDenied = terror.ClassServer.New(errno.ErrAccessDenied, errno.MySQLErrName[errno.ErrAccessDenied])
errConCount = terror.ClassServer.New(errno.ErrConCount, errno.MySQLErrName[errno.ErrConCount])
errSecureTransportRequired = terror.ClassServer.New(errno.ErrSecureTransportRequired, errno.MySQLErrName[errno.ErrSecureTransportRequired])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down Expand Up @@ -209,12 +210,13 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
if err != nil {
logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err))
return nil, err
}
if tlsConfig != nil {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.BgLogger().Info("mysql protocol server secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
} else if cfg.Security.RequireSecureTransport {
return nil, errSecureTransportRequired.FastGenByArgs()
}

setSystemTimeZoneVariable()
Expand Down
8 changes: 4 additions & 4 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,11 +663,11 @@ func (ts *tidbTestSerialSuite) TestErrorNoRollback(c *C) {
cfg.Port = cli.port
cfg.Status.ReportStatus = false

// test cannot startup with wrong tls config
cfg.Security = config.Security{
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
RequireSecureTransport: true,
SSLCA: "wrong path",
SSLCert: "wrong path",
SSLKey: "wrong path",
}
_, err = NewServer(cfg, ts.tidbdrv)
c.Assert(err, NotNil)
Expand Down
59 changes: 32 additions & 27 deletions tidb-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,33 +69,34 @@ import (

// Flag Names
const (
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRepairMode = "repair-mode"
nmRepairList = "repair-list"
nmVersion = "V"
nmConfig = "config"
nmConfigCheck = "config-check"
nmConfigStrict = "config-strict"
nmStore = "store"
nmStorePath = "path"
nmHost = "host"
nmAdvertiseAddress = "advertise-address"
nmPort = "P"
nmCors = "cors"
nmSocket = "socket"
nmEnableBinlog = "enable-binlog"
nmRunDDL = "run-ddl"
nmLogLevel = "L"
nmLogFile = "log-file"
nmLogSlowQuery = "log-slow-query"
nmReportStatus = "report-status"
nmStatusHost = "status-host"
nmStatusPort = "status"
nmMetricsAddr = "metrics-addr"
nmMetricsInterval = "metrics-interval"
nmDdlLease = "lease"
nmTokenLimit = "token-limit"
nmPluginDir = "plugin-dir"
nmPluginLoad = "plugin-load"
nmRepairMode = "repair-mode"
nmRepairList = "repair-list"
nmRequireSecureTransport = "require-secure-transport"

nmProxyProtocolNetworks = "proxy-protocol-networks"
nmProxyProtocolHeaderTimeout = "proxy-protocol-header-timeout"
Expand Down Expand Up @@ -125,6 +126,7 @@ var (
affinityCPU = flag.String(nmAffinityCPU, "", "affinity cpu (cpu-no. separated by comma, e.g. 1,2,3)")
repairMode = flagBoolean(nmRepairMode, false, "enable admin repair mode")
repairList = flag.String(nmRepairList, "", "admin repair table list")
requireTLS = flag.Bool(nmRequireSecureTransport, false, "require client use secure transport")

// Log
logLevel = flag.String(nmLogLevel, "info", "log level: info, debug, warn, error, fatal")
Expand Down Expand Up @@ -421,6 +423,9 @@ func overrideConfig(cfg *config.Config) {
if actualFlags[nmPluginDir] {
cfg.Plugin.Dir = *pluginDir
}
if actualFlags[nmRequireSecureTransport] {
cfg.Security.RequireSecureTransport = *requireTLS
}
if actualFlags[nmRepairMode] {
cfg.RepairMode = *repairMode
}
Expand Down
11 changes: 10 additions & 1 deletion util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,13 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
return
}

requireTLS := config.GetGlobalConfig().Security.RequireSecureTransport

// Try loading CA cert.
clientAuthPolicy := tls.NoClientCert
if requireTLS {
clientAuthPolicy = tls.RequestClientCert
}
var certPool *x509.CertPool
if len(ca) > 0 {
var caCert []byte
Expand All @@ -382,7 +387,11 @@ func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error
}
certPool = x509.NewCertPool()
if certPool.AppendCertsFromPEM(caCert) {
clientAuthPolicy = tls.VerifyClientCertIfGiven
if requireTLS {
clientAuthPolicy = tls.RequireAndVerifyClientCert
} else {
clientAuthPolicy = tls.VerifyClientCertIfGiven
}
}
}
tlsConfig = &tls.Config{
Expand Down

0 comments on commit aec6143

Please sign in to comment.