From f1a502842b0a6df594dc59097baea33a24992618 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Bredeho=CC=88ft?= Date: Tue, 30 Jun 2015 14:56:08 +0200 Subject: [PATCH] ADDED: initial version --- .gitignore | 59 ++ LICENSE.md | 22 + README.md | 33 ++ drf_tools/__init__.py | 1 + drf_tools/exceptions.py | 50 ++ drf_tools/fields.py | 9 + drf_tools/filters.py | 89 +++ drf_tools/renderers.py | 69 +++ drf_tools/routers.py | 57 ++ drf_tools/serializers.py | 128 +++++ drf_tools/test/__init__.py | 1 + drf_tools/test/base.py | 553 +++++++++++++++++++ drf_tools/test/utils.py | 26 + drf_tools/utils.py | 29 + drf_tools/views.py | 250 +++++++++ requirements.txt | 9 + setup.py | 36 ++ tests/__init__.py | 1 + tests/manage.py | 10 + tests/testproject/__init__.py | 1 + tests/testproject/migrations/0001_initial.py | 47 ++ tests/testproject/migrations/__init__.py | 1 + tests/testproject/models.py | 18 + tests/testproject/settings.py | 67 +++ tests/testproject/tests.py | 139 +++++ tests/testproject/urls.py | 18 + tests/testproject/views.py | 16 + 27 files changed, 1739 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE.md create mode 100644 README.md create mode 100644 drf_tools/__init__.py create mode 100644 drf_tools/exceptions.py create mode 100644 drf_tools/fields.py create mode 100644 drf_tools/filters.py create mode 100644 drf_tools/renderers.py create mode 100644 drf_tools/routers.py create mode 100644 drf_tools/serializers.py create mode 100644 drf_tools/test/__init__.py create mode 100644 drf_tools/test/base.py create mode 100644 drf_tools/test/utils.py create mode 100644 drf_tools/utils.py create mode 100644 drf_tools/views.py create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100755 tests/manage.py create mode 100644 tests/testproject/__init__.py create mode 100644 tests/testproject/migrations/0001_initial.py create mode 100644 tests/testproject/migrations/__init__.py create mode 100644 tests/testproject/models.py create mode 100644 tests/testproject/settings.py create mode 100644 tests/testproject/tests.py create mode 100644 tests/testproject/urls.py create mode 100644 tests/testproject/views.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4373a72 --- /dev/null +++ b/.gitignore @@ -0,0 +1,59 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +**/.idea/* diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..deafd1a --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Sebastian Bredehöft + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..8bd99fd --- /dev/null +++ b/README.md @@ -0,0 +1,33 @@ +drf-tools +================= +Multiple extensions and test utilities for Django REST Framework 3. + +## Setup ## + + pip install drf-tools + +## Requirement ## + +* Python 2.7+ +* Django 1.6+ +* Django REST Framework 3 +* drf-nested-fields 0.9+ +* drf-hal-json 0.9+ +* drf-enum-field 0.9+ +* drf-nested-routing 0.9+ +* django-filter 0.10+ +* openpyxl 2.0+ +* chardet 2.3+ + +## Features ## + +* Combination of the following libs: + * https://github.com/seebass/drf-nested-fields + * https://github.com/seebass/drf-hal-json + * https://github.com/seebass/drf-enum-field + * https://github.com/seebass/drf-nested-routing +* Additional renderers + * CsvRenderer + * ZipFileRenderer + * XlsxRenderer +* Test utitilities diff --git a/drf_tools/__init__.py b/drf_tools/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/drf_tools/__init__.py @@ -0,0 +1 @@ + diff --git a/drf_tools/exceptions.py b/drf_tools/exceptions.py new file mode 100644 index 0000000..70d805e --- /dev/null +++ b/drf_tools/exceptions.py @@ -0,0 +1,50 @@ +import logging + +from django.core.exceptions import ValidationError, ObjectDoesNotExist, PermissionDenied +from django.http import Http404 +from rest_framework.exceptions import APIException +from rest_framework import status +from rest_framework.response import Response + + +logger = logging.getLogger(__name__) + + +def exception_handler(exc): + headers = {} + if isinstance(exc, APIException): + if getattr(exc, 'auth_header', None): + headers['WWW-Authenticate'] = exc.auth_header + if getattr(exc, 'wait', None): + headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + headers['Retry-After'] = '%d' % exc.wait + status_code = exc.status_code + elif isinstance(exc, (ValueError, ValidationError)): + status_code = status.HTTP_400_BAD_REQUEST + elif isinstance(exc, PermissionDenied): + status_code = status.HTTP_403_FORBIDDEN + elif isinstance(exc, (ObjectDoesNotExist, Http404)): + status_code = status.HTTP_404_NOT_FOUND + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + + if status_code == status.HTTP_500_INTERNAL_SERVER_ERROR or logger.isEnabledFor(logging.DEBUG): + logger.exception(str(exc)) + + return Response(__create_error_response_by_exception(exc), status=status_code, headers=headers) + + +def __create_error_response_by_exception(exc): + if hasattr(exc, 'messages'): + messages = exc.messages + else: + messages = [str(exc)] + return __create_error_response(exc.__class__.__name__, messages) + + +def __create_error_response(error_type, messages, code=0): + error = dict() + error['type'] = error_type + error['messages'] = messages + error['code'] = code + return {'error': error} diff --git a/drf_tools/fields.py b/drf_tools/fields.py new file mode 100644 index 0000000..90cdcc5 --- /dev/null +++ b/drf_tools/fields.py @@ -0,0 +1,9 @@ +from rest_framework.fields import CharField + + +class FilenameField(CharField): + def to_representation(self, value): + value = super(FilenameField, self).to_representation(value) + if value: + value = value.split("/")[-1] + return value diff --git a/drf_tools/filters.py b/drf_tools/filters.py new file mode 100644 index 0000000..983f2b2 --- /dev/null +++ b/drf_tools/filters.py @@ -0,0 +1,89 @@ +from django.core.validators import EMPTY_VALUES +from django_filters import FilterSet, BooleanFilter +from django_filters.filters import Filter +from django import forms +from django.utils import six + + +class ListFilterSet(FilterSet): + """ + The filterset handles a list of values as filter, that are connected using the OR-Operator + """ + + @property + def qs(self): + if not hasattr(self, '_qs'): + valid = self.is_bound and self.form.is_valid() + + if self.strict and self.is_bound and not valid: + self._qs = self.queryset.none() + return self._qs + + # start with all the results and filter from there + qs = self.queryset.all() + for name, filter_ in six.iteritems(self.filters): + # CUSTOM:START + value_list = None + + if valid and self.data: + value_list = self.data.getlist(name) + + if value_list: # valid & clean data + filtered_qs = None + for value in value_list: + if isinstance(filter_, BooleanFilter): + value = self._str_to_boolean(value) + + if not filtered_qs: + filtered_qs = filter_.filter(qs, value) + else: + filtered_qs |= filter_.filter(qs, value) + qs = filtered_qs + # CUSTOM:END + + if self._meta.order_by: + order_field = self.form.fields[self.order_by_field] + data = self.form[self.order_by_field].data + ordered_value = None + try: + ordered_value = order_field.clean(data) + except forms.ValidationError: + pass + + if ordered_value in EMPTY_VALUES and self.strict: + ordered_value = self.form.fields[self.order_by_field].choices[0][0] + + if ordered_value: + qs = qs.order_by(*self.get_order_by(ordered_value)) + + self._qs = qs + + return self._qs + + @staticmethod + def _str_to_boolean(value): + if value.lower() == "true": + return True + if value.lower() == "false": + return False + return value + + +class EnumFilter(Filter): + def __init__(self, enum_type, *args, **kwargs): + super(EnumFilter, self).__init__(*args, **kwargs) + self.enum_type = enum_type + + field_class = forms.CharField + + def filter(self, qs, value): + if value in ([], (), {}, None, ''): + return qs + enum_value = None + for choice in self.enum_type: + if choice.name == value or choice.value == value: + enum_value = choice + break + if enum_value is None: + raise ValueError("'{value}' is not a valid value for '{enum}'".format(value=value, enum=self.enum_type.__name__)) + return super(EnumFilter, self).filter(qs, enum_value) diff --git a/drf_tools/renderers.py b/drf_tools/renderers.py new file mode 100644 index 0000000..4eb6070 --- /dev/null +++ b/drf_tools/renderers.py @@ -0,0 +1,69 @@ +from rest_framework.renderers import BaseRenderer as OriginalBaseRenderer + +from drf_tools.serializers import ZipSerializer, CsvSerializer +from drf_tools.serializers import XlsxSerializer + + +class BaseFileRenderer(OriginalBaseRenderer): + KWARGS_KEY_FILENAME = "filename" + + def _add_filename_to_response(self, renderer_context): + filename = self.__get_filename(renderer_context) + if filename: + renderer_context['response']['Content-Disposition'] = 'attachment; filename="{}"'.format(filename) + + def __get_filename(self, renderer_context): + filename = renderer_context['kwargs'].get(self.KWARGS_KEY_FILENAME) + if filename and self.format and not filename.endswith('.' + self.format): + filename += "." + self.format + return filename + + +class CsvRenderer(BaseFileRenderer): + media_type = "text/csv" + format = "txt" + + def render(self, data, accepted_media_type=None, renderer_context=None): + self._add_filename_to_response(renderer_context) + return CsvSerializer.serialize(data) + + +class XlsxRenderer(BaseFileRenderer): + media_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + format = 'xlsx' + charset = None + render_style = 'binary' + + def render(self, data, accepted_media_type=None, renderer_context=None): + self._add_filename_to_response(renderer_context) + if not isinstance(data, list): + return data + + return XlsxSerializer.serialize(data) + + +class ZipFileRenderer(BaseFileRenderer): + """ + A zip file is created containing the given dict with filename->bytes + """ + media_type = 'application/x-zip-compressed' + format = 'zip' + charset = None + render_style = 'binary' + + def render(self, data, accepted_media_type=None, renderer_context=None): + self._add_filename_to_response(renderer_context) + return ZipSerializer.serialize(data) + + +class AnyFileFromSystemRenderer(BaseFileRenderer): + """ + Given the full file path, the file is opened, read and returned + """ + media_type = '*/*' + render_style = 'binary' + + def render(self, data, accepted_media_type=None, renderer_context=None): + self._add_filename_to_response(renderer_context) + with open(data, "rb") as file: + return file.read() diff --git a/drf_tools/routers.py b/drf_tools/routers.py new file mode 100644 index 0000000..ac2be09 --- /dev/null +++ b/drf_tools/routers.py @@ -0,0 +1,57 @@ +from collections import OrderedDict + +from django.core.urlresolvers import NoReverseMatch +from rest_framework.permissions import AllowAny +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.routers import DefaultRouter +from rest_framework.views import APIView + +from drf_nested_routing.routers import NestedRouterMixin + + +class NestedRouterWithExtendedRootView(NestedRouterMixin, DefaultRouter): + """ + Router that handles nested routes and additionally adds given api_view_urls to the ApiRootView (the api entrypoint) + """ + def __init__(self, api_view_urls): + self.__api_view_urls = api_view_urls + super(NestedRouterWithExtendedRootView, self).__init__() + + def get_api_root_view(self): + api_root_routes = {} + list_name = self.routes[0].name + for prefix, viewset, basename in self.registry: + api_root_routes[prefix] = list_name.format(basename=basename) + + api_view_urls = self.__api_view_urls + + class ApiRootView(APIView): + + permission_classes = (AllowAny,) + + def get(self, request, *args, **kwargs): + links = OrderedDict() + links['viewsets'] = OrderedDict() + for key, url_name in api_root_routes.items(): + try: + links['viewsets'][key] = reverse(url_name, request=request, format=kwargs.get('format', None)) + except NoReverseMatch: + continue + + links['views'] = OrderedDict() + for api_view_url in api_view_urls: + url_name = api_view_url.name + try: + if '' in api_view_url._regex: + links['views'][url_name] = reverse(url_name, request=request, + format=kwargs.get('format', None), args=(0,)) + else: + links['views'][url_name] = reverse(url_name, request=request, + format=kwargs.get('format', None)) + except NoReverseMatch as e: + continue + + return Response({"_links": links}) + + return ApiRootView().as_view() diff --git a/drf_tools/serializers.py b/drf_tools/serializers.py new file mode 100644 index 0000000..68c875d --- /dev/null +++ b/drf_tools/serializers.py @@ -0,0 +1,128 @@ +import csv +from io import BytesIO +import zipfile + +from chardet.universaldetector import UniversalDetector +from openpyxl import Workbook, load_workbook +from openpyxl.cell import Cell +from rest_framework.serializers import HyperlinkedModelSerializer +from drf_enum_field.serializers import EnumFieldSerializerMixin +from drf_hal_json.serializers import HalModelSerializer, HalEmbeddedSerializer +from drf_nested_routing.serializers import NestedRoutingSerializerMixin + + +class HalNestedRoutingEmbeddedSerializer(NestedRoutingSerializerMixin, EnumFieldSerializerMixin, HalEmbeddedSerializer): + pass + + +class HalNestedRoutingLinksSerializer(NestedRoutingSerializerMixin, EnumFieldSerializerMixin, HyperlinkedModelSerializer): + pass + + +class HalNestedFieldsModelSerializer(NestedRoutingSerializerMixin, EnumFieldSerializerMixin, HalModelSerializer): + links_serializer_class = HalNestedRoutingLinksSerializer + embedded_serializer_class = HalNestedRoutingEmbeddedSerializer + + +class CsvSerializer(object): + @staticmethod + def serialize(data): + if isinstance(data, bytes): + return data + + if not isinstance(data, list): + data = [str(data)] + + csv_buffer = BytesIO() + for row in data: + if not isinstance(row, (list, tuple)): + row = [row] + csv_buffer.write(('\t'.join(CsvSerializer.__validate_cell(cell) for cell in row) + '\n').encode('utf-8')) + + return csv_buffer.getvalue() + + @staticmethod + def deserialize(file_bytes): + try: + file_string = file_bytes.decode('utf-8') + except UnicodeDecodeError as ude: + detector = UniversalDetector() + for line in BytesIO(file_bytes): + detector.feed(line) + if detector.done: + break + detector.close() + if detector.result['confidence'] < 0.5: + raise ValueError("Failed to guess the encoding of the file (it's not utf-8). Use utf-8 encoded files.") + try: + file_string = file_bytes.decode(detector.result['encoding']) + except UnicodeDecodeError: + raise ValueError("Failed to guess the encoding of the file (it's not utf-8). Use utf-8 encoded files. " + "(The invalid character is '{char:#x}' at {pos})".format(pos=ude.start, + char=file_bytes[ude.start])) + csv_lines = file_string.splitlines() + first_line = csv_lines[:1] + first_row_tab = next(csv.reader(first_line, delimiter="\t")) + first_row_semicolon = next(csv.reader(first_line, delimiter=";")) + if len(first_row_tab) > 1: + rows = csv.reader(csv_lines, delimiter="\t") + elif len(first_row_semicolon) > 1: + rows = csv.reader(csv_lines, delimiter=";") + else: + raise ValueError("Csv file is not delimited by ';' or 'tab'") + + return rows + + @staticmethod + def __validate_cell(cell): + cell = str(cell) if cell is not None else '' + if "\t" in cell or "\n" in cell: + cell = '"{}"'.format(cell) + return cell + + +class ZipSerializer(object): + @staticmethod + def serialize(data): + byte_buffer = BytesIO() + zip_file = zipfile.ZipFile(byte_buffer, "w") + for filename, data_bytes in data.items(): + if not isinstance(data_bytes, bytes): + return data + zip_file.writestr(filename, data_bytes) + zip_file.close() + return byte_buffer.getvalue() + + +class XlsxSerializer(object): + @staticmethod + def serialize(data): + workbook = Workbook() + sheet = workbook.active + for row_index, row in enumerate(data): + for column_index, value in enumerate(row): + data_type = Cell.TYPE_STRING + if isinstance(value, (int, float)): + data_type = Cell.TYPE_NUMERIC + if data_type == Cell.TYPE_STRING: + value = str(value) + sheet.cell(column=column_index + 1, row=row_index + 1).set_explicit_value(value, data_type=data_type) + xlsx_file = BytesIO() + workbook.save(xlsx_file) + return xlsx_file.getvalue() + + @staticmethod + def deserialize(file_bytes, sheet_name): + workbook = load_workbook(filename=BytesIO(file_bytes), data_only=True) + if len(workbook.worksheets) > 1: + if not sheet_name: + raise ValueError("The uploaded file contains several sheets. The name of the sheet to be imported " + "needs to be specified with the 'sheetName' parameter.") + worksheet = workbook.get_sheet_by_name(sheet_name) + else: + worksheet = workbook.get_active_sheet() + + if not worksheet: + raise ValueError("No worksheet found.") + + return [[cell.value for cell in row] for row in worksheet.rows] diff --git a/drf_tools/test/__init__.py b/drf_tools/test/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/drf_tools/test/__init__.py @@ -0,0 +1 @@ + diff --git a/drf_tools/test/base.py b/drf_tools/test/base.py new file mode 100644 index 0000000..fb96e54 --- /dev/null +++ b/drf_tools/test/base.py @@ -0,0 +1,553 @@ +from collections import defaultdict +from datetime import datetime, date +import json +import random + +from six.moves.urllib.parse import urlparse, urlencode, unquote, parse_qs +from six.moves.urllib.request import urlopen, Request +from six.moves.urllib.error import HTTPError + +import logging + +from django.core.urlresolvers import reverse +from django.db.models import Model +from django.test import TestCase +from enumfields import Enum + +from drf_hal_json import LINKS_FIELD_NAME, EMBEDDED_FIELD_NAME, HAL_JSON_MEDIA_TYPE +import drf_nested_routing +from rest_framework.settings import api_settings + +from drf_tools.test.utils import skip_abstract_test +from drf_tools.utils import DATETIME_FORMAT_ISO + + +class BaseRestTest(TestCase): + _TESTSERVER_NAME = "testserver" + _TESTSERVER_BASE_URL = "http://" + _TESTSERVER_NAME + _COUNT_FIELD_NAME = "count" + _PAGE_SIZE_FIELD_NAME = "page_size" + _SELF_FIELD_NAME = api_settings.URL_FIELD_NAME + _CONTENT_TYPE_HEADER_NAME = "Content-Type" + _ALLOW_HEADER_NAME = "Allow" + _LOCATION_HEADER_NAME = "Location" + _QUERY_PARAM_FIELDS = "fields" + _PARENT_LOOKUPS_MODEL_FIELD = "parent_lookups" + + def setUp(self): + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + self.client.get("/") # hack: has to be called initial for router registration + + def _assertDatetimesEqual(self, datetime1, datetime2): + if datetime1 and isinstance(datetime1, datetime): + datetime1 = datetime1.strftime(DATETIME_FORMAT_ISO) + if datetime2 and isinstance(datetime2, datetime): + datetime2 = datetime2.strftime(DATETIME_FORMAT_ISO) + self.assertEqual(datetime1, datetime2) + + def _assertDatesEqual(self, date1, date2): + if date1 and isinstance(date1, date): + date1 = date1.strftime('%Y-%m-%d') + if date2 and isinstance(date2, date): + date2 = date2.strftime('%Y-%m-%d') + self.assertEqual(date1, date2) + + def _doGETDetails(self, modelObj, queryParams=None, **headers): + resp = self.client.get(self._getRelativeDetailURI(modelObj=modelObj), queryParams, **headers) + self.assertEqual(resp[self._CONTENT_TYPE_HEADER_NAME], HAL_JSON_MEDIA_TYPE) + return resp + + def _doGETList(self, modelClass, queryParams=None, parentLookups=None, **headers): + resp = self.client.get(self._getRelativeListURI(modelClass, parentLookups), queryParams, **headers) + self.assertEqual(resp[self._CONTENT_TYPE_HEADER_NAME], HAL_JSON_MEDIA_TYPE) + return resp + + def _doPOST(self, modelClass, content, parentLookups=None, **headers): + return self.client.post(self._getRelativeListURI(modelClass, parentLookups), + self.__contentToJson(content), HAL_JSON_MEDIA_TYPE, **headers) + + def _doPUT(self, modelObj, content, **headers): + resp = self.client.put(self._getRelativeDetailURI(modelObj), self.__contentToJson(content), + HAL_JSON_MEDIA_TYPE, **headers) + self.assertEqual(resp[self._CONTENT_TYPE_HEADER_NAME], HAL_JSON_MEDIA_TYPE) + return resp + + def _doPATCH(self, modelObj, content, **headers): + resp = self.client.patch(self._getRelativeDetailURI(modelObj), self.__contentToJson(content), + HAL_JSON_MEDIA_TYPE, **headers) + self.assertEqual(resp[self._CONTENT_TYPE_HEADER_NAME], HAL_JSON_MEDIA_TYPE) + return resp + + def _doDELETE(self, modelObj, **headers): + return self.client.delete(self._getRelativeDetailURI(modelObj), **headers) + + def _doOPTIONSList(self, modelClass, parentLookups=None, **headers): + return self.client.options(self._getRelativeListURI(modelClass, parentLookups), **headers) + + def _doOPTIONSDetails(self, modelObj=None, **headers): + return self.client.options(self._getRelativeDetailURI(modelObj), **headers) + + def _doHEADList(self, modelClass, parentLookups=None, **headers): + return self.client.head(self._getRelativeListURI(modelClass, parentLookups), **headers) + + def _doHEADDetails(self, modelObj=None, **headers): + return self.client.head(self._getRelativeDetailURI(modelObj), **headers) + + def _extractIdFromLocationHeader(self, locationHeader): + return locationHeader.split("/")[-2] + + def _getAbsoluteDetailURI(self, modelObj): + if not modelObj: + return None + return self._TESTSERVER_BASE_URL + self._getRelativeDetailURI(modelObj) + + def _getRelativeDetailURI(self, modelObj): + if not modelObj: + return None + + lookup_field = getattr(modelObj, "pk", None) + kwargs = {"pk": lookup_field} + parent_lookups = drf_nested_routing.get_parent_query_lookups_by_class(modelObj.__class__) + if parent_lookups: + for lookup in parent_lookups: + lookup_path = lookup.split('__') + parent_lookup = modelObj + for part in lookup_path: + parent_lookup = getattr(parent_lookup, part) + parentLookupId = parent_lookup.id if isinstance(parent_lookup, Model) else parent_lookup + kwargs[drf_nested_routing.PARENT_LOOKUP_NAME_PREFIX + lookup] = parentLookupId + + # Handle unsaved object case + if lookup_field is None: + return None + + return reverse(modelObj.__class__.__name__.lower() + '-detail', kwargs=kwargs) + + def _getAbsoluteListURI(self, modelClass, parentLookups=None): + return self._TESTSERVER_BASE_URL + self._getRelativeListURI(modelClass, parentLookups) + + def _getRelativeListURI(self, modelClass, parentLookups=None): + parent_lookups = drf_nested_routing.get_parent_query_lookups_by_class(modelClass) + if parent_lookups and not parentLookups: + raise ValueError("Please specify parent lookups for '{}'".format(modelClass.__name__)) + + composedParentLookups = dict() + if parent_lookups and parentLookups: + for lookup in parent_lookups: + lookupId = parentLookups.get(lookup) + if not lookupId: + continue + composedParentLookups[drf_nested_routing.PARENT_LOOKUP_NAME_PREFIX + lookup] = lookupId + + return reverse(modelClass.__name__.lower() + '-list', kwargs=composedParentLookups) + + def _assertLinksAndModelListEqual(self, linksList, modelList): + if linksList is not None: + self.assertEqual(len(linksList), len(modelList)) + for model in modelList: + self.assertTrue(self._getAbsoluteDetailURI(model) in linksList) + else: + self.assertEqual(0, len(modelList)) + + def __contentToJson(self, content): + cleanedContent = dict() + for attr, value in content.items(): + cleanedValue = value + if not isinstance(value, dict) and value is not None: + cleanedValue = str(value) + cleanedContent[attr] = cleanedValue + return json.dumps(cleanedContent) + + def _buildContent(self, stateAttrs, linkAttrs=None, embeddedAttrs=None): + content = dict(stateAttrs) + if linkAttrs: + content[LINKS_FIELD_NAME] = linkAttrs + if embeddedAttrs: + content[EMBEDDED_FIELD_NAME] = embeddedAttrs + return content + + def _splitContent(self, content): + stateAttrs = {key: value for key, value in content.items() if + key not in (LINKS_FIELD_NAME, EMBEDDED_FIELD_NAME)} + linkAttrs = content.get(LINKS_FIELD_NAME, dict()) + embeddedAttrs = content.get(EMBEDDED_FIELD_NAME, dict()) + return stateAttrs, linkAttrs, embeddedAttrs + + +class BaseModelViewSetTest(BaseRestTest): + def _getModelClass(self): + raise NotImplementedError() + + def _getOrCreateModelInstance(self): + raise NotImplementedError() + + def _assertModelEqual(self, content, modelObj): + raise NotImplementedError() + + def _getOrCreateModelList(self, minCount=5, maxCount=15): + modelList = list() + for i in range(random.randint(minCount, maxCount)): + modelList.append(self._getOrCreateModelInstance()) + return modelList + + def _getAllowedListMethods(self): + return ["OPTIONS"] + + def _getAllowedDetailsMethods(self): + return ["OPTIONS"] + + @skip_abstract_test + def testOPTIONSList(self): + resp = self._doOPTIONSList(self._getModelClass(), self._getWildcardedParentLookups(self._getModelClass())) + self.assertEqual(200, resp.status_code, resp.content) + + self.assertEqual(resp[self._CONTENT_TYPE_HEADER_NAME], HAL_JSON_MEDIA_TYPE) + allowedMethods = [allowedMethod.strip() for allowedMethod in resp[self._ALLOW_HEADER_NAME].split(",")] + for expectedAllowedMethod in self._getAllowedListMethods(): + self.assertTrue(expectedAllowedMethod in allowedMethods) + + self.assertTrue(HAL_JSON_MEDIA_TYPE in resp.data['renders']) + self.assertTrue(HAL_JSON_MEDIA_TYPE in resp.data['parses']) + + @skip_abstract_test + def testOPTIONSDetails(self): + resp = self._doOPTIONSDetails(self._getOrCreateModelInstance()) + self.assertEqual(200, resp.status_code, resp.content) + + self.assertEqual(resp[self._CONTENT_TYPE_HEADER_NAME], HAL_JSON_MEDIA_TYPE) + allowedMethods = [allowedMethod.strip() for allowedMethod in resp[self._ALLOW_HEADER_NAME].split(",")] + for expectedAllowedMethod in self._getAllowedDetailsMethods(): + self.assertTrue(expectedAllowedMethod in allowedMethods) + + self.assertTrue(HAL_JSON_MEDIA_TYPE in resp.data['renders']) + self.assertTrue(HAL_JSON_MEDIA_TYPE in resp.data['parses']) + + @skip_abstract_test + def testHEADList(self): + resp = self._doHEADList(self._getModelClass(), self._getWildcardedParentLookups(self._getModelClass())) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testHEADDetails(self): + resp = self._doHEADDetails(self._getOrCreateModelInstance()) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testGETList(self): + resp = self._doGETList(self._getModelClass(), self._getWildcardedParentLookups(self._getModelClass())) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testGETDetails(self): + resp = self._doGETDetails(self._getOrCreateModelInstance()) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testPUT(self): + resp = self._doPUT(self._getOrCreateModelInstance(), {}) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testPATCH(self): + resp = self._doPATCH(self._getOrCreateModelInstance(), {}) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testPOST(self): + resp = self._doPOST(self._getModelClass(), {}, parentLookups=self._getWildcardedParentLookups(self._getModelClass())) + self.assertEqual(405, resp.status_code, resp.content) + + @skip_abstract_test + def testDELETE(self): + resp = self._doDELETE(self._getOrCreateModelInstance()) + self.assertEqual(405, resp.status_code, resp.content) + + def _getWildcardedParentLookups(self, modelClass): + parentLookups = drf_nested_routing.get_parent_query_lookups_by_class(modelClass) + if not parentLookups: + return None + return {parentLookup: "*" for parentLookup in parentLookups} + + +class CreateModelViewSetTest(BaseModelViewSetTest): + def _getAllowedListMethods(self): + return super(CreateModelViewSetTest, self)._getAllowedListMethods() + ["POST"] + + def _createModelAsJson(self): + raise NotImplementedError() + + @skip_abstract_test + def testPOST(self): + content, parentLookups = self._createModelAsJson() + resp = self._doPOST(self._getModelClass(), content, parentLookups) + self.assertEqual(201, resp.status_code, resp.content) + objectFromDb = self._getModelClass().objects.get( + id=self._extractIdFromLocationHeader(resp[self._LOCATION_HEADER_NAME])) + self._assertModelEqual(content, objectFromDb) + + +class ReadModelViewSetTest(BaseModelViewSetTest): + + def _getAllowedListMethods(self): + return super(ReadModelViewSetTest, self)._getAllowedListMethods() + ["GET", "HEAD"] + + def _getAllowedDetailsMethods(self): + return super(ReadModelViewSetTest, self)._getAllowedDetailsMethods() + ["GET", "HEAD"] + + @skip_abstract_test + def testHEADList(self): + resp = self._doHEADList(self._getModelClass(), self._getWildcardedParentLookups(self._getModelClass())) + self.assertEqual(200, resp.status_code, resp.content) + + @skip_abstract_test + def testHEADDetails(self): + resp = self._doHEADDetails(self._getOrCreateModelInstance()) + self.assertEqual(200, resp.status_code, resp.content) + + @skip_abstract_test + def testGETList(self): + modelList = self._getOrCreateModelList() + modelCount = len(modelList) + queryParams = {self._PAGE_SIZE_FIELD_NAME: modelCount} + modelsByUrl = {self._getAbsoluteDetailURI(model): model for model in modelList} + wildCardedParentLookups = self._getWildcardedParentLookups(self._getModelClass()) + resp = self._doGETList(self._getModelClass(), queryParams, wildCardedParentLookups) + self.assertEqual(200, resp.status_code, resp.content) + stateAttrs, linkAttrs, embeddedAttrs = self._splitContent(resp.data) + self.assertEqual(stateAttrs[self._COUNT_FIELD_NAME], modelCount) + self.assertEqual(stateAttrs[self._PAGE_SIZE_FIELD_NAME], modelCount) + self.assertEqual(linkAttrs[self._SELF_FIELD_NAME], "{}?{}={}".format( + unquote(self._getAbsoluteListURI(self._getModelClass(), wildCardedParentLookups)), + self._PAGE_SIZE_FIELD_NAME, modelCount)) + self.assertEqual(len(embeddedAttrs), modelCount) + for embeddedAttr in embeddedAttrs: + model = modelsByUrl[embeddedAttr[LINKS_FIELD_NAME][self._SELF_FIELD_NAME]] + self._assertModelEqual(embeddedAttr, model) + + @skip_abstract_test + def testGETDetails(self): + modelObj = self._getOrCreateModelInstance() + resp = self._doGETDetails(modelObj) + self.assertEqual(200, resp.status_code, resp.content) + self._assertModelEqual(resp.data, modelObj) + + +class IncludeFields: + def __init__(self, stateFields=None, linkFields=None, embeddedFields=None): + self.__stateFields = stateFields or [] + self.__linkFields = linkFields or [] + self.__embeddedFields = embeddedFields or {} + + def buildQueryParamValue(self): + fields = self.__stateFields + self.__linkFields + for embeddedAttr, includeFields in self.__embeddedFields.items(): + fields.append("{}.fields({})".format(embeddedAttr, includeFields.buildQueryParamValue())) + return ",".join(fields) + + def getResultingStateFields(self): + return self.__stateFields + ['id'] + + def getResultingLinkFields(self): + return self.__linkFields + ['self'] + + def getResultingEmbeddedFields(self): + return self.__embeddedFields + + +class AdvancedReadModelViewSetTestMixin(ReadModelViewSetTest): + def _getIncludeFields(self): + raise NotImplementedError() + + @skip_abstract_test + def testGETListPaginated(self): + def __getPageParams(url): + pr = urlparse(url) + params = {k: v[0] for k, v in parse_qs(pr.query).items()} + location = "{scheme}://{netloc}{path}".format(scheme=pr.scheme, netloc=pr.netloc, path=pr.path) + pageNumber = int(params['page']) if 'page' in params else 1 + return location, int(params[self._PAGE_SIZE_FIELD_NAME]), pageNumber + + specifiedPageSize = 5 + modelList = self._getOrCreateModelList() + + pageCount = len(modelList) // specifiedPageSize + rest = len(modelList) % specifiedPageSize + if rest != 0: + pageCount += 1 + + for i in range(pageCount): + queryParams = {'page': i + 1, self._PAGE_SIZE_FIELD_NAME: specifiedPageSize} + wildcardedParentLookups = self._getWildcardedParentLookups(self._getModelClass()) + resp = self._doGETList(self._getModelClass(), queryParams, wildcardedParentLookups) + path, pageSize, page = __getPageParams(resp.data[LINKS_FIELD_NAME][self._SELF_FIELD_NAME]) + expectedPath = "{}{}".format(self._TESTSERVER_BASE_URL, unquote( + self._getRelativeListURI(self._getModelClass(), wildcardedParentLookups))) + self.assertEqual(expectedPath, path) + self.assertEqual(page, i + 1) + self.assertEqual(pageSize, 5) + self.assertEqual(len(modelList), resp.data[self._COUNT_FIELD_NAME]) + self.assertEqual(5, resp.data[self._PAGE_SIZE_FIELD_NAME]) + embeddedCount = pageSize + if rest != 0 and i == pageCount - 1: + embeddedCount = rest + self.assertEqual(embeddedCount, len(resp.data[EMBEDDED_FIELD_NAME])) + + if i != pageCount - 1: + path, pageSize, page = __getPageParams(resp.data[LINKS_FIELD_NAME]['next']) + self.assertEqual(expectedPath, path) + self.assertEqual(5, pageSize) + self.assertEqual(i + 2, page) + if i > 0: + path, pageSize, page = __getPageParams(resp.data[LINKS_FIELD_NAME]['previous']) + self.assertEqual(expectedPath, path) + self.assertEqual(5, pageSize) + self.assertEqual(i, page) + + @skip_abstract_test + def testGETListIncludeCertainFields(self): + modelList = self._getOrCreateModelList() + modelsByUrl = {self._getAbsoluteDetailURI(model): model for model in modelList} + includeFields = self._getIncludeFields() + fieldsQueryParamValue = includeFields.buildQueryParamValue() + + modelCount = len(modelList) + queryParams = {self._PAGE_SIZE_FIELD_NAME: modelCount, self._QUERY_PARAM_FIELDS: fieldsQueryParamValue} + wildcardedParentLookups = self._getWildcardedParentLookups(self._getModelClass()) + resp = self._doGETList(self._getModelClass(), queryParams, wildcardedParentLookups) + self.assertEqual(200, resp.status_code, resp.content) + stateAttrs, linkAttrs, embeddedAttrs = self._splitContent(resp.data) + + self.assertEqual(stateAttrs[self._COUNT_FIELD_NAME], modelCount) + self.assertEqual(stateAttrs[self._PAGE_SIZE_FIELD_NAME], modelCount) + selfUrl = unquote(linkAttrs[self._SELF_FIELD_NAME]) + self.assertTrue( + selfUrl.startswith(unquote(self._getAbsoluteListURI(self._getModelClass(), wildcardedParentLookups)))) + self.assertTrue("{}={}".format(self._QUERY_PARAM_FIELDS, fieldsQueryParamValue) in selfUrl) + self.assertTrue("{}={}".format(self._PAGE_SIZE_FIELD_NAME, modelCount) in selfUrl) + self.assertEqual(modelCount, len(embeddedAttrs)) + + for embeddedObjectAttrs in embeddedAttrs: + modelObj = modelsByUrl[embeddedObjectAttrs[LINKS_FIELD_NAME][self._SELF_FIELD_NAME]] + self.assertIsNotNone(modelObj) + self.__assertIncludeFieldsContentEqual(includeFields, modelObj, embeddedObjectAttrs) + + @skip_abstract_test + def testGETDetailsIncludeCertainFields(self): + modelObj = self._getOrCreateModelInstance() + includeFields = self._getIncludeFields() + fieldsQueryParamValue = includeFields.buildQueryParamValue() + + resp = self._doGETDetails(modelObj, {self._QUERY_PARAM_FIELDS: fieldsQueryParamValue}) + self.assertEqual(200, resp.status_code, resp.data) + self.__assertIncludeFieldsContentEqual(includeFields, modelObj, resp.data) + + def __assertIncludeFieldsContentEqual(self, includeFields, modelObj, content): + stateAttrs, linkAttrs, embeddedAttrs = self._splitContent(content) + + self.assertEqual(len(includeFields.getResultingStateFields()), len(stateAttrs)) + self.assertEqual(modelObj.id, stateAttrs['id']) + self.assertEqual(len(includeFields.getResultingLinkFields()), len(linkAttrs)) + self.assertEqual(len(includeFields.getResultingEmbeddedFields()), len(embeddedAttrs)) + + for stateField in includeFields.getResultingStateFields(): + if stateField in stateAttrs: + modelValue = getattr(modelObj, stateField) + if isinstance(modelValue, Enum): + modelValue = modelValue.value + if isinstance(modelValue, datetime): + self._assertDatetimesEqual(modelValue, stateAttrs[stateField]) + continue + if isinstance(modelValue, date): + self._assertDatesEqual(modelValue, stateAttrs[stateField]) + continue + self.assertEqual(modelValue, stateAttrs[stateField]) + + for linkField in includeFields.getResultingLinkFields(): + if linkField in linkAttrs: + if linkField == self._SELF_FIELD_NAME: + ref = modelObj + else: + ref = getattr(modelObj, linkField) + self.assertEqual(self._getAbsoluteDetailURI(ref), linkAttrs[linkField]) + + for embeddedField, embeddedIncludeFields in includeFields.getResultingEmbeddedFields().items(): + if embeddedField in embeddedAttrs: + if isinstance(embeddedAttrs[embeddedField], dict): + embeddedObj = getattr(modelObj, embeddedField) + self.__assertIncludeFieldsContentEqual(embeddedIncludeFields, embeddedObj, + embeddedAttrs[embeddedField]) + else: + embeddedList = getattr(modelObj, embeddedField).all() + modelsByUrl = {self._getAbsoluteDetailURI(model): model for model in embeddedList} + for embeddedAttr in embeddedAttrs[embeddedField]: + embeddedObj = modelsByUrl[embeddedAttr[LINKS_FIELD_NAME][self._SELF_FIELD_NAME]] + self.__assertIncludeFieldsContentEqual(embeddedIncludeFields, embeddedObj, embeddedAttr) + + +class UpdateModelViewSetTest(BaseModelViewSetTest): + def _getUpdateAttributes(self): + raise NotImplementedError() + + def _getAllowedDetailsMethods(self): + return super(UpdateModelViewSetTest, self)._getAllowedDetailsMethods() + ["PATCH", "PUT"] + + @skip_abstract_test + def testPUT(self): + modelObj = self._getOrCreateModelInstance() + resp = self._doGETDetails(modelObj) + self.assertEqual(200, resp.status_code, resp.data) + content = dict(resp.data) + + for changeAttr, value in self._getUpdateAttributes().items(): + if changeAttr in content: + content[changeAttr] = value + elif changeAttr in content[LINKS_FIELD_NAME]: + content[LINKS_FIELD_NAME][changeAttr] = value + else: + raise ValueError("Attribute '{}' not available.".format(changeAttr)) + + resp = self._doPUT(modelObj, content) + self.assertEqual(200, resp.status_code, resp.data) + objectFromDb = modelObj.__class__.objects.get(id=modelObj.id) + self._assertModelEqual(content, objectFromDb) + + @skip_abstract_test + def testPATCH(self): + modelObj = self._getOrCreateModelInstance() + resp = self._doGETDetails(modelObj) + self.assertEqual(200, resp.status_code, resp.data) + content = dict(resp.data) + patchData = defaultdict(dict) + for changeAttr, value in self._getUpdateAttributes().items(): + if changeAttr in content: + content[changeAttr] = value + patchData[changeAttr] = value + elif changeAttr in content[LINKS_FIELD_NAME]: + content[LINKS_FIELD_NAME][changeAttr] = value + patchData[LINKS_FIELD_NAME][changeAttr] = value + else: + raise ValueError("Attribute '{}' not available.".format(changeAttr)) + + resp = self._doPATCH(modelObj, patchData) + self.assertEqual(200, resp.status_code, resp.content) + objectFromDb = modelObj.__class__.objects.get(id=modelObj.id) + self._assertModelEqual(content, objectFromDb) + + +class DeleteModelViewSetTest(BaseModelViewSetTest): + def _getAllowedDetailsMethods(self): + return super(DeleteModelViewSetTest, self)._getAllowedDetailsMethods() + ["DELETE"] + + @skip_abstract_test + def testDELETE(self): + modelObj = self._getOrCreateModelInstance() + self.assertTrue(modelObj.__class__.objects.filter(id=modelObj.id).exists()) + resp = self._doDELETE(modelObj) + self.assertEqual(204, resp.status_code, resp.content) + self.assertFalse(modelObj.__class__.objects.filter(id=modelObj.id).exists()) + + +class ModelViewSetTest(CreateModelViewSetTest, ReadModelViewSetTest, UpdateModelViewSetTest, DeleteModelViewSetTest, + BaseModelViewSetTest): + pass diff --git a/drf_tools/test/utils.py b/drf_tools/test/utils.py new file mode 100644 index 0000000..f629ce9 --- /dev/null +++ b/drf_tools/test/utils.py @@ -0,0 +1,26 @@ +from functools import wraps + + +def withDebug(func): + """Switch on DEBUG during a test (disabled by default). Useful for query logging.""" + + @wraps(func) + def wrapper(*args, **kwargs): + from django.conf import settings + from django.db import connection + + settings.DEBUG = True + connection.queries = [] + result = func(*args, **kwargs) + settings.DEBUG = False + return result + + return wrapper + +def skip_abstract_test(func): + def func_wrapper(self): + if self.__class__.__subclasses__(): + return + return func(self) + + return func_wrapper diff --git a/drf_tools/utils.py b/drf_tools/utils.py new file mode 100644 index 0000000..d86af32 --- /dev/null +++ b/drf_tools/utils.py @@ -0,0 +1,29 @@ +from django.core.exceptions import ValidationError +from django.core.validators import URLValidator + +DATETIME_FORMAT = '%d.%m.%Y %H:%M:%S' +DATETIME_FORMAT_ISO = '%Y-%m-%dT%H:%M:%S' + +def get_id_from_detail_uri(uri): + return int(uri.split('/')[-2]) + + +def is_detail_uri(uri): + try: + get_id_from_detail_uri(uri) + return True + except ValueError: + return False + + +def get_valid_uri(uri): + if not uri: + return None, True + url = uri.strip() + if not url.startswith("http"): + url = 'http://' + url + try: + URLValidator()(url) + except ValidationError: + return None, False + return url, True diff --git a/drf_tools/views.py b/drf_tools/views.py new file mode 100644 index 0000000..553ea55 --- /dev/null +++ b/drf_tools/views.py @@ -0,0 +1,250 @@ +from datetime import datetime +import logging + +from rest_framework import status +from rest_framework.mixins import RetrieveModelMixin, ListModelMixin, DestroyModelMixin +from rest_framework.parsers import MultiPartParser +from rest_framework.relations import PrimaryKeyRelatedField +from rest_framework.renderers import JSONRenderer +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.views import APIView +from rest_framework.viewsets import GenericViewSet +from rest_framework.exceptions import ParseError +import drf_hal_json +from drf_hal_json.views import HalCreateModelMixin +from drf_nested_fields.views import CustomFieldsMixin, copy_meta_attributes + +from drf_nested_routing.views import CreateNestedModelMixin, UpdateNestedModelMixin + +from drf_tools import utils +from drf_tools.serializers import HalNestedFieldsModelSerializer, CsvSerializer, XlsxSerializer + +logger = logging.getLogger(__name__) + + +def _add_parent_to_hal_request_data(request, parentKey): + if not drf_hal_json.is_hal_content_type(request.content_type): + return + links = request.data.get('_links') + if links and parentKey in links: + return + + if not links: + links = {} + request.data['_links'] = links + + uriSplit = request.build_absolute_uri().split('/') + if request.method == 'PUT': + uriSplit = uriSplit[:-3] # in case of PUT the id must be removed as well + else: + uriSplit = uriSplit[:-2] + + links[parentKey] = '/'.join(uriSplit) + '/' + + +class RestLoggingMixin(object): + """Provides full logging of requests and responses""" + + def finalize_response(self, request, response, *args, **kwargs): + if logger.isEnabledFor(logging.DEBUG): + logger.debug("{} {}".format(response.status_code, response.data)) + return super(RestLoggingMixin, self).finalize_response(request, response, *args, **kwargs) + + def initial(self, request, *args, **kwargs): + if logger.isEnabledFor(logging.DEBUG): + logger.debug("{} {} {} {}".format(request.method, request.path, request.query_params, request.data)) + super(RestLoggingMixin, self).initial(request, *args, **kwargs) + + +class DefaultSerializerMixin(object): + """ + If a view has no serializer_class specified, this mixin takes care of creating a default serializer_class that inherits + HalNestedFieldsModelSerializer + """ + + def get_serializer_class(self): + if not self.serializer_class: + class DefaultSerializer(HalNestedFieldsModelSerializer): + class Meta: + model = self.queryset.model + + self.serializer_class = DefaultSerializer + + return self.serializer_class + + +class HalNoLinksMixin(ListModelMixin): + """ + For responses with a high amount of data, link generation can be switched of via query-param 'no_links'. Instead of links, + simple ids are returned + """ + + def get_serializer_class(self): + no_links = extract_boolean_from_query_params(self.get_serializer_context().get('request'), "no_links") + if not no_links: + return super(HalNoLinksMixin, self).get_serializer_class() + + self.always_included_fields = ["id"] + serializer_class = super(HalNoLinksMixin, self).get_serializer_class() + + class HalNoLinksSerializer(serializer_class): + serializer_related_field = PrimaryKeyRelatedField + + class Meta: + pass + + copy_meta_attributes(serializer_class.Meta, Meta) + + @staticmethod + def _is_link_field(field): + return False + + @staticmethod + def _get_links_serializer(model_cls, link_field_names): + return None + + return HalNoLinksSerializer + + +class CreateModelMixin(CreateNestedModelMixin, HalCreateModelMixin): + """ + Parents of nested resources are automatically added to the request content, so that they don't have to be defined twice + (url and request content) + """ + + def _add_parent_to_request_data(self, request, parentKey, parentId): + _add_parent_to_hal_request_data(request, parentKey) + + +class ReadModelMixin(HalNoLinksMixin, CustomFieldsMixin, RetrieveModelMixin, ListModelMixin): + always_included_fields = ["id", api_settings.URL_FIELD_NAME] + + +class UpdateModelMixin(UpdateNestedModelMixin): + """ + Additionally to the django-method it is checked if the resource exists and 404 is returned if not, + instead of creating that resource + + Parents of nested resources are automatically added to the request content, so that they don't have to be defined twice + (url and request content) + """ + + def update(self, request, *args, **kwargs): + instance = self.get_object() + if instance is None: + return Response("Resource with the given id/pk does not exist.", status=status.HTTP_404_NOT_FOUND) + + return super(UpdateModelMixin, self).update(request, *args, **kwargs) + + def _add_parent_to_request_data(self, request, parentKey, parentId): + _add_parent_to_hal_request_data(request, parentKey) + + +class PartialUpdateOnlyMixin(object): + def partial_update(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + return Response(serializer.data) + + @staticmethod + def perform_update(serializer): + serializer.save() + + +class BaseViewSet(RestLoggingMixin, DefaultSerializerMixin, GenericViewSet): + pass + + +class ModelViewSet(CreateModelMixin, ReadModelMixin, UpdateModelMixin, DestroyModelMixin, BaseViewSet): + pass + + +class FileUploadView(RestLoggingMixin, APIView): + parser_classes = (MultiPartParser,) + renderer_classes = (JSONRenderer,) + + def _get_file_and_name(self, request): + file = self._get_file_from_request(request) + return file.file, file.name + + def _get_file_bytes_and_name(self, request): + file = self._get_file_from_request(request) + return file.read(), file.name + + @staticmethod + def _get_file_from_request(request): + in_memory_upload_file = request.data.get('file') + if not in_memory_upload_file or not in_memory_upload_file.file: + raise ValueError("Mulitpart content must contain file.") + return in_memory_upload_file + + +class XlsxImportView(FileUploadView): + default_sheet_name = None + media_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + + def _get_xlsx_content_as_list_and_file_info(self, request): + file_bytes, filename = self._get_file_bytes_and_name(request) + sheetName = request.query_params.get('sheetName') or self.default_sheet_name + return XlsxSerializer.deserialize(file_bytes, sheetName), filename, file_bytes + + +class CsvImportView(FileUploadView): + media_type = 'text/csv' + + def _get_csv_content_as_list_and_file_info(self, request): + file_bytes, filename = self._get_file_bytes_and_name(request) + return CsvSerializer.deserialize(file_bytes), filename, file_bytes + + +def extract_int_from_query_params(request, key): + value = request.query_params.get(key) + if value: + try: + value = int(value) + except ValueError: + raise ParseError("Type of parameter '{}' must be 'int'".format(key)) + return value + + +def extract_datetime_from_query_params(request, key): + value = request.query_params.get(key) + if value: + try: + value = datetime.strptime(value, utils.DATETIME_FORMAT_ISO) + except ValueError: + raise ParseError( + "Value of parameter '{}' has wrong format. Use '{}' instead".format(key, utils.DATETIME_FORMAT_ISO)) + return value + + +def extract_enum_from_query_params(request, key, enum_type): + value = request.query_params.get(key) + choices = [context.value for context in enum_type] + if value: + if value.upper() not in choices: + raise ParseError("Value of query-parameter '{}' must be one out of {}".format(key, choices)) + return enum_type[value.upper()] + + return value + + +def extract_boolean_from_query_params(request, key): + value = request.query_params.get(key) + if not value: + return None + return value == 'true' + + +def get_instance_from_params(request, key, model_cls, optional=False): + value = extract_int_from_query_params(request, key) + if not value: + if not optional: + raise ParseError("Query-parameter '{}' must be set".format(key)) + else: + return None + + return model_cls.objects.get(id=value) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0330705 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +Django==1.8.2 +djangorestframework==3.1.3 +drf-nested-fields==0.9.2 +drf-hal-json==0.9.0 +drf-enum-field==0.9.1 +drf-nested-routing==0.9.0 +django-filter==0.10.0 +openpyxl==2.2.5 +chardet==2.3.0 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7c4bc22 --- /dev/null +++ b/setup.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +from setuptools import setup, find_packages + +setup( + name='drf-tools', + version="0.9.0", + url='https://github.com/seebass/drf-tools', + license='MIT', + description='Multiple extensions and test utilities for Django REST Framework 3', + author='Sebastian Bredehöft', + author_email='bredehoeft.sebastian@gmail.com', + packages=find_packages(exclude=['tests*']), + install_requires=[ + 'django>=1.8', + 'djangorestframework>=3.0.0', + 'drf-nested-fields>=0.9.0', + 'drf-hal-json>=0.9.0', + 'drf-enum-field>=0.9.0', + 'drf-nested-routin>=0.9.0', + 'django-filter>=0.10.0', + 'openpyxl>=2.2.5', + 'chardet>=2.3.0' + ], + zip_safe=False, + classifiers=[ + 'Development Status :: 4 - Beta', + 'Environment :: Web Environment', + 'Framework :: Django', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Topic :: Internet :: WWW/HTTP', + ] +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/manage.py b/tests/manage.py new file mode 100755 index 0000000..3a77e56 --- /dev/null +++ b/tests/manage.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dimm.settings") + + from django.core.management import execute_from_command_line + + execute_from_command_line(sys.argv) diff --git a/tests/testproject/__init__.py b/tests/testproject/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/testproject/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/testproject/migrations/0001_initial.py b/tests/testproject/migrations/0001_initial.py new file mode 100644 index 0000000..5b2d585 --- /dev/null +++ b/tests/testproject/migrations/0001_initial.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import models, migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='RelatedResource1', + fields=[ + ('id', models.AutoField(primary_key=True, verbose_name='ID', serialize=False, auto_created=True)), + ('name', models.CharField(max_length=255)), + ('active', models.BooleanField(default=True)), + ], + ), + migrations.CreateModel( + name='RelatedResource2', + fields=[ + ('id', models.AutoField(primary_key=True, verbose_name='ID', serialize=False, auto_created=True)), + ('name', models.CharField(max_length=255)), + ('active', models.BooleanField(default=True)), + ('related_resources_1', models.ManyToManyField(to='testproject.RelatedResource1')), + ], + ), + migrations.CreateModel( + name='TestResource', + fields=[ + ('id', models.AutoField(primary_key=True, verbose_name='ID', serialize=False, auto_created=True)), + ('name', models.CharField(max_length=255)), + ], + ), + migrations.AddField( + model_name='relatedresource2', + name='resource', + field=models.ForeignKey(to='testproject.TestResource'), + ), + migrations.AddField( + model_name='relatedresource1', + name='resource', + field=models.OneToOneField(to='testproject.TestResource'), + ), + ] diff --git a/tests/testproject/migrations/__init__.py b/tests/testproject/migrations/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/testproject/migrations/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/testproject/models.py b/tests/testproject/models.py new file mode 100644 index 0000000..43cf78b --- /dev/null +++ b/tests/testproject/models.py @@ -0,0 +1,18 @@ +from django.db import models + + +class TestResource(models.Model): + name = models.CharField(max_length=255) + + +class RelatedResource1(models.Model): + name = models.CharField(max_length=255) + active = models.BooleanField(default=True) + resource = models.OneToOneField(TestResource) + + +class RelatedResource2(models.Model): + name = models.CharField(max_length=255) + active = models.BooleanField(default=True) + related_resources_1 = models.ManyToManyField(RelatedResource1) + resource = models.ForeignKey(TestResource) diff --git a/tests/testproject/settings.py b/tests/testproject/settings.py new file mode 100644 index 0000000..9dffc59 --- /dev/null +++ b/tests/testproject/settings.py @@ -0,0 +1,67 @@ +import os + +DEBUG = True +TEMPLATE_DEBUG = True + +TIME_ZONE = 'Europe/Berlin' + +INSTALLED_APPS = ( + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'django.contrib.admindocs', + 'rest_framework', + 'testproject' +) + +ROOT_URLCONF = 'testproject.urls' + +SECRET_KEY = '9q7324#45RWtw843q$%&/' + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': os.path.join(os.path.dirname(os.path.dirname(__file__)), 'db.sqlite3'), + } +} + +MIDDLEWARE_CLASSES = ( + 'django.middleware.common.CommonMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', +) + +REST_FRAMEWORK = { + 'DEFAULT_PERMISSION_CLASSES': ['rest_framework.permissions.AllowAny'], + 'DEFAULT_PAGINATION_CLASS': 'drf_hal_json.pagination.HalPageNumberPagination', + 'DEFAULT_PARSER_CLASSES': ('drf_hal_json.parsers.JsonHalParser',), + 'DEFAULT_RENDERER_CLASSES': ('drf_hal_json.renderers.JsonHalRenderer',), +} + +STATIC_URL = '/static/' + +LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'handlers': { + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + } + }, + 'loggers': { + 'django.db.backends': { + 'handlers': ['console'], + 'level': 'INFO', + 'propagate': False + } + }, + 'root': { + 'handlers': ['console'], + 'level': 'INFO' + } +} diff --git a/tests/testproject/tests.py b/tests/testproject/tests.py new file mode 100644 index 0000000..35a56ac --- /dev/null +++ b/tests/testproject/tests.py @@ -0,0 +1,139 @@ +from drf_tools.test.base import IncludeFields, ModelViewSetTest, AdvancedReadModelViewSetTestMixin + +from .models import TestResource, RelatedResource1, RelatedResource2 + + +class TestResourceViewSetTest(AdvancedReadModelViewSetTestMixin, ModelViewSetTest): + def setUp(self): + super(TestResourceViewSetTest, self).setUp() + self.nameCt = 0 + + def _getModelClass(self): + return TestResource + + def _getOrCreateModelInstance(self): + self.nameCt += 1 + return TestResource.objects.create(name="resource_{}".format(self.nameCt)) + + def _createModelAsJson(self): + stateAttrs = { + "name": "test-resource", + } + linkAttrs = {} + return self._buildContent(stateAttrs, linkAttrs), None + + def _assertModelEqual(self, content, testResource): + stateAttrs, linkAttrs, embeddedAttrs = self._splitContent(content) + + self.assertEqual(stateAttrs.get('name'), testResource.name) + + if self._SELF_FIELD_NAME in content: + self.assertEqual(linkAttrs[self._SELF_FIELD_NAME], self._getAbsoluteDetailURI(testResource)) + + def _getUpdateAttributes(self): + return {"name": "new_name"} + + def _getIncludeFields(self): + return IncludeFields(["name"]) + + +class RelatedResource1ViewSetTest(AdvancedReadModelViewSetTestMixin, ModelViewSetTest): + def setUp(self): + super(RelatedResource1ViewSetTest, self).setUp() + self.nameCt = 0 + + def _getModelClass(self): + return RelatedResource1 + + def _getOrCreateModelInstance(self): + self.nameCt += 1 + resource = TestResource.objects.create(name="resource_{}".format(self.nameCt)) + return RelatedResource1.objects.create(name="relatedresource1_{}".format(self.nameCt), resource=resource) + + def _createModelAsJson(self): + testResource = TestResource.objects.create(name="test-resource") + stateAttrs = { + "name": "related-resource1", + "active": True, + } + linkAttrs = { + "resource": self._getAbsoluteDetailURI(testResource) + } + parentLookups = { + "resource": testResource.id + } + return self._buildContent(stateAttrs, linkAttrs), parentLookups + + def _assertModelEqual(self, content, relatedResource1): + stateAttrs, linkAttrs, embeddedAttrs = self._splitContent(content) + + self.assertEqual(stateAttrs.get('name'), relatedResource1.name) + self.assertEqual(stateAttrs.get('active'), relatedResource1.active) + + if self._SELF_FIELD_NAME in content: + self.assertEqual(linkAttrs[self._SELF_FIELD_NAME], self._getAbsoluteDetailURI(relatedResource1)) + self.assertEqual(linkAttrs["resource"], self._getAbsoluteDetailURI(relatedResource1.resource)) + + def _getUpdateAttributes(self): + return {"name": "new_name"} + + def _getIncludeFields(self): + return IncludeFields(["name"], [], {"resource": IncludeFields(["name"])}) + + +class RelatedResource2ViewSetTest(AdvancedReadModelViewSetTestMixin, ModelViewSetTest): + def setUp(self): + super(RelatedResource2ViewSetTest, self).setUp() + self.nameCt = 0 + + def _getModelClass(self): + return RelatedResource2 + + def _getOrCreateModelInstance(self): + self.nameCt += 1 + resource1 = TestResource.objects.create(name="resource1_{}".format(self.nameCt)) + resource2 = TestResource.objects.create(name="resource2_{}".format(self.nameCt)) + nestedRelatedResource11 = RelatedResource1.objects.create(name="nestedrelatedresource11_{}".format(self.nameCt), + resource=resource2) + resource3 = TestResource.objects.create(name="resource3_{}".format(self.nameCt)) + nestedRelatedResource12 = RelatedResource1.objects.create(name="nestedrelatedresource12_{}".format(self.nameCt), + resource=resource3) + relatedResource2 = RelatedResource2.objects.create(name="relatedresource2_{}".format(self.nameCt), resource=resource1) + relatedResource2.related_resources_1.add(nestedRelatedResource11, nestedRelatedResource12) + return relatedResource2 + + def _createModelAsJson(self): + resource1 = TestResource.objects.create(name="resource1") + resource2 = TestResource.objects.create(name="resource2") + nestedRelatedResource11 = RelatedResource1.objects.create(name="nestedrelatedresource11", resource=resource2) + resource3 = TestResource.objects.create(name="resource3") + nestedRelatedResource12 = RelatedResource1.objects.create(name="nestedrelatedresource12", resource=resource3) + stateAttrs = { + "name": "related-resource2", + "active": False, + } + linkAttrs = { + "related_resources_1": [self._getAbsoluteDetailURI(nestedRelatedResource11), + self._getAbsoluteDetailURI(nestedRelatedResource12)] + } + parentLookups = { + "resource": resource1.id + } + return self._buildContent(stateAttrs, linkAttrs), parentLookups + + def _assertModelEqual(self, content, relatedResource2): + stateAttrs, linkAttrs, embeddedAttrs = self._splitContent(content) + + self.assertEqual(stateAttrs.get('name'), relatedResource2.name) + self.assertEqual(stateAttrs.get('active'), relatedResource2.active) + + if self._SELF_FIELD_NAME in content: + self.assertEqual(linkAttrs[self._SELF_FIELD_NAME], self._getAbsoluteDetailURI(relatedResource2)) + self._assertLinksAndModelListEqual(linkAttrs["related_resources_1"], relatedResource2.related_resources_1) + self.assertEqual(linkAttrs["resource"], self._getAbsoluteDetailURI(relatedResource2.resource)) + + def _getUpdateAttributes(self): + return {"name": "new_name"} + + def _getIncludeFields(self): + return IncludeFields(["name"], ["resource"], {"related_resources_1": IncludeFields(["name"])}) diff --git a/tests/testproject/urls.py b/tests/testproject/urls.py new file mode 100644 index 0000000..64c69b3 --- /dev/null +++ b/tests/testproject/urls.py @@ -0,0 +1,18 @@ +from django.conf.urls import patterns, url, include + +from django.contrib import admin + +from drf_tools.routers import NestedRouterWithExtendedRootView +from .views import TestResourceViewSet, RelatedResource1ViewSet, RelatedResource2ViewSet + +admin.autodiscover() + +router = NestedRouterWithExtendedRootView(list()) +test_resource_route = router.register(r'test-resources', TestResourceViewSet) +test_resource_route.register(r'related-1', RelatedResource1ViewSet, ['resource']) +test_resource_route.register(r'related-2', RelatedResource2ViewSet, ['resource']) + +urlpatterns = patterns( + '', + url(r'', include(router.urls)), +) diff --git a/tests/testproject/views.py b/tests/testproject/views.py new file mode 100644 index 0000000..b10d070 --- /dev/null +++ b/tests/testproject/views.py @@ -0,0 +1,16 @@ +from drf_nested_routing.views import NestedViewSetMixin + +from drf_tools.views import ModelViewSet +from .models import TestResource, RelatedResource2, RelatedResource1 + + +class TestResourceViewSet(ModelViewSet): + queryset = TestResource.objects.all() + + +class RelatedResource1ViewSet(NestedViewSetMixin, ModelViewSet): + queryset = RelatedResource1.objects.all() + + +class RelatedResource2ViewSet(NestedViewSetMixin, ModelViewSet): + queryset = RelatedResource2.objects.all()