diff --git a/pkg/datasource/sql/datasource/mysql/trigger.go b/pkg/datasource/sql/datasource/mysql/trigger.go index 60157a474..34622ecbf 100644 --- a/pkg/datasource/sql/datasource/mysql/trigger.go +++ b/pkg/datasource/sql/datasource/mysql/trigger.go @@ -27,11 +27,6 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/undo/executor" ) -const ( - columnMetaSql = "SELECT `TABLE_NAME`, `TABLE_SCHEMA`, `COLUMN_NAME`, `DATA_TYPE`, `COLUMN_TYPE`, `COLUMN_KEY`, `IS_NULLABLE`, `EXTRA` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - indexMetaSql = "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" -) - type mysqlTrigger struct { } @@ -91,6 +86,7 @@ func (m *mysqlTrigger) getColumnMetas(ctx context.Context, dbName string, table table = executor.DelEscape(table, types.DBTypeMySQL) var columnMetas []types.ColumnMeta + columnMetaSql := "SELECT `TABLE_NAME`, `TABLE_SCHEMA`, `COLUMN_NAME`, `DATA_TYPE`, `COLUMN_TYPE`, `COLUMN_KEY`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `EXTRA` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" stmt, err := conn.PrepareContext(ctx, columnMetaSql) if err != nil { return nil, err @@ -104,14 +100,15 @@ func (m *mysqlTrigger) getColumnMetas(ctx context.Context, dbName string, table for rows.Next() { var ( - tableName string - tableSchema string - columnName string - dataType string - columnType string - columnKey string - isNullable string - extra string + tableName string + tableSchema string + columnName string + dataType string + columnType string + columnKey string + isNullable string + columnDefault []byte + extra string ) columnMeta := types.ColumnMeta{} @@ -123,6 +120,7 @@ func (m *mysqlTrigger) getColumnMetas(ctx context.Context, dbName string, table &columnType, &columnKey, &isNullable, + &columnDefault, &extra); err != nil { return nil, err } @@ -139,9 +137,9 @@ func (m *mysqlTrigger) getColumnMetas(ctx context.Context, dbName string, table } else { columnMeta.IsNullable = 0 } + columnMeta.ColumnDef = columnDefault columnMeta.Extra = extra columnMeta.Autoincrement = strings.Contains(strings.ToLower(extra), "auto_increment") - columnMetas = append(columnMetas, columnMeta) } @@ -157,6 +155,7 @@ func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName tableName = executor.DelEscape(tableName, types.DBTypeMySQL) result := make([]types.IndexMeta, 0) + indexMetaSql := "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" stmt, err := conn.PrepareContext(ctx, indexMetaSql) if err != nil { return nil, err diff --git a/pkg/datasource/sql/types/meta.go b/pkg/datasource/sql/types/meta.go index fae03d209..90722cae2 100644 --- a/pkg/datasource/sql/types/meta.go +++ b/pkg/datasource/sql/types/meta.go @@ -26,8 +26,11 @@ import ( // ColumnMeta type ColumnMeta struct { // Schema - Schema string - Table string + Schema string + Table string + // ColumnDef the column default + ColumnDef []byte + // Autoincrement Autoincrement bool // todo get columnType //ColumnTypeInfo *sql.ColumnType diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go new file mode 100644 index 000000000..c42d16c63 --- /dev/null +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package builder + +import ( + "context" + "database/sql/driver" + "fmt" + "strings" +) + +import ( + "github.com/arana-db/parser/ast" +) + +import ( + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/undo" + "github.com/seata/seata-go/pkg/datasource/sql/undo/executor" + "github.com/seata/seata-go/pkg/util/log" +) + +func init() { + undo.RegisterUndoLogBuilder(types.InsertOnDuplicateExecutor, GetMySQLInsertOnDuplicateUndoLogBuilder) +} + +type MySQLInsertOnDuplicateUndoLogBuilder struct { + MySQLInsertUndoLogBuilder + BeforeSelectSql string + Args []driver.Value + BeforeImageSqlPrimaryKeys map[string]bool +} + +func GetMySQLInsertOnDuplicateUndoLogBuilder() undo.UndoLogBuilder { + return &MySQLInsertOnDuplicateUndoLogBuilder{ + MySQLInsertUndoLogBuilder: MySQLInsertUndoLogBuilder{}, + Args: make([]driver.Value, 0), + BeforeImageSqlPrimaryKeys: make(map[string]bool), + } +} + +func (u *MySQLInsertOnDuplicateUndoLogBuilder) GetExecutorType() types.ExecutorType { + return types.InsertOnDuplicateExecutor +} + +func (u *MySQLInsertOnDuplicateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { + if execCtx.ParseContext.InsertStmt == nil { + log.Errorf("invalid insert stmt") + return nil, fmt.Errorf("invalid insert stmt") + } + vals := execCtx.Values + if vals == nil { + vals = make([]driver.Value, len(execCtx.NamedValues)) + for n, param := range execCtx.NamedValues { + vals[n] = param.Value + } + } + tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O + metaData := execCtx.MetaDataMap[tableName] + selectSQL, selectArgs, err := u.buildBeforeImageSQL(execCtx.ParseContext.InsertStmt, metaData, vals) + if err != nil { + return nil, err + } + if len(selectArgs) == 0 { + log.Errorf("the SQL statement has no primary key or unique index value, it will not hit any row data.recommend to convert to a normal insert statement") + return nil, fmt.Errorf("the SQL statement has no primary key or unique index value, it will not hit any row data.recommend to convert to a normal insert statement") + } + u.BeforeSelectSql = selectSQL + u.Args = selectArgs + stmt, err := execCtx.Conn.Prepare(selectSQL) + if err != nil { + log.Errorf("build prepare stmt: %+v", err) + return nil, err + } + + rows, err := stmt.Query(selectArgs) + if err != nil { + log.Errorf("stmt query: %+v", err) + return nil, err + } + image, err := u.buildRecordImages(rows, &metaData) + if err != nil { + return nil, err + } + return []*types.RecordImage{image}, nil +} + +// buildBeforeImageSQL build select sql from insert on duplicate update sql +func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *ast.InsertStmt, metaData types.TableMeta, args []driver.Value) (string, []driver.Value, error) { + if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { + return "", nil, err + } + var selectArgs []driver.Value + pkIndexMap := u.getPkIndex(insertStmt, metaData) + var pkIndexArray []int + for _, val := range pkIndexMap { + tmpVal := val + pkIndexArray = append(pkIndexArray, tmpVal) + } + insertRows, err := getInsertRows(insertStmt, pkIndexArray) + if err != nil { + return "", nil, err + } + insertNum := len(insertRows) + paramMap, err := u.buildImageParameters(insertStmt, args, insertRows) + if err != nil { + return "", nil, err + } + + sql := strings.Builder{} + sql.WriteString("SELECT * FROM " + metaData.TableName + " ") + isContainWhere := false + for i := 0; i < insertNum; i++ { + finalI := i + paramAppenderTempList := make([]driver.Value, 0) + for _, index := range metaData.Indexs { + //unique index + if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false { + continue + } + columnIsNull := true + uniqueList := make([]string, 0) + for _, columnMeta := range index.Columns { + columnName := columnMeta.ColumnName + imageParameters, ok := paramMap[columnName] + if !ok && columnMeta.ColumnDef != nil { + if strings.EqualFold("PRIMARY", index.Name) { + u.BeforeImageSqlPrimaryKeys[columnName] = true + } + uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ") + columnIsNull = false + continue + } + if strings.EqualFold("PRIMARY", index.Name) { + u.BeforeImageSqlPrimaryKeys[columnName] = true + } + columnIsNull = false + uniqueList = append(uniqueList, columnName+" = ? ") + paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI]) + } + + if !columnIsNull { + if isContainWhere { + sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ") + } else { + sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ") + isContainWhere = true + } + } + } + selectArgs = append(selectArgs, paramAppenderTempList...) + } + log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String()) + return sql.String(), selectArgs, nil +} + +func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { + afterSelectSql, selectArgs := u.buildAfterImageSQL(ctx, beforeImages) + stmt, err := execCtx.Conn.Prepare(afterSelectSql) + if err != nil { + log.Errorf("build prepare stmt: %+v", err) + return nil, err + } + + rows, err := stmt.Query(selectArgs) + if err != nil { + log.Errorf("stmt query: %+v", err) + return nil, err + } + tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O + metaData := execCtx.MetaDataMap[tableName] + image, err := u.buildRecordImages(rows, &metaData) + if err != nil { + return nil, err + } + return []*types.RecordImage{image}, nil +} + +func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) { + selectSQL, selectArgs := u.BeforeSelectSql, u.Args + + var beforeImage *types.RecordImage + if len(beforeImages) > 0 { + beforeImage = beforeImages[0] + } + primaryValueMap := make(map[string][]interface{}) + for _, row := range beforeImage.Rows { + for _, col := range row.Columns { + if col.KeyType == types.IndexTypePrimaryKey { + primaryValueMap[col.ColumnName] = append(primaryValueMap[col.ColumnName], col.Value) + } + } + } + + var afterImageSql strings.Builder + var primaryValues []driver.Value + afterImageSql.WriteString(selectSQL) + for i := 0; i < len(beforeImage.Rows); i++ { + wherePrimaryList := make([]string, 0) + for name, value := range primaryValueMap { + if !u.BeforeImageSqlPrimaryKeys[name] { + wherePrimaryList = append(wherePrimaryList, name+" = ? ") + primaryValues = append(primaryValues, value[i]) + } + } + if len(wherePrimaryList) != 0 { + afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ") + } + } + selectArgs = append(selectArgs, primaryValues...) + log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String()) + return afterImageSql.String(), selectArgs +} + +func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { + duplicateColsMap := make(map[string]bool) + for _, v := range insert.OnDuplicate { + duplicateColsMap[v.Column.Name.L] = true + } + if len(duplicateColsMap) == 0 { + return nil + } + for _, index := range metaData.Indexs { + if types.IndexTypePrimaryKey != index.IType { + continue + } + for name, col := range index.Columns { + if duplicateColsMap[strings.ToLower(col.ColumnName)] { + log.Errorf("update pk value is not supported! index name:%s update column name: %s", name, col.ColumnName) + return fmt.Errorf("update pk value is not supported! ") + } + } + } + return nil +} + +// build sql params +func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) { + var ( + parameterMap = make(map[string][]driver.Value) + ) + insertColumns := getInsertColumns(insert) + var placeHolderIndex = 0 + for _, row := range insertRows { + if len(row) != len(insertColumns) { + log.Errorf("insert row's column size not equal to insert column size") + return nil, fmt.Errorf("insert row's column size not equal to insert column size") + } + for i, col := range insertColumns { + columnName := executor.DelEscape(col, types.DBTypeMySQL) + val := row[i] + rStr, ok := val.(string) + if ok && strings.EqualFold(rStr, SqlPlaceholder) { + objects := args[placeHolderIndex] + parameterMap[columnName] = append(parameterMap[col], objects) + placeHolderIndex++ + } else { + parameterMap[columnName] = append(parameterMap[col], val) + } + } + } + return parameterMap, nil +} + +func getInsertColumns(insertStmt *ast.InsertStmt) []string { + if insertStmt == nil { + return nil + } + colList := insertStmt.Columns + if len(colList) == 0 { + return nil + } + var list []string + for _, col := range colList { + list = append(list, col.Name.L) + } + return list +} + +func isIndexValueNotNull(indexMeta types.IndexMeta, imageParameterMap map[string][]driver.Value, rowIndex int) bool { + for _, colMeta := range indexMeta.Columns { + columnName := colMeta.ColumnName + imageParameters := imageParameterMap[columnName] + if imageParameters == nil && colMeta.ColumnDef == nil { + return false + } else if imageParameters != nil && (rowIndex >= len(imageParameters) || imageParameters[rowIndex] == nil) { + return false + } + } + return true +} diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go new file mode 100644 index 000000000..b3f41538e --- /dev/null +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package builder + +import ( + "context" + "testing" + + "database/sql/driver" + + "github.com/seata/seata-go/pkg/datasource/sql/parser" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/stretchr/testify/assert" +) + +func TestInsertOnDuplicateBuildBeforeImageSQL(t *testing.T) { + var ( + builder = MySQLInsertOnDuplicateUndoLogBuilder{ + BeforeImageSqlPrimaryKeys: make(map[string]bool), + } + tableMeta1 types.TableMeta + //one index table + tableMeta2 types.TableMeta + columns = make(map[string]types.ColumnMeta) + index = make(map[string]types.IndexMeta) + index2 = make(map[string]types.IndexMeta) + columnMeta1 []types.ColumnMeta + columnMeta2 []types.ColumnMeta + ColumnNames []string + ) + columnId := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "id", + } + columnName := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "name", + } + columnAge := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "age", + } + columns["id"] = columnId + columns["name"] = columnName + columns["age"] = columnAge + columnMeta1 = append(columnMeta1, columnId) + columnMeta2 = append(columnMeta2, columnName, columnAge) + index["id"] = types.IndexMeta{ + Name: "PRIMARY", + IType: types.IndexTypePrimaryKey, + Columns: columnMeta1, + } + index["id_name_age"] = types.IndexMeta{ + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: columnMeta2, + } + + ColumnNames = []string{"id", "name", "age"} + tableMeta1 = types.TableMeta{ + TableName: "t_user", + Columns: columns, + Indexs: index, + ColumnNames: ColumnNames, + } + + index2["id_name_age"] = types.IndexMeta{ + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: columnMeta2, + } + + tableMeta2 = types.TableMeta{ + TableName: "t_user", + Columns: columns, + Indexs: index2, + ColumnNames: ColumnNames, + } + + tests := []struct { + name string + execCtx *types.ExecContext + sourceQueryArgs []driver.Value + expectQuery1 string + expectQueryArgs1 []driver.Value + expectQuery2 string + expectQueryArgs2 []driver.Value + }{ + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update name = ?,age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack1", 81, "Link", 18}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{1, "Jack1", 81}, + expectQuery2: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ", + expectQueryArgs2: []driver.Value{"Jack1", 81, 1}, + }, + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(1,'Jack1',?) on duplicate key update name = 'Michael',age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{81, "Link", 18}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{int64(1), "Jack1", 81}, + expectQuery2: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ", + expectQueryArgs2: []driver.Value{"Jack1", 81, int64(1)}, + }, + // multi insert one index + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?),(?,?,?) on duplicate key update name = ?,age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta2}, + }, + sourceQueryArgs: []driver.Value{1, "Jack1", 81, 2, "Michal", 35, "Link", 18}, + expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", 35}, + }, + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,'Jack1',?),(?,?,35) on duplicate key update name = 'Faker',age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta2}, + }, + sourceQueryArgs: []driver.Value{1, 81, 2, "Michal", 26}, + expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", int64(35)}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := parser.DoParser(tt.execCtx.Query) + assert.Nil(t, err) + tt.execCtx.ParseContext = c + query, args, err := builder.buildBeforeImageSQL(tt.execCtx.ParseContext.InsertStmt, tt.execCtx.MetaDataMap["t_user"], tt.sourceQueryArgs) + assert.Nil(t, err) + if query == tt.expectQuery1 { + assert.Equal(t, tt.expectQuery1, query) + assert.Equal(t, tt.expectQueryArgs1, args) + } else { + assert.Equal(t, tt.expectQuery2, query) + assert.Equal(t, tt.expectQueryArgs2, args) + } + }) + } +} + +func TestInsertOnDuplicateBuildAfterImageSQL(t *testing.T) { + var ( + builder = MySQLInsertOnDuplicateUndoLogBuilder{} + ) + tests := []struct { + name string + beforeSelectSql string + BeforeImageSqlPrimaryKeys map[string]bool + beforeSelectArgs []driver.Value + beforeImages []*types.RecordImage + expectQuery string + expectQueryArgs []driver.Value + }{ + { + beforeSelectSql: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + BeforeImageSqlPrimaryKeys: map[string]bool{"id": true}, + beforeSelectArgs: []driver.Value{1, "Jack1", 81}, + beforeImages: []*types.RecordImage{ + { + TableName: "t_user", + Rows: []types.RowImage{ + { + Columns: []types.ColumnImage{ + { + KeyType: types.IndexTypePrimaryKey, + ColumnName: "id", + Value: 2, + }, + { + KeyType: types.IndexUnique, + ColumnName: "name", + Value: "Jack", + }, + { + KeyType: types.IndexUnique, + ColumnName: "age", + Value: 18, + }, + }, + }, + }, + }, + }, + expectQuery: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs: []driver.Value{1, "Jack1", 81}, + }, + { + beforeSelectSql: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) OR (id = ? ) OR (name = ? and age = ? ) ", + BeforeImageSqlPrimaryKeys: map[string]bool{"id": true}, + beforeSelectArgs: []driver.Value{1, "Jack1", 30, 2, "Michael", 18}, + beforeImages: []*types.RecordImage{ + { + TableName: "t_user", + Rows: []types.RowImage{ + { + Columns: []types.ColumnImage{ + { + KeyType: types.IndexTypePrimaryKey, + ColumnName: "id", + Value: 1, + }, + { + KeyType: types.IndexUnique, + ColumnName: "name", + Value: "Jack", + }, + { + KeyType: types.IndexUnique, + ColumnName: "age", + Value: 18, + }, + }, + }, + { + Columns: []types.ColumnImage{ + { + KeyType: types.IndexTypePrimaryKey, + ColumnName: "id", + Value: 2, + }, + { + KeyType: types.IndexUnique, + ColumnName: "name", + Value: "Michael", + }, + { + KeyType: types.IndexUnique, + ColumnName: "age", + Value: 30, + }, + }, + }, + }, + }, + }, + expectQuery: "SELECT * FROM t_user WHERE (id = ? ) OR (name = ? and age = ? ) OR (id = ? ) OR (name = ? and age = ? ) ", + expectQueryArgs: []driver.Value{1, "Jack1", 30, 2, "Michael", 18}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder.BeforeSelectSql = tt.beforeSelectSql + builder.BeforeImageSqlPrimaryKeys = tt.BeforeImageSqlPrimaryKeys + builder.Args = tt.beforeSelectArgs + query, args := builder.buildAfterImageSQL(context.TODO(), tt.beforeImages) + assert.Equal(t, tt.expectQuery, query) + assert.Equal(t, tt.expectQueryArgs, args) + }) + } +}