From 8c918c25b17b5a8e651b1fe0c08831c60fc80752 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 5 Oct 2017 21:45:33 +0200 Subject: [PATCH] #893: moved missing validation to serializers in posting endpoints --- misago/threads/api/postingendpoint/__init__.py | 3 +++ misago/threads/api/postingendpoint/close.py | 18 +++++++++++------- misago/threads/api/postingendpoint/hide.py | 18 +++++++++++------- misago/threads/api/postingendpoint/pin.py | 16 +++++++++++----- misago/threads/api/postingendpoint/protect.py | 13 +++++++++++-- .../threads/tests/test_thread_editreply_api.py | 17 +++++++++++++---- misago/threads/tests/test_thread_reply_api.py | 18 ++++++++++++++---- misago/threads/tests/test_thread_start_api.py | 12 ++++++++++++ 8 files changed, 86 insertions(+), 29 deletions(-) diff --git a/misago/threads/api/postingendpoint/__init__.py b/misago/threads/api/postingendpoint/__init__.py index 3ad09ccbe8..46f78174cd 100755 --- a/misago/threads/api/postingendpoint/__init__.py +++ b/misago/threads/api/postingendpoint/__init__.py @@ -1,4 +1,7 @@ +from rest_framework import serializers + from django.core.exceptions import PermissionDenied +from django.http import QueryDict from django.utils import timezone from django.utils.module_loading import import_string diff --git a/misago/threads/api/postingendpoint/close.py b/misago/threads/api/postingendpoint/close.py index 5e7a3702d5..c84cbdf225 100755 --- a/misago/threads/api/postingendpoint/close.py +++ b/misago/threads/api/postingendpoint/close.py @@ -1,3 +1,5 @@ +from rest_framework import serializers + from misago.threads import moderation from . import PostingEndpoint, PostingMiddleware @@ -5,14 +7,16 @@ class CloseMiddleware(PostingMiddleware): def use_this_middleware(self): - return self.mode == PostingEndpoint.START and 'close' in self.request.data + return self.mode == PostingEndpoint.START + + def get_serializer(self): + return CloseSerializer(data=self.request.data) def post_save(self, serializer): if self.thread.category.acl['can_close_threads']: - try: - close = bool(self.request.data['close']) - except (TypeError, ValueError): - close = False - - if close: + if serializer.validated_data.get('close'): moderation.close_thread(self.request, self.thread) + + +class CloseSerializer(serializers.Serializer): + close = serializers.BooleanField(required=False, default=False) diff --git a/misago/threads/api/postingendpoint/hide.py b/misago/threads/api/postingendpoint/hide.py index 4ba03e3322..688058d8af 100644 --- a/misago/threads/api/postingendpoint/hide.py +++ b/misago/threads/api/postingendpoint/hide.py @@ -1,3 +1,5 @@ +from rest_framework import serializers + from misago.threads import moderation from . import PostingEndpoint, PostingMiddleware @@ -5,19 +7,21 @@ class HideMiddleware(PostingMiddleware): def use_this_middleware(self): - return self.mode == PostingEndpoint.START and 'hide' in self.request.data + return self.mode == PostingEndpoint.START + + def get_serializer(self): + return HideSerializer(data=self.request.data) def post_save(self, serializer): if self.thread.category.acl['can_hide_threads']: - try: - hide = bool(self.request.data['hide']) - except (TypeError, ValueError): - hide = False - - if hide: + if serializer.validated_data.get('hide'): moderation.hide_thread(self.request, self.thread) self.thread.update_all = True self.thread.save(update_fields=['is_hidden']) self.thread.category.synchronize() self.thread.category.update_all = True + + +class HideSerializer(serializers.Serializer): + hide = serializers.BooleanField(required=False, default=False) diff --git a/misago/threads/api/postingendpoint/pin.py b/misago/threads/api/postingendpoint/pin.py index f94605d97e..76de3eba9d 100755 --- a/misago/threads/api/postingendpoint/pin.py +++ b/misago/threads/api/postingendpoint/pin.py @@ -1,3 +1,5 @@ +from rest_framework import serializers + from misago.threads import moderation from misago.threads.models import Thread @@ -6,18 +8,22 @@ class PinMiddleware(PostingMiddleware): def use_this_middleware(self): - return self.mode == PostingEndpoint.START and 'pin' in self.request.data + return self.mode == PostingEndpoint.START + + def get_serializer(self): + return PinSerializer(data=self.request.data) def post_save(self, serializer): allowed_pin = self.thread.category.acl['can_pin_threads'] if allowed_pin > 0: - try: - pin = int(self.request.data['pin']) - except (TypeError, ValueError): - pin = 0 + pin = serializer.validated_data['pin'] if pin <= allowed_pin: if pin == Thread.WEIGHT_GLOBAL: moderation.pin_thread_globally(self.request, self.thread) elif pin == Thread.WEIGHT_PINNED: moderation.pin_thread_locally(self.request, self.thread) + + +class PinSerializer(serializers.Serializer): + pin = serializers.IntegerField(required=False, default=0) diff --git a/misago/threads/api/postingendpoint/protect.py b/misago/threads/api/postingendpoint/protect.py index 6d6ac8e1d5..3cd75cf4ff 100644 --- a/misago/threads/api/postingendpoint/protect.py +++ b/misago/threads/api/postingendpoint/protect.py @@ -1,14 +1,23 @@ +from rest_framework import serializers + from . import PostingEndpoint, PostingMiddleware class ProtectMiddleware(PostingMiddleware): def use_this_middleware(self): - return self.mode == PostingEndpoint.EDIT and 'protect' in self.request.data + return self.mode == PostingEndpoint.EDIT + + def get_serializer(self): + return ProtectSerializer(data=self.request.data) def post_save(self, serializer): if self.thread.category.acl['can_protect_posts']: try: - self.post.is_protected = bool(self.request.data['protect']) + self.post.is_protected = serializer.validated_data.get('protect', False) self.post.update_fields.append('is_protected') except (TypeError, ValueError): pass + + +class ProtectSerializer(serializers.Serializer): + protect = serializers.BooleanField(required=False, default=False) diff --git a/misago/threads/tests/test_thread_editreply_api.py b/misago/threads/tests/test_thread_editreply_api.py index 879932a0f5..f51dcc7b72 100644 --- a/misago/threads/tests/test_thread_editreply_api.py +++ b/misago/threads/tests/test_thread_editreply_api.py @@ -164,10 +164,19 @@ def test_empty_data(self): response = self.put(self.api_link, data={}) - self.assertEqual(response.status_code, 400) - self.assertEqual(response.json(), { - 'post': ["You have to enter a message."], - }) + self.assertContains(response, "You have to enter a message.", status_code=400) + + def test_invalid_data(self): + """api errors for invalid request data""" + self.override_acl() + + response = self.client.put( + self.api_link, + 'false', + content_type="application/json", + ) + + self.assertContains(response, "Invalid data.", status_code=400) def test_edit_event(self): """events can't be edited""" diff --git a/misago/threads/tests/test_thread_reply_api.py b/misago/threads/tests/test_thread_reply_api.py index d25dabfc62..4969ca2b89 100644 --- a/misago/threads/tests/test_thread_reply_api.py +++ b/misago/threads/tests/test_thread_reply_api.py @@ -110,10 +110,20 @@ def test_empty_data(self): self.override_acl() response = self.client.post(self.api_link, data={}) - self.assertEqual(response.status_code, 400) - self.assertEqual(response.json(), { - 'post': ["You have to enter a message."], - }) + + self.assertContains(response, "You have to enter a message.", status_code=400) + + def test_invalid_data(self): + """api errors for invalid request data""" + self.override_acl() + + response = self.client.post( + self.api_link, + 'false', + content_type="application/json", + ) + + self.assertContains(response, "Invalid data.", status_code=400) def test_post_is_validated(self): """post is validated""" diff --git a/misago/threads/tests/test_thread_start_api.py b/misago/threads/tests/test_thread_start_api.py index 32763c738c..26d86e6574 100644 --- a/misago/threads/tests/test_thread_start_api.py +++ b/misago/threads/tests/test_thread_start_api.py @@ -116,6 +116,18 @@ def test_empty_data(self): } ) + def test_invalid_data(self): + """api errors for invalid request data""" + self.override_acl() + + response = self.client.post( + self.api_link, + 'false', + content_type="application/json", + ) + + self.assertContains(response, "Invalid data.", status_code=400) + def test_title_is_validated(self): """title is validated""" self.override_acl()