diff --git a/src/polymorphic/query_translate.py b/src/polymorphic/query_translate.py index 59960815..2c48085e 100644 --- a/src/polymorphic/query_translate.py +++ b/src/polymorphic/query_translate.py @@ -114,9 +114,11 @@ def _translate_polymorphic_filter_definition( # handle instance_of expressions or alternatively, # if this is a normal Django filter expression, return None if field_path == "instance_of": - return create_instanceof_q(field_val, using=using) + return create_instanceof_q(field_val, using=using, queryset_model=queryset_model) elif field_path == "not_instance_of": - return create_instanceof_q(field_val, not_instance_of=True, using=using) + return create_instanceof_q( + field_val, not_instance_of=True, using=using, queryset_model=queryset_model + ) elif "___" not in field_path: return None # no change @@ -253,13 +255,62 @@ def _get_query_related_name(myclass): return myclass.__name__.lower() -def create_instanceof_q(modellist, not_instance_of=False, using=DEFAULT_DB_ALIAS): +def _resolve_model_string(model_string, queryset_model=None): + """ + Resolve a string model reference to an actual model class. + + Supports formats: + - 'app_label.ModelName' - full app.model reference + - 'ModelName' - looks up in subclasses of queryset_model if provided + + Args: + model_string: String reference to model + queryset_model: Optional base model for context when using short names + + Returns: + Model class + + Raises: + ValueError: If model cannot be resolved + """ + if "." in model_string: + # Format: 'app_label.ModelName' + app_label, model_name = model_string.rsplit(".", 1) + try: + model = apps.get_model(app_label, model_name) + return model + except LookupError: + raise ValueError(f"PolymorphicModel: model '{model_string}' not found") + else: + # Format: 'ModelName' - try to find in subclasses if queryset_model provided + if queryset_model: + submodels = _get_all_sub_models(queryset_model) + model = submodels.get(model_string) + if model: + return model + + # Model name without app label - cannot resolve without context + raise ValueError( + f"PolymorphicModel: model '{model_string}' not found. " + f"Use 'app_label.ModelName' format for clarity." + ) + + +def create_instanceof_q( + modellist, not_instance_of=False, using=DEFAULT_DB_ALIAS, queryset_model=None +): """ Helper function for instance_of / not_instance_of Creates and returns a Q object that filters for the models in modellist, including all subclasses of these models (as we want to do the same as pythons isinstance() ). - . + + Args: + modellist: Model class, list of models, or string references ('app.Model') + not_instance_of: If True, negate the query + using: Database alias + queryset_model: Base model for resolving short model name strings + We recursively collect all __subclasses__(), create a Q filter for each, and or-combine these Q objects. This could be done much more efficiently however (regarding the resulting sql), should an optimization @@ -268,18 +319,33 @@ def create_instanceof_q(modellist, not_instance_of=False, using=DEFAULT_DB_ALIAS if not modellist: return None + # Normalize to list if not isinstance(modellist, (list, tuple)): - from .models import PolymorphicModel - - if issubclass(modellist, PolymorphicModel): - modellist = [modellist] + modellist = [modellist] + + # Resolve any string references and validate models + resolved_models = [] + from .models import PolymorphicModel + + for item in modellist: + if isinstance(item, str): + # String reference - resolve it + model = _resolve_model_string(item, queryset_model) + if not issubclass(model, PolymorphicModel): + raise TypeError( + f"PolymorphicModel: '{item}' resolves to {model.__name__} " + f"which is not a PolymorphicModel" + ) + resolved_models.append(model) + elif isinstance(item, type) and issubclass(item, PolymorphicModel): + resolved_models.append(item) else: raise TypeError( "PolymorphicModel: instance_of expects a list of (polymorphic) " - "models or a single (polymorphic) model" + "models, model strings ('app.Model'), or a single (polymorphic) model/string" ) - contenttype_ids = _get_mro_content_type_ids(modellist, using) + contenttype_ids = _get_mro_content_type_ids(resolved_models, using) q = Q(polymorphic_ctype__in=sorted(contenttype_ids)) if not_instance_of: q = ~q diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 2f579b3e..2e264d8f 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -650,6 +650,60 @@ def test_instance_of_filter(self): objects, [Model2A], transform=lambda o: o.__class__, ordered=False ) + def test_instance_of_with_string_reference(self): + """Test instance_of with string model reference (issue #505)""" + self.create_model2abcd() + + # Test with full app.Model format + objects = Model2A.objects.instance_of("tests.Model2B") + self.assertQuerySetEqual( + objects, + [Model2B, Model2C, Model2D], + transform=lambda o: o.__class__, + ordered=False, + ) + + # Test with short name (requires queryset context) + objects = Model2A.objects.instance_of("Model2B") + self.assertQuerySetEqual( + objects, + [Model2B, Model2C, Model2D], + transform=lambda o: o.__class__, + ordered=False, + ) + + def test_instance_of_Q_object_with_string_reference(self): + """Test instance_of in Q-object with string model reference (issue #505)""" + self.create_model2abcd() + + # Test with full app.Model format in filter + objects = Model2A.objects.filter(instance_of="tests.Model2B") + self.assertQuerySetEqual( + objects, + [Model2B, Model2C, Model2D], + transform=lambda o: o.__class__, + ordered=False, + ) + + # Test with full app.Model format in explicit Q object + objects = Model2A.objects.filter(Q(instance_of="tests.Model2B")) + self.assertQuerySetEqual( + objects, + [Model2B, Model2C, Model2D], + transform=lambda o: o.__class__, + ordered=False, + ) + + def test_not_instance_of_with_string_reference(self): + """Test not_instance_of with string model reference (issue #505)""" + self.create_model2abcd() + + # Test with full app.Model format + objects = Model2A.objects.not_instance_of("tests.Model2B") + self.assertQuerySetEqual( + objects, [Model2A], transform=lambda o: o.__class__, ordered=False + ) + def test_polymorphic___filter(self): self.create_model2abcd()