Skip to content

Commit

Permalink
NEOS-1478 Separate sql statements from args (#2718)
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi authored Sep 19, 2024
1 parent 11091bf commit 763d6b3
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 30 deletions.
2 changes: 1 addition & 1 deletion internal/postgres/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) {
continue
}
}
jObj[col] = string(t)
jObj[col] = t
case *PgxArray[any]:
jObj[col] = pgArrayToGoSlice(t)
default:
Expand Down
20 changes: 15 additions & 5 deletions worker/pkg/benthos/sql/output_sql_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa

processedCols, processedRows := s.processRows(s.columns, rows)

insertQuery, err := querybuilder.BuildInsertQuery(s.driver, s.schema, s.table, processedCols, s.columnDataTypes, processedRows, &s.onConflictDoNothing)
insertQuery, args, err := querybuilder.BuildInsertQuery(s.driver, s.schema, s.table, processedCols, s.columnDataTypes, processedRows, &s.onConflictDoNothing)
if err != nil {
return err
}
Expand All @@ -304,8 +304,11 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa
insertQuery = sqlmanager_postgres.BuildPgInsertIdentityAlwaysSql(insertQuery)
}

query := s.buildQuery(insertQuery)
if _, err := s.db.ExecContext(ctx, query); err != nil {
if s.driver != sqlmanager_shared.PostgresDriver {
insertQuery = s.buildQuery(insertQuery)
}

if _, err := s.db.ExecContext(ctx, insertQuery, args...); err != nil {
if !s.skipForeignKeyViolations || !neosync_benthos.IsForeignKeyViolationError(err.Error()) {
return err
}
Expand All @@ -314,6 +317,13 @@ func (s *pooledInsertOutput) WriteBatch(ctx context.Context, batch service.Messa
return err
}
}
if s.driver == sqlmanager_shared.PostgresDriver && s.suffix != nil && *s.suffix != "" {
// to prevent postgres cannot insert multiple commands into a prepared statement error
// must run table identity count reset separately
if _, err := s.db.ExecContext(ctx, *s.suffix); err != nil {
return err
}
}
return nil
}

Expand All @@ -325,12 +335,12 @@ func (s *pooledInsertOutput) RetryInsertRowByRow(
errorCount := 0
insertCount := 0
for _, row := range rows {
insertQuery, err := querybuilder.BuildInsertQuery(s.driver, s.schema, s.table, columns, s.columnDataTypes, [][]any{row}, &s.onConflictDoNothing)
insertQuery, args, err := querybuilder.BuildInsertQuery(s.driver, s.schema, s.table, columns, s.columnDataTypes, [][]any{row}, &s.onConflictDoNothing)
if err != nil {
return err
}
query := s.buildQuery(insertQuery)
_, err = s.db.ExecContext(ctx, query)
_, err = s.db.ExecContext(ctx, query, args...)
if err != nil && neosync_benthos.IsForeignKeyViolationError(err.Error()) {
errorCount++
} else if err != nil && !neosync_benthos.IsForeignKeyViolationError(err.Error()) {
Expand Down
28 changes: 18 additions & 10 deletions worker/pkg/query-builder/query-builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// import the dialect
_ "github.com/doug-martin/goqu/v9/dialect/mysql"
_ "github.com/doug-martin/goqu/v9/dialect/postgres"
_ "github.com/doug-martin/goqu/v9/dialect/sqlserver"
"github.com/doug-martin/goqu/v9/exp"
gotypeutil "github.com/nucleuscloud/neosync/internal/gotypeutil"
pgutil "github.com/nucleuscloud/neosync/internal/postgres"
Expand All @@ -29,12 +30,19 @@ type SubsetColumnConstraint struct {
ForeignKey *SubsetReferenceKey
}

func getGoquDialect(driver string) goqu.DialectWrapper {
if driver == sqlmanager_shared.PostgresDriver {
return goqu.Dialect("postgres")
}
return goqu.Dialect(driver)
}

func BuildSelectQuery(
driver, table string,
columns []string,
whereClause *string,
) (string, error) {
builder := goqu.Dialect(driver)
builder := getGoquDialect(driver)
sqltable := goqu.I(table)

selectColumns := make([]any, len(columns))
Expand Down Expand Up @@ -62,7 +70,7 @@ func BuildSelectLimitQuery(
driver, table string,
limit uint,
) (string, error) {
builder := goqu.Dialect(driver)
builder := getGoquDialect(driver)
sqltable := goqu.I(table)
sql, _, err := builder.From((sqltable)).Limit(limit).ToSQL()
if err != nil {
Expand Down Expand Up @@ -132,14 +140,14 @@ func BuildInsertQuery(
columnDataTypes []string,
values [][]any,
onConflictDoNothing *bool,
) (string, error) {
builder := goqu.Dialect(driver)
) (sql string, args []any, err error) {
builder := getGoquDialect(driver)
sqltable := goqu.S(schema).Table(table)
insertCols := make([]any, len(columns))
for i, col := range columns {
insertCols[i] = col
}
insert := builder.Insert(sqltable).Cols(insertCols...)
insert := builder.Insert(sqltable).Prepared(true).Cols(insertCols...)
for _, row := range values {
gval := getGoquVals(driver, row, columnDataTypes)
insert = insert.Vals(gval)
Expand All @@ -149,11 +157,11 @@ func BuildInsertQuery(
insert = insert.OnConflict(goqu.DoNothing())
}

query, _, err := insert.ToSQL()
query, args, err := insert.ToSQL()
if err != nil {
return "", err
return "", nil, err
}
return query, nil
return query, args, nil
}

func BuildUpdateQuery(
Expand All @@ -162,7 +170,7 @@ func BuildUpdateQuery(
whereColumns []string,
columnValueMap map[string]any,
) (string, error) {
builder := goqu.Dialect(driver)
builder := getGoquDialect(driver)
sqltable := goqu.S(schema).Table(table)

updateRecord := goqu.Record{}
Expand Down Expand Up @@ -195,7 +203,7 @@ func BuildUpdateQuery(
func BuildTruncateQuery(
driver, table string,
) (string, error) {
builder := goqu.Dialect(driver)
builder := getGoquDialect(driver)
sqltable := goqu.I(table)
truncate := builder.Truncate(sqltable)
query, _, err := truncate.ToSQL()
Expand Down
23 changes: 13 additions & 10 deletions worker/pkg/query-builder/query-builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,21 @@ func Test_BuildInsertQuery(t *testing.T) {
values [][]any
onConflictDoNothing bool
expected string
expectedArgs []any
}{
{"Single Column mysql", "mysql", "public", "users", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, "INSERT INTO `public`.`users` (`name`) VALUES ('Alice'), ('Bob')"},
{"Special characters mysql", "mysql", "public", "users.stage$dev", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, "INSERT INTO `public`.`users.stage$dev` (`name`) VALUES ('Alice'), ('Bob')"},
{"Multiple Columns mysql", "mysql", "public", "users", []string{"name", "email"}, []string{}, [][]any{{"Alice", "[email protected]"}, {"Bob", "[email protected]"}}, true, "INSERT IGNORE INTO `public`.`users` (`name`, `email`) VALUES ('Alice', '[email protected]'), ('Bob', '[email protected]')"},
{"Single Column postgres", "postgres", "public", "users", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, `INSERT INTO "public"."users" ("name") VALUES ('Alice'), ('Bob')`},
{"Multiple Columns postgres", "postgres", "public", "users", []string{"name", "email"}, []string{}, [][]any{{"Alice", "[email protected]"}, {"Bob", "[email protected]"}}, true, `INSERT INTO "public"."users" ("name", "email") VALUES ('Alice', '[email protected]'), ('Bob', '[email protected]') ON CONFLICT DO NOTHING`},
{"Single Column mysql", "mysql", "public", "users", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, "INSERT INTO `public`.`users` (`name`) VALUES (?), (?)", []any{"Alice", "Bob"}},
{"Special characters mysql", "mysql", "public", "users.stage$dev", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, "INSERT INTO `public`.`users.stage$dev` (`name`) VALUES (?), (?)", []any{"Alice", "Bob"}},
{"Multiple Columns mysql", "mysql", "public", "users", []string{"name", "email"}, []string{}, [][]any{{"Alice", "[email protected]"}, {"Bob", "[email protected]"}}, true, "INSERT IGNORE INTO `public`.`users` (`name`, `email`) VALUES (?, ?), (?, ?)", []any{"Alice", "[email protected]", "Bob", "[email protected]"}},
{"Single Column postgres", "postgres", "public", "users", []string{"name"}, []string{}, [][]any{{"Alice"}, {"Bob"}}, false, `INSERT INTO "public"."users" ("name") VALUES ($1), ($2)`, []any{"Alice", "Bob"}},
{"Multiple Columns postgres", "postgres", "public", "users", []string{"name", "email"}, []string{}, [][]any{{"Alice", "[email protected]"}, {"Bob", "[email protected]"}}, true, `INSERT INTO "public"."users" ("name", "email") VALUES ($1, $2), ($3, $4) ON CONFLICT DO NOTHING`, []any{"Alice", "[email protected]", "Bob", "[email protected]"}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual, err := BuildInsertQuery(tt.driver, tt.schema, tt.table, tt.columns, tt.columnDataTypes, tt.values, &tt.onConflictDoNothing)
actual, args, err := BuildInsertQuery(tt.driver, tt.schema, tt.table, tt.columns, tt.columnDataTypes, tt.values, &tt.onConflictDoNothing)
require.NoError(t, err)
require.Equal(t, tt.expected, actual)
require.Equal(t, tt.expectedArgs, args)
})
}
}
Expand All @@ -138,9 +140,9 @@ func Test_BuildInsertQuery_JsonArray(t *testing.T) {
}
onConflictDoNothing := false

query, err := BuildInsertQuery(driver, schema, table, columns, columnDataTypes, values, &onConflictDoNothing)
query, _, err := BuildInsertQuery(driver, schema, table, columns, columnDataTypes, values, &onConflictDoNothing)
require.NoError(t, err)
expectedQuery := `INSERT INTO "public"."test_table" ("id", "name", "tags") VALUES (1, 'John', ARRAY['{"tag":"cool"}','{"tag":"awesome"}']::jsonb[]), (2, 'Jane', ARRAY['{"tag":"smart"}','{"tag":"clever"}']::jsonb[])`
expectedQuery := `INSERT INTO "public"."test_table" ("id", "name", "tags") VALUES ($1, $2, ARRAY['{"tag":"cool"}','{"tag":"awesome"}']::jsonb[]), ($3, $4, ARRAY['{"tag":"smart"}','{"tag":"clever"}']::jsonb[])`
require.Equal(t, expectedQuery, query)
}

Expand All @@ -156,10 +158,11 @@ func Test_BuildInsertQuery_Json(t *testing.T) {
}
onConflictDoNothing := false

query, err := BuildInsertQuery(driver, schema, table, columns, columnDataTypes, values, &onConflictDoNothing)
query, args, err := BuildInsertQuery(driver, schema, table, columns, columnDataTypes, values, &onConflictDoNothing)
require.NoError(t, err)
expectedQuery := `INSERT INTO "public"."test_table" ("id", "name", "tags") VALUES (1, 'John', '{"tag":"cool"}'), (2, 'Jane', '{"tag":"smart"}')`
expectedQuery := `INSERT INTO "public"."test_table" ("id", "name", "tags") VALUES ($1, $2, $3), ($4, $5, $6)`
require.Equal(t, expectedQuery, query)
require.Equal(t, []any{int64(1), "John", []byte{123, 34, 116, 97, 103, 34, 58, 34, 99, 111, 111, 108, 34, 125}, int64(2), "Jane", []byte{123, 34, 116, 97, 103, 34, 58, 34, 115, 109, 97, 114, 116, 34, 125}}, args)
}

func TestGetGoquVals(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ CREATE TABLE transformers (
str VARCHAR (255),
character_scramble VARCHAR (255),
bool BOOLEAN,
card_number VARCHAR(255),
card_number BIGINT,
categorical VARCHAR(255),
city VARCHAR(255),
full_address VARCHAR(255),
Expand All @@ -27,7 +27,7 @@ CREATE TABLE transformers (
street_address VARCHAR(255),
unix_time BIGINT,
username VARCHAR(255),
utc_timestamp VARCHAR(255),
utc_timestamp TIMESTAMPTZ,
uuid VARCHAR(255),
zipcode BIGINT
);

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,42 @@ INSERT INTO alltypes.all_postgres_types (
123456 -- oid_col
);

INSERT INTO alltypes.all_postgres_types (
Id
) VALUES (
DEFAULT
);


CREATE TABLE IF NOT EXISTS alltypes.time_time (
id SERIAL PRIMARY KEY,
timestamp_col TIMESTAMP,
timestamptz_col TIMESTAMPTZ,
date_col DATE
);

INSERT INTO alltypes.time_time (
timestamp_col,
timestamptz_col,
date_col
)
VALUES (
'2024-03-18 10:30:00',
'2024-03-18 10:30:00+00',
'2024-03-18'
);

INSERT INTO alltypes.time_time (
timestamp_col,
timestamptz_col,
date_col
)
VALUES (
'0001-01-01 00:00:00 BC',
'0001-01-01 00:00:00+00 BC',
'0001-01-01 BC'
);


CREATE TABLE IF NOT EXISTS alltypes.array_types (
"id" BIGINT NOT NULL PRIMARY KEY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
TruncateCascade: true,
},
Expected: map[string]*workflow_testdata.ExpectedOutput{
"alltypes.all_postgres_types": &workflow_testdata.ExpectedOutput{RowCount: 1},
"alltypes.all_postgres_types": &workflow_testdata.ExpectedOutput{RowCount: 2},
"alltypes.array_types": &workflow_testdata.ExpectedOutput{RowCount: 1},
"alltypes.time_time": &workflow_testdata.ExpectedOutput{RowCount: 2},
},
},
{
Expand All @@ -29,8 +30,9 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
InitSchema: true,
},
Expected: map[string]*workflow_testdata.ExpectedOutput{
"alltypes.all_postgres_types": &workflow_testdata.ExpectedOutput{RowCount: 1},
"alltypes.all_postgres_types": &workflow_testdata.ExpectedOutput{RowCount: 2},
"alltypes.array_types": &workflow_testdata.ExpectedOutput{RowCount: 1},
"alltypes.time_time": &workflow_testdata.ExpectedOutput{RowCount: 2},
},
},
}
Expand Down

0 comments on commit 763d6b3

Please sign in to comment.