diff --git a/pkg/datasource/sql/types/image.go b/pkg/datasource/sql/types/image.go index 0290b366..af3fe5be 100644 --- a/pkg/datasource/sql/types/image.go +++ b/pkg/datasource/sql/types/image.go @@ -18,6 +18,7 @@ package types import ( + "database/sql/driver" "encoding/base64" "encoding/json" "reflect" @@ -117,14 +118,16 @@ type RecordImage struct { // Rows data row Rows []RowImage `json:"rows"` // TableMeta table information schema - TableMeta *TableMeta `json:"-"` + TableMeta *TableMeta `json:"-"` + PrimaryKeyMap map[string][]driver.Value `json:"primaryKeyMap,omitempty"` } func NewEmptyRecordImage(tableMeta *TableMeta, sqlType SQLType) *RecordImage { return &RecordImage{ - TableName: tableMeta.TableName, - TableMeta: tableMeta, - SQLType: sqlType, + TableName: tableMeta.TableName, + TableMeta: tableMeta, + SQLType: sqlType, + PrimaryKeyMap: make(map[string][]driver.Value), } } diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go index 6b82d537..c40d4d9c 100644 --- a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder.go @@ -97,68 +97,108 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { return "", nil, err } - var selectArgs []driver.Value + u.BeforeImageSqlPrimaryKeys = make(map[string]bool, len(metaData.Indexs)) pkIndexMap := u.getPkIndex(insertStmt, metaData) var pkIndexArray []int for _, val := range pkIndexMap { - tmpVal := val - pkIndexArray = append(pkIndexArray, tmpVal) + pkIndexArray = append(pkIndexArray, val) } insertRows, err := getInsertRows(insertStmt, pkIndexArray) if err != nil { return "", nil, err } - insertNum := len(insertRows) paramMap, err := u.buildImageParameters(insertStmt, args, insertRows) if err != nil { return "", nil, err } - - sql := strings.Builder{} - sql.WriteString("SELECT * FROM " + metaData.TableName + " ") + if len(paramMap) == 0 || len(metaData.Indexs) == 0 { + return "", nil, nil + } + hasPK := false + for _, index := range metaData.Indexs { + if strings.EqualFold("PRIMARY", index.Name) { + allPKColumnsHaveValue := true + for _, col := range index.Columns { + if params, ok := paramMap[col.ColumnName]; !ok || len(params) == 0 || params[0] == nil { + allPKColumnsHaveValue = false + break + } + } + hasPK = allPKColumnsHaveValue + break + } + } + if !hasPK { + hasValidUniqueIndex := false + for _, index := range metaData.Indexs { + if !index.NonUnique && !strings.EqualFold("PRIMARY", index.Name) { + if _, _, valid := validateIndexPrefix(index, paramMap, 0); valid { + hasValidUniqueIndex = true + break + } + } + } + if !hasValidUniqueIndex { + return "", nil, nil + } + } + var sql strings.Builder + sql.WriteString("SELECT * FROM " + metaData.TableName + " ") + var selectArgs []driver.Value isContainWhere := false - for i := 0; i < insertNum; i++ { - finalI := i - paramAppenderTempList := make([]driver.Value, 0) + hasConditions := false + for i := 0; i < len(insertRows); i++ { + var rowConditions []string + var rowArgs []driver.Value + usedParams := make(map[string]bool) + + // First try unique indexes for _, index := range metaData.Indexs { - //unique index - if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false { + if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) { continue } - columnIsNull := true - uniqueList := make([]string, 0) - for _, columnMeta := range index.Columns { - columnName := strings.ToLower(columnMeta.ColumnName) - imageParameters, ok := paramMap[columnName] - if !ok && columnMeta.ColumnDef != nil { - if strings.EqualFold("PRIMARY", index.Name) { - u.BeforeImageSqlPrimaryKeys[columnName] = true - } - uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ") - columnIsNull = false - continue - } - if strings.EqualFold("PRIMARY", index.Name) { - u.BeforeImageSqlPrimaryKeys[columnName] = true + if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid { + rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")") + rowArgs = append(rowArgs, args...) + hasConditions = true + for _, colMeta := range index.Columns { + usedParams[colMeta.ColumnName] = true } - columnIsNull = false - uniqueList = append(uniqueList, columnName+" = ? ") - paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI]) } + } - if !columnIsNull { - if isContainWhere { - sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ") - } else { - sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ") - isContainWhere = true + // Then try primary key + for _, index := range metaData.Indexs { + if !strings.EqualFold("PRIMARY", index.Name) { + continue + } + if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid { + rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")") + rowArgs = append(rowArgs, args...) + hasConditions = true + for _, colMeta := range index.Columns { + usedParams[colMeta.ColumnName] = true } } } - selectArgs = append(selectArgs, paramAppenderTempList...) + + if len(rowConditions) > 0 { + if !isContainWhere { + sql.WriteString("WHERE ") + isContainWhere = true + } else { + sql.WriteString(" OR ") + } + sql.WriteString(strings.Join(rowConditions, " OR ") + " ") + selectArgs = append(selectArgs, rowArgs...) + } + } + if !hasConditions { + return "", nil, nil } - log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String()) - return sql.String(), selectArgs, nil + sqlStr := sql.String() + log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr) + return sqlStr, selectArgs, nil } func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { @@ -168,14 +208,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e log.Errorf("build prepare stmt: %+v", err) return nil, err } - + defer stmt.Close() + tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O + metaData := execCtx.MetaDataMap[tableName] rows, err := stmt.Query(selectArgs) if err != nil { - log.Errorf("stmt query: %+v", err) return nil, err } - tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O - metaData := execCtx.MetaDataMap[tableName] + defer rows.Close() image, err := u.buildRecordImages(rows, &metaData) if err != nil { return nil, err @@ -185,11 +225,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) { selectSQL, selectArgs := u.BeforeSelectSql, u.Args - var beforeImage *types.RecordImage if len(beforeImages) > 0 { beforeImage = beforeImages[0] } + if beforeImage == nil || len(beforeImage.Rows) == 0 { + return selectSQL, selectArgs + } primaryValueMap := make(map[string][]interface{}) for _, row := range beforeImage.Rows { for _, col := range row.Columns { @@ -198,25 +240,46 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co } } } - var afterImageSql strings.Builder - var primaryValues []driver.Value afterImageSql.WriteString(selectSQL) - for i := 0; i < len(beforeImage.Rows); i++ { - wherePrimaryList := make([]string, 0) - for name, value := range primaryValueMap { - if !u.BeforeImageSqlPrimaryKeys[name] { - wherePrimaryList = append(wherePrimaryList, name+" = ? ") - primaryValues = append(primaryValues, value[i]) + if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) { + return selectSQL, selectArgs + } + var primaryValues []driver.Value + usedPrimaryKeys := make(map[string]bool) + for name := range primaryValueMap { + if !u.BeforeImageSqlPrimaryKeys[name] { + usedPrimaryKeys[name] = true + for i := 0; i < len(beforeImage.Rows); i++ { + if value := primaryValueMap[name][i]; value != nil { + if dv, ok := value.(driver.Value); ok { + primaryValues = append(primaryValues, dv) + } else { + primaryValues = append(primaryValues, value) + } + } } } - if len(wherePrimaryList) != 0 { - afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ") + } + if len(primaryValues) > 0 { + afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ") + } + finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues)) + copy(finalArgs, selectArgs) + copy(finalArgs[len(selectArgs):], primaryValues) + sqlStr := afterImageSql.String() + log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr) + return sqlStr, finalArgs +} + +func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string { + var conditions []string + for name := range primaryValueMap { + if !usedPrimaryKeys[name] { + conditions = append(conditions, name+" = ? ") } } - selectArgs = append(selectArgs, primaryValues...) - log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String()) - return afterImageSql.String(), selectArgs + return conditions } func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { @@ -243,11 +306,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e // build sql params func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) { - var ( - parameterMap = make(map[string][]driver.Value) - ) + parameterMap := make(map[string][]driver.Value) insertColumns := getInsertColumns(insert) - var placeHolderIndex = 0 + placeHolderIndex := 0 + for _, row := range insertRows { if len(row) != len(insertColumns) { log.Errorf("insert row's column size not equal to insert column size") @@ -256,13 +318,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast. for i, col := range insertColumns { columnName := strings.ToLower(executor.DelEscape(col, types.DBTypeMySQL)) val := row[i] - rStr, ok := val.(string) - if ok && strings.EqualFold(rStr, SqlPlaceholder) { - objects := args[placeHolderIndex] - parameterMap[columnName] = append(parameterMap[col], objects) + if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) { + if placeHolderIndex >= len(args) { + return nil, fmt.Errorf("not enough parameters for placeholders") + } + parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex]) placeHolderIndex++ } else { - parameterMap[columnName] = append(parameterMap[col], val) + parameterMap[columnName] = append(parameterMap[columnName], val) } } } @@ -296,3 +359,28 @@ func isIndexValueNotNull(indexMeta types.IndexMeta, imageParameterMap map[string } return true } + +func validateIndexPrefix(index types.IndexMeta, paramMap map[string][]driver.Value, rowIndex int) ([]string, []driver.Value, bool) { + var indexConditions []string + var indexArgs []driver.Value + if len(index.Columns) > 1 { + for _, colMeta := range index.Columns { + params, ok := paramMap[colMeta.ColumnName] + if !ok || len(params) <= rowIndex || params[rowIndex] == nil { + return nil, nil, false + } + } + } + for _, colMeta := range index.Columns { + columnName := colMeta.ColumnName + params, ok := paramMap[columnName] + if ok && len(params) > rowIndex && params[rowIndex] != nil { + indexConditions = append(indexConditions, columnName+" = ? ") + indexArgs = append(indexArgs, params[rowIndex]) + } + } + if len(indexConditions) != len(index.Columns) { + return nil, nil, false + } + return indexConditions, indexArgs, true +} diff --git a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go index 59e673f7..f6e5b7cc 100644 --- a/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go +++ b/pkg/datasource/sql/undo/builder/mysql_insertonduplicate_update_undo_log_builder_test.go @@ -143,6 +143,69 @@ func TestInsertOnDuplicateBuildBeforeImageSQL(t *testing.T) { expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ", expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", int64(35)}, }, + // Test case for null unique index + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?, ?, ?) on duplicate key update age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, nil, 2, 5}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ", + expectQueryArgs1: []driver.Value{1}, + }, + // Test case for null primary key + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, age) values(?, ?) on duplicate key update age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{nil, 2, 5}, + expectQuery1: "SELECT * FROM t_user WHERE (age = ? )", + expectQueryArgs1: []driver.Value{2}, + }, + // Test case for null unique index with no primary key + { + execCtx: &types.ExecContext{ + Query: "insert into t_user(name, age) values(?, ?) on duplicate key update age = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta2}, + }, + sourceQueryArgs: []driver.Value{nil, 2, 5}, + expectQuery1: "", + expectQueryArgs1: nil, + }, + // Test case for composite index with all columns + { + name: "composite_index_full", + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update other = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack", 25, "other"}, + expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (id = ? ) ", + expectQueryArgs1: []driver.Value{"Jack", 25, 1}, + }, + // Test case for composite index with null value + { + name: "composite_index_with_null", + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name, age) values(?,?,?) on duplicate key update other = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack", nil, "other"}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ", + expectQueryArgs1: []driver.Value{1}, + }, + // Test case for composite index with leftmost prefix only + { + name: "composite_index_leftmost_prefix", + execCtx: &types.ExecContext{ + Query: "insert into t_user(id, name) values(?,?) on duplicate key update other = ?", + MetaDataMap: map[string]types.TableMeta{"t_user": tableMeta1}, + }, + sourceQueryArgs: []driver.Value{1, "Jack", "other"}, + expectQuery1: "SELECT * FROM t_user WHERE (id = ? ) ", + expectQueryArgs1: []driver.Value{1}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {