Skip to content

Commit

Permalink
Fixes & features for higher taxon ranks (#496)
Browse files Browse the repository at this point in the history
* Update admin to use new parents_json

* Add tests for parents_json & recursive counts

* Default taxon parents to list, update django-pydantic-fields

* Use real TaxonRank objects everywhere except API responses

* Show all occurrences under a taxon parent

* Placeholder hack to include a recursive occurrence count for detail views

* Give up on add all occurrence counts in one query for now

* Link to occurrences list in species detail modal, even if occurrences count not available
  • Loading branch information
mihow authored Aug 13, 2024
1 parent f7f3470 commit 0a0ea5e
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 41 deletions.
6 changes: 4 additions & 2 deletions ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ class TaxonAdmin(admin.ModelAdmin[Taxon]):
)
list_filter = ("lists", "rank", TaxonParentFilter)
search_fields = ("name",)
exclude = ("parents",)

# annotate queryset with occurrence counts and allow sorting
# https://docs.djangoproject.com/en/3.2/ref/contrib/admin/#django.contrib.admin.ModelAdmin.list_display
Expand Down Expand Up @@ -278,7 +277,10 @@ def update_display_names(self, request: HttpRequest, queryset: QuerySet[Taxon])
ordering="parents",
)
def parent_names(self, obj) -> str:
return ", ".join([str(taxon) for taxon in obj.parents.values_list("name", flat=True)])
if obj.parents_json:
return ", ".join([str(taxon.name) for taxon in obj.parents_json])
else:
return ""

actions = [update_species_parents, update_display_names]

Expand Down
17 changes: 12 additions & 5 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime

from django.db.models import QuerySet
from django_pydantic_field.rest_framework import SchemaField
from rest_framework import serializers

from ami.base.serializers import DefaultSerializer, get_current_user, reverse_with_params
Expand Down Expand Up @@ -29,7 +28,6 @@
SourceImageCollection,
SourceImageUpload,
Taxon,
TaxonParent,
)


Expand Down Expand Up @@ -408,13 +406,22 @@ class Meta:
]


class TaxonParentSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
rank = serializers.SerializerMethodField()

def get_rank(self, obj):
return obj.rank.value


class TaxonNestedSerializer(TaxonNoParentNestedSerializer):
"""
Simple Taxon serializer with 1 level of nested parents.
"""

parent = TaxonNoParentNestedSerializer(read_only=True)
parents = SchemaField(list[TaxonParent], source="parents_json", read_only=True)
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")

