Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
de06188
first fix versoin, working only if the items has the same id
Ckk3 Nov 19, 2024
0cd732d
bring back the first version, still missin the different ids logic!
Ckk3 Nov 19, 2024
401cd65
fix: now query can pickup related_model and self_model id
Ckk3 Nov 22, 2024
6f644e3
fix: not working with different ids
Ckk3 Nov 22, 2024
ca9bc1c
add nes tests
Ckk3 Nov 23, 2024
eb852ce
add tests
Ckk3 Nov 23, 2024
5770379
Fix mypy erros, still missing some tests
Ckk3 Nov 24, 2024
be77996
update code to work with sqlalchemy 1.4
Ckk3 Nov 24, 2024
fb6a580
remove old code that only works with sqlalchemy 2
Ckk3 Nov 24, 2024
0fb61bb
add seconday tables tests in test_loader
Ckk3 Nov 24, 2024
03a5438
add new tests to loadar and start mapper tests
Ckk3 Nov 26, 2024
a575650
add mapper tests
Ckk3 Nov 28, 2024
beaa3f9
refactor conftest
Ckk3 Nov 30, 2024
8a65328
refactor test_loader
Ckk3 Nov 30, 2024
9d76061
refactor test_mapper
Ckk3 Nov 30, 2024
91c24c5
run autopep
Ckk3 Nov 30, 2024
1cd8df4
run autopep
Ckk3 Nov 30, 2024
e96f179
separate test
Ckk3 Nov 30, 2024
4b6516b
fix lint
Ckk3 Nov 30, 2024
9b079d4
add release file
Ckk3 Nov 30, 2024
4baa7ae
refactor tests
Ckk3 Nov 30, 2024
33d7758
refactor loader
Ckk3 Nov 30, 2024
2a53474
fix release
Ckk3 Nov 30, 2024
d04af46
update pre-commit to work with python 3.8
Ckk3 Jan 26, 2025
3f7f13d
update loader.py
Ckk3 Jan 26, 2025
ff3e419
updated mapper
Ckk3 Jan 26, 2025
6752231
fix lint
Ckk3 Jan 26, 2025
0cd68d2
remote autopep8 from dev container because it give problems when work…
Ckk3 Jan 26, 2025
0745c64
fix lint
Ckk3 Jan 26, 2025
a8c03a5
Merge remote-tracking branch 'origin/main' into issue-19
Ckk3 May 17, 2025
276ffd5
Formatter and ruff
Ckk3 May 17, 2025
f3e2149
change release type to minor, add new todo
Ckk3 Jun 11, 2025
94ae7b1
Merge remote-tracking branch 'origin/main' into issue-19
Ckk3 Jun 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Add support for secondary table relationships in SQLAlchemy mapper, addressing a bug and enhancing the loader to handle these relationships efficiently.
8 changes: 8 additions & 0 deletions src/strawberry_sqlalchemy_mapper/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ def __init__(self, model):
f"Model `{model}` is not polymorphic or is not the base model of its "
+ "inheritance chain, and thus cannot be used as an interface."
)


class InvalidLocalRemotePairs(Exception):
def __init__(self, relationship_name):
super().__init__(
f"The `local_remote_pairs` for the relationship `{relationship_name}` is invalid or missing. "
+ "This is likely an issue with the library. Please report this error to the maintainers."
)
84 changes: 75 additions & 9 deletions src/strawberry_sqlalchemy_mapper/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Tuple,
Union,
)
from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs

from sqlalchemy import select, tuple_
from sqlalchemy.engine.base import Connection
Expand Down Expand Up @@ -45,12 +46,16 @@ def __init__(
"One of bind or async_bind_factory must be set for loader to function properly."
)

async def _scalars_all(self, *args, **kwargs):
async def _scalars_all(self, *args, disabled_optimization_to_secondary_tables=False, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: maybe call this enable_ and have True as the default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont want to do this because it removes optimizations that we only need to remove when we need to pick up secondary tables values, so if the default is True we will lose peformance in queries that dont need it.
But I agree that this var name aren't good enought, so I will change the name to query_secondary_tables and refactor the function.

if self._async_bind_factory:
async with self._async_bind_factory() as bind:
if disabled_optimization_to_secondary_tables is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick:

Suggested change
if disabled_optimization_to_secondary_tables is True:
if disabled_optimization_to_secondary_tables:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated! Thank you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated!

return (await bind.execute(*args, **kwargs)).all()
return (await bind.scalars(*args, **kwargs)).all()
else:
assert self._bind is not None
if disabled_optimization_to_secondary_tables is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick:

Suggested change
if disabled_optimization_to_secondary_tables is True:
if disabled_optimization_to_secondary_tables:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

return self._bind.execute(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).all()

def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
Expand All @@ -63,14 +68,69 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
related_model = relationship.entity.entity

async def load_fn(keys: List[Tuple]) -> List[Any]:
query = select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
)
def _build_normal_relationship_query(related_model, relationship, keys):
return select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
)

