Skip to content

Commit

Permalink
refactor: add subscribed() queryset methods
Browse files Browse the repository at this point in the history
  • Loading branch information
danjac committed Aug 29, 2024
1 parent 16f6747 commit dbac4d4
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 30 deletions.
10 changes: 10 additions & 0 deletions radiofeed/episodes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
class EpisodeQuerySet(FastCountQuerySetMixin, SearchQuerySetMixin, FastUpdateQuerySet):
"""QuerySet for Episode model."""

def subscribed(self, user: User) -> models.QuerySet[Podcast]:
"""Returns episodes belonging to episodes subscribed by user."""
return self.alias(
is_subscribed=models.Exists(
user.subscriptions.filter(
podcast=models.OuterRef("podcast"),
)
)
).filter(is_subscribed=True)


class Episode(models.Model):
"""Individual podcast episode."""
Expand Down
11 changes: 10 additions & 1 deletion radiofeed/episodes/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
EpisodeFactory,
)
from radiofeed.podcasts.models import Podcast
from radiofeed.podcasts.tests.factories import PodcastFactory
from radiofeed.podcasts.tests.factories import PodcastFactory, SubscriptionFactory


class TestEpisodeManager:
Expand All @@ -23,6 +23,15 @@ def test_search_empty(self):
EpisodeFactory(title="testing")
assert Episode.objects.search("").count() == 0

@pytest.mark.django_db
def test_subscribed_true(self, user, episode):
SubscriptionFactory(subscriber=user, podcast=episode.podcast)
assert Episode.objects.subscribed(user).exists() is True

@pytest.mark.django_db
def test_subscribed_false(self, user, episode):
assert Episode.objects.subscribed(user).exists() is False


class TestEpisodeModel:
link = "https://example.com"
Expand Down
8 changes: 1 addition & 7 deletions radiofeed/episodes/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from django.contrib import messages
from django.contrib.auth.decorators import login_required
from django.db import IntegrityError
from django.db.models import Exists, OuterRef
from django.http import (
Http404,
HttpRequest,
Expand Down Expand Up @@ -32,13 +31,8 @@ def index(request: HttpRequest) -> TemplateResponse:
"""List latest episodes from subscriptions."""

episodes = (
Episode.objects.alias(
is_subscribed=Exists(
request.user.subscriptions.filter(podcast=OuterRef("podcast"))
)
)
Episode.objects.subscribed(request.user)
.filter(
is_subscribed=True,
pub_date__gt=timezone.now() - timedelta(days=14),
)
.select_related("podcast")
Expand Down
10 changes: 10 additions & 0 deletions radiofeed/podcasts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def search(self, search_term: str) -> models.QuerySet[Podcast]:
)
return qs.annotate(exact_match=models.Value(0))

def subscribed(self, user: User) -> models.QuerySet[Podcast]:
"""Returns podcasts subscribed by user."""
return self.alias(
is_subscribed=models.Exists(
user.subscriptions.filter(
podcast=models.OuterRef("pk"),
)
)
).filter(is_subscribed=True)


class Podcast(models.Model):
"""Podcast channel or feed."""
Expand Down
10 changes: 10 additions & 0 deletions radiofeed/podcasts/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CategoryFactory,
PodcastFactory,
RecommendationFactory,
SubscriptionFactory,
)


Expand Down Expand Up @@ -109,6 +110,15 @@ def test_compare_exact_and_partial_matches_in_search(self):
assert second.title == "the testing"
assert second.exact_match == 0

@pytest.mark.django_db
def test_subscribed_true(self, user):
SubscriptionFactory(subscriber=user)
assert Podcast.objects.subscribed(user).exists() is True

@pytest.mark.django_db
def test_subscribed_false(self, user, podcast):
assert Podcast.objects.subscribed(user).exists() is False


class TestPodcastModel:
def test_str(self):
Expand Down
24 changes: 2 additions & 22 deletions radiofeed/podcasts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,7 @@ def index(request: HttpRequest) -> HttpResponseRedirect | TemplateResponse:
@login_required
def subscriptions(request: HttpRequest) -> TemplateResponse:
"""Render podcast index page."""
podcasts = (
_get_podcasts()
.alias(
is_subscribed=Exists(
request.user.subscriptions.filter(
podcast=OuterRef("pk"),
)
)
)
.filter(is_subscribed=True)
)
podcasts = _get_podcasts().subscribed(request.user)
podcasts = (
podcasts.search(request.search.value).order_by(
"-exact_match",
Expand Down Expand Up @@ -307,17 +297,7 @@ def unsubscribe(request: HttpRequest, podcast_id: int) -> TemplateResponse:
@login_required
def private_feeds(request: HttpRequest) -> TemplateResponse:
"""Lists user's private feeds."""
podcasts = (
_get_podcasts()
.alias(
is_subscribed=Exists(
request.user.subscriptions.filter(
podcast=OuterRef("pk"),
)
)
)
.filter(private=True, is_subscribed=True)
)
podcasts = _get_podcasts().subscribed(request.user).filter(private=True)

podcasts = (
podcasts.search(request.search.value).order_by(
Expand Down

0 comments on commit dbac4d4

Please sign in to comment.