diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 6138bb5..178c2be 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -14,3 +14,14 @@ jobs: python-version: '3.11' - run: pip install ruff - run: ruff check beangulp/ examples/ + mypy: + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + - run: pip install . mypy types-python-dateutil types-beautifulsoup4 beancount@git+https://github.com/beancount/beancount.git + - run: mypy beangulp examples diff --git a/beangulp/__init__.py b/beangulp/__init__.py index 00b0216..8b1408d 100644 --- a/beangulp/__init__.py +++ b/beangulp/__init__.py @@ -14,6 +14,7 @@ import sys import warnings import click +from typing import Callable, Optional, Union, Sequence from beancount import loader @@ -257,7 +258,9 @@ def _importer(importer): class Ingest: - def __init__(self, importers, hooks=None): + def __init__(self, + importers: Sequence[Union[Importer, ImporterProtocol]], + hooks: Optional[Sequence[Callable]] = None) -> None: self.importers = [_importer(i) for i in importers] self.hooks = list(hooks) if hooks is not None else [] diff --git a/beangulp/cache.py b/beangulp/cache.py index ab2494c..9c72b45 100644 --- a/beangulp/cache.py +++ b/beangulp/cache.py @@ -155,7 +155,7 @@ def get_file(filename): return _CACHE[filename] -_CACHE = utils.DefaultDictWithKey(_FileMemo) +_CACHE = utils.DefaultDictWithKey(_FileMemo) # type: ignore def cache(func=None, *, key=None): diff --git a/beangulp/file_type.py b/beangulp/file_type.py index 7c5f705..e8fac8b 100644 --- a/beangulp/file_type.py +++ b/beangulp/file_type.py @@ -12,6 +12,7 @@ __license__ = "GNU GPLv2" import warnings +from typing import Optional from beangulp import mimetypes @@ -19,10 +20,10 @@ try: import magic except (ImportError, OSError): - magic = None + magic = None # type: ignore -def guess_file_type(filename): +def guess_file_type(filename: str) -> Optional[str]: """Attempt to guess the type of the input file. Args: diff --git a/beangulp/importer.py b/beangulp/importer.py index 39281e4..09ca684 100644 --- a/beangulp/importer.py +++ b/beangulp/importer.py @@ -177,7 +177,7 @@ class ImporterProtocol: # you prefer to create your imported transactions with a different flag. FLAG = flags.FLAG_OKAY - def name(self): + def name(self) -> str: """See Importer class name property.""" return f"{self.__class__.__module__}.{self.__class__.__name__}" @@ -185,9 +185,11 @@ def name(self): def identify(self, file) -> bool: """See Importer class identify() method.""" + raise NotImplementedError def file_account(self, file) -> data.Account: """See Importer class account() method.""" + raise NotImplementedError def file_date(self, file) -> Optional[datetime.date]: """See Importer class date() method.""" @@ -195,19 +197,20 @@ def file_date(self, file) -> Optional[datetime.date]: def file_name(self, file) -> Optional[str]: """See Importer class filename() method.""" - def extract(self, file, existing_entries: data.Entries = None) -> data.Entries: + def extract(self, file, existing_entries: Optional[data.Entries] = None) -> data.Entries: """See Importer class extract() method.""" + return [] class Adapter(Importer): """Adapter from ImporterProtocol to Importer ABC interface.""" - def __init__(self, importer): + def __init__(self, importer: ImporterProtocol) -> None: assert isinstance(importer, ImporterProtocol) self.importer = importer @property - def name(self): + def name(self) -> str: return self.importer.name() def identify(self, filepath): diff --git a/beangulp/importer_test.py b/beangulp/importer_test.py index c9db9de..266e1f2 100644 --- a/beangulp/importer_test.py +++ b/beangulp/importer_test.py @@ -14,8 +14,10 @@ def test_importer_methods(self): memo = cache._FileMemo('/tmp/test') imp = importer.ImporterProtocol() self.assertIsInstance(imp.FLAG, str) - self.assertFalse(imp.identify(memo)) + with self.assertRaises(NotImplementedError): + self.assertFalse(imp.identify(memo)) self.assertFalse(imp.extract(memo)) - self.assertFalse(imp.file_account(memo)) + with self.assertRaises(NotImplementedError): + self.assertFalse(imp.file_account(memo)) self.assertFalse(imp.file_date(memo)) self.assertFalse(imp.file_name(memo)) diff --git a/beangulp/importers/csv.py b/beangulp/importers/csv.py index bc13e76..a371048 100644 --- a/beangulp/importers/csv.py +++ b/beangulp/importers/csv.py @@ -11,7 +11,7 @@ from inspect import signature from os import path -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union, Tuple import collections import csv import datetime @@ -159,7 +159,7 @@ def normalize_config(config, head, dialect='excel', skip_lines: int = 0): def prepare_for_identifier(regexps: Union[str, List[str]], - matchers: Optional[List[str]]) -> Dict[str, str]: + matchers: Optional[List[Tuple[str, str]]]) -> Dict[str, List[Tuple[str, str]]]: """Prepare data for identifier mixin.""" if isinstance(regexps, str): regexps = [regexps] diff --git a/beangulp/importers/csvbase.py b/beangulp/importers/csvbase.py index 2c79c4b..110957e 100644 --- a/beangulp/importers/csvbase.py +++ b/beangulp/importers/csvbase.py @@ -7,11 +7,12 @@ from collections import defaultdict from itertools import islice +from typing import Any, Dict, FrozenSet from beancount.core import data import beangulp -EMPTY = frozenset() +EMPTY: FrozenSet[str] = frozenset() def _resolve(spec, names): @@ -191,7 +192,7 @@ class CSVReader(metaclass=CSVMeta): """Order of entries in the CSV file. If None the order will be inferred from the file content.""" # This is populated by the CSVMeta metaclass. - columns = {} + columns: Dict[Any, Any] = {} def read(self, filepath): """Read CSV file according to class defined columns specification. diff --git a/beangulp/petl_utils.py b/beangulp/petl_utils.py index 4d45932..54052b2 100644 --- a/beangulp/petl_utils.py +++ b/beangulp/petl_utils.py @@ -1,11 +1,11 @@ """Utilities using petl. """ -from typing import Optional +from typing import Optional, Set import datetime import re -import petl +import petl # type: ignore from beancount.core import data from beancount.core import amount @@ -47,32 +47,33 @@ def table_to_directives( metas.append((column, match.group(1))) # Create transactions. - entries = [] + entries: data.Entries = [] filename = filename or f"<{__file__}>" for index, rec in enumerate(table.records()): meta = data.new_metadata(filename, index) units = amount.Amount(rec.amount, currency) - tags, links = set(), set() + tags: Set[str] = set() + links: Set[str] = set() + link = getattr(rec, "link", None) + if link: + links.add(link) + tag = getattr(rec, "tag", None) + if tag: + tags.add(tag) txn = data.Transaction( meta, rec.date, flags.FLAG_OKAY, getattr(rec, "payee", None), getattr(rec, "narration", ""), - tags, - links, + frozenset(tags), + frozenset(links), [data.Posting(rec.account, units, None, None, None, None)], ) if hasattr(rec, "other_account") and rec.other_account: txn.postings.append( data.Posting(rec.other_account, None, None, None, None, None) ) - link = getattr(rec, "link", None) - if link: - links.add(link) - tag = getattr(rec, "tag", None) - if tag: - tags.add(tag) for column, key in metas: value = getattr(rec, column, None) diff --git a/beangulp/petl_utils_test.py b/beangulp/petl_utils_test.py index ad1c25b..9951c37 100644 --- a/beangulp/petl_utils_test.py +++ b/beangulp/petl_utils_test.py @@ -1,7 +1,7 @@ import datetime import decimal import unittest -import petl +import petl # type: ignore from beancount.parser import cmptest from beangulp import petl_utils diff --git a/beangulp/similar.py b/beangulp/similar.py index f2c6819..03dcf2c 100644 --- a/beangulp/similar.py +++ b/beangulp/similar.py @@ -7,7 +7,7 @@ __license__ = "GNU GPLv2" from decimal import Decimal -from typing import Callable, Optional +from typing import Callable, Optional, FrozenSet, Set, Union import collections import datetime import re @@ -196,8 +196,8 @@ def cmp(entry1: data.Directive, entry2: data.Directive) -> bool: ): return False - links1 = entry1.links - links2 = entry2.links + links1: Union[FrozenSet[str], Set[str]] = entry1.links + links2: Union[FrozenSet[str], Set[str]] = entry2.links if regex: links1 = {link for link in links1 if re.match(regex, link)} links2 = {link for link in links2 if re.match(regex, link)} diff --git a/beangulp/testing.py b/beangulp/testing.py index 9da6cd2..e81a1e8 100644 --- a/beangulp/testing.py +++ b/beangulp/testing.py @@ -34,10 +34,10 @@ def write_expected(outfile: TextIO, name: The filename for filing, produced by the importer. entries: The list of entries extracted by the importer. """ - date = date.isoformat() if date else '' + formatted_date = date.isoformat() if date else '' name = name or '' print(f';; Account: {account}', file=outfile) - print(f';; Date: {date}', file=outfile) + print(f';; Date: {formatted_date}', file=outfile) print(f';; Name: {name}', file=outfile) printer.print_entries(entries, file=outfile) @@ -46,7 +46,7 @@ def write_expected_file(filepath: str, *data, force: bool = False): """Writes out the expected file.""" mode = 'w' if force else 'x' with open(filepath, mode) as expfile: - write_expected(expfile, *data) + write_expected(expfile, *data) # type: ignore def compare_expected(filepath: str, *data) -> List[str]: diff --git a/examples/importers/acme.py b/examples/importers/acme.py index 67c18fb..b3d24c1 100644 --- a/examples/importers/acme.py +++ b/examples/importers/acme.py @@ -9,19 +9,22 @@ __copyright__ = "Copyright (C) 2016 Martin Blais" __license__ = "GNU GPLv2" +import datetime import re import subprocess +from typing import Optional from dateutil.parser import parse as parse_datetime import beangulp +from beancount.core import data from beangulp import mimetypes from beangulp.cache import cache from beangulp.testing import main @cache -def pdf_to_text(filename): +def pdf_to_text(filename: str) -> str: """Convert a PDF document to a text equivalent.""" r = subprocess.run(['pdftotext', filename, '-'], stdout=subprocess.PIPE, check=True) @@ -31,10 +34,10 @@ def pdf_to_text(filename): class Importer(beangulp.Importer): """An importer for ACME Bank PDF statements.""" - def __init__(self, account_filing): + def __init__(self, account_filing: str) -> None: self.account_filing = account_filing - def identify(self, filepath): + def identify(self, filepath: str) -> bool: mimetype, encoding = mimetypes.guess_type(filepath) if mimetype != 'application/pdf': return False @@ -44,20 +47,22 @@ def identify(self, filepath): text = pdf_to_text(filepath) if text: return re.match('ACME Bank', text) is not None + return False - def filename(self, filepath): + def filename(self, filepath: str) -> str: # Normalize the name to something meaningful. return 'acmebank.pdf' - def account(self, filepath): + def account(self, filepath: str) -> data.Account: return self.account_filing - def date(self, filepath): + def date(self, filepath: str) -> Optional[datetime.date]: # Get the actual statement's date from the contents of the file. text = pdf_to_text(filepath) match = re.search('Date: ([^\n]*)', text) if match: return parse_datetime(match.group(1)).date() + return None if __name__ == '__main__': diff --git a/examples/importers/csvbank.py b/examples/importers/csvbank.py index 7cc47d7..30d44b9 100644 --- a/examples/importers/csvbank.py +++ b/examples/importers/csvbank.py @@ -5,7 +5,7 @@ class Importer(csvbase.Importer): - date = csvbase.Date('Posting Date', '%m/%d/%Y') + date = csvbase.Date('Posting Date', '%m/%d/%Y') # type: ignore narration = csvbase.Columns('Description', 'Check or Slip #', sep='; ') amount = csvbase.Amount('Amount') balance = csvbase.Amount('Balance') diff --git a/examples/importers/utrade.py b/examples/importers/utrade.py index 2f0c742..26111d1 100644 --- a/examples/importers/utrade.py +++ b/examples/importers/utrade.py @@ -26,13 +26,13 @@ class Importer(beangulp.Importer): """An importer for UTrade CSV files (an example investment bank).""" - def __init__(self, currency, - account_root, - account_cash, - account_dividends, - account_gains, - account_fees, - account_external): + def __init__(self, currency: str, + account_root: data.Account, + account_cash: data.Account, + account_dividends: data.Account, + account_gains: data.Account, + account_fees: data.Account, + account_external: data.Account) -> None: self.currency = currency self.account_root = account_root self.account_cash = account_cash @@ -41,7 +41,7 @@ def __init__(self, currency, self.account_fees = account_fees self.account_external = account_external - def identify(self, filepath): + def identify(self, filepath: str) -> bool: # Match if the filename is as downloaded and the header has the unique # fields combination we're looking for. if not re.match(r"UTrade\d\d\d\d\d\d\d\d\.csv", path.basename(filepath)): @@ -52,27 +52,27 @@ def identify(self, filepath): return False return True - def filename(self, filepath): + def filename(self, filepath: str) -> str: return 'utrade.{}'.format(path.basename(filepath)) - def account(self, filepath): + def account(self, filepath: str) -> data.Account: return self.account_root - def date(self, filepath): + def date(self, filepath: str) -> datetime.date: # Extract the statement date from the filename. return datetime.datetime.strptime(path.basename(filepath), 'UTrade%Y%m%d.csv').date() - def extract(self, filepath, existing): + def extract(self, filepath: str, existing: data.Entries) -> data.Entries: # Open the CSV file and create directives. - entries = [] + entries: data.Entries = [] index = 0 with open(filepath) as infile: for index, row in enumerate(csv.DictReader(infile)): meta = data.new_metadata(filepath, index) date = parse(row['DATE']).date() rtype = row['TYPE'] - link = f"ut{row['REF #']}" + links = frozenset([f"ut{row['REF #']}"]) desc = f"({row['TYPE']}) {row['DESCRIPTION']}" units = amount.Amount(D(row['AMOUNT']), self.currency) fees = amount.Amount(D(row['FEES']), self.currency) @@ -81,7 +81,7 @@ def extract(self, filepath, existing): if rtype == 'XFER': assert fees.number == ZERO txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(self.account_external, -other, None, None, None, @@ -100,7 +100,7 @@ def extract(self, filepath, existing): account_dividends = self.account_dividends.format(instrument) txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(account_dividends, -other, None, None, None, None), ]) @@ -122,9 +122,9 @@ def extract(self, filepath, existing): rate = D(match.group(3)) if rtype == 'BUY': - cost = position.Cost(rate, self.currency, None, None) + cost = position.CostSpec(rate, None, self.currency, None, None, None) txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(self.account_fees, fees, None, None, None, @@ -143,11 +143,11 @@ def extract(self, filepath, existing): logging.error("Missing cost basis in '%s'", row['DESCRIPTION']) continue cost_number = D(match.group(1)) - cost = position.Cost(cost_number, self.currency, None, None) + cost = position.CostSpec(cost_number, None, self.currency, None, None, None) price = amount.Amount(rate, self.currency) account_gains = self.account_gains.format(instrument) txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(self.account_fees, fees, None, None, None,