Skip to content

Commit

Permalink
fix insert on through tables
Browse files Browse the repository at this point in the history
  • Loading branch information
toluaina committed Jan 29, 2023
1 parent ea018fe commit ec966b9
Show file tree
Hide file tree
Showing 15 changed files with 495 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: black

- repo: https://github.com/pycqa/isort
rev: 5.11.4
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
1 change: 1 addition & 0 deletions examples/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
set -u
# create database prior to running this bootstrap
source .pythonpath
source .env

if [ $# -eq 0 ]; then
echo "No arguments supplied"
Expand Down
5 changes: 5 additions & 0 deletions examples/through/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DELETE FROM customer_group;

INSERT INTO customer_group (customer_id, group_id) VALUES ( 1, 2);

INSERT INTO customer_group (customer_id, group_id) VALUES ( 1, 3);
67 changes: 67 additions & 0 deletions examples/through/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import datetime
import random
from typing import Dict, List

import click
from faker import Faker
from schema import Customer, CustomerGroup, Group
from sqlalchemy.orm import sessionmaker

from pgsync.base import pg_engine, subtransactions
from pgsync.constants import DEFAULT_SCHEMA
from pgsync.helper import teardown
from pgsync.utils import config_loader, get_config


@click.command()
@click.option(
    "--config",
    "-c",
    help="Schema config",
    type=click.Path(exists=True),
)
def main(config):
    """Seed the example 'through' database.

    Inserts three customers, three groups, and one customer_group link per
    customer/group pair, for every document in the sync config.
    """
    config = get_config(config)
    teardown(drop_db=False, config=config)

    for doc in config_loader(config):
        database = doc.get("database", doc["index"])
        with pg_engine(database) as engine:
            schema = doc.get("schema", DEFAULT_SCHEMA)
            # Route unqualified table names to the configured schema.
            connection = engine.connect().execution_options(
                schema_translate_map={None: schema}
            )
            session = sessionmaker(bind=connection, autoflush=True)()

            customers = [
                Customer(name=name)
                for name in ("CustomerA", "CustomerB", "CustomerC")
            ]
            with subtransactions(session):
                session.add_all(customers)

            groups = [
                Group(group_name=name)
                for name in ("GroupA", "GroupB", "GroupC")
            ]
            with subtransactions(session):
                session.add_all(groups)

            # Link customer[i] to group[i] through the association table.
            links = [
                CustomerGroup(customer=customer, group=group)
                for customer, group in zip(customers, groups)
            ]
            with subtransactions(session):
                session.add_all(links)

            session.commit()


if __name__ == "__main__":
    main()
35 changes: 35 additions & 0 deletions examples/through/schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
[
{
"database": "through",
"index": "through",
"nodes":
{
"table": "customer",
"columns":
[
"id",
"name"
],
"children":
[
{
"table": "group",
"columns":
[
"id",
"group_name"
],
"relationship":
{
"variant": "object",
"type": "one_to_many",
"through_tables":
[
"customer_group"
]
}
}
]
}
}
]
88 changes: 88 additions & 0 deletions examples/through/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import click
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.schema import UniqueConstraint

from pgsync.base import create_database, create_schema, pg_engine
from pgsync.constants import DEFAULT_SCHEMA
from pgsync.helper import teardown
from pgsync.utils import config_loader, get_config

Base = declarative_base()


class Customer(Base):
    """A customer; may belong to many groups via the customer_group table."""

    __tablename__ = "customer"
    __table_args__ = (UniqueConstraint("name"),)

    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
    # Customer display name; uniqueness is enforced at the table level.
    name = sa.Column(sa.String, nullable=False)


class Group(Base):
    """A group; may contain many customers via the customer_group table."""

    __tablename__ = "group"
    __table_args__ = (UniqueConstraint("group_name"),)

    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
    # Group display name; uniqueness is enforced at the table level.
    group_name = sa.Column(sa.String, nullable=False)


class CustomerGroup(Base):
    """Association ("through") table linking customers to groups.

    Each (customer_id, group_id) pair is unique; deleting either side
    cascades to the link row.
    """

    __tablename__ = "customer_group"
    __table_args__ = (UniqueConstraint("customer_id", "group_id"),)

    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
    # NOTE(review): FK columns are nullable by default here — confirm that
    # orphan link rows are intended to be representable.
    customer_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(Customer.id, ondelete="CASCADE"),
    )
    # NOTE(review): backref names look swapped by convention — Customer gains
    # a "customers" attribute and Group gains "groups", each pointing at
    # CustomerGroup rows. Kept as-is to preserve the existing interface.
    customer = sa.orm.relationship(
        Customer,
        backref=sa.orm.backref("customers"),
    )
    group_id = sa.Column(
        sa.Integer,
        sa.ForeignKey(Group.id, ondelete="CASCADE"),
    )
    group = sa.orm.relationship(
        Group,
        backref=sa.orm.backref("groups"),
    )


