Skip to content

Commit

Permalink
migrate DefaultDictWithKey
Browse files Browse the repository at this point in the history
  • Loading branch information
trim21 authored and blais committed Dec 23, 2024
1 parent 97ffca4 commit f12c168
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
4 changes: 2 additions & 2 deletions beangulp/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import chardet

from beancount.utils import defdict
from beangulp import mimetypes
from beangulp import utils

# NOTE: See get_file() at the end of this file to create instances of FileMemo.

Expand Down Expand Up @@ -155,7 +155,7 @@ def get_file(filename):
return _CACHE[filename]


_CACHE = defdict.DefaultDictWithKey(_FileMemo)
_CACHE = utils.DefaultDictWithKey(_FileMemo)


def cache(func=None, *, key=None):
Expand Down
12 changes: 12 additions & 0 deletions beangulp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import path
from typing import Iterator, Sequence, Union, Set, Optional, Dict
import datetime
import collections
import decimal
import hashlib
import logging
Expand All @@ -13,6 +14,17 @@
from beangulp import mimetypes


class DefaultDictWithKey(collections.defaultdict):
"""A version of defaultdict whose factory accepts the key as an argument.
Note: collections.defaultdict would be improved by supporting this directly,
this is a common occurrence.
"""

def __missing__(self, key):
self[key] = value = self.default_factory(key)
return value


def getmdate(filepath: str) -> datetime.date:
"""Return file modification date."""
mtime = path.getmtime(filepath)
Expand Down
13 changes: 12 additions & 1 deletion beangulp/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import types
import unittest

from unittest import mock
from shutil import rmtree
from tempfile import mkdtemp

Expand Down Expand Up @@ -111,3 +111,14 @@ def test_idify(self):
)
self.assertEqual("A____B.pdf", utils.idify("A____B_._pdf"))


class TestDefDictWithKey(unittest.TestCase):
def test_defdict_with_key(self):
factory = mock.MagicMock()
testdict = utils.DefaultDictWithKey(factory)

testdict["a"]
testdict["b"]
self.assertEqual(2, len(factory.mock_calls))
self.assertEqual(("a",), factory.mock_calls[0][1])
self.assertEqual(("b",), factory.mock_calls[1][1])

0 comments on commit f12c168

Please sign in to comment.