Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unique together validator doesn't respect condition's fields #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,15 +1430,18 @@ def get_unique_together_constraints(self, model):
"""
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager
yield unique_together, model._default_manager, []
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
yield (
constraint.fields,
model._default_manager
if constraint.condition is None
else model._default_manager.filter(constraint.condition)
)
if constraint.condition is None:
queryset = model._default_manager
condition_fields = []
else:
queryset = model._default_manager.filter(constraint.condition)
condition_fields = [
f[0].split("__")[0] for f in constraint.condition.children
]
yield (constraint.fields, queryset, condition_fields)

def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
"""
Expand Down Expand Up @@ -1470,9 +1473,9 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs

# Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer.
for unique_together_list, queryset in self.get_unique_together_constraints(model):
if set(field_names).issuperset(unique_together_list):
unique_constraint_names |= set(unique_together_list)
for unique_together_list, queryset, condition_fields in self.get_unique_together_constraints(model):
if set(field_names).issuperset((*unique_together_list, *condition_fields)):
unique_constraint_names |= set((*unique_together_list, *condition_fields))

# Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...'
Expand Down Expand Up @@ -1592,12 +1595,12 @@ def get_unique_together_validators(self):
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
for unique_together, queryset, condition_fields in self.get_unique_together_constraints(self.Meta.model):
# Skip if serializer does not map to all unique together sources
if not set(source_map).issuperset(unique_together):
if not set(source_map).issuperset((*unique_together, *condition_fields)):
continue

for source in unique_together:
for source in (*unique_together, *condition_fields):
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
Expand All @@ -1614,9 +1617,9 @@ def get_unique_together_validators(self):
)

field_names = tuple(source_map[f][0] for f in unique_together)
condition_fields = tuple(source_map[f][0] for f in condition_fields)
validator = UniqueTogetherValidator(
queryset=queryset,
fields=field_names
queryset=queryset, fields=field_names, condition_fields=condition_fields
)
validators.append(validator)
return validators
Expand Down
7 changes: 4 additions & 3 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ class UniqueTogetherValidator:
missing_message = _('This field is required.')
requires_context = True

def __init__(self, queryset, fields, message=None):
def __init__(self, queryset, fields, message=None, condition_fields=None):
self.queryset = queryset
self.fields = fields
self.message = message or self.message
self.condition_fields = [] if condition_fields is None else condition_fields

def enforce_required_fields(self, attrs, serializer):
"""
Expand All @@ -114,7 +115,7 @@ def enforce_required_fields(self, attrs, serializer):

missing_items = {
field_name: self.missing_message
for field_name in self.fields
for field_name in (*self.fields, *self.condition_fields)
if serializer.fields[field_name].source not in attrs
}
if missing_items:
Expand All @@ -127,7 +128,7 @@ def filter_queryset(self, attrs, queryset, serializer):
# field names => field sources
sources = [
serializer.fields[field_name].source
for field_name in self.fields
for field_name in (*self.fields, *self.condition_fields)
]

# If this is an update, then any unprovided field should
Expand Down
48 changes: 46 additions & 2 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ class Meta:
name="unique_constraint_model_together_uniq",
fields=('race_name', 'position'),
condition=models.Q(race_name='example'),
),
models.UniqueConstraint(
name="unique_constraint_model_together_uniq2",
fields=('race_name', 'position'),
condition=models.Q(fancy_conditions__gte=10),
)
]

Expand Down Expand Up @@ -563,13 +568,52 @@ def test_unique_together_field(self):
to UniqueTogetherValidator as fields and queryset
"""
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
assert len(serializer.validators) == 2
validator = serializer.validators[0]
assert validator.fields == ('race_name', 'position')
assert set(validator.queryset.values_list(flat=True)) == set(
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
)

def test_unique_together_condition(self):
"""
Fields used in UniqueConstraint's condition must be included
into queryset existence check
"""
UniqueConstraintModel.objects.create(
race_name='condition',
position=1,
global_id=10,
fancy_conditions=10
)
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
'fancy_conditions': 11,
})
assert serializer.is_valid()

def test_unique_together_condition_fields_required(self):
"""
Fields used in UniqueConstraint's condition must be present in serializer
"""
serializer = UniqueConstraintSerializer(data={
'race_name': 'condition',
'position': 1,
'global_id': 11,
})
assert not serializer.is_valid()
assert serializer.errors == {'fancy_conditions': ['This field is required.']}

class NoFieldsSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintModel
fields = ('race_name', 'position', 'global_id')

serializer = NoFieldsSerializer()
assert len(serializer.validators) == 1

def test_single_field_uniq_validators(self):
"""
UniqueConstraint with single field must be transformed into
Expand All @@ -579,7 +623,7 @@ def test_single_field_uniq_validators(self):
extra_validators_qty = 2 if django_version[0] >= 5 else 0
#
serializer = UniqueConstraintSerializer()
assert len(serializer.validators) == 1
assert len(serializer.validators) == 2
validators = serializer.fields['global_id'].validators
assert len(validators) == 1 + extra_validators_qty
assert validators[0].queryset == UniqueConstraintModel.objects
Expand Down