def setup(config: str) -> None:
    """Create the database/schema for each config document and rebuild
    all example tables from scratch (drop then create).
    """
    for doc in config_loader(config):
        database: str = doc.get("database", doc["index"])
        schema: str = doc.get("schema", DEFAULT_SCHEMA)
        create_database(database)
        create_schema(database, schema)
        with pg_engine(database) as engine:
            # Route unqualified table names to the configured schema.
            connection = engine.connect().execution_options(
                schema_translate_map={None: schema}
            )
            Base.metadata.drop_all(connection)
            Base.metadata.create_all(connection)


@click.command()
@click.option(
    "--config",
    "-c",
    help="Schema config",
    type=click.Path(exists=True),
)
def main(config):
    """Tear down any previous example state, then build the schema."""
    config = get_config(config)
    teardown(config=config)
    setup(config)


if __name__ == "__main__":
    main()
32 changes: 31 additions & 1 deletion pgsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ def data(self) -> dict:
return self.old
return self.new

def foreign_key_constraint(self, model) -> dict:
    """Map each referred table to the FK column linking it to this payload.

    Returns a dict keyed by the referred table's fully qualified name, e.g.::

        {
            'public.customer': {
                'local': 'customer_id',   # FK column on *model*
                'remote': 'id',           # referred column on the target
                'value': 1,               # value carried in the payload
            },
            'public.group': {
                'local': 'group_id',
                'remote': 'id',
                'value': 1,
            },
        }

    Only foreign keys whose (first) local column is present in the payload
    data are included.  The previous implementation seeded an empty dict for
    every referred table via ``setdefault`` — consumers that checked key
    presence and then indexed ``["remote"]``/``["value"]`` would crash on
    those empty entries, so they are now omitted entirely.

    NOTE(review): only ``column_keys[0]`` is inspected, so composite foreign
    keys are reduced to their first column — confirm this is intended.
    """
    constraints: dict = {}
    for foreign_key in model.foreign_keys:
        constraint = foreign_key.constraint
        if not constraint.column_keys:
            continue
        local: str = constraint.column_keys[0]
        if local not in self.data:
            # Payload carries no value for this FK column; skip the table.
            continue
        referred_table: str = str(constraint.referred_table)
        constraints[referred_table] = {
            "local": local,
            "remote": foreign_key.column.name,
            "value": self.data[local],
        }
    return constraints


class TupleIdentifierType(sa.types.UserDefinedType):
cache_ok: bool = True
Expand Down Expand Up @@ -453,7 +483,7 @@ def logical_slot_peek_changes(
upto_nchanges: Optional[int] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[sa.engine.row.LegacyRow]:
) -> List[sa.engine.row.Row]:
"""Peek a logical replication slot without consuming changes.
SELECT * FROM PG_LOGICAL_SLOT_PEEK_CHANGES('testdb', NULL, 1)
Expand Down
53 changes: 39 additions & 14 deletions pgsync/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,23 @@ def _root_foreign_key_resolver(

return filters

def _through_node_resolver(
    self, node: Node, payload: Payload, filters: list
) -> list:
    """Resolve a through-table node that references the root table directly.

    If one of *node*'s foreign keys points at the tree's root table, append
    a ``{remote_column: value}`` filter built from the payload and return
    the (mutated) *filters* list.
    """
    root_name = self.tree.root.name
    constraints: dict = payload.foreign_key_constraint(node.model)
    if root_name in constraints:
        entry = constraints[root_name]
        filters.append({entry["remote"]: entry["value"]})
    return filters

def _insert_op(
self, node: Node, filters: dict, payloads: List[Payload]
) -> dict:
Expand Down Expand Up @@ -540,28 +557,36 @@ def _insert_op(
)
raise

# set the parent as the new entity that has changed
foreign_keys = self.query_builder._get_foreign_keys(
node.parent,
node,
)
try:
foreign_keys = self.query_builder.get_foreign_keys(
node.parent,
node,
)
except ForeignKeyError:
foreign_keys = self.query_builder._get_foreign_keys(
node.parent,
node,
)

_filters: list = []
for payload in payloads:
for i, key in enumerate(foreign_keys[node.name]):
if key == foreign_keys[node.parent.name][i]:
filters[node.parent.table].append(
{
foreign_keys[node.parent.name][
i
]: payload.data[key]
}
)
for node_key in foreign_keys[node.name]:
for parent_key in foreign_keys[node.parent.name]:
if node_key == parent_key:
filters[node.parent.table].append(
{parent_key: payload.data[node_key]}
)

_filters = self._root_foreign_key_resolver(
node, payload, foreign_keys, _filters
)

# through table with a direct reference to the root
if not _filters:
_filters = self._through_node_resolver(
node, payload, _filters
)

if _filters:
filters[self.tree.root.table].extend(_filters)

Expand Down
4 changes: 3 additions & 1 deletion requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ faker
psycopg2-binary
redis
requests-aws4auth
sqlalchemy

# pin sqlalchemy to the latest 1.* release until 2.0 is supported
sqlalchemy==1.4.*
sqlparse

# pin these libs because latest flake8 does not allow newer versions of importlib-metadata https://github.com/PyCQA/flake8/issues/1522
Expand Down
Loading

0 comments on commit ec966b9

Please sign in to comment.