Skip to content

Commit

Permalink
feat(electric): Prevent updates to table PKs (#725)
Browse files Browse the repository at this point in the history
Co-authored-by: Oleksii Sholik <[email protected]>
  • Loading branch information
magnetised and alco authored Dec 12, 2023
1 parent 9676b4d commit 0dfb35d
Show file tree
Hide file tree
Showing 41 changed files with 1,275 additions and 307 deletions.
6 changes: 6 additions & 0 deletions .changeset/chilled-pillows-live.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@core/electric": patch
"electric-sql": patch
---

[VAX-1324] Prevent updates to table PKs
30 changes: 29 additions & 1 deletion clients/typescript/src/_generated/protocol/satellite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ export interface SatAuthResp {
export interface SatErrorResp {
$type: "Electric.Satellite.SatErrorResp";
errorType: SatErrorResp_ErrorCode;
/** lsn of the txn that caused the problem, if available */
lsn?:
| Uint8Array
| undefined;
/** human readable explanation of what went wrong */
message?: string | undefined;
}

export enum SatErrorResp_ErrorCode {
Expand Down Expand Up @@ -931,7 +937,7 @@ export const SatAuthResp = {
messageTypeRegistry.set(SatAuthResp.$type, SatAuthResp);

function createBaseSatErrorResp(): SatErrorResp {
return { $type: "Electric.Satellite.SatErrorResp", errorType: 0 };
return { $type: "Electric.Satellite.SatErrorResp", errorType: 0, lsn: undefined, message: undefined };
}

export const SatErrorResp = {
Expand All @@ -941,6 +947,12 @@ export const SatErrorResp = {
if (message.errorType !== 0) {
writer.uint32(8).int32(message.errorType);
}
if (message.lsn !== undefined) {
writer.uint32(18).bytes(message.lsn);
}
if (message.message !== undefined) {
writer.uint32(26).string(message.message);
}
return writer;
},

Expand All @@ -958,6 +970,20 @@ export const SatErrorResp = {

message.errorType = reader.int32() as any;
continue;
case 2:
if (tag !== 18) {
break;
}

message.lsn = reader.bytes();
continue;
case 3:
if (tag !== 26) {
break;
}

message.message = reader.string();
continue;
}
if ((tag & 7) === 4 || tag === 0) {
break;
Expand All @@ -974,6 +1000,8 @@ export const SatErrorResp = {
fromPartial<I extends Exact<DeepPartial<SatErrorResp>, I>>(object: I): SatErrorResp {
const message = createBaseSatErrorResp();
message.errorType = object.errorType ?? 0;
message.lsn = object.lsn ?? undefined;
message.message = object.message ?? undefined;
return message;
},
};
Expand Down
6 changes: 4 additions & 2 deletions components/electric/lib/electric/plug/migrations.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule Electric.Plug.Migrations do
use Plug.Router
use Electric.Satellite.Protobuf

alias Electric.Postgres.Extension.SchemaCache
alias Electric.Postgres.Extension.{SchemaCache, SchemaLoader}

require Logger

Expand Down Expand Up @@ -106,8 +106,10 @@ defmodule Electric.Plug.Migrations do

defp translate_stmts(version, schema, stmts, dialect) do
Enum.flat_map(stmts, fn stmt ->
schema_version = SchemaLoader.Version.new(version, schema)

{:ok, msgs, _relations} =
Electric.Postgres.Replication.migrate(schema, version, stmt, dialect)
Electric.Postgres.Replication.migrate(schema_version, stmt, dialect)

msgs
end)
Expand Down
78 changes: 43 additions & 35 deletions components/electric/lib/electric/postgres/extension/schema_cache.ex
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,6 @@ defmodule Electric.Postgres.Extension.SchemaCache do
call(origin, {:relation_oid, type, schema, name})
end

@impl SchemaLoader
def primary_keys(origin, {schema, name}) do
call(origin, {:primary_keys, schema, name})
end

@impl SchemaLoader
def primary_keys(origin, schema, name) do
call(origin, {:primary_keys, schema, name})
end

@impl SchemaLoader
def refresh_subscription(origin, name) do
call(origin, {:refresh_subscription, name})
Expand Down Expand Up @@ -246,27 +236,39 @@ defmodule Electric.Postgres.Extension.SchemaCache do

@impl GenServer
def handle_call({:load, :current}, _from, %{current: nil} = state) do
{result, state} = load_current_schema(state)

{:reply, result, state}
with {{:ok, schema_version}, state} <- load_current_schema(state) do
{:reply, {:ok, schema_version}, state}
else
{error, state} ->
{:reply, error, state}
end
end

def handle_call({:load, :current}, _from, %{current: {version, schema}} = state) do
{:reply, {:ok, version, schema}, state}
def handle_call(
{:load, :current},
_from,
%{current: %SchemaLoader.Version{} = schema_version} = state
) do
{:reply, {:ok, schema_version}, state}
end

def handle_call({:load, {:version, version}}, _from, %{current: {version, schema}} = state) do
{:reply, {:ok, version, schema}, state}
def handle_call(
{:load, {:version, version}},
_from,
%{current: %{version: version} = schema_version} = state
) do
{:reply, {:ok, schema_version}, state}
end

def handle_call({:load, {:version, version}}, _from, state) do
{:reply, SchemaLoader.load(state.backend, version), state}
end

def handle_call({:save, version, schema, stmts}, _from, state) do
{:ok, backend} = SchemaLoader.save(state.backend, version, schema, stmts)
{:ok, backend, schema_version} = SchemaLoader.save(state.backend, version, schema, stmts)

{:reply, {:ok, state.origin}, %{state | backend: backend, current: {version, schema}}}
{:reply, {:ok, state.origin, schema_version},
%{state | backend: backend, current: schema_version}}
end

def handle_call({:relation_oid, type, schema, name}, _from, state) do
Expand All @@ -275,8 +277,14 @@ defmodule Electric.Postgres.Extension.SchemaCache do

def handle_call({:primary_keys, sname, tname}, _from, state) do
{result, state} =
with {{:ok, _version, schema}, state} <- current_schema(state) do
{Schema.primary_keys(schema, sname, tname), state}
with {{:ok, schema_version}, state} <- current_schema(state) do
case SchemaLoader.Version.primary_keys(schema_version, {sname, tname}) do
{:ok, pks} ->
{{:ok, pks}, state}

{:error, _reason} = error ->
{error, state}
end
end

{:reply, result, state}
Expand All @@ -296,8 +304,8 @@ defmodule Electric.Postgres.Extension.SchemaCache do
end

def handle_call(:electrified_tables, _from, state) do
load_and_reply(state, fn schema ->
{:ok, Schema.table_info(schema)}
load_and_reply(state, fn schema_version ->
{:ok, Schema.table_info(schema_version.schema)}
end)
end

Expand All @@ -323,21 +331,21 @@ defmodule Electric.Postgres.Extension.SchemaCache do
end

def handle_call({:relation, oid}, _from, state) when is_integer(oid) do
load_and_reply(state, fn schema ->
Schema.table_info(schema, oid)
load_and_reply(state, fn schema_version ->
Schema.table_info(schema_version.schema, oid)
end)
end

def handle_call({:relation, {_sname, _tname} = relation}, _from, state) do
load_and_reply(state, fn schema ->
Schema.table_info(schema, relation)
load_and_reply(state, fn schema_version ->
Schema.table_info(schema_version.schema, relation)
end)
end

def handle_call({:relation, relation, version}, _from, state) do
{result, state} =
with {:ok, ^version, schema} <- SchemaLoader.load(state.backend, version) do
{Schema.table_info(schema, relation), state}
with {:ok, schema_version} <- SchemaLoader.load(state.backend, version) do
{Schema.table_info(schema_version.schema, relation), state}
else
error -> {error, state}
end
Expand Down Expand Up @@ -422,14 +430,14 @@ defmodule Electric.Postgres.Extension.SchemaCache do
load_current_schema(state)
end

defp current_schema(%{current: {version, schema}} = state) do
{{:ok, version, schema}, state}
defp current_schema(%{current: schema_version} = state) do
{{:ok, schema_version}, state}
end

defp load_current_schema(state) do
case SchemaLoader.load(state.backend) do
{:ok, version, schema} ->
{{:ok, version, schema}, %{state | current: {version, schema}}}
{:ok, schema_version} ->
{{:ok, schema_version}, %{state | current: schema_version}}

error ->
{error, state}
Expand All @@ -438,8 +446,8 @@ defmodule Electric.Postgres.Extension.SchemaCache do

defp load_and_reply(state, process) when is_function(process, 1) do
{result, state} =
with {{:ok, _version, schema}, state} <- current_schema(state) do
{process.(schema), state}
with {{:ok, schema_version}, state} <- current_schema(state) do
{process.(schema_version), state}
else
error -> {error, state}
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,6 @@ defmodule Electric.Postgres.Extension.SchemaCache.Global do
fun.(pid)
end

def primary_keys({_schema, _name} = relation) do
with_instance(fn pid ->
SchemaCache.primary_keys(pid, relation)
end)
end

def primary_keys(schema, name) when is_binary(schema) and is_binary(name) do
with_instance(fn pid ->
SchemaCache.primary_keys(pid, schema, name)
end)
end

def migration_history(version) do
with_instance(fn pid ->
SchemaCache.migration_history(pid, version)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
defmodule Electric.Postgres.Extension.SchemaLoader do
alias Electric.Postgres.{Schema, Extension.Migration}
alias Electric.Replication.Connectors
alias __MODULE__.Version

@type state() :: term()
@type version() :: String.t()
Expand All @@ -18,13 +19,11 @@ defmodule Electric.Postgres.Extension.SchemaLoader do
@type tx_fk_row() :: %{binary() => integer() | binary()}

@callback connect(Connectors.config(), Keyword.t()) :: {:ok, state()}
@callback load(state()) :: {:ok, version(), Schema.t()}
@callback load(state(), version()) :: {:ok, version(), Schema.t()} | {:error, binary()}
@callback load(state()) :: {:ok, Version.t()}
@callback load(state(), version()) :: {:ok, Version.t()} | {:error, binary()}
@callback save(state(), version(), Schema.t(), [String.t()]) ::
{:ok, state()} | {:error, term()}
{:ok, state(), Version.t()} | {:error, term()}
@callback relation_oid(state(), rel_type(), schema(), name()) :: oid_result()
@callback primary_keys(state(), schema(), name()) :: pk_result()
@callback primary_keys(state(), relation()) :: pk_result()
@callback refresh_subscription(state(), name()) :: :ok | {:error, term()}
@callback migration_history(state(), version() | nil) ::
{:ok, [Migration.t()]} | {:error, term()}
Expand Down Expand Up @@ -61,23 +60,15 @@ defmodule Electric.Postgres.Extension.SchemaLoader do
end

def save({module, state}, version, schema, stmts) do
with {:ok, state} <- module.save(state, version, schema, stmts) do
{:ok, {module, state}}
with {:ok, state, schema_version} <- module.save(state, version, schema, stmts) do
{:ok, {module, state}, schema_version}
end
end

def relation_oid({module, state}, rel_type, schema, table) do
module.relation_oid(state, rel_type, schema, table)
end

def primary_keys({module, state}, schema, table) do
module.primary_keys(state, schema, table)
end

def primary_keys({_module, _state} = impl, {schema, table}) do
primary_keys(impl, schema, table)
end

def refresh_subscription({module, state}, name) do
module.refresh_subscription(state, name)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,26 @@ defmodule Electric.Postgres.Extension.SchemaLoader.Epgsql do
@impl true
def load(pool) do
checkout!(pool, fn conn ->
Extension.current_schema(conn)
with {:ok, version, schema} <- Extension.current_schema(conn) do
{:ok, SchemaLoader.Version.new(version, schema)}
end
end)
end

@impl true
def load(pool, version) do
checkout!(pool, fn conn ->
Extension.schema_version(conn, version)
with {:ok, version, schema} <- Extension.schema_version(conn, version) do
{:ok, SchemaLoader.Version.new(version, schema)}
end
end)
end

@impl true
def save(pool, version, schema, stmts) do
checkout!(pool, fn conn ->
with :ok <- Extension.save_schema(conn, version, schema, stmts) do
{:ok, pool}
{:ok, pool, SchemaLoader.Version.new(version, schema)}
end
end)
end
Expand All @@ -124,35 +128,6 @@ defmodule Electric.Postgres.Extension.SchemaLoader.Epgsql do
end)
end

@primary_keys_query """
SELECT a.attname
FROM pg_class c
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
INNER JOIN pg_index i ON i.indrelid = c.oid
INNER JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE
n.nspname = $1
AND c.relname = $2
AND c.relkind = 'r'
AND i.indisprimary
"""

@impl true
def primary_keys(pool, schema, name) do
checkout!(pool, fn conn ->
{:ok, _, pks_data} = :epgsql.equery(conn, @primary_keys_query, [schema, name])

{:ok, Enum.map(pks_data, &elem(&1, 0))}
end)
end

@impl true
def primary_keys(pool, {schema, name}) do
checkout!(pool, fn conn ->
primary_keys(conn, schema, name)
end)
end

@impl true
def refresh_subscription(pool, name) do
checkout!(pool, fn conn ->
Expand Down
Loading

0 comments on commit 0dfb35d

Please sign in to comment.