diff --git a/pkg/dbutil/common.go b/pkg/dbutil/common.go index 306ed30e0..ff54da953 100644 --- a/pkg/dbutil/common.go +++ b/pkg/dbutil/common.go @@ -110,6 +110,52 @@ func GetDBConfigFromEnv(schema string) DBConfig { } } +type DSNType interface { + DSNString() string +} + +type DSNStringType struct { + Key string + Value string +} + +func (s DSNStringType) DSNString() string { + // key='val'. add single quote for better compatibility. + return fmt.Sprintf("&%s=%%27%s%%27", s.Key, url.QueryEscape(s.Value)) +} + +type DSNBoolType struct { + Key string + Value bool +} + +func (b DSNBoolType) DSNString() string { + return fmt.Sprintf("&%s=%t", b.Key, b.Value) +} + +// OpenDB opens a mysql connection FD +func OpenDBWithDSN(cfg DBConfig, vars []DSNType) (*sql.DB, error) { + var dbDSN string + if len(cfg.Snapshot) != 0 { + log.Info("create connection with snapshot", zap.String("snapshot", cfg.Snapshot)) + dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&tidb_snapshot=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Snapshot) + } else { + dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", cfg.User, cfg.Password, cfg.Host, cfg.Port) + } + + for _, dsnType := range vars { + dbDSN += dsnType.DSNString() + } + + dbConn, err := sql.Open("mysql", dbDSN) + if err != nil { + return nil, errors.Trace(err) + } + + err = dbConn.Ping() + return dbConn, errors.Trace(err) +} + // OpenDB opens a mysql connection FD func OpenDB(cfg DBConfig, vars map[string]string) (*sql.DB, error) { var dbDSN string diff --git a/sync_diff_inspector/source/common/conn.go b/sync_diff_inspector/source/common/conn.go index ad2feff9c..05da79550 100644 --- a/sync_diff_inspector/source/common/conn.go +++ b/sync_diff_inspector/source/common/conn.go @@ -22,8 +22,8 @@ import ( ) // CreateDB creates sql.DB used for select data -func CreateDB(ctx context.Context, dbConfig *dbutil.DBConfig, vars map[string]string, num int) (db *sql.DB, err error) { - db, err = dbutil.OpenDB(*dbConfig, vars) +func CreateDB(ctx context.Context, dbConfig *dbutil.DBConfig, vars []dbutil.DSNType, num int) (db *sql.DB, err error) { + db, err = dbutil.OpenDBWithDSN(*dbConfig, vars) if err != nil { return nil, errors.Errorf("create db connections %s error %v", dbConfig.String(), err) } diff --git a/sync_diff_inspector/source/source.go b/sync_diff_inspector/source/source.go index 3bc1db0da..0a6145e25 100644 --- a/sync_diff_inspector/source/source.go +++ b/sync_diff_inspector/source/source.go @@ -235,8 +235,8 @@ func buildSourceFromCfg(ctx context.Context, tableDiffs []*common.TableDiff, con return NewMySQLSources(ctx, tableDiffs, dbs, connCount, f) } -func getAutoSnapshotPosition(dbConfig *dbutil.DBConfig, vars map[string]string) (string, string, error) { - tmpConn, err := dbutil.OpenDB(*dbConfig, vars) +func getAutoSnapshotPosition(dbConfig *dbutil.DBConfig, vars []dbutil.DSNType) (string, string, error) { + tmpConn, err := dbutil.OpenDBWithDSN(*dbConfig, vars) if err != nil { return "", "", errors.Annotatef(err, "connecting to auto-position tidb_snapshot failed") } @@ -250,9 +250,16 @@ func getAutoSnapshotPosition(dbConfig *dbutil.DBConfig, vars map[string]string) } func initDBConn(ctx context.Context, cfg *config.Config) error { - // Unified time zone - vars := map[string]string{ - "time_zone": UnifiedTimeZone, + vars := []dbutil.DSNType{ + // Unified time zone + dbutil.DSNStringType{ + Key: "time_zone", + Value: UnifiedTimeZone, + }, + dbutil.DSNBoolType{ + Key: "interpolateParams", + Value: true, + }, } // Fill in tidb_snapshot if it is set to AUTO