Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove strict mode #676

Merged
merged 2 commits into from
Oct 3, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,20 +294,6 @@ supposed to happen, setting this on some MySQL providers (such as AWS Aurora)
is safer for failovers.


##### `strict`

```
Type: bool
Valid Values: true, false
Default: false
```

`strict=true` enables a driver-side strict mode in which MySQL warnings are treated as errors. This mode should not be used in production as it may lead to data corruption in certain situations.

A server-side strict mode, which is safe for production use, can be set via the [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html) system variable.

By default MySQL also treats notes as warnings. Use [`sql_notes=false`](http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_sql_notes) to ignore notes.

##### `timeout`

```
Expand All @@ -317,6 +303,7 @@ Default: OS default

Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.


##### `tls`

```
Expand All @@ -327,6 +314,7 @@ Default: false

`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).


##### `writeTimeout`

```
Expand Down
6 changes: 1 addition & 5 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
if w, ok := err.(MySQLWarnings); ok {
b.Logf("warning on %q: %v", query, w)
} else {
b.Fatalf("error on %q: %v", query, err)
}
b.Fatalf("error on %q: %v", query, err)
}
}
return db
Expand Down
1 change: 0 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ type mysqlConn struct {
status statusFlag
sequence uint8
parseTime bool
strict bool

// for context support (Go 1.8+)
watching bool
Expand Down
1 change: 0 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
mc.strict = mc.cfg.Strict

// Connect to Server
if dial, ok := dials[mc.cfg.Net]; ok {
Expand Down
82 changes: 3 additions & 79 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func init() {
addr = env("MYSQL_TEST_ADDR", "localhost:3306")
dbname = env("MYSQL_TEST_DBNAME", "gotest")
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname)
c, err := net.Dial(prot, addr)
if err == nil {
available = true
Expand Down Expand Up @@ -1170,82 +1170,6 @@ func TestFoundRows(t *testing.T) {
})
}

func TestStrict(t *testing.T) {
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
// make sure the MySQL version is recent enough with a separate connection
// before running the test
conn, err := MySQLDriver{}.Open(relaxedDsn)
if conn != nil {
conn.Close()
}
// Error 1231: Variable 'sql_mode' can't be set to the value of
// 'ALLOW_INVALID_DATES' => skip test, MySQL server version is too old
maybeSkip(t, err, 1231)
runTests(t, relaxedDsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")

var queries = [...]struct {
in string
codes []string
}{
{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
}
var err error

var checkWarnings = func(err error, mode string, idx int) {
if err == nil {
dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in)
}

if warnings, ok := err.(MySQLWarnings); ok {
var codes = make([]string, len(warnings))
for i := range warnings {
codes[i] = warnings[i].Code
}
if len(codes) != len(queries[idx].codes) {
dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
}

for i := range warnings {
if codes[i] != queries[idx].codes[i] {
dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
return
}
}

} else {
dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
}
}

// text protocol
for i := range queries {
_, err = dbt.db.Exec(queries[i].in)
checkWarnings(err, "text", i)
}

var stmt *sql.Stmt

// binary protocol
for i := range queries {
stmt, err = dbt.db.Prepare(queries[i].in)
if err != nil {
dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error())
}

_, err = stmt.Exec()
checkWarnings(err, "binary", i)

err = stmt.Close()
if err != nil {
dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error())
}
}
})
}

func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
Expand Down Expand Up @@ -1762,7 +1686,7 @@ func TestCustomDial(t *testing.T) {
return net.Dial(prot, addr)
})

db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
Expand Down Expand Up @@ -1859,7 +1783,7 @@ func TestUnixSocketAuthFail(t *testing.T) {
}
}
t.Logf("socket: %s", socket)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
db, err := sql.Open("mysql", badDSN)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
Expand Down
16 changes: 1 addition & 15 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
Strict bool // Return warnings as errors
}

// FormatDSN formats the given Config into a DSN string which can be passed to
Expand Down Expand Up @@ -206,15 +205,6 @@ func (cfg *Config) FormatDSN() string {
}
}

if cfg.Strict {
if hasParam {
buf.WriteString("&strict=true")
} else {
hasParam = true
buf.WriteString("?strict=true")
}
}

if cfg.Timeout > 0 {
if hasParam {
buf.WriteString("&timeout=")
Expand Down Expand Up @@ -502,11 +492,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {

// Strict mode
case "strict":
var isBool bool
cfg.Strict, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
return errors.New("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this panic?


// Dial Timeout
case "timeout":
Expand Down
73 changes: 0 additions & 73 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
package mysql

import (
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
"os"
)
Expand Down Expand Up @@ -65,74 +63,3 @@ type MySQLError struct {
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}

// MySQLWarnings is an error type which represents a group of one or more MySQL
// warnings
type MySQLWarnings []MySQLWarning

func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf(
"%s %s: %s",
warning.Level,
warning.Code,
warning.Message,
)
}
return msg
}

// MySQLWarning is an error type which represents a single MySQL warning.
// Warnings are returned in groups only. See MySQLWarnings
type MySQLWarning struct {
Level string
Code string
Message string
}

func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", nil)
if err != nil {
return
}

var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)

for {
err = rows.Next(values)
switch err {
case nil:
warning := MySQLWarning{}

if raw, ok := values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok := values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok := values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}

warnings = append(warnings, warning)

case io.EOF:
return warnings

default:
rows.Close()
return
}
}
}
14 changes: 0 additions & 14 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,14 +624,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
}

// warning count [2 bytes]
if !mc.strict {
return nil
}

pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
return mc.getWarnings()
}
return nil
}

Expand Down Expand Up @@ -843,14 +836,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
// Reserved [8 bit]

// Warning count [16 bit uint]
if !stmt.mc.strict {
return columnCount, nil
}

// Check for warnings count > 0, only available in MySQL > 4.1
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
return columnCount, stmt.mc.getWarnings()
}
return columnCount, nil
}
return 0, err
Expand Down