diff --git a/boto3_helpers/dynamodb.py b/boto3_helpers/dynamodb.py index ff1ecc7..aa2b0be 100644 --- a/boto3_helpers/dynamodb.py +++ b/boto3_helpers/dynamodb.py @@ -1,14 +1,24 @@ +from base64 import b64decode +from json import loads + from boto3 import resource as boto3_resource from boto3.dynamodb.types import TypeDeserializer, TypeSerializer from time import sleep -class CustomTypeDeserializer(TypeDeserializer): - def __init__(self, *args, use_decimal=False, **kwargs): +class _CustomTypeDeserializer(TypeDeserializer): + def __init__(self, *args, use_decimal=False, decode_binary=False, **kwargs): self.use_decimal = use_decimal + self.decode_binary = decode_binary super().__init__(*args, **kwargs) + def _deserialize_b(self, value): + if self.decode_binary: + return b64decode(value) + + return super()._deserialize_b(value) + def _deserialize_n(self, value): if self.use_decimal: return super()._deserialize_n(value) @@ -219,7 +229,7 @@ def batch_yield_items( def fix_numbers(item): - """``boto3`` DB infamously deserializes numeric types from DynamoDB to + """``boto3`` infamously deserializes numeric types from DynamoDB to Python ``Decimal`` objects. This function changes these objects into ``int`` objects and ``float`` objects. @@ -238,6 +248,40 @@ def fix_numbers(item): so think about what your application needs before using this function. """ s = TypeSerializer().serialize - d = CustomTypeDeserializer().deserialize + d = _CustomTypeDeserializer().deserialize wire_format = {k: s(v) for k, v in item.items()} return {k: d(v) for k, v in wire_format.items()} + + +def load_dynamodb_json(text, use_decimal=False): + """The DynamoDB API returns JSON data with typing information. This function + deserializes this JSON format into standard Python types. + + .. 
code-block:: python + + from boto3_helpers.dynamodb import load_dynamodb_json + + text = '{"Item": {"some_number": {"N": "100"}}}' + info = load_dynamodb_json(text) + assert info['Item']['some_number'] == 100 + + JSON from the ``GetItem``, ``Query``, and ``Scan`` API endpoints is supported. + + If ``use_decimal`` is ``True``, numeric types will be deserialized to + ``decimal.Decimal`` objects. This matches the ``boto3`` client behavior, but + is often inconvenient. + """ + d = _CustomTypeDeserializer(use_decimal=use_decimal, decode_binary=True).deserialize + ret = {} + for key, value in loads(text).items(): + if key == 'Item': + ret['Item'] = {k: d(v) for k, v in value.items()} + elif key == 'Items': + all_items = [] + for item in value: + all_items.append({k: d(v) for k, v in item.items()}) + ret['Items'] = all_items + else: + ret[key] = value + + return ret diff --git a/tests/test_dynamodb.py b/tests/test_dynamodb.py index 2d64ab7..757f3e1 100644 --- a/tests/test_dynamodb.py +++ b/tests/test_dynamodb.py @@ -9,11 +9,79 @@ from boto3_helpers.dynamodb import ( batch_yield_items, fix_numbers, + load_dynamodb_json, query_table, scan_table, update_attributes, ) +SCAN_RESPONSE = """\ +{ + "Items": [ + { + "bin_set": { + "BS": [ + "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk" + ] + }, + "string_set": { + "SS": [ + "ss_1", + "ss_2" + ] + }, + "number_int": { + "N": "1" + }, + "number_set": { + "NS": [ + "1.1", + "1" + ] + }, + "string_literal": { + "S": "s" + }, + "list_value": { + "L": [ + { + "S": "sl_1" + }, + { + "N": "1" + } + ] + }, + "bin_value": { + "B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk" + }, + "bool_value": { + "BOOL": true + }, + "null_value": { + "NULL": true + }, + "number_float": { + "N": "1.1" + }, + "map_value": { + "M": { + "n_key": { + "N": "1.1" + }, + "s_key": { + "S": "s_value" + } + } + } + } + ], + "Count": 1, + "ScannedCount": 1, + "ConsumedCapacity": null +} +""" + class DynamoDBTests(TestCase): def test_query_table(self): @@ -276,3 +344,51 @@ def 
test_fix_numbers(self): 'map_value': {'n_key': 1.1, 's_key': 's_value'}, } self.assertEqual(actual, expected) + + def test_load_dynamodb_json_scan(self): + actual = load_dynamodb_json(SCAN_RESPONSE) + expected = { + 'Items': [ + { + 'bin_set': {b'this text is base64-encoded'}, + 'string_set': {'ss_1', 'ss_2'}, + 'number_int': 1, + 'number_set': {1.1, 1}, + 'string_literal': 's', + 'list_value': ['sl_1', 1], + 'bin_value': b'this text is base64-encoded', + 'bool_value': True, + 'null_value': None, + 'number_float': 1.1, + 'map_value': {'n_key': 1.1, 's_key': 's_value'}, + } + ], + 'Count': 1, + 'ScannedCount': 1, + 'ConsumedCapacity': None, + } + self.assertEqual(actual, expected) + + def test_load_dynamodb_json_get(self): + i = 0 + for text, use_decimal, expected in ( + ( + '{"Item": {"some_number": {"N": "100"}}}', + False, + {'Item': {'some_number': 100}}, + ), + ( + '{"Item": {"some_number": {"N": "100.1"}}}', + False, + {'Item': {'some_number': 100.1}}, + ), + ( + '{"Item": {"some_number": {"N": "100.1"}}}', + True, + {'Item': {'some_number': Decimal('100.1')}}, + ), + ): + i += 1 + with self.subTest(i=i): + actual = load_dynamodb_json(text, use_decimal=use_decimal) + self.assertEqual(actual, expected)