diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 2874fa2217608..1757d6066ef32 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -7890,7 +7890,7 @@ func validateCommentLength(vars *variable.SessionVars, name string, comment *str } if len(*comment) > maxLen { err := errTooLongComment.GenWithStackByArgs(name, maxLen) - if vars.StrictSQLMode { + if vars.SQLMode.HasStrictMode() { // may be treated like an error. return "", err } diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go index 22178087a6344..d36fbf29686ef 100644 --- a/pkg/ddl/index.go +++ b/pkg/ddl/index.go @@ -62,6 +62,7 @@ import ( "github.com/pingcap/tidb/pkg/util/dbterror" "github.com/pingcap/tidb/pkg/util/logutil" decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" + "github.com/pingcap/tidb/pkg/util/stringutil" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" @@ -77,8 +78,17 @@ const ( var ( telemetryAddIndexIngestUsage = metrics.TelemetryAddIndexIngestCnt + // SuppressErrorTooLongKeyKey is used by SchemaTracker to suppress err too long key error + SuppressErrorTooLongKeyKey stringutil.StringerStr = "suppressErrorTooLongKeyKey" ) +func suppressErrorTooLongKeyKey(sctx sessionctx.Context) bool { + if suppress, ok := sctx.Value(SuppressErrorTooLongKeyKey).(bool); ok && suppress { + return true + } + return false +} + func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { // Build offsets. idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications)) @@ -113,7 +123,8 @@ func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, inde if sumLength > maxIndexLength { // The multiple column index and the unique index in which the length sum exceeds the maximum size // will return an error instead produce a warning. - if ctx == nil || ctx.GetSessionVars().StrictSQLMode || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { + suppress := suppressErrorTooLongKeyKey(ctx) + if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppress) || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(sumLength, maxIndexLength) } // truncate index length and produce warning message in non-restrict sql mode. @@ -222,9 +233,12 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn } // Specified length must be shorter than the max length for prefix. maxIndexLength := config.GetGlobalConfig().MaxIndexLength - if indexColumnLen > maxIndexLength && (ctx == nil || ctx.GetSessionVars().StrictSQLMode) { - // return error in strict sql mode - return dbterror.ErrTooLongKey.GenWithStackByArgs(indexColumnLen, maxIndexLength) + if indexColumnLen > maxIndexLength { + suppress := suppressErrorTooLongKeyKey(ctx) + if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppress) { + // return error in strict sql mode + return dbterror.ErrTooLongKey.GenWithStackByArgs(indexColumnLen, maxIndexLength) + } } return nil } diff --git a/pkg/ddl/schematracker/checker.go b/pkg/ddl/schematracker/checker.go index bc63e50639b93..84fbfcc59909f 100644 --- a/pkg/ddl/schematracker/checker.go +++ b/pkg/ddl/schematracker/checker.go @@ -226,7 +226,6 @@ func (d *Checker) CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) // some unit test will also check warnings, we reset the warnings after SchemaTracker use session context again. count := ctx.GetSessionVars().StmtCtx.WarningCount() // backup old session variables because CreateTable will change them. - strictSQLMode := ctx.GetSessionVars().StrictSQLMode enableClusteredIndex := ctx.GetSessionVars().EnableClusteredIndex err = d.tracker.CreateTable(ctx, stmt) @@ -234,7 +233,6 @@ func (d *Checker) CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) panic(err) } - ctx.GetSessionVars().StrictSQLMode = strictSQLMode ctx.GetSessionVars().EnableClusteredIndex = enableClusteredIndex ctx.GetSessionVars().StmtCtx.TruncateWarnings(int(count)) diff --git a/pkg/ddl/schematracker/dm_tracker.go b/pkg/ddl/schematracker/dm_tracker.go index eb8076c4be0dd..07bef2c2efab5 100644 --- a/pkg/ddl/schematracker/dm_tracker.go +++ b/pkg/ddl/schematracker/dm_tracker.go @@ -186,13 +186,12 @@ func (d SchemaTracker) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStm return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) } // suppress ErrTooLongKey - strictSQLModeBackup := ctx.GetSessionVars().StrictSQLMode - ctx.GetSessionVars().StrictSQLMode = false + ctx.SetValue(ddl.SuppressErrorTooLongKeyKey, true) // support drop PK enableClusteredIndexBackup := ctx.GetSessionVars().EnableClusteredIndex ctx.GetSessionVars().EnableClusteredIndex = variable.ClusteredIndexDefModeOff defer func() { - ctx.GetSessionVars().StrictSQLMode = strictSQLModeBackup + ctx.ClearValue(ddl.SuppressErrorTooLongKeyKey) ctx.GetSessionVars().EnableClusteredIndex = enableClusteredIndexBackup }() diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 573e2cc56acf6..5941665aefe55 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -2105,6 +2105,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // pushing them down to TiKV as flags. sc.InRestrictedSQL = vars.InRestrictedSQL + strictSQLMode := vars.SQLMode.HasStrictMode() errLevels := sc.ErrLevels() errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn @@ -2126,26 +2127,26 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { errLevels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn } - errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !vars.StrictSQLMode || stmt.IgnoreErr) + errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( !vars.SQLMode.HasErrorForDivisionByZeroMode(), - !vars.StrictSQLMode || stmt.IgnoreErr, + !strictSQLMode || stmt.IgnoreErr, ) sc.Priority = stmt.Priority sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr). + WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || - !vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode || stmt.IgnoreErr || + !vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) case *ast.CreateTableStmt, *ast.AlterTableStmt: sc.InCreateOrAlterStmt = true sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!vars.StrictSQLMode). + WithTruncateAsWarning(!strictSQLMode). WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.StrictSQLMode || + WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !strictSQLMode || vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroDateErr(!vars.SQLMode.HasNoZeroDateMode() || !vars.StrictSQLMode)) + WithIgnoreZeroDateErr(!vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode)) case *ast.LoadDataStmt: sc.InLoadDataStmt = true @@ -2257,41 +2258,43 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // ResetUpdateStmtCtx resets statement context for UpdateStmt. func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars *variable.SessionVars) { + strictSQLMode := vars.SQLMode.HasStrictMode() sc.InUpdateStmt = true errLevels := sc.ErrLevels() errLevels[errctx.ErrGroupDupKey] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) - errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !vars.StrictSQLMode || stmt.IgnoreErr) + errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( !vars.SQLMode.HasErrorForDivisionByZeroMode(), - !vars.StrictSQLMode || stmt.IgnoreErr, + !strictSQLMode || stmt.IgnoreErr, ) errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) sc.SetErrLevels(errLevels) sc.Priority = stmt.Priority sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr). + WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || - !vars.StrictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) + !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) } // ResetDeleteStmtCtx resets statement context for DeleteStmt. func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars *variable.SessionVars) { + strictSQLMode := vars.SQLMode.HasStrictMode() sc.InDeleteStmt = true errLevels := sc.ErrLevels() errLevels[errctx.ErrGroupDupKey] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) - errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !vars.StrictSQLMode || stmt.IgnoreErr) + errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( !vars.SQLMode.HasErrorForDivisionByZeroMode(), - !vars.StrictSQLMode || stmt.IgnoreErr, + !strictSQLMode || stmt.IgnoreErr, ) sc.SetErrLevels(errLevels) sc.Priority = stmt.Priority sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!vars.StrictSQLMode || stmt.IgnoreErr). + WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || - !vars.StrictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) + !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) } func setOptionForTopSQL(sc *stmtctx.StatementContext, snapshot kv.Snapshot) { diff --git a/pkg/executor/executor_pkg_test.go b/pkg/executor/executor_pkg_test.go index 62400d70fd313..faf77a0a44cd9 100644 --- a/pkg/executor/executor_pkg_test.go +++ b/pkg/executor/executor_pkg_test.go @@ -484,7 +484,6 @@ func TestErrLevelsForResetStmtContext(t *testing.T) { for _, stmt := range c.stmt { msg := fmt.Sprintf("%d: %s, stmt: %T", i, c.name, stmt) ctx.GetSessionVars().SQLMode = c.sqlMode - ctx.GetSessionVars().StrictSQLMode = ctx.GetSessionVars().SQLMode.HasStrictMode() require.NoError(t, ResetContextOfStmt(ctx, stmt), msg) ec := ctx.GetSessionVars().StmtCtx.ErrCtx() require.Equal(t, c.levels, ec.LevelMap(), msg) diff --git a/pkg/executor/insert_common.go b/pkg/executor/insert_common.go index 034515684a516..ff6cb313afed5 100644 --- a/pkg/executor/insert_common.go +++ b/pkg/executor/insert_common.go @@ -607,7 +607,7 @@ func (e *InsertValues) fillColValue( if !hasValue && mysql.HasNoDefaultValueFlag(column.ToInfo().GetFlag()) { vars := e.Ctx().GetSessionVars() sc := vars.StmtCtx - if vars.StrictSQLMode { + if vars.SQLMode.HasStrictMode() { return datum, table.ErrNoDefaultValue.FastGenByArgs(column.ToInfo().Name) } sc.AppendWarning(table.ErrNoDefaultValue.FastGenByArgs(column.ToInfo().Name)) diff --git a/pkg/executor/mem_reader.go b/pkg/executor/mem_reader.go index 38fab12f728a2..11bb6bf079876 100644 --- a/pkg/executor/mem_reader.go +++ b/pkg/executor/mem_reader.go @@ -1158,11 +1158,7 @@ func getColIDAndPkColIDs(ctx sessionctx.Context, tbl table.Table, columns []*mod pkColIDs = []int64{-1} } defVal := func(i int) ([]byte, error) { - sessVars := ctx.GetSessionVars() - originStrict := sessVars.StrictSQLMode - sessVars.StrictSQLMode = false - d, err := table.GetColOriginDefaultValue(ctx, columns[i]) - sessVars.StrictSQLMode = originStrict + d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx, columns[i]) if err != nil { return nil, err } diff --git a/pkg/executor/set_test.go b/pkg/executor/set_test.go index 8da41ee078d90..ee72d9b4070a5 100644 --- a/pkg/executor/set_test.go +++ b/pkg/executor/set_test.go @@ -101,9 +101,9 @@ func TestSetVar(t *testing.T) { require.False(t, vars.IsAutocommit()) tk.MustExec("set @@sql_mode = 'strict_trans_tables'") - require.True(t, vars.StrictSQLMode) + require.True(t, vars.SQLMode.HasStrictMode()) tk.MustExec("set @@sql_mode = ''") - require.False(t, vars.StrictSQLMode) + require.False(t, vars.SQLMode.HasStrictMode()) tk.MustExec("set names utf8") charset, collation := vars.GetCharsetInfo() diff --git a/pkg/expression/builtin_string.go b/pkg/expression/builtin_string.go index b53fb7fbd16d5..95d3fe303be24 100644 --- a/pkg/expression/builtin_string.go +++ b/pkg/expression/builtin_string.go @@ -2413,7 +2413,7 @@ func (b *builtinCharSig) evalString(ctx EvalContext, row chunk.Row) (string, boo if err != nil { tc := typeCtx(ctx) tc.AppendWarning(err) - if strictMode(ctx) { + if sqlMode(ctx).HasStrictMode() { return "", true, nil } } diff --git a/pkg/expression/builtin_string_test.go b/pkg/expression/builtin_string_test.go index 78baf4a75feb9..6d2eac0b3b663 100644 --- a/pkg/expression/builtin_string_test.go +++ b/pkg/expression/builtin_string_test.go @@ -1424,9 +1424,11 @@ func TestChar(t *testing.T) { run(i, v.result, v.warnings, v.str, v.iNum, v.fNum, v.charset) } // char() returns null only when the sql_mode is strict. - ctx.GetSessionVars().StrictSQLMode = true + require.True(t, ctx.GetSessionVars().SQLMode.HasStrictMode()) run(-1, nil, 1, 123456, "utf8") - ctx.GetSessionVars().StrictSQLMode = false + + ctx.GetSessionVars().SQLMode = ctx.GetSessionVars().SQLMode &^ (mysql.ModeStrictTransTables | mysql.ModeStrictAllTables) + require.False(t, ctx.GetSessionVars().SQLMode.HasStrictMode()) run(-2, string([]byte{1}), 1, 123456, "utf8") } diff --git a/pkg/expression/builtin_string_vec.go b/pkg/expression/builtin_string_vec.go index 4b46a78cb4506..85e4ee9a54765 100644 --- a/pkg/expression/builtin_string_vec.go +++ b/pkg/expression/builtin_string_vec.go @@ -2330,7 +2330,7 @@ func (b *builtinCharSig) vecEvalString(ctx EvalContext, input *chunk.Chunk, resu } encBuf := &bytes.Buffer{} enc := charset.FindEncoding(b.tp.GetCharset()) - hasStrictMode := strictMode(ctx) + hasStrictMode := sqlMode(ctx).HasStrictMode() for i := 0; i < n; i++ { bigints = bigints[0:0] for j := 0; j < l-1; j++ { diff --git a/pkg/expression/builtin_time_test.go b/pkg/expression/builtin_time_test.go index 69397b54774ff..d50435e6ba80c 100644 --- a/pkg/expression/builtin_time_test.go +++ b/pkg/expression/builtin_time_test.go @@ -212,6 +212,7 @@ func TestDate(t *testing.T) { } // test nil + ctx.GetSessionVars().SQLMode = mysql.DelSQLMode(ctx.GetSessionVars().SQLMode, mysql.ModeNoZeroDate) tblNil := []struct { Input interface{} Year interface{} diff --git a/pkg/expression/builtin_time_vec_test.go b/pkg/expression/builtin_time_vec_test.go index 2b435c80be452..d6d56465b8e9d 100644 --- a/pkg/expression/builtin_time_vec_test.go +++ b/pkg/expression/builtin_time_vec_test.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/mock" "github.com/stretchr/testify/require" ) @@ -579,8 +578,7 @@ func BenchmarkVectorizedBuiltinTimeFunc(b *testing.B) { } func TestVecMonth(t *testing.T) { - ctx := mock.NewContext() - ctx.GetSessionVars().SQLMode |= mysql.ModeNoZeroDate + ctx := createContext(t) typeFlags := ctx.GetSessionVars().StmtCtx.TypeFlags() ctx.GetSessionVars().StmtCtx.SetTypeFlags(typeFlags.WithTruncateAsWarning(true)) input := chunk.New([]*types.FieldType{types.NewFieldType(mysql.TypeDatetime)}, 3, 3) @@ -590,7 +588,7 @@ func TestVecMonth(t *testing.T) { input.AppendTime(0, types.ZeroDate) f, _, _, result := genVecBuiltinFuncBenchCase(ctx, ast.Month, vecExprBenchCase{retEvalType: types.ETInt, childrenTypes: []types.EvalType{types.ETDatetime}}) - require.True(t, ctx.GetSessionVars().StrictSQLMode) + require.True(t, ctx.GetSessionVars().SQLMode.HasStrictMode()) require.NoError(t, f.vecEvalInt(ctx, input, result)) require.Equal(t, 0, len(ctx.GetSessionVars().StmtCtx.GetWarnings())) diff --git a/pkg/expression/context.go b/pkg/expression/context.go index 98cac331f9ff3..300714986bb57 100644 --- a/pkg/expression/context.go +++ b/pkg/expression/context.go @@ -55,10 +55,6 @@ func sqlMode(ctx EvalContext) mysql.SQLMode { return ctx.GetSessionVars().SQLMode } -func strictMode(ctx EvalContext) bool { - return ctx.GetSessionVars().StrictSQLMode -} - func typeCtx(ctx EvalContext) types.Context { return ctx.GetSessionVars().StmtCtx.TypeCtx() } diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index 7d0b42396847c..7f1fb811025e4 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -1413,11 +1413,6 @@ func TestTimeBuiltin(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) - originSQLMode := tk.Session().GetSessionVars().StrictSQLMode - tk.Session().GetSessionVars().StrictSQLMode = true - defer func() { - tk.Session().GetSessionVars().StrictSQLMode = originSQLMode - }() tk.MustExec("use test") // for makeDate diff --git a/pkg/expression/main_test.go b/pkg/expression/main_test.go index bff6688b806d6..75240c3c11335 100644 --- a/pkg/expression/main_test.go +++ b/pkg/expression/main_test.go @@ -19,6 +19,7 @@ import ( "time" "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/testkit/testmain" "github.com/pingcap/tidb/pkg/testkit/testsetup" "github.com/pingcap/tidb/pkg/util/mock" @@ -58,6 +59,10 @@ func TestMain(m *testing.M) { func createContext(t *testing.T) *mock.Context { ctx := mock.NewContext() + sqlMode, err := mysql.GetSQLMode(mysql.DefaultSQLMode) + require.NoError(t, err) + require.True(t, sqlMode.HasStrictMode()) + ctx.GetSessionVars().SQLMode = sqlMode ctx.GetSessionVars().StmtCtx.SetTimeZone(time.Local) sc := ctx.GetSessionVars().StmtCtx sc.SetTypeFlags(sc.TypeFlags().WithTruncateAsWarning(true)) diff --git a/pkg/planner/optimize.go b/pkg/planner/optimize.go index e3db7c030c564..c840900bee7d5 100644 --- a/pkg/planner/optimize.go +++ b/pkg/planner/optimize.go @@ -144,7 +144,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in } } - if sctx.GetSessionVars().StrictSQLMode && !IsReadOnly(node, sessVars) { + if sctx.GetSessionVars().SQLMode.HasStrictMode() && !IsReadOnly(node, sessVars) { sessVars.StmtCtx.TiFlashEngineRemovedDueToStrictSQLMode = true _, hasTiFlashAccess := sessVars.IsolationReadEngines[kv.TiFlash] if hasTiFlashAccess { diff --git a/pkg/session/session.go b/pkg/session/session.go index 6da9d88672676..da0a87b8c24a3 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -1049,7 +1049,7 @@ func (s *session) String() string { "user": sessVars.User, "currDBName": sessVars.CurrentDB, "status": sessVars.Status, - "strictMode": sessVars.StrictSQLMode, + "strictMode": sessVars.SQLMode.HasStrictMode(), } if s.txn.Valid() { // if txn is committed or rolled back, txn is nil. diff --git a/pkg/sessionctx/variable/session.go b/pkg/sessionctx/variable/session.go index 0fe46fe82941b..eace57a438a80 100644 --- a/pkg/sessionctx/variable/session.go +++ b/pkg/sessionctx/variable/session.go @@ -759,9 +759,6 @@ type SessionVars struct { // the slow log to make it be compatible with MySQL, https://github.com/pingcap/tidb/issues/17846. CurrentDBChanged bool - // StrictSQLMode indicates if the session is in strict mode. - StrictSQLMode bool - // CommonGlobalLoaded indicates if common global variable has been loaded for this session. CommonGlobalLoaded bool @@ -1944,7 +1941,6 @@ func NewSessionVars(hctx HookContext) *SessionVars { TxnCtx: &TransactionContext{}, RetryInfo: &RetryInfo{}, ActiveRoles: make([]*auth.RoleIdentity, 0, 10), - StrictSQLMode: true, AutoIncrementIncrement: DefAutoIncrementIncrement, AutoIncrementOffset: DefAutoIncrementOffset, Status: mysql.ServerStatusAutocommit, diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index 6441a30ff66d2..fd9b055f66c28 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -1446,7 +1446,6 @@ var defaultSysVars = []*SysVar{ if err != nil { return errors.Trace(err) } - s.StrictSQLMode = sqlMode.HasStrictMode() s.SQLMode = sqlMode s.SetStatusFlag(mysql.ServerStatusNoBackslashEscaped, sqlMode.HasNoBackslashEscapesMode()) return nil diff --git a/pkg/sessionctx/variable/sysvar_test.go b/pkg/sessionctx/variable/sysvar_test.go index 0408cff31d293..cca8b42ff451b 100644 --- a/pkg/sessionctx/variable/sysvar_test.go +++ b/pkg/sessionctx/variable/sysvar_test.go @@ -67,7 +67,7 @@ func TestSQLModeVar(t *testing.T) { require.Equal(t, "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION", val) require.Nil(t, sv.SetSessionFromHook(vars, val)) // sets to strict from above - require.True(t, vars.StrictSQLMode) + require.True(t, vars.SQLMode.HasStrictMode()) sqlMode, err := mysql.GetSQLMode(val) require.NoError(t, err) @@ -79,7 +79,7 @@ func TestSQLModeVar(t *testing.T) { require.Equal(t, "ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION", val) require.Nil(t, sv.SetSessionFromHook(vars, val)) // sets to non-strict from above - require.False(t, vars.StrictSQLMode) + require.False(t, vars.SQLMode.HasStrictMode()) sqlMode, err = mysql.GetSQLMode(val) require.NoError(t, err) require.Equal(t, sqlMode, vars.SQLMode) diff --git a/pkg/sessionctx/variable/varsutil_test.go b/pkg/sessionctx/variable/varsutil_test.go index 975cea729ac48..13e1f15cd9642 100644 --- a/pkg/sessionctx/variable/varsutil_test.go +++ b/pkg/sessionctx/variable/varsutil_test.go @@ -133,10 +133,10 @@ func TestVarsutil(t *testing.T) { val, err = v.GetSessionOrGlobalSystemVar(context.Background(), "sql_mode") require.NoError(t, err) require.Equal(t, "STRICT_TRANS_TABLES", val) - require.True(t, v.StrictSQLMode) + require.True(t, v.SQLMode.HasStrictMode()) err = v.SetSystemVar("sql_mode", "") require.NoError(t, err) - require.False(t, v.StrictSQLMode) + require.False(t, v.SQLMode.HasStrictMode()) err = v.SetSystemVar("character_set_connection", "utf8") require.NoError(t, err) diff --git a/pkg/table/column.go b/pkg/table/column.go index d14bce08b1236..9ffd7c3476747 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -643,7 +643,7 @@ func getColDefaultValueFromNil(ctx sessionctx.Context, col *model.ColumnInfo, ar if args != nil { strictSQLMode = args.StrictSQLMode } else { - strictSQLMode = vars.StrictSQLMode + strictSQLMode = vars.SQLMode.HasStrictMode() } if !strictSQLMode { sc.AppendWarning(ErrNoDefaultValue.FastGenByArgs(col.Name)) diff --git a/pkg/table/column_test.go b/pkg/table/column_test.go index 140d04de41701..0724ced076717 100644 --- a/pkg/table/column_test.go +++ b/pkg/table/column_test.go @@ -466,8 +466,16 @@ func TestGetDefaultValue(t *testing.T) { expression.EvalAstExpr = exp }() + defaultMode, err := mysql.GetSQLMode(mysql.DefaultSQLMode) + require.NoError(t, err) + require.True(t, defaultMode.HasStrictMode()) for _, tt := range tests { sc := ctx.GetSessionVars().StmtCtx + if tt.strict { + ctx.GetSessionVars().SQLMode = defaultMode + } else { + ctx.GetSessionVars().SQLMode = mysql.DelSQLMode(defaultMode, mysql.ModeStrictAllTables|mysql.ModeStrictTransTables) + } levels := sc.ErrLevels() levels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !tt.strict) sc.SetErrLevels(levels) @@ -485,6 +493,11 @@ func TestGetDefaultValue(t *testing.T) { for _, tt := range tests { sc := ctx.GetSessionVars().StmtCtx + if tt.strict { + ctx.GetSessionVars().SQLMode = defaultMode + } else { + ctx.GetSessionVars().SQLMode = mysql.DelSQLMode(defaultMode, mysql.ModeStrictAllTables|mysql.ModeStrictTransTables) + } levels := sc.ErrLevels() levels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !tt.strict) sc.SetErrLevels(levels) diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index 6204a34fe5ec3..f25e3cfb729a5 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -2326,10 +2326,7 @@ func SetPBColumnsDefaultValue(ctx sessionctx.Context, pbColumns []*tipb.ColumnIn } sessVars := ctx.GetSessionVars() - originStrict := sessVars.StrictSQLMode - sessVars.StrictSQLMode = false - d, err := table.GetColOriginDefaultValue(ctx, c) - sessVars.StrictSQLMode = originStrict + d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx, c) if err != nil { return err }