Skip to content

Commit

Permalink
Small codebase refactor, keep track of changes to models in state
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Aug 5, 2024
1 parent 46886aa commit 90036ea
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 66 deletions.
10 changes: 3 additions & 7 deletions misago/posting/forms/start.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from django import forms

from ..states.start import StartState
from ..states.start import StartThreadState


class StartForm(forms.Form):
class ThreadStartForm(forms.Form):
title = forms.CharField(max_length=200)
post = forms.CharField(max_length=2000, widget=forms.Textarea)

def update_state(self, state: StartState):
def update_state(self, state: StartThreadState):
state.set_thread_title(self.cleaned_data["title"])
state.set_post_message(self.cleaned_data["post"])


class ThreadStartForm(StartForm):
pass
104 changes: 104 additions & 0 deletions misago/posting/states/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from copy import deepcopy
from datetime import datetime
from typing import Any, TYPE_CHECKING

from django.db import models
from django.http import HttpRequest
from django.utils import timezone

from ...categories.models import Category
from ...core.utils import slugify
from ...parser.context import ParserContext, create_parser_context
from ...parser.enums import ContentType, PlainTextFormat
from ...parser.factory import create_parser
from ...parser.html import render_ast_to_html
from ...parser.metadata import create_ast_metadata
from ...parser.plaintext import render_ast_to_plaintext
from ...threads.models import Post, Thread

if TYPE_CHECKING:
from ...users.models import User


class State:
request: HttpRequest
timestamp: datetime
user: "User"

category: Category
thread: Thread
post: Post

parser_context: ParserContext
message_ast: list[dict] | None
message_metadata: dict | None

models_states: dict

def __init__(self, request: HttpRequest):
self.request = request
self.timestamp = timezone.now()
self.user = request.user

self.parser_context = self.initialize_parser_context()
self.message_ast = None
self.message_metadata = None

self.models_states = {}
self.store_model_state(self.user)

def store_model_state(self, model: models.Model):
state_key = self.get_model_state_key(model)
self.models_states[state_key] = self.get_model_state(model)

def get_model_state_key(self, model: models.Model) -> str:
return f"{model.__class__.__name__}:{model.pk}"

def get_model_state(self, model: models.Model) -> dict[str, Any]:
state = {}

for field in model._meta.get_fields():
if not isinstance(
field,
(models.ManyToManyRel, models.ManyToOneRel, models.ManyToManyField),
):
state[field.name] = deepcopy(getattr(model, field.attname))

return state

def get_model_changed_fields(self, model: models.Model) -> set[str]:
state_key = self.get_model_state_key(model)
old_state = self.models_states[state_key]

changed_fields: set[str] = set()
for field, value in self.get_model_state(model).items():
if old_state[field] != value:
changed_fields.add(field)

return changed_fields

def save_model_changes(self, model: models.Model) -> set[str]:
update_fields = self.get_model_changed_fields(model)
if update_fields:
model.save(update_fields=update_fields)
return update_fields

def initialize_parser_context(self) -> ParserContext:
return create_parser_context(self.request, content_type=ContentType.POST)

def set_post_message(self, message: str):
parser = create_parser(self.parser_context)
ast = parser(message)
metadata = create_ast_metadata(self.parser_context, ast)

self.post.original = message
self.post.parsed = render_ast_to_html(self.parser_context, ast, metadata)
self.post.search_document = render_ast_to_plaintext(
self.parser_context,
ast,
metadata,
text_format=PlainTextFormat.SEARCH_DOCUMENT,
)

self.message_ast = ast
self.message_metadata = metadata
47 changes: 14 additions & 33 deletions misago/posting/states/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from ...parser.plaintext import render_ast_to_plaintext
from ...threads.checksums import update_post_checksum
from ...threads.models import Post, Thread
from .base import State

if TYPE_CHECKING:
from ...users.models import User


class StartState:
class StartThreadState(State):
request: HttpRequest
timestamp: datetime
category: Category
Expand All @@ -32,15 +33,13 @@ class StartState:
message_metadata: dict | None

def __init__(self, request: HttpRequest, category: Category):
self.request = request
self.timestamp = timezone.now()
self.category = category
self.user = request.user
super().__init__(request)

self.category = category
self.thread = self.initialize_thread()
self.post = self.initialize_post()

self.parser_context = self.initialize_parser_context()
self.store_model_state(category)