def _build_relationship_with_secondary_table_query(related_model, relationship, keys):
# Use another query when relationship uses a secondary table
self_model = relationship.parent.entity

if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{related_model.__name__} -- {self_model.__name__}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish: for ruff/black this parenthesis should be closed in the next line. I think you forgot to pre-commit install =P (ditto for the lines below)

Also, we are probably missing a lint check in here which runs ruff/black/etc (and maybe migrate to ruff formatter instead of black soon)

Copy link
Contributor Author

@Ckk3 Ckk3 Jan 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, I see now that pre-commit dont run due to some updates that dont work with dev container python version (3.8).
I updated that imports and now i'm fixing all the erros ;)


self_model_key_label = str(
relationship.local_remote_pairs[0][1].key)
related_model_key_label = str(
relationship.local_remote_pairs[1][1].key)

self_model_key = str(
relationship.local_remote_pairs[0][0].key)
related_model_key = str(
relationship.local_remote_pairs[1][0].key)

remote_to_use = relationship.local_remote_pairs[0][1]
query_keys = tuple([item[0] for item in keys])

# This query returns rows in this format -> (self_model.key, related_model)
return (
select(
getattr(self_model, self_model_key).label(
self_model_key_label),
related_model
)
.join(
relationship.secondary,
getattr(relationship.secondary.c,
related_model_key_label) == getattr(related_model, related_model_key)
)
.join(
self_model,
getattr(relationship.secondary.c,
self_model_key_label) == getattr(self_model, self_model_key)
)
.filter(
remote_to_use.in_(query_keys)
)
)

def _build_query(*args):
return _build_normal_relationship_query(*args) if relationship.secondary is None else _build_relationship_with_secondary_table_query(*args)

query = _build_query(related_model, relationship, keys)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: total personal preference, so feel free to ignore. I feel like those 2 _buil functions could be defined in the module and query could be defined with the if directly, like query = _build_normal_relationship_query(related_model, relationship, keys) if relationship.secondary is None else _build_relationship_with_secondary_table_query(related_model, relationship, keys)

Even better to avoid *args in there as mypy/pyright can validate them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!


if relationship.order_by:
query = query.order_by(*relationship.order_by)
rows = await self._scalars_all(query)

if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
else:
rows = await self._scalars_all(query)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion:

Suggested change
if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
else:
rows = await self._scalars_all(query)
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values. This is necessary because we
# use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=relationship.secondary is not None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Updated!


def group_by_remote_key(row: Any) -> Tuple:
return tuple(
Expand All @@ -82,8 +142,13 @@ def group_by_remote_key(row: Any) -> Tuple:
)

grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
if relationship.secondary is None:
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
else:
for row in rows:
grouped_keys[(row[0],)].append(row[1])

if relationship.uselist:
return [grouped_keys[key] for key in keys]
else:
Expand All @@ -94,3 +159,4 @@ def group_by_remote_key(row: Any) -> Tuple:

self._loaders[relationship] = DataLoader(load_fn=load_fn)
return self._loaders[relationship]

93 changes: 63 additions & 30 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from strawberry_sqlalchemy_mapper.exc import (
HybridPropertyNotAnnotated,
InterfaceModelNotPolymorphic,
InvalidLocalRemotePairs,
UnsupportedAssociationProxyTarget,
UnsupportedColumnType,
UnsupportedDescriptorType,
Expand Down Expand Up @@ -154,7 +155,8 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...

@overload
@classmethod
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
def from_type(cls, type_: type, *,
strict: bool = False) -> Optional[Self]: ...

@classmethod
def from_type(
Expand All @@ -165,7 +167,8 @@ def from_type(
) -> Optional[Self]:
definition = getattr(type_, cls.TYPE_KEY_NAME, None)
if strict and definition is None:
raise TypeError(f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
raise TypeError(
f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
return definition


Expand Down Expand Up @@ -228,11 +231,12 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):

def __init__(
self,
model_to_type_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
Mapping[Type[TypeEngine], Type[Any]]
] = None,
model_to_type_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[Mapping[Type[TypeEngine], Type[Any]]
] = None,
) -> None:
if model_to_type_name is None:
model_to_type_name = self._default_model_to_type_name
Expand Down Expand Up @@ -295,7 +299,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
"""
edge_name = f"{type_name}Edge"
if edge_name not in self.edge_types:
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
self.edge_types[edge_name] = edge_type = strawberry.type(
dataclasses.make_dataclass(
edge_name,
Expand All @@ -314,14 +319,15 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
connection_name = f"{type_name}Connection"
if connection_name not in self.connection_types:
edge_type = self._edge_type_for(type_name)
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
self.connection_types[connection_name] = connection_type = strawberry.type(
dataclasses.make_dataclass(
connection_name,
[
("edges", List[edge_type]), # type: ignore[valid-type]
],
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
)
)
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])
Expand Down Expand Up @@ -387,7 +393,7 @@ def _convert_relationship_to_strawberry_type(
if relationship.uselist:
# Use list if excluding relay pagination
if use_list:
return List[ForwardRef(type_name)] # type: ignore
return List[ForwardRef(type_name)] # type: ignore

return self._connection_type_for(type_name)
else:
Expand Down Expand Up @@ -451,7 +457,8 @@ def _get_association_proxy_annotation(
strawberry_type.__forward_arg__
)
else:
strawberry_type = self._connection_type_for(strawberry_type.__name__)
strawberry_type = self._connection_type_for(
strawberry_type.__name__)
return strawberry_type

def make_connection_wrapper_resolver(
Expand Down Expand Up @@ -500,13 +507,29 @@ async def resolve(self, info: Info):
if relationship.key not in instance_state.unloaded:
related_objects = getattr(self, relationship.key)
else:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
)
if relationship.secondary is None:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: you can pass the iterator to the tuple directly, no need to create a list for that

Suggested change
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

)
else:
# If has a secondary table, gets only the first ID as additional IDs require a separate query
if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}")

local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
0][0]
relationship_key = tuple(
[
getattr(
self, str(local_remote_pairs_secondary_table_local.key)),
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

)

if any(item is None for item in relationship_key):
if relationship.uselist:
return []
Expand Down Expand Up @@ -536,7 +559,8 @@ def connection_resolver_for(
if relationship.uselist and not use_list:
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
self.model_to_type_or_interface_name(
relationship.entity.entity), # type: ignore[arg-type]
)
else:
return relationship_resolver
Expand All @@ -554,13 +578,15 @@ def association_proxy_resolver_for(
Return an async field resolver for the given association proxy.
"""
in_between_relationship = mapper.relationships[descriptor.target_collection]
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
in_between_resolver = self.relationship_resolver_for(
in_between_relationship)
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
descriptor.target_collection
].entity
assert descriptor.value_attr in in_between_mapper.relationships
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
end_relationship_resolver = self.relationship_resolver_for(
end_relationship)
end_type_name = self.model_to_type_or_interface_name(
end_relationship.entity.entity # type: ignore[arg-type]
)
Expand All @@ -587,7 +613,8 @@ async def resolve(self, info: Info):
if outputs and isinstance(outputs[0], list):
outputs = list(chain.from_iterable(outputs))
else:
outputs = [output for output in outputs if output is not None]
outputs = [
output for output in outputs if output is not None]
else:
outputs = await end_relationship_resolver(in_between_objects, info)
if not isinstance(outputs, collections.abc.Iterable):
Expand Down Expand Up @@ -683,7 +710,8 @@ def convert(type_: Any) -> Any:
setattr(type_, key, field(resolver=val))
generated_field_keys.append(key)

self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
self._handle_columns(
mapper, type_, excluded_keys, generated_field_keys)
relationship: RelationshipProperty
for key, relationship in mapper.relationships.items():
if (
Expand Down Expand Up @@ -805,7 +833,7 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(func, type_), # type: ignore[arg-type]
types.MethodType(func, type_), # type: ignore[arg-type]
)

# Adjust types that inherit from other types/interfaces that implement Node
Expand All @@ -818,7 +846,8 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(cast(classmethod, meth).__func__, type_),
types.MethodType(
cast(classmethod, meth).__func__, type_),
)

# need to make fields that are already in the type
Expand Down Expand Up @@ -846,7 +875,8 @@ def convert(type_: Any) -> Any:
model=model,
),
)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY,
generated_field_keys)
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
return mapped_type

Expand Down Expand Up @@ -886,14 +916,16 @@ def _fix_annotation_namespaces(self) -> None:
self.edge_types.values(),
self.connection_types.values(),
):
strawberry_definition = get_object_definition(mapped_type, strict=True)
strawberry_definition = get_object_definition(
mapped_type, strict=True)
for f in strawberry_definition.fields:
if f.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
namespace = {}
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):
namespace.update(
sys.modules[
getattr(mapped_type, _ORIGINAL_TYPE_KEY).__module__
getattr(mapped_type,
_ORIGINAL_TYPE_KEY).__module__
].__dict__
)
namespace.update(self.mapped_types)
Expand Down Expand Up @@ -924,7 +956,8 @@ def _map_unmapped_relationships(self) -> None:
if type_name not in self.mapped_interfaces:
unmapped_interface_models.add(model)
for model in unmapped_models:
self.type(model)(type(self.model_to_type_name(model), (object,), {}))
self.type(model)(
type(self.model_to_type_name(model), (object,), {}))
for model in unmapped_interface_models:
self.interface(model)(
type(self.model_to_interface_name(model), (object,), {})
Expand Down
Loading
Loading