Skip to content

Commit

Permalink
tmp - use types in some of the examples importers
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Jan 3, 2025
1 parent cc708dd commit f4c3b0f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 29 deletions.
8 changes: 4 additions & 4 deletions beangulp/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def filename(self, filepath: str) -> Optional[str]:
"""
return None

def extract(self, filepath: str, existing: data.Entries) -> data.Entries:
def extract(self, filepath: str, existing: data.Directives) -> data.Directives:
"""Extract transactions and other directives from a document.
The existing entries list is loaded from the existing ledger
Expand All @@ -123,7 +123,7 @@ def extract(self, filepath: str, existing: data.Entries) -> data.Entries:

cmp = staticmethod(similar.heuristic_comparator())

def deduplicate(self, entries: data.Entries, existing: data.Entries) -> None:
def deduplicate(self, entries: data.Directives, existing: data.Directives) -> None:
"""Mark duplicates in extracted entries.
The default implementation uses the cmp() method to compare
Expand All @@ -144,7 +144,7 @@ def deduplicate(self, entries: data.Entries, existing: data.Entries) -> None:
window = datetime.timedelta(days=2)
extract.mark_duplicate_entries(entries, existing, window, self.cmp)

def sort(self, entries: data.Entries, reverse=False) -> None:
def sort(self, entries: data.Directives, reverse=False) -> None:
"""Sort the extracted directives.
The sort is in-place and stable. The reverse flag can be set
Expand Down Expand Up @@ -201,7 +201,7 @@ def file_date(self, file) -> Optional[datetime.date]:
def file_name(self, file) -> Optional[str]:
"""See Importer class filename() method."""

def extract(self, file, existing_entries: Optional[data.Entries] = None) -> data.Entries:
def extract(self, file, existing_entries: Optional[data.Directives] = None) -> data.Directives:
"""See Importer class extract() method."""
return []

Expand Down
17 changes: 11 additions & 6 deletions examples/importers/acme.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
__copyright__ = "Copyright (C) 2016 Martin Blais"
__license__ = "GNU GPLv2"

import datetime
import re
import subprocess
from typing import Optional

from dateutil.parser import parse as parse_datetime

import beangulp
from beancount.core import data
from beangulp import mimetypes
from beangulp.cache import cache
from beangulp.testing import main


@cache
def pdf_to_text(filename):
def pdf_to_text(filename: str) -> str:
"""Convert a PDF document to a text equivalent."""
r = subprocess.run(['pdftotext', filename, '-'],
stdout=subprocess.PIPE, check=True)
Expand All @@ -31,10 +34,10 @@ def pdf_to_text(filename):
class Importer(beangulp.Importer):
"""An importer for ACME Bank PDF statements."""

def __init__(self, account_filing):
def __init__(self, account_filing: str) -> None:
self.account_filing = account_filing

def identify(self, filepath):
def identify(self, filepath: str) -> bool:
mimetype, encoding = mimetypes.guess_type(filepath)
if mimetype != 'application/pdf':
return False
Expand All @@ -44,20 +47,22 @@ def identify(self, filepath):
text = pdf_to_text(filepath)
if text:
return re.match('ACME Bank', text) is not None
return False

def filename(self, filepath):
def filename(self, filepath: str) -> str:
# Normalize the name to something meaningful.
return 'acmebank.pdf'

def account(self, filepath):
def account(self, filepath: str) -> data.Account:
return self.account_filing

def date(self, filepath):
def date(self, filepath: str) -> Optional[datetime.date]:
# Get the actual statement's date from the contents of the file.
text = pdf_to_text(filepath)
match = re.search('Date: ([^\n]*)', text)
if match:
return parse_datetime(match.group(1)).date()
return None


if __name__ == '__main__':
Expand Down
39 changes: 20 additions & 19 deletions examples/importers/utrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
class Importer(beangulp.Importer):
"""An importer for UTrade CSV files (an example investment bank)."""

def __init__(self, currency,
account_root,
account_cash,
account_dividends,
account_gains,
account_fees,
account_external):
def __init__(self, currency: str,
account_root: data.Account,
account_cash: data.Account,
account_dividends: data.Account,
account_gains: data.Account,
account_fees: data.Account,
account_external: data.Account) -> None:
self.currency = currency
self.account_root = account_root
self.account_cash = account_cash
Expand All @@ -41,7 +41,7 @@ def __init__(self, currency,
self.account_fees = account_fees
self.account_external = account_external

def identify(self, filepath):
def identify(self, filepath: str) -> bool:
# Match if the filename is as downloaded and the header has the unique
# fields combination we're looking for.
if not re.match(r"UTrade\d\d\d\d\d\d\d\d\.csv", path.basename(filepath)):
Expand All @@ -52,27 +52,28 @@ def identify(self, filepath):
return False
return True

def filename(self, filepath):
def filename(self, filepath: str) -> str:
return 'utrade.{}'.format(path.basename(filepath))

def account(self, filepath):
def account(self, filepath: str) -> data.Account:
return self.account_root

def date(self, filepath):
def date(self, filepath: str) -> datetime.date:
# Extract the statement date from the filename.
return datetime.datetime.strptime(path.basename(filepath),
'UTrade%Y%m%d.csv').date()

def extract(self, filepath, existing):
def extract(self, filepath: str, existing: data.Directives) -> data.Directives:
# Open the CSV file and create directives.
entries = []
entries: data.Directives = []
index = 0
with open(filepath) as infile:
for index, row in enumerate(csv.DictReader(infile)):
meta = data.new_metadata(filepath, index)
date = parse(row['DATE']).date()
rtype = row['TYPE']
link = f"ut{row['REF #']}"
links = frozenset([link])
desc = f"({row['TYPE']}) {row['DESCRIPTION']}"
units = amount.Amount(D(row['AMOUNT']), self.currency)
fees = amount.Amount(D(row['FEES']), self.currency)
Expand All @@ -81,7 +82,7 @@ def extract(self, filepath, existing):
if rtype == 'XFER':
assert fees.number == ZERO
txn = data.Transaction(
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [
data.Posting(self.account_cash, units, None, None, None,
None),
data.Posting(self.account_external, -other, None, None, None,
Expand All @@ -100,7 +101,7 @@ def extract(self, filepath, existing):
account_dividends = self.account_dividends.format(instrument)

txn = data.Transaction(
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [
data.Posting(self.account_cash, units, None, None, None, None),
data.Posting(account_dividends, -other, None, None, None, None),
])
Expand All @@ -122,9 +123,9 @@ def extract(self, filepath, existing):
rate = D(match.group(3))

if rtype == 'BUY':
cost = position.Cost(rate, self.currency, None, None)
cost = position.CostSpec(rate, None, self.currency, None, None, None)
txn = data.Transaction(
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [
data.Posting(self.account_cash, units, None, None, None,
None),
data.Posting(self.account_fees, fees, None, None, None,
Expand All @@ -143,11 +144,11 @@ def extract(self, filepath, existing):
logging.error("Missing cost basis in '%s'", row['DESCRIPTION'])
continue
cost_number = D(match.group(1))
cost = position.Cost(cost_number, self.currency, None, None)
cost = position.CostSpec(cost_number, None, self.currency, None, None, None)
price = amount.Amount(rate, self.currency)
account_gains = self.account_gains.format(instrument)
txn = data.Transaction(
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, {link}, [
meta, date, flags.FLAG_OKAY, None, desc, data.EMPTY_SET, links, [
data.Posting(self.account_cash, units, None, None, None,
None),
data.Posting(self.account_fees, fees, None, None, None,
Expand Down

0 comments on commit f4c3b0f

Please sign in to comment.