def initialize_thread(self) -> Thread:
return Thread(
Expand All @@ -65,46 +64,26 @@ def initialize_post(self) -> Post:
updated_on=self.timestamp,
)

def initialize_parser_context(self) -> ParserContext:
return create_parser_context(self.request, content_type=ContentType.POST)

def set_thread_title(self, title: str):
self.thread.title = title
self.thread.slug = slugify(title)

def set_post_message(self, message: str):
parser = create_parser(self.parser_context)
ast = parser(message)
metadata = create_ast_metadata(self.parser_context, ast)

self.post.original = message
self.post.parsed = render_ast_to_html(self.parser_context, ast, metadata)
self.post.search_document = render_ast_to_plaintext(
self.parser_context,
ast,
metadata,
text_format=PlainTextFormat.SEARCH_DOCUMENT,
)

self.message_ast = ast
self.message_metadata = metadata

@transaction.atomic()
def save(self):
self.thread.save()
self.post.save()

self.save_final_thread()
self.save_final_post()
self.save_thread()
self.save_post()

self.save_category()
self.save_user()

def save_final_thread(self):
def save_thread(self):
self.thread.first_post = self.thread.last_post = self.post
self.thread.save()

def save_final_post(self):
def save_post(self):
update_post_checksum(self.post)
self.post.update_search_vector()
self.post.save()
Expand All @@ -113,13 +92,15 @@ def save_category(self):
self.category.threads = models.F("threads") + 1
self.category.posts = models.F("posts") + 1
self.category.set_last_thread(self.thread)
self.category.save()

self.save_model_changes(self.category)

def save_user(self):
self.user.threads = models.F("threads") + 1
self.user.posts = models.F("posts") + 1
self.user.save()

self.save_model_changes(self.user)


class StartThreadState(StartState):
class StartPrivateThreadState(StartThreadState):
pass
45 changes: 19 additions & 26 deletions misago/posting/views/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...permissions.threads import check_start_thread_in_category_permission
from ...threads.models import Thread
from ..forms.start import ThreadStartForm
from ..states.start import StartState, StartThreadState
from ..states.start import StartPrivateThreadState, StartThreadState


def start_thread_login_required():
Expand All @@ -26,8 +26,8 @@ def start_thread_login_required():
)


class StartView(View):
template_name: str
class ThreadStartView(View):
template_name: str = "misago/posting/start.html"
form_class = ThreadStartForm
state_class = StartThreadState

Expand Down Expand Up @@ -69,27 +69,6 @@ def post(self, request: HttpRequest, **kwargs) -> HttpResponse:
thread_url = self.get_thread_url(request, state.thread)
return redirect(thread_url)

def get_form(self, request: HttpRequest, category: Category) -> Form:
if request.method == "POST":
return self.form_class(request.POST, request.FILES)

return self.form_class()

def get_state(self, request: HttpRequest, category: Category) -> StartState:
return self.state_class(request, category)

def get_context_data(
self, request: HttpRequest, category: Category, form: Form
) -> dict:
return {"category": category, "form": form}

def get_thread_url(self, request: HttpRequest, thread: Thread) -> str:
raise NotImplementedError()


class ThreadStartView(StartView):
template_name: str = "misago/posting/start.html"

def get_category(self, request: HttpRequest, category_id: int) -> Category:
try:
category = Category.objects.get(
Expand All @@ -106,15 +85,29 @@ def get_category(self, request: HttpRequest, category_id: int) -> Category:

return category

def get_form(self, request: HttpRequest, category: Category) -> Form:
if request.method == "POST":
return self.form_class(request.POST, request.FILES)

return self.form_class()

def get_state(self, request: HttpRequest, category: Category) -> StartThreadState:
return self.state_class(request, category)

def get_context_data(
self, request: HttpRequest, category: Category, form: Form
) -> dict:
return {"category": category, "form": form}

def get_thread_url(self, request: HttpRequest, thread: Thread) -> str:
return reverse(
"misago:thread",
kwargs={"pk": thread.id, "slug": thread.slug},
)


class PrivateThreadStartView(StartView):
pass
class PrivateThreadStartView(ThreadStartView):
state_class = StartPrivateThreadState


class ThreadStartSelectCategoryView(View):
Expand Down

0 comments on commit 90036ea

Please sign in to comment.