Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions src/polymorphic/query_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ def _translate_polymorphic_filter_definition(
# handle instance_of expressions or alternatively,
# if this is a normal Django filter expression, return None
if field_path == "instance_of":
return create_instanceof_q(field_val, using=using)
return create_instanceof_q(field_val, using=using, queryset_model=queryset_model)
elif field_path == "not_instance_of":
return create_instanceof_q(field_val, not_instance_of=True, using=using)
return create_instanceof_q(
field_val, not_instance_of=True, using=using, queryset_model=queryset_model
)
elif "___" not in field_path:
return None # no change

Expand Down Expand Up @@ -253,13 +255,62 @@ def _get_query_related_name(myclass):
return myclass.__name__.lower()


def create_instanceof_q(modellist, not_instance_of=False, using=DEFAULT_DB_ALIAS):
def _resolve_model_string(model_string, queryset_model=None):
"""
Resolve a string model reference to an actual model class.

Supports formats:
- 'app_label.ModelName' - full app.model reference
- 'ModelName' - looks up in subclasses of queryset_model if provided

Args:
model_string: String reference to model
queryset_model: Optional base model for context when using short names

Returns:
Model class

Raises:
ValueError: If model cannot be resolved
"""
if "." in model_string:
# Format: 'app_label.ModelName'
app_label, model_name = model_string.rsplit(".", 1)
try:
model = apps.get_model(app_label, model_name)
return model
except LookupError:
raise ValueError(f"PolymorphicModel: model '{model_string}' not found")
else:
# Format: 'ModelName' - try to find in subclasses if queryset_model provided
if queryset_model:
submodels = _get_all_sub_models(queryset_model)
model = submodels.get(model_string)
if model:
return model

# Model name without app label - cannot resolve without context
raise ValueError(
f"PolymorphicModel: model '{model_string}' not found. "
f"Use 'app_label.ModelName' format for clarity."
)


def create_instanceof_q(
modellist, not_instance_of=False, using=DEFAULT_DB_ALIAS, queryset_model=None
):
"""
Helper function for instance_of / not_instance_of
Creates and returns a Q object that filters for the models in modellist,
including all subclasses of these models (as we want to do the same
as pythons isinstance() ).
.

Args:
modellist: Model class, list of models, or string references ('app.Model')
not_instance_of: If True, negate the query
using: Database alias
queryset_model: Base model for resolving short model name strings

We recursively collect all __subclasses__(), create a Q filter for each,
and or-combine these Q objects. This could be done much more
efficiently however (regarding the resulting sql), should an optimization
Expand All @@ -268,18 +319,33 @@ def create_instanceof_q(modellist, not_instance_of=False, using=DEFAULT_DB_ALIAS
if not modellist:
return None

# Normalize to list
if not isinstance(modellist, (list, tuple)):
from .models import PolymorphicModel

if issubclass(modellist, PolymorphicModel):
modellist = [modellist]
modellist = [modellist]

# Resolve any string references and validate models
resolved_models = []
from .models import PolymorphicModel

for item in modellist:
if isinstance(item, str):
# String reference - resolve it
model = _resolve_model_string(item, queryset_model)
if not issubclass(model, PolymorphicModel):
raise TypeError(
f"PolymorphicModel: '{item}' resolves to {model.__name__} "
f"which is not a PolymorphicModel"
)
resolved_models.append(model)
elif isinstance(item, type) and issubclass(item, PolymorphicModel):
resolved_models.append(item)
else:
raise TypeError(
"PolymorphicModel: instance_of expects a list of (polymorphic) "
"models or a single (polymorphic) model"
"models, model strings ('app.Model'), or a single (polymorphic) model/string"
)

contenttype_ids = _get_mro_content_type_ids(modellist, using)
contenttype_ids = _get_mro_content_type_ids(resolved_models, using)
q = Q(polymorphic_ctype__in=sorted(contenttype_ids))
if not_instance_of:
q = ~q
Expand Down
54 changes: 54 additions & 0 deletions src/polymorphic/tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,60 @@ def test_instance_of_filter(self):
objects, [Model2A], transform=lambda o: o.__class__, ordered=False
)

def test_instance_of_with_string_reference(self):
"""Test instance_of with string model reference (issue #505)"""
self.create_model2abcd()

# Test with full app.Model format
objects = Model2A.objects.instance_of("tests.Model2B")
self.assertQuerySetEqual(
objects,
[Model2B, Model2C, Model2D],
transform=lambda o: o.__class__,
ordered=False,
)

# Test with short name (requires queryset context)
objects = Model2A.objects.instance_of("Model2B")
self.assertQuerySetEqual(
objects,
[Model2B, Model2C, Model2D],
transform=lambda o: o.__class__,
ordered=False,
)

def test_instance_of_Q_object_with_string_reference(self):
"""Test instance_of in Q-object with string model reference (issue #505)"""
self.create_model2abcd()

# Test with full app.Model format in filter
objects = Model2A.objects.filter(instance_of="tests.Model2B")
self.assertQuerySetEqual(
objects,
[Model2B, Model2C, Model2D],
transform=lambda o: o.__class__,
ordered=False,
)

# Test with full app.Model format in explicit Q object
objects = Model2A.objects.filter(Q(instance_of="tests.Model2B"))
self.assertQuerySetEqual(
objects,
[Model2B, Model2C, Model2D],
transform=lambda o: o.__class__,
ordered=False,
)

def test_not_instance_of_with_string_reference(self):
"""Test not_instance_of with string model reference (issue #505)"""
self.create_model2abcd()

# Test with full app.Model format
objects = Model2A.objects.not_instance_of("tests.Model2B")
self.assertQuerySetEqual(
objects, [Model2A], transform=lambda o: o.__class__, ordered=False
)

def test_polymorphic___filter(self):
self.create_model2abcd()

Expand Down
Loading