class Meta(TaxonNoParentNestedSerializer.Meta):
fields = TaxonNoParentNestedSerializer.Meta.fields + [
Expand Down Expand Up @@ -492,7 +499,7 @@ def get_occurrence_images(self, obj):

class CaptureTaxonSerializer(DefaultSerializer):
parent = TaxonNoParentNestedSerializer(read_only=True)
parents = SchemaField(list[TaxonParent], source="parents_json", read_only=True)
parents = TaxonParentSerializer(many=True, read_only=True)

class Meta:
model = Taxon
Expand Down Expand Up @@ -649,7 +656,7 @@ class TaxonSerializer(DefaultSerializer):
parent = TaxonNoParentNestedSerializer(read_only=True)
parent_id = serializers.PrimaryKeyRelatedField(queryset=Taxon.objects.all(), source="parent", write_only=True)
# parents = TaxonParentNestedSerializer(many=True, read_only=True, source="parents_json")
parents = SchemaField(list[TaxonParent], source="parents_json", read_only=True)
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")

class Meta:
model = Taxon
Expand Down
30 changes: 28 additions & 2 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import exceptions as api_exceptions
from rest_framework import serializers, status, viewsets
from rest_framework import filters, serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.filters import SearchFilter
Expand Down Expand Up @@ -626,6 +626,20 @@ def get_serializer_class(self):
# "detection_algorithm").all()


class CustomDeterminationFilter(filters.BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
determination_id = request.query_params.get("determination")
if determination_id:
try:
taxon = Taxon.objects.get(id=determination_id)
return queryset.filter(
models.Q(determination=taxon) | models.Q(determination__parents_json__contains=[{"id": taxon.id}])
)
except Taxon.DoesNotExist:
return queryset.none() # or just return queryset if you prefer
return queryset


class OccurrenceViewSet(DefaultViewSet):
"""
API endpoint that allows occurrences to be viewed or edited.
Expand All @@ -634,7 +648,9 @@ class OccurrenceViewSet(DefaultViewSet):
queryset = Occurrence.objects.all()

serializer_class = OccurrenceSerializer
filterset_fields = ["event", "deployment", "determination", "project", "determination__rank"]
# filter_backends = [CustomDeterminationFilter, DjangoFilterBackend, NullsLastOrderingFilter, SearchFilter]
filter_backends = DefaultViewSetMixin.filter_backends + [CustomDeterminationFilter]
filterset_fields = ["event", "deployment", "project", "determination__rank"]
ordering_fields = [
"created_at",
"updated_at",
Expand Down Expand Up @@ -681,6 +697,7 @@ def get_queryset(self) -> QuerySet:
.exclude(first_appearance_timestamp=None) # This must come after annotations
.order_by("-determination_score")
)

else:
qs = qs.prefetch_related(
Prefetch(
Expand Down Expand Up @@ -895,6 +912,15 @@ def get_queryset(self) -> QuerySet:

return qs

# def retrieve(self, request: Request, *args, **kwargs) -> Response:
# """
# Override the serializer to include the recursive occurrences count
# """
# taxon: Taxon = self.get_object()
# taxon.occurrences_count = taxon.occurrences_count_recursive() # type: ignore
# response = Response(TaxonSerializer(taxon, context={"request": request}).data)
# return response


class ClassificationViewSet(DefaultViewSet):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Generated by Django 4.2.10 on 2024-08-06 17:47

import ami.main.models
from django.db import migrations, models
import django_pydantic_field._migration_serializers
import django_pydantic_field.fields


class Migration(migrations.Migration):
dependencies = [
("main", "0034_remove_taxon_parents_taxon_parents_json"),
]

operations = [
migrations.AlterField(
model_name="taxon",
name="parents_json",
field=django_pydantic_field.fields.PydanticSchemaField(
blank=True,
config=None,
default=list,
schema=django_pydantic_field._migration_serializers.GenericContainer(
list, (ami.main.models.TaxonParent,)
),
),
),
migrations.AlterField(
model_name="taxon",
name="rank",
field=models.CharField(
choices=[
("ORDER", "ORDER"),
("SUPERFAMILY", "SUPERFAMILY"),
("FAMILY", "FAMILY"),
("SUBFAMILY", "SUBFAMILY"),
("TRIBE", "TRIBE"),
("SUBTRIBE", "SUBTRIBE"),
("GENUS", "GENUS"),
("SPECIES", "SPECIES"),
("UNKNOWN", "UNKNOWN"),
],
default="SPECIES",
max_length=255,
),
),
]
125 changes: 95 additions & 30 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,45 +2106,93 @@ def root(self):

def update_all_parents(self):
"""Efficiently update all parents for all taxa."""

taxa = self.get_queryset().select_related("parent")

logging.info(f"Updating the cached parent tree for {taxa.count()} taxa")

# Build a dictionary of taxon parents
parents = {taxon: taxon.parent for taxon in taxa}
parents = {taxon.id: taxon.parent_id for taxon in taxa}

# Precompute all parents in a single pass
all_parents = {}
for taxon_id in parents:
if taxon_id not in all_parents:
taxon_parents = []
current_id = taxon_id
while current_id in parents:
current_id = parents[current_id]
taxon_parents.append(current_id)
all_parents[taxon_id] = taxon_parents

# Prepare bulk update data
bulk_update_data = []
for taxon in taxa:
taxon_parents = all_parents[taxon.id]
parent_taxa = list(taxa.filter(id__in=taxon_parents))
taxon_parents = [
TaxonParent(
id=taxon.id,
name=taxon.name,
rank=taxon.rank,
)
for taxon in parent_taxa
]
taxon_parents.sort(key=lambda t: t.rank)

# Update all parents
for taxon, parent in parents.items():
logging.info(f"Updating parents for {taxon}")
bulk_update_data.append(taxon)

taxon_parents = []
while parent:
taxon_parents.append(parent)
# If this is None, the parent is the root taxon, so we stop here.
parent = parents.get(parent)
# Perform bulk update
# with transaction.atomic():
# self.bulk_update(bulk_update_data, ["parents_json"], batch_size=1000)
# There is a bug that causes the bulk update to fail with a custom JSONField
# https://code.djangoproject.com/ticket/35167
# So we have to update each taxon individually
for taxon in bulk_update_data:
taxon.save(update_fields=["parents_json"])

# Convert the taxa to the JSON TaxonParent type
taxon_parents = [TaxonParent(id=t.pk, name=t.name, rank=TaxonRank(t.rank)) for t in taxon_parents]
logging.info(f"Updated parents for {len(bulk_update_data)} taxa")

# Sort the parents by rank (achievable because TaxonRank is an ordered enum)
taxon_parents.sort(key=lambda t: t.rank)
def with_children(self):
qs = self.get_queryset()
# Add Taxon that are children of this Taxon using parents_json field (not direct_children)

taxon.parents_json = taxon_parents
taxon.save()
# example for single taxon:
taxon = Taxon.objects.get(pk=1)
taxa = Taxon.objects.filter(parents_json__contains=[{"id": taxon.id}])
# add them to the queryset
qs = qs.annotate(children=models.Subquery(taxa.values("id")))
return qs

def with_occurrence_counts(self) -> models.QuerySet:
"""
Count the number of occurrences for a taxon and all occurrences of the taxon's children.
@TODO Try a recursive CTE in a raw SQL query,
or count the occurrences in a separate query and attach them to the Taxon objects.
"""

raise NotImplementedError(
"Occurrence counts can not be calculated in a subquery with the current JSONField schema. "
"Fetch them per taxon."
)


class TaxonParent(pydantic.BaseModel):
"""
Should contain all data needed for TaxonParentSerializer
Needs a custom encoder and decoder for for the TaxonRank enum
because it is an OrderedEnum and not a standard str Enum.
"""

id: int
name: str
rank: TaxonRank

class Config:
use_enum_values = True
# Make sure the TaxonRank is retrieved as an object and not a string
# so we can sort by rank. The DRF serializer will convert it to a string.
# just for the API responses.
use_enum_values = False


@final
Expand All @@ -2161,7 +2209,7 @@ class Taxon(BaseModel):
# Examples how to query this JSON array field
# Taxon.objects.filter(parents_json__contains=[{"id": 1}])
# https://stackoverflow.com/a/53942463/966058
parents_json = SchemaField(list[TaxonParent], null=True, blank=True)
parents_json = SchemaField(list[TaxonParent], null=False, blank=True, default=list)

active = models.BooleanField(default=True)
synonym_of = models.ForeignKey("self", on_delete=models.SET_NULL, null=True, blank=True, related_name="synonyms")
Expand All @@ -2174,7 +2222,6 @@ class Taxon(BaseModel):

projects = models.ManyToManyField("Project", related_name="taxa")
direct_children: models.QuerySet["Taxon"]
children: models.QuerySet["Taxon"]
occurrences: models.QuerySet[Occurrence]
classifications: models.QuerySet["Classification"]
lists: models.QuerySet["TaxaList"]
Expand All @@ -2186,6 +2233,9 @@ class Taxon(BaseModel):

objects: TaxaManager = TaxaManager()

# Type hints for auto-generated fields
parent_id: int | None

def __str__(self) -> str:
name_with_rank = f"{self.name} ({self.rank})"
return name_with_rank
Expand Down Expand Up @@ -2213,13 +2263,24 @@ def num_direct_children(self) -> int:
return self.direct_children.count()

def num_children_recursive(self) -> int:
# @TODO how to do this with a single query?
return self.children.count() + sum(child.num_children_recursive() for child in self.children.all())
# Use the parents_json field to get all children
return Taxon.objects.filter(parents_json__contains=[{"id": self.pk}]).count()

def occurrences_count(self) -> int:
# return self.occurrences.count()
return 0

def occurrences_count_recursive(self) -> int:
"""
Use the parents_json field to get all children, count their occurrences and sum them.
"""
return (
Taxon.objects.filter(models.Q(models.Q(parents_json__contains=[{"id": self.pk}]) | models.Q(id=self.pk)))
.annotate(occurrences_count=models.Count("occurrences"))
.aggregate(models.Sum("occurrences_count"))["occurrences_count__sum"]
or 0
)

def detections_count(self) -> int:
# return Detection.objects.filter(occurrence__determination=self).count()
return 0
Expand Down Expand Up @@ -2288,21 +2349,25 @@ def list_names(self) -> str:

def update_parents(self, save=True):
"""
Populate the cached "parents" list by recursively following the "parent" field.
Populate the cached `parents_json` list by recursively following the `parent` field.
@TODO this requires the parents' parents already being up-to-date, which may not always be the case.
@TODO this requires all of the taxon's parent taxa to have the `parent` attribute set correctly.
"""

taxon = self
current_taxon = self
parents = []
while taxon.parent is not None:
parents.append(TaxonParent(id=taxon.parent.id, name=taxon.parent.name, rank=taxon.parent.rank))
taxon = taxon.parent
while current_taxon.parent is not None:
parents.append(
TaxonParent(id=current_taxon.parent.id, name=current_taxon.parent.name, rank=current_taxon.parent.rank)
)
current_taxon = current_taxon.parent
# Sort parents by rank using ordered enum
parents = sorted(parents, key=lambda t: t.rank)
taxon.parents_json = parents
self.parents_json = parents
if save:
taxon.save()
self.save()

return parents

class Meta:
ordering = [
Expand Down
Loading

0 comments on commit 0a0ea5e

Please sign in to comment.