diff --git a/beangulp/cache.py b/beangulp/cache.py index 155cdcf..ab2494c 100644 --- a/beangulp/cache.py +++ b/beangulp/cache.py @@ -19,8 +19,8 @@ import chardet -from beancount.utils import defdict from beangulp import mimetypes +from beangulp import utils # NOTE: See get_file() at the end of this file to create instances of FileMemo. @@ -155,7 +155,7 @@ def get_file(filename): return _CACHE[filename] -_CACHE = defdict.DefaultDictWithKey(_FileMemo) +_CACHE = utils.DefaultDictWithKey(_FileMemo) def cache(func=None, *, key=None): diff --git a/beangulp/utils.py b/beangulp/utils.py index fa6c022..9567ee3 100644 --- a/beangulp/utils.py +++ b/beangulp/utils.py @@ -2,6 +2,7 @@ from os import path from typing import Iterator, Sequence, Union, Set, Optional, Dict import datetime +import collections import decimal import hashlib import logging @@ -13,6 +14,17 @@ from beangulp import mimetypes +class DefaultDictWithKey(collections.defaultdict): + """A version of defaultdict whose factory accepts the key as an argument. + Note: collections.defaultdict would be improved by supporting this directly, + this is a common occurrence. + """ + + def __missing__(self, key): + self[key] = value = self.default_factory(key) + return value + + def getmdate(filepath: str) -> datetime.date: """Return file modification date.""" mtime = path.getmtime(filepath) diff --git a/beangulp/utils_test.py b/beangulp/utils_test.py index 418c96c..1f323a7 100644 --- a/beangulp/utils_test.py +++ b/beangulp/utils_test.py @@ -5,7 +5,7 @@ import os import types import unittest - +from unittest import mock from shutil import rmtree from tempfile import mkdtemp @@ -111,3 +111,14 @@ def test_idify(self): ) self.assertEqual("A____B.pdf", utils.idify("A____B_._pdf")) + +class TestDefDictWithKey(unittest.TestCase): + def test_defdict_with_key(self): + factory = mock.MagicMock() + testdict = utils.DefaultDictWithKey(factory) + + testdict["a"] + testdict["b"] + self.assertEqual(2, len(factory.mock_calls)) + self.assertEqual(("a",), factory.mock_calls[0][1]) + self.assertEqual(("b",), factory.mock_calls[1][1])