From f4c3b0f841c5b15fde9ff9582096fd750bfa9b09 Mon Sep 17 00:00:00 2001 From: Jakob Schnitzer Date: Fri, 3 Jan 2025 09:01:52 +0100 Subject: [PATCH] tmp - use types in some of the examples importers --- beangulp/importer.py | 8 ++++---- examples/importers/acme.py | 17 ++++++++++------ examples/importers/utrade.py | 39 ++++++++++++++++++------------------ 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/beangulp/importer.py b/beangulp/importer.py index 91746cc..48947c5 100644 --- a/beangulp/importer.py +++ b/beangulp/importer.py @@ -102,7 +102,7 @@ def filename(self, filepath: str) -> Optional[str]: """ return None - def extract(self, filepath: str, existing: data.Entries) -> data.Entries: + def extract(self, filepath: str, existing: data.Directives) -> data.Directives: """Extract transactions and other directives from a document. The existing entries list is loaded from the existing ledger @@ -123,7 +123,7 @@ def extract(self, filepath: str, existing: data.Entries) -> data.Entries: cmp = staticmethod(similar.heuristic_comparator()) - def deduplicate(self, entries: data.Entries, existing: data.Entries) -> None: + def deduplicate(self, entries: data.Directives, existing: data.Directives) -> None: """Mark duplicates in extracted entries. The default implementation uses the cmp() method to compare @@ -144,7 +144,7 @@ def deduplicate(self, entries: data.Entries, existing: data.Entries) -> None: window = datetime.timedelta(days=2) extract.mark_duplicate_entries(entries, existing, window, self.cmp) - def sort(self, entries: data.Entries, reverse=False) -> None: + def sort(self, entries: data.Directives, reverse=False) -> None: """Sort the extracted directives. The sort is in-place and stable. The reverse flag can be set @@ -201,7 +201,7 @@ def file_date(self, file) -> Optional[datetime.date]: def file_name(self, file) -> Optional[str]: """See Importer class filename() method.""" - def extract(self, file, existing_entries: Optional[data.Entries] = None) -> data.Entries: + def extract(self, file, existing_entries: Optional[data.Directives] = None) -> data.Directives: """See Importer class extract() method.""" return [] diff --git a/examples/importers/acme.py b/examples/importers/acme.py index 67c18fb..b3d24c1 100644 --- a/examples/importers/acme.py +++ b/examples/importers/acme.py @@ -9,19 +9,22 @@ __copyright__ = "Copyright (C) 2016 Martin Blais" __license__ = "GNU GPLv2" +import datetime import re import subprocess +from typing import Optional from dateutil.parser import parse as parse_datetime import beangulp +from beancount.core import data from beangulp import mimetypes from beangulp.cache import cache from beangulp.testing import main @cache -def pdf_to_text(filename): +def pdf_to_text(filename: str) -> str: """Convert a PDF document to a text equivalent.""" r = subprocess.run(['pdftotext', filename, '-'], stdout=subprocess.PIPE, check=True) @@ -31,10 +34,10 @@ def pdf_to_text(filename): class Importer(beangulp.Importer): """An importer for ACME Bank PDF statements.""" - def __init__(self, account_filing): + def __init__(self, account_filing: str) -> None: self.account_filing = account_filing - def identify(self, filepath): + def identify(self, filepath: str) -> bool: mimetype, encoding = mimetypes.guess_type(filepath) if mimetype != 'application/pdf': return False @@ -44,20 +47,22 @@ def identify(self, filepath): text = pdf_to_text(filepath) if text: return re.match('ACME Bank', text) is not None + return False - def filename(self, filepath): + def filename(self, filepath: str) -> str: # Normalize the name to something meaningful. return 'acmebank.pdf' - def account(self, filepath): + def account(self, filepath: str) -> data.Account: return self.account_filing - def date(self, filepath): + def date(self, filepath: str) -> Optional[datetime.date]: # Get the actual statement's date from the contents of the file. text = pdf_to_text(filepath) match = re.search('Date: ([^\n]*)', text) if match: return parse_datetime(match.group(1)).date() + return None if __name__ == '__main__': diff --git a/examples/importers/utrade.py b/examples/importers/utrade.py index 2f0c742..ec1d6b9 100644 --- a/examples/importers/utrade.py +++ b/examples/importers/utrade.py @@ -26,13 +26,13 @@ class Importer(beangulp.Importer): """An importer for UTrade CSV files (an example investment bank).""" - def __init__(self, currency, - account_root, - account_cash, - account_dividends, - account_gains, - account_fees, - account_external): + def __init__(self, currency: str, + account_root: data.Account, + account_cash: data.Account, + account_dividends: data.Account, + account_gains: data.Account, + account_fees: data.Account, + account_external: data.Account) -> None: self.currency = currency self.account_root = account_root self.account_cash = account_cash @@ -41,7 +41,7 @@ def __init__(self, currency, self.account_fees = account_fees self.account_external = account_external - def identify(self, filepath): + def identify(self, filepath: str) -> bool: # Match if the filename is as downloaded and the header has the unique # fields combination we're looking for. if not re.match(r"UTrade\d\d\d\d\d\d\d\d\.csv", path.basename(filepath)): @@ -52,20 +52,20 @@ def identify(self, filepath): return False return True - def filename(self, filepath): + def filename(self, filepath: str) -> str: return 'utrade.{}'.format(path.basename(filepath)) - def account(self, filepath): + def account(self, filepath: str) -> data.Account: return self.account_root - def date(self, filepath): + def date(self, filepath: str) -> datetime.date: # Extract the statement date from the filename. return datetime.datetime.strptime(path.basename(filepath), 'UTrade%Y%m%d.csv').date() - def extract(self, filepath, existing): + def extract(self, filepath: str, existing: data.Directives) -> data.Directives: # Open the CSV file and create directives. - entries = [] + entries: data.Directives = [] index = 0 with open(filepath) as infile: for index, row in enumerate(csv.DictReader(infile)): @@ -73,6 +73,7 @@ def extract(self, filepath, existing): date = parse(row['DATE']).date() rtype = row['TYPE'] link = f"ut{row['REF #']}" + links = frozenset([link]) desc = f"({row['TYPE']}) {row['DESCRIPTION']}" units = amount.Amount(D(row['AMOUNT']), self.currency) fees = amount.Amount(D(row['FEES']), self.currency) @@ -81,7 +82,7 @@ def extract(self, filepath, existing): if rtype == 'XFER': assert fees.number == ZERO txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(self.account_external, -other, None, None, None, @@ -100,7 +101,7 @@ def extract(self, filepath, existing): account_dividends = self.account_dividends.format(instrument) txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(account_dividends, -other, None, None, None, None), ]) @@ -122,9 +123,9 @@ def extract(self, filepath, existing): rate = D(match.group(3)) if rtype == 'BUY': - cost = position.Cost(rate, self.currency, None, None) + cost = position.CostSpec(rate, None, self.currency, None, None, None) txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(self.account_fees, fees, None, None, None, @@ -143,11 +144,11 @@ def extract(self, filepath, existing): logging.error("Missing cost basis in '%s'", row['DESCRIPTION']) continue cost_number = D(match.group(1)) - cost = position.Cost(cost_number, self.currency, None, None) + cost = position.CostSpec(cost_number, None, self.currency, None, None, None) price = amount.Amount(rate, self.currency) account_gains = self.account_gains.format(instrument) txn = data.Transaction( - meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [ + meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [ data.Posting(self.account_cash, units, None, None, None, None), data.Posting(self.account_fees, fees, None, None, None,