diff --git a/django_coreapi/client.py b/django_coreapi/client.py index ac83f23..cd73429 100644 --- a/django_coreapi/client.py +++ b/django_coreapi/client.py @@ -34,7 +34,7 @@ def reload(self, document): transport = determine_transport(link.url, transports=self.transports) return transport.transition(link, decoders=self.decoders) - def action(self, document, keys, params=None, action=None, transform=None): + def action(self, document, keys, params=None, action=None, encoding=None, transform=None): if isinstance(keys, string_types): keys = [keys] @@ -42,13 +42,12 @@ def action(self, document, keys, params=None, action=None, transform=None): link, link_ancestors = _lookup_link(document, keys) url = _make_absolute(link.url) - if (action is not None) or (transform is not None): + if (action is not None) or (encoding is not None) or (transform is not None): # Handle any explicit overrides. action = link.action if (action is None) else action + encoding = link.encoding if (encoding is None) else encoding transform = link.transform if (transform is None) else transform - link = Link(url, action, transform, link.fields) - else: - link = Link(url, link.action, link.transform, link.fields) + link = Link(url, action=action, encoding=encoding, transform=transform, fields=link.fields) # Perform the action, and return a new document. transport = determine_transport(url, transports=self.transports) diff --git a/django_coreapi/mock.py b/django_coreapi/mock.py new file mode 100644 index 0000000..e71c1c9 --- /dev/null +++ b/django_coreapi/mock.py @@ -0,0 +1,66 @@ +import functools +import coreapi +import django_coreapi.client + + +_responses = [] + + +def get_match(keys): + """ + Find the first matching response in the current set + :param keys: the key path to be matched + :return: the matching response or None + """ + for response in _responses: + if response[0] == keys: + return response[1] + + +class Mock(object): + """ + Can be used as a context manager. Takes handler functions as arguments, which are evaluated in order in place of + """ + def __init__(self): + super(Mock, self).__init__() + + def __enter__(self): + # prepare and inject the mock methods to coreapi's session + self._real_action = coreapi.Client.action + self._real_django_client_action = django_coreapi.client.DjangoCoreAPIClient.action + + def fake_action(client, document, keys, *args, **kwargs): + res = get_match(keys) + if res is not None: + return res + raise + + coreapi.Client.action = fake_action + django_coreapi.client.DjangoCoreAPIClient.action = fake_action + + def __exit__(self, exc_type, value, tb): + global _responses + # replace the real methods + coreapi.Client.action = self._real_action + django_coreapi.client.DjangoCoreAPIClient.action = self._real_django_client_action + # clear out the match list + _responses = [] + + +def activate(f): + """ + A decorator which mocks the coreapi and django_coreapi clients, allowing use of `add` + :param f: the function to be wrapped + :return: the wrapped function + """ + @functools.wraps(f) + def decorated(*args, **kwargs): + with Mock(): + return f(*args, **kwargs) + + return decorated + + +def add(keys, response): + global _responses + _responses.append((keys, response)) diff --git a/tests.py b/tests.py index 935bca3..50c5ff3 100644 --- a/tests.py +++ b/tests.py @@ -4,6 +4,7 @@ settings.configure(ROOT_URLCONF='test_app', DEBUG=True, REST_FRAMEWORK={'UNAUTHENTICATED_USER': None}) import django django.setup() +from django_coreapi import mock from django_coreapi.client import DjangoCoreAPIClient from django_coreapi.transports import DjangoTestHTTPTransport @@ -35,3 +36,20 @@ def test_post_data(self): 'test': 'cat' }}) self.assertIsNotNone(doc) + + @mock.activate + def test_mocking(self): + mock.add(['test', 'post_data'], {"a": 1}) + content = { + 'test': { + 'post_data': Link(url='/post_data/', action='post', fields=[ + Field('data', location='body') + ]), + } + } + schema = Document(title='test', content=content) + client = DjangoCoreAPIClient() + doc = client.action(schema, ['test', 'post_data'], params={'data': { + 'test': 'cat' + }}) + self.assertEqual(doc, {"a": 1})