Skip to content

Commit e87e37b

Browse files
authored
Backport: fix UPDATE failures on GENERATED ALWAYS AS IDENTITY columns (#820) (#822)
1 parent 6aba9d7 commit e87e37b

8 files changed

Lines changed: 415 additions & 34 deletions

File tree

pkg/schemalog/schema.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ func (c *Column) IsGenerated() bool {
276276
return c.Generated || c.Identity != ""
277277
}
278278

279+
// IsAlwaysIdentity reports whether the column is defined as
280+
// GENERATED ALWAYS AS IDENTITY. Such columns reject explicit values in UPDATE
281+
// SET clauses — only DEFAULT is accepted. The DDL injector encodes this as
282+
// "ALWAYS" in the schema log (see migrations/postgres/core/2_create_emit_ddl_function_and_triggers.up.sql).
283+
func (c *Column) IsAlwaysIdentity() bool {
284+
return c.Identity == "ALWAYS"
285+
}
286+
279287
func (c *Column) IsSerial() bool {
280288
return c.HasSequence() &&
281289
(strings.ToUpper(c.DataType) == "INTEGER" ||

pkg/stream/integration/pg_pg_integration_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,93 @@ func Test_PostgresToPostgres_SchemaObjects(t *testing.T) {
416416
}, 20*time.Second, 200*time.Millisecond)
417417
}
418418

419+
// Test_PostgresToPostgres_AlwaysIdentityUpdate verifies that UPDATE events on a
420+
// table with a GENERATED ALWAYS AS IDENTITY column are replicated successfully.
421+
// Postgres rejects UPDATE ... SET <always-identity-col> = <value> with
422+
// "column can only be updated to DEFAULT", so the column must be filtered out
423+
// of the SET clause on the target.
424+
func Test_PostgresToPostgres_AlwaysIdentityUpdate(t *testing.T) {
425+
if os.Getenv("PGSTREAM_INTEGRATION_TESTS") == "" {
426+
t.Skip("skipping integration test...")
427+
}
428+
429+
cfg := &stream.Config{
430+
Listener: testPostgresListenerCfg(),
431+
Processor: testPostgresProcessorCfg(pgurl, withoutBulkIngestion),
432+
}
433+
434+
ctx, cancel := context.WithCancel(context.Background())
435+
defer cancel()
436+
437+
targetConn, err := pglib.NewConn(ctx, targetPGURL)
438+
require.NoError(t, err)
439+
defer targetConn.Close(ctx)
440+
441+
runStream(t, ctx, cfg)
442+
443+
testTable := "pg2pg_always_identity_update_test"
444+
defer execQuery(t, ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", testTable))
445+
446+
// Table has a separate primary key plus a GENERATED ALWAYS AS IDENTITY
447+
// column. UPDATEs on `name` must not include `request_id` in the SET
448+
// clause, or Postgres rejects them.
449+
execQuery(t, ctx, fmt.Sprintf(`
450+
CREATE TABLE %s (
451+
id bigint PRIMARY KEY,
452+
request_id bigint GENERATED ALWAYS AS IDENTITY,
453+
name text
454+
)
455+
`, testTable))
456+
457+
// Wait for the table schema to land on the target before querying it.
458+
require.Eventually(t, func() bool {
459+
columns := getInformationSchemaColumns(t, ctx, targetConn, testTable)
460+
return len(columns) == 3
461+
}, 20*time.Second, 200*time.Millisecond, "table schema not replicated")
462+
463+
execQuery(t, ctx, fmt.Sprintf("INSERT INTO %s(id, name) VALUES(1, 'Alice')", testTable))
464+
execQuery(t, ctx, fmt.Sprintf("INSERT INTO %s(id, name) VALUES(2, 'Bob')", testTable))
465+
466+
require.Eventually(t, func() bool {
467+
rows := getIDNameRows(t, ctx, targetConn,
468+
fmt.Sprintf("SELECT id, name FROM %s ORDER BY id", testTable))
469+
return len(rows) == 2 && rows[0].name == "Alice" && rows[1].name == "Bob"
470+
}, 20*time.Second, 200*time.Millisecond, "initial inserts not replicated")
471+
472+
// Capture the original request_id values on the target so we can verify
473+
// they are preserved across the UPDATE (since the SET clause must not
474+
// touch the always-identity column).
475+
var aliceReqIDBefore, bobReqIDBefore int64
476+
err = targetConn.QueryRow(ctx, []any{&aliceReqIDBefore},
477+
fmt.Sprintf("SELECT request_id FROM %s WHERE id = 1", testTable))
478+
require.NoError(t, err)
479+
err = targetConn.QueryRow(ctx, []any{&bobReqIDBefore},
480+
fmt.Sprintf("SELECT request_id FROM %s WHERE id = 2", testTable))
481+
require.NoError(t, err)
482+
483+
// Perform UPDATEs that, pre-fix, produced
484+
// "column \"request_id\" can only be updated to DEFAULT" on the target.
485+
execQuery(t, ctx, fmt.Sprintf("UPDATE %s SET name = 'Alice2' WHERE id = 1", testTable))
486+
execQuery(t, ctx, fmt.Sprintf("UPDATE %s SET name = 'Bob2' WHERE id = 2", testTable))
487+
488+
require.Eventually(t, func() bool {
489+
rows := getIDNameRows(t, ctx, targetConn,
490+
fmt.Sprintf("SELECT id, name FROM %s ORDER BY id", testTable))
491+
return len(rows) == 2 && rows[0].name == "Alice2" && rows[1].name == "Bob2"
492+
}, 20*time.Second, 200*time.Millisecond, "UPDATE not replicated — likely rejected by always-identity rule")
493+
494+
var aliceReqIDAfter, bobReqIDAfter int64
495+
err = targetConn.QueryRow(ctx, []any{&aliceReqIDAfter},
496+
fmt.Sprintf("SELECT request_id FROM %s WHERE id = 1", testTable))
497+
require.NoError(t, err)
498+
err = targetConn.QueryRow(ctx, []any{&bobReqIDAfter},
499+
fmt.Sprintf("SELECT request_id FROM %s WHERE id = 2", testTable))
500+
require.NoError(t, err)
501+
502+
require.Equal(t, aliceReqIDBefore, aliceReqIDAfter, "request_id should be unchanged by UPDATE")
503+
require.Equal(t, bobReqIDBefore, bobReqIDAfter, "request_id should be unchanged by UPDATE")
504+
}
505+
419506
func Test_PostgresToPostgres_Sequences(t *testing.T) {
420507
if os.Getenv("PGSTREAM_INTEGRATION_TESTS") == "" {
421508
t.Skip("skipping integration test...")

pkg/wal/processor/postgres/helper_test.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,25 @@ func (m *mockAdapter) close() error {
2222
}
2323

2424
type mockSchemaObserver struct {
25-
getGeneratedColumnNamesFn func(ctx context.Context, schema, table string) (map[string]struct{}, error)
26-
getSequenceColumnsFn func(ctx context.Context, schema, table string) (map[string]string, error)
27-
isMaterializedViewFn func(schema, table string) bool
28-
updateFn func(logEntry *schemalog.LogEntry)
29-
closeFn func() error
25+
getGeneratedColumnNamesFn func(ctx context.Context, schema, table string) (map[string]struct{}, error)
26+
getAlwaysIdentityColumnNamesFn func(ctx context.Context, schema, table string) (map[string]struct{}, error)
27+
getSequenceColumnsFn func(ctx context.Context, schema, table string) (map[string]string, error)
28+
isMaterializedViewFn func(schema, table string) bool
29+
updateFn func(logEntry *schemalog.LogEntry)
30+
closeFn func() error
3031
}
3132

3233
func (m *mockSchemaObserver) getGeneratedColumnNames(ctx context.Context, schema, table string) (map[string]struct{}, error) {
3334
return m.getGeneratedColumnNamesFn(ctx, schema, table)
3435
}
3536

37+
func (m *mockSchemaObserver) getAlwaysIdentityColumnNames(ctx context.Context, schema, table string) (map[string]struct{}, error) {
38+
if m.getAlwaysIdentityColumnNamesFn == nil {
39+
return nil, nil
40+
}
41+
return m.getAlwaysIdentityColumnNamesFn(ctx, schema, table)
42+
}
43+
3644
func (m *mockSchemaObserver) getSequenceColumns(ctx context.Context, schema, table string) (map[string]string, error) {
3745
return m.getSequenceColumnsFn(ctx, schema, table)
3846
}

pkg/wal/processor/postgres/postgres_schema_observer.go

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ type pgSchemaObserver struct {
2121
pgConn pglib.Querier
2222
// generatedTableColumns is a map of schema.table to a list of generated column names.
2323
generatedTableColumns *synclib.Map[string, map[string]struct{}]
24+
// alwaysIdentityTableColumns is a map of schema.table to a set of column names
25+
// defined as GENERATED ALWAYS AS IDENTITY. These must be filtered from UPDATE
26+
// SET clauses since Postgres rejects explicit values for them.
27+
alwaysIdentityTableColumns *synclib.Map[string, map[string]struct{}]
2428
// materializedViews is a map of schema name to a set of materialized view names.
2529
materializedViews *synclib.Map[string, map[string]struct{}]
2630
// columnTableSequences is a map of schema.table to a map of sequence column names.
@@ -37,11 +41,12 @@ func newPGSchemaObserver(ctx context.Context, pgURL string, logger loglib.Logger
3741
return nil, err
3842
}
3943
return &pgSchemaObserver{
40-
pgConn: pgConn,
41-
generatedTableColumns: synclib.NewMap[string, map[string]struct{}](),
42-
materializedViews: synclib.NewMap[string, map[string]struct{}](),
43-
columnTableSequences: synclib.NewMap[string, map[string]string](),
44-
logger: logger,
44+
pgConn: pgConn,
45+
generatedTableColumns: synclib.NewMap[string, map[string]struct{}](),
46+
alwaysIdentityTableColumns: synclib.NewMap[string, map[string]struct{}](),
47+
materializedViews: synclib.NewMap[string, map[string]struct{}](),
48+
columnTableSequences: synclib.NewMap[string, map[string]string](),
49+
logger: logger,
4550
}, nil
4651
}
4752

@@ -66,6 +71,25 @@ func (o *pgSchemaObserver) getGeneratedColumnNames(ctx context.Context, schema,
6671
return colNames, nil
6772
}
6873

74+
// getAlwaysIdentityColumnNames returns the set of GENERATED ALWAYS AS IDENTITY
75+
// column names for the given schema.table. If not cached, it queries postgres.
76+
func (o *pgSchemaObserver) getAlwaysIdentityColumnNames(ctx context.Context, schema, table string) (map[string]struct{}, error) {
77+
key := pglib.QuoteQualifiedIdentifier(schema, table)
78+
79+
columns, found := o.alwaysIdentityTableColumns.Get(key)
80+
if found {
81+
return columns, nil
82+
}
83+
84+
colNames, err := o.queryAlwaysIdentityColumnNames(ctx, schema, table)
85+
if err != nil {
86+
return nil, err
87+
}
88+
89+
o.alwaysIdentityTableColumns.Set(key, colNames)
90+
return colNames, nil
91+
}
92+
6993
// isMaterializedView will return true if the input schema.table is a
7094
// materialized view. It uses an internal cache to reduce the number of calls to
7195
// postgres. If the value is not in the cache, it will query postgres.
@@ -114,18 +138,28 @@ func (o *pgSchemaObserver) update(logEntry *schemalog.LogEntry) {
114138
}
115139

116140
// updateGeneratedColumnNames will update the internal cache with the table
117-
// columns for the schema log on input.
141+
// columns for the schema log on input. Identity columns are added to
142+
// generatedColumns via IsGenerated() (preserved historical behavior so live
143+
// INSERTs let the target auto-generate ids and the sequence increments
144+
// naturally). GENERATED ALWAYS AS IDENTITY columns are additionally tracked in
145+
// alwaysIdentityTableColumns so UPDATE SET clauses can drop them even on
146+
// cache paths where generatedColumns is empty (e.g. populated via SQL query).
118147
func (o *pgSchemaObserver) updateGeneratedColumnNames(logEntry *schemalog.LogEntry) {
119148
for _, table := range logEntry.Schema.Tables {
120149
key := pglib.QuoteQualifiedIdentifier(logEntry.SchemaName, table.Name)
121150
generatedColumns := make(map[string]struct{}, len(table.Columns))
151+
alwaysIdentityColumns := make(map[string]struct{}, len(table.Columns))
122152
for _, c := range table.Columns {
153+
if c.IsAlwaysIdentity() {
154+
alwaysIdentityColumns[pglib.QuoteIdentifier(c.Name)] = struct{}{}
155+
}
123156
if c.IsGenerated() {
124157
generatedColumns[pglib.QuoteIdentifier(c.Name)] = struct{}{}
125158
}
126159
}
127160

128161
o.generatedTableColumns.Set(key, generatedColumns)
162+
o.alwaysIdentityTableColumns.Set(key, alwaysIdentityColumns)
129163
}
130164
}
131165

@@ -183,6 +217,34 @@ func (o *pgSchemaObserver) queryGeneratedColumnNames(ctx context.Context, schema
183217
return columnNames, nil
184218
}
185219

220+
const alwaysIdentityTableColumnsQuery = `SELECT attname FROM pg_attribute
221+
WHERE attnum > 0
222+
AND attrelid = (SELECT c.oid FROM pg_class c JOIN pg_namespace n ON c.relnamespace=n.oid WHERE c.relname=$1 and n.nspname=$2)
223+
AND attidentity = 'a'`
224+
225+
func (o *pgSchemaObserver) queryAlwaysIdentityColumnNames(ctx context.Context, schemaName, tableName string) (map[string]struct{}, error) {
226+
columnNames := map[string]struct{}{}
227+
rows, err := o.pgConn.Query(ctx, alwaysIdentityTableColumnsQuery, tableName, schemaName)
228+
if err != nil {
229+
return nil, fmt.Errorf("getting table always-identity column names for table %s.%s: %w", schemaName, tableName, err)
230+
}
231+
defer rows.Close()
232+
233+
for rows.Next() {
234+
var columnName string
235+
if err := rows.Scan(&columnName); err != nil {
236+
return nil, fmt.Errorf("scanning table always-identity column name: %w", err)
237+
}
238+
columnNames[pglib.QuoteIdentifier(columnName)] = struct{}{}
239+
}
240+
241+
if err := rows.Err(); err != nil {
242+
return nil, err
243+
}
244+
245+
return columnNames, nil
246+
}
247+
186248
const materializedViewsQuery = `SELECT matviewname FROM pg_matviews WHERE schemaname = $1`
187249

188250
func (o *pgSchemaObserver) queryMaterializedViews(ctx context.Context, schemaName string) (map[string]struct{}, error) {

pkg/wal/processor/postgres/postgres_schema_observer_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,112 @@ func TestPGSchemaObserver_getGeneratedColumnNames(t *testing.T) {
153153
}
154154
}
155155

156+
func TestPGSchemaObserver_updateGeneratedColumnNames(t *testing.T) {
157+
t.Parallel()
158+
159+
// The DDL injector encodes attidentity as the human-readable strings
160+
// "ALWAYS" / "BY DEFAULT" in the schema log JSON, not the raw 'a'/'d'.
161+
const always = "ALWAYS"
162+
const byDefault = "BY DEFAULT"
163+
quotedTable := `"test_schema"."test_table"`
164+
165+
tests := []struct {
166+
name string
167+
cols []schemalog.Column
168+
169+
wantGenerated map[string]struct{}
170+
wantAlwaysIdentity map[string]struct{}
171+
}{
172+
{
173+
name: "plain columns",
174+
cols: []schemalog.Column{
175+
{Name: "id", DataType: "integer"},
176+
{Name: "name", DataType: "text"},
177+
},
178+
179+
wantGenerated: map[string]struct{}{},
180+
wantAlwaysIdentity: map[string]struct{}{},
181+
},
182+
{
183+
name: "truly generated column",
184+
cols: []schemalog.Column{
185+
{Name: "id", DataType: "integer"},
186+
{Name: "full_name", DataType: "text", Generated: true},
187+
},
188+
189+
wantGenerated: map[string]struct{}{`"full_name"`: {}},
190+
wantAlwaysIdentity: map[string]struct{}{},
191+
},
192+
{
193+
// GENERATED ALWAYS AS IDENTITY lands in BOTH caches: filtered
194+
// from all DML via generatedColumns (historical behavior — lets
195+
// target sequence auto-increment on INSERT) and additionally
196+
// tracked in alwaysIdentityColumns so UPDATE SET drops it even
197+
// when generatedColumns is empty (e.g. SQL-query cache path).
198+
name: "always identity column",
199+
cols: []schemalog.Column{
200+
{Name: "id", DataType: "bigint", Identity: always},
201+
{Name: "name", DataType: "text"},
202+
},
203+
204+
wantGenerated: map[string]struct{}{`"id"`: {}},
205+
wantAlwaysIdentity: map[string]struct{}{`"id"`: {}},
206+
},
207+
{
208+
// GENERATED BY DEFAULT AS IDENTITY also lands in generatedColumns
209+
// via IsGenerated() (Identity != ""). It is NOT in the
210+
// alwaysIdentity map since PG accepts explicit values for it in
211+
// UPDATE SET.
212+
name: "by-default identity column",
213+
cols: []schemalog.Column{
214+
{Name: "id", DataType: "bigint", Identity: byDefault},
215+
{Name: "name", DataType: "text"},
216+
},
217+
218+
wantGenerated: map[string]struct{}{`"id"`: {}},
219+
wantAlwaysIdentity: map[string]struct{}{},
220+
},
221+
{
222+
name: "mix of generated and always identity",
223+
cols: []schemalog.Column{
224+
{Name: "id", DataType: "bigint", Identity: always},
225+
{Name: "request_id", DataType: "bigint", Identity: always},
226+
{Name: "full_name", DataType: "text", Generated: true},
227+
{Name: "name", DataType: "text"},
228+
},
229+
230+
wantGenerated: map[string]struct{}{`"id"`: {}, `"request_id"`: {}, `"full_name"`: {}},
231+
wantAlwaysIdentity: map[string]struct{}{`"id"`: {}, `"request_id"`: {}},
232+
},
233+
}
234+
235+
for _, tc := range tests {
236+
t.Run(tc.name, func(t *testing.T) {
237+
t.Parallel()
238+
239+
obs := &pgSchemaObserver{
240+
generatedTableColumns: synclib.NewMap[string, map[string]struct{}](),
241+
alwaysIdentityTableColumns: synclib.NewMap[string, map[string]struct{}](),
242+
logger: loglib.NewNoopLogger(),
243+
}
244+
245+
obs.updateGeneratedColumnNames(&schemalog.LogEntry{
246+
SchemaName: testSchema,
247+
Schema: schemalog.Schema{
248+
Tables: []schemalog.Table{
249+
{Name: "test_table", Columns: tc.cols},
250+
},
251+
},
252+
})
253+
254+
require.Equal(t, map[string]map[string]struct{}{quotedTable: tc.wantGenerated},
255+
obs.generatedTableColumns.GetMap())
256+
require.Equal(t, map[string]map[string]struct{}{quotedTable: tc.wantAlwaysIdentity},
257+
obs.alwaysIdentityTableColumns.GetMap())
258+
})
259+
}
260+
}
261+
156262
func TestPGSchemaObserver_isMaterializedView(t *testing.T) {
157263
t.Parallel()
158264

0 commit comments

Comments
 (0)