diff --git a/src/polymorphic/base.py b/src/polymorphic/base.py index 112b9c46..96e6bce9 100644 --- a/src/polymorphic/base.py +++ b/src/polymorphic/base.py @@ -85,11 +85,6 @@ def __new__(cls, model_name, bases, attrs, **kwargs): if not new_class._meta.abstract and not new_class._meta.swapped: cls.validate_model_manager(new_class.objects, model_name, "objects") - # determine the name of the primary key field and store it into the class variable - # polymorphic_primary_key_name (it is needed by query.py) - if new_class._meta.pk: - new_class.polymorphic_primary_key_name = new_class._meta.pk.attname - # wrap on_delete handlers of reverse relations back to this model with the # polymorphic deletion guard for fk in new_class._meta.fields: diff --git a/src/polymorphic/models.py b/src/polymorphic/models.py index 4850e023..b097e375 100644 --- a/src/polymorphic/models.py +++ b/src/polymorphic/models.py @@ -2,13 +2,17 @@ Seamless Polymorphic Inheritance for Django Models """ +import warnings + from django.contrib.contenttypes.models import ContentType from django.db import models, transaction from django.db.utils import DEFAULT_DB_ALIAS +from django.utils.functional import classproperty from .base import PolymorphicModelBase from .managers import PolymorphicManager from .query_translate import translate_polymorphic_Q_object +from .utils import get_base_polymorphic_model ################################################################################### # PolymorphicModel @@ -58,6 +62,20 @@ class Meta: abstract = True base_manager_name = "objects" + @classproperty + def polymorphic_primary_key_name(cls): + """ + The name of the root primary key field of this polymorphic inheritance chain. + """ + warnings.warn( + "polymorphic_primary_key_name is deprecated and will be removed in " + "version 5.0, use get_base_polymorphic_model(Model)._meta.pk.attname " + "instead.", + DeprecationWarning, + stacklevel=2, + ) + return get_base_polymorphic_model(cls, allow_abstract=True)._meta.pk.attname + @classmethod def translate_polymorphic_Q_object(cls, q): return translate_polymorphic_Q_object(cls, q) diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 34e85d07..30f1dd86 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -392,14 +392,9 @@ class self.model, but as a class derived from self.model. We want to re-fetch # classes if child class retrieval fails classes_to_query = [] - # django's automatic ".pk" field does not always work correctly for - # custom fields in derived objects (unclear yet who to put the blame on). - # We get different type(o.pk) in this case. - # We work around this by using the real name of the field directly - # for accessing the primary key of the the derived objects. - # We might assume that self.model._meta.pk.name gives us the name of the primary key field, - # but it doesn't. Therefore we use polymorphic_primary_key_name, which we set up in base.py. - pk_name = self.model.polymorphic_primary_key_name + # use the pk attribute for the base model type used in the query to identify + # objects + pk_name = self.model._meta.pk.attname # - sort base_result_object ids into idlist_per_model lists, depending on their real class; # - store objects that already have the correct class into "results" @@ -444,7 +439,7 @@ class self.model, but as a class derived from self.model. We want to re-fetch classes_to_query, (class_priorities.get(real_concrete_class, 0), real_concrete_class), ) - idlist_per_model[real_concrete_class].append(base_object.pk) + idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) indexlist_per_model[real_concrete_class].append((i, len(resultlist))) resultlist.append(None) diff --git a/src/polymorphic/tests/test_base.py b/src/polymorphic/tests/test_base.py index e3d005f2..16837f5d 100644 --- a/src/polymorphic/tests/test_base.py +++ b/src/polymorphic/tests/test_base.py @@ -136,3 +136,54 @@ def test_default_manager_without_dumpdata_command(self): # Should be a PolymorphicManager assert isinstance(manager, PolymorphicManager) + + +class PrimaryKeyNameTest(TestCase): + def test_polymorphic_primary_key_name_correctness(self): + """ + Verify that polymorphic_primary_key_name points to the root pk in the + inheritance chain. + + Regression test for #758. Will go away in version 5.0 + """ + from polymorphic.tests.models import ( + CustomPkInherit, + CustomPkBase, + Model2A, + Model2B, + Model2C, + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.assertEqual( + CustomPkInherit.polymorphic_primary_key_name, CustomPkBase._meta.pk.attname + ) + self.assertEqual(CustomPkInherit.polymorphic_primary_key_name, "id") + + self.assertEqual(Model2A.polymorphic_primary_key_name, Model2A._meta.pk.attname) + self.assertEqual(Model2A.polymorphic_primary_key_name, "id") + + self.assertEqual(Model2B.polymorphic_primary_key_name, Model2A._meta.pk.attname) + self.assertEqual(Model2B.polymorphic_primary_key_name, "id") + + self.assertEqual(Model2C.polymorphic_primary_key_name, Model2A._meta.pk.attname) + self.assertEqual(Model2C.polymorphic_primary_key_name, "id") + + assert w[0].category is DeprecationWarning + assert "polymorphic_primary_key_name" in str(w[0].message) + + def test_multiple_inheritance_pk_name(self): + """ + Verify multiple inheritance scenarios. + """ + from polymorphic.tests.models import Enhance_Inherit, Enhance_Base + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.assertEqual( + Enhance_Inherit.polymorphic_primary_key_name, Enhance_Base._meta.pk.attname + ) + self.assertEqual(Enhance_Inherit.polymorphic_primary_key_name, "base_id") + assert w[0].category is DeprecationWarning + assert "polymorphic_primary_key_name" in str(w[0].message) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 119109cc..41bb850c 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -1,3 +1,4 @@ +import warnings import pytest import uuid @@ -1659,8 +1660,15 @@ def test_subqueries(self): def test_one_to_one_primary_key(self): # check pk name resolution - for mdl in [Account, SpecialAccount1, SpecialAccount1_1, SpecialAccount2]: - assert mdl.polymorphic_primary_key_name == mdl._meta.pk.attname + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + for mdl in [Account, SpecialAccount1, SpecialAccount1_1, SpecialAccount2]: + assert mdl.polymorphic_primary_key_name == Account._meta.pk.attname + + assert w[0].category is DeprecationWarning + assert "polymorphic_primary_key_name" in str(w[0].message) user1 = get_user_model().objects.create( username="user1", email="user1@example.com", password="password" diff --git a/src/polymorphic/utils.py b/src/polymorphic/utils.py index d3c04f4a..49636992 100644 --- a/src/polymorphic/utils.py +++ b/src/polymorphic/utils.py @@ -72,6 +72,7 @@ def sort_by_subclass(*classes): return sorted(classes, key=cmp_to_key(_compare_mro)) +@lru_cache(maxsize=None) def get_base_polymorphic_model(ChildModel, allow_abstract=False): """ First the first concrete model in the inheritance chain that inherited from the