diff --git a/broker.py b/broker.py new file mode 100644 index 0000000..5f87ae0 --- /dev/null +++ b/broker.py @@ -0,0 +1,38 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +import utils + + +class Broker(ABC): + + @classmethod + @abstractmethod + def name(cls) -> str: + ... + + @classmethod + @abstractmethod + def isFileForBroker(cls, file: str) -> bool: + ... + + @classmethod + @abstractmethod + def parseFileToTxnList(cls, file: str, tax_year: Optional[int]) -> list[utils.Transaction]: + ... diff --git a/brokers.py b/brokers.py index 285ec61..c202fba 100644 --- a/brokers.py +++ b/brokers.py @@ -28,12 +28,17 @@ def isFileForBroker(cls, filename): 3) Add your class to the BROKERS map below. """ +from __future__ import annotations + +from typing import Optional, Type + +from broker import Broker from interactive_brokers import InteractiveBrokers from tdameritrade import TDAmeritrade from vanguard import Vanguard -BROKERS = { +BROKERS: dict[str, Type[Broker]] = { 'amtd': TDAmeritrade, 'ib': InteractiveBrokers, 'tdameritrade': TDAmeritrade, @@ -41,7 +46,7 @@ def isFileForBroker(cls, filename): } -def DetectBroker(filename): +def DetectBroker(filename: str) -> Optional[Type[Broker]]: for (broker_name, broker) in BROKERS.items(): if hasattr(broker, 'isFileForBroker'): if broker.isFileForBroker(filename): @@ -50,12 +55,11 @@ def DetectBroker(filename): return None -def GetBroker(broker_name, filename): - if not broker_name or broker_name not in BROKERS: - broker = DetectBroker(filename) - else: - broker = BROKERS[broker_name] +def GetBroker(broker_name: str, filename: str) -> Type[Broker]: + if broker_name in BROKERS: + return BROKERS[broker_name] + broker: Optional[Type[Broker]] = DetectBroker(filename) if not broker: raise Exception('Invalid broker name: %s' % broker_name) diff --git a/csv2txf.py b/csv2txf.py index 0b1cd2b..18e893a 100755 --- a/csv2txf.py +++ b/csv2txf.py @@ -23,19 +23,23 @@ * TXF standard: http://turbotax.intuit.com/txf/ """ +from __future__ import annotations + from decimal import Decimal from datetime import datetime import sys -from utils import txfDate +from typing import List + from brokers import GetBroker +import utils -def ConvertTxnListToTxf(txn_list, tax_year, date): +def ConvertTxnListToTxf(txn_list: list[utils.Transaction], tax_year: int, date: str) -> List[str]: lines = [] lines.append('V042') # Version lines.append('Acsv2txf') # Program name/version if date is None: - date = txfDate(datetime.today()) + date = utils.txfDate(datetime.today()) lines.append('D%s' % date) # Export date lines.append('^') for txn in txn_list: @@ -54,19 +58,21 @@ def ConvertTxnListToTxf(txn_list, tax_year, date): return lines -def RunConverter(broker_name, filename, tax_year, date): +def RunConverter(broker_name: str, filename: str, tax_year: int, date: str) -> List[str]: broker = GetBroker(broker_name, filename) txn_list = broker.parseFileToTxnList(filename, tax_year) return ConvertTxnListToTxf(txn_list, tax_year, date) -def GetSummary(broker_name, filename, tax_year): +def GetSummary(broker_name: str, filename: str, tax_year: int) -> str: broker = GetBroker(broker_name, filename) total_cost = Decimal(0) total_sales = Decimal(0) txn_list = broker.parseFileToTxnList(filename, tax_year) for txn in txn_list: + assert txn.costBasis is not None total_cost += txn.costBasis + assert txn.saleProceeds is not None total_sales += txn.saleProceeds return '\n'.join([ @@ -94,15 +100,17 @@ def main(argv): sys.stderr.write('Filename is required; specify with `--file` flag.\n') sys.exit(1) - if not options.year: - options.year = datetime.today().year - 1 + if options.year: + year = int(options.year) + else: + year = datetime.today().year - 1 + utils.Warning(f'Year not specified, defaulting to {year} (last year)\n') output = None if options.out_format == 'summary': - output = GetSummary(options.broker, options.filename, options.year) + output = GetSummary(options.broker, options.filename, year) else: - txf_lines = RunConverter(options.broker, options.filename, options.year, - options.date) + txf_lines = RunConverter(options.broker, options.filename, year, options.date) output = '\n'.join(txf_lines) if options.out_filename: diff --git a/decorators.py b/decorators.py new file mode 100644 index 0000000..d552ef8 --- /dev/null +++ b/decorators.py @@ -0,0 +1,33 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys + + +if (sys.version_info.major, sys.version_info.minor) >= (3, 12): + # This was added in Python 3.12: + # + # * https://docs.python.org/3/library/typing.html#typing.override + # * https://peps.python.org/pep-0698/ + from typing import override +else: + from typing import TypeVar + + _F = TypeVar('_F') + + def override(func: _F) -> _F: + """No-op @override for Python versions prior to 3.12.""" + return func diff --git a/interactive_brokers.py b/interactive_brokers.py index 9f3f717..43c0996 100644 --- a/interactive_brokers.py +++ b/interactive_brokers.py @@ -18,22 +18,30 @@ * dividends """ +from __future__ import annotations + import csv from datetime import datetime from decimal import Decimal +from typing import Optional + +from broker import Broker +from decorators import override import utils FIRST_LINE = 'Title,Worksheet for Form 8949,' -class InteractiveBrokers: +class InteractiveBrokers(Broker): + @classmethod - def name(cls): + @override + def name(cls) -> str: return 'Interactive Brokers' @classmethod - def DetermineEntryCode(cls, part, box): + def DetermineEntryCode(cls, part: int, box: str) -> Optional[int]: if part == 1: if box == 'A': return 321 @@ -51,34 +59,36 @@ def DetermineEntryCode(cls, part, box): return None @classmethod - def TryParseYear(cls, date_str): + def TryParseYear(cls, date_str: str) -> Optional[int]: try: return datetime.strptime(date_str, '%m/%d/%Y').year except ValueError: return None @classmethod - def ParseDollarValue(cls, value): + def ParseDollarValue(cls, value: str) -> Decimal: return Decimal(value.replace(',', '').replace('"', '')) @classmethod - def isFileForBroker(cls, filename): + @override + def isFileForBroker(cls, filename: str) -> bool: with open(filename) as f: first_line = f.readline() return first_line.find(FIRST_LINE) == 0 @classmethod - def parseFileToTxnList(cls, filename, tax_year): + @override + def parseFileToTxnList(cls, filename: str, tax_year: Optional[int]) -> list[utils.Transaction]: with open(filename) as f: # First 2 lines are headers. f.readline() f.readline() txns = csv.reader(f, delimiter=',', quotechar='"') - txn_list = [] - part = None - box = None - entry_code = None + txn_list: list[utils.Transaction] = [] + part: Optional[int] = None + box: Optional[str] = None + entry_code: Optional[int] = None for row in txns: if row[0] == 'Part' and len(row) == 3: diff --git a/tdameritrade.py b/tdameritrade.py index af072aa..62a3f1a 100644 --- a/tdameritrade.py +++ b/tdameritrade.py @@ -23,10 +23,16 @@ * partial lot sales """ +from __future__ import annotations + import csv from datetime import datetime from decimal import Decimal import re +from typing import Optional + +from broker import Broker +from decorators import override import utils @@ -38,63 +44,67 @@ TRANSACTION_TYPE = 'Trans type' -class TDAmeritrade: +class TDAmeritrade(Broker): + @classmethod - def name(cls): + @override + def name(cls) -> str: return "TD Ameritrade" @classmethod - def buyDate(cls, dict): + def buyDate(cls, txn: dict[str, str]) -> datetime: """Returns date of transaction as datetime object.""" # Our input date format is MM/DD/YYYY. - return datetime.strptime(dict['Open date'], '%m/%d/%Y') + return datetime.strptime(txn['Open date'], '%m/%d/%Y') @classmethod - def sellDate(cls, dict): + def sellDate(cls, txn: dict[str, str]) -> datetime: """Returns date of transaction as datetime object.""" # Our input date format is MM/DD/YYYY. - return datetime.strptime(dict['Close date'], '%m/%d/%Y') + return datetime.strptime(txn['Close date'], '%m/%d/%Y') @classmethod - def isShortTerm(cls, dict): - return dict['Term'] == 'Short-term' + def isShortTerm(cls, txn: dict[str, str]) -> bool: + return txn['Term'] == 'Short-term' @classmethod - def symbol(cls, dict): - match = re.match('^.*\((.*)\)$', dict['Security']) + def symbol(cls, txn: dict[str, str]) -> str: + match = re.match(r'^.*\((.*)\)$', txn['Security']) if match: return match.group(1) else: - raise Exception('Security symbol not found in: %s' % dict) + raise Exception('Security symbol not found in: %s' % txn) @classmethod - def numShares(cls, dict): - return Decimal(dict['Qty']) + def numShares(cls, txn: dict[str, str]) -> Decimal: + return Decimal(txn['Qty']) @classmethod - def costBasis(cls, dict): + def costBasis(cls, txn: dict[str, str]) -> Decimal: # Proceeds amount may include commas as thousand separators, which # Decimal does not handle. - return Decimal(dict['Adj cost'].replace(',', '')) + return Decimal(txn['Adj cost'].replace(',', '')) @classmethod - def saleProceeds(cls, dict): + def saleProceeds(cls, txn: dict[str, str]) -> Decimal: # Proceeds amount may include commas as thousand separators, which # Decimal does not handle. - return Decimal(dict['Adj proceeds'].replace(',', '')) + return Decimal(txn['Adj proceeds'].replace(',', '')) @classmethod - def isFileForBroker(cls, filename): + @override + def isFileForBroker(cls, filename: str) -> bool: with open(filename) as f: first_line = f.readline() return first_line == FIRST_LINE @classmethod - def parseFileToTxnList(cls, filename, tax_year): + @override + def parseFileToTxnList(cls, filename: str, tax_year: Optional[int]) -> list[utils.Transaction]: txns = csv.reader(open(filename), delimiter=',', quotechar='"') line_num = 0 - txn_list = [] - names = None + txn_list: list[utils.Transaction] = [] + names: list[str] = [] for row in txns: line_num = line_num + 1 if line_num == 1: @@ -113,14 +123,14 @@ def parseFileToTxnList(cls, filename, tax_year): curr_txn = utils.Transaction() curr_txn.desc = '%s shares %s' % ( cls.numShares(txn_dict), cls.symbol(txn_dict)) - curr_txn.buyDate = cls.buyDate(txn_dict) - curr_txn.buyDateStr = utils.txfDate(curr_txn.buyDate) + buyDate = cls.buyDate(txn_dict) + curr_txn.buyDateStr = utils.txfDate(buyDate) curr_txn.costBasis = cls.costBasis(txn_dict) - curr_txn.sellDate = cls.sellDate(txn_dict) - curr_txn.sellDateStr = utils.txfDate(curr_txn.sellDate) + sellDate = cls.sellDate(txn_dict) + curr_txn.sellDateStr = utils.txfDate(sellDate) curr_txn.saleProceeds = cls.saleProceeds(txn_dict) - assert curr_txn.sellDate >= curr_txn.buyDate + assert sellDate >= buyDate, f'Sell date ({sellDate}) must be on or after buy date ({buyDate})' if cls.isShortTerm(txn_dict): # TODO(mbrukman): assert here that (sellDate - buyDate) <= 1 year curr_txn.entryCode = 321 # "ST gain/loss - security" @@ -128,7 +138,7 @@ def parseFileToTxnList(cls, filename, tax_year): # TODO(mbrukman): assert here that (sellDate - buyDate) > 1 year curr_txn.entryCode = 323 # "LT gain/loss - security" - if tax_year and curr_txn.sellDate.year != tax_year: + if tax_year and sellDate.year != tax_year: utils.Warning('ignoring txn: "%s" (line %d) as the sale is not from %d\n' % (curr_txn.desc, line_num, tax_year)) continue diff --git a/update_testdata.py b/update_testdata.py index 6c22514..f9a820f 100755 --- a/update_testdata.py +++ b/update_testdata.py @@ -20,17 +20,21 @@ strings. """ +from __future__ import annotations + import glob import os import sys +from typing import Type +from broker import Broker from brokers import DetectBroker # If your broker parser does not support isFileForBroker, you'll need -# need to add an entry here. -# Example: -# 'vanguard.csv' : Vanguard -BROKER_CSV = {} +# need to add an entry here. Example entry: +# +# 'my_broker.csv' : MyBroker +BROKER_CSV: dict[str, Type[Broker]] = {} def main(argv): diff --git a/utils.py b/utils.py index 25b2eec..a2a971a 100644 --- a/utils.py +++ b/utils.py @@ -14,7 +14,12 @@ """Utility methods/classes.""" +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal import sys +from typing import Optional class Error(Exception): @@ -25,7 +30,7 @@ class ValueError(Error): def __init__(self, msg): self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg @@ -33,25 +38,27 @@ class UnimplementedError(Error): def __init__(self, msg): self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg -def Warning(str): - sys.stderr.write('warning: %s' % str) +def Warning(msg: str): + sys.stderr.write('warning: %s' % msg) + +class Transaction: -class Transaction(object): - def __init__(self): - self.desc = None - self.buyDateStr = None - self.costBasis = None - self.sellDateStr = None - self.saleProceeds = None - self.adjustment = None - self.entryCode = None + desc: Optional[str] = None + buyDate: Optional[datetime] = None + buyDateStr: Optional[str] = None + costBasis: Optional[Decimal] = None + sellDate: Optional[datetime] = None + sellDateStr: Optional[str] = None + saleProceeds: Optional[Decimal] = None + adjustment: Optional[Decimal] = None + entryCode: Optional[int] = None - def __str__(self): + def __str__(self) -> str: data = [ ('desc:%s', self.desc), ('buyDateStr:%s', self.buyDateStr), @@ -65,12 +72,12 @@ def __str__(self): return ','.join(formatted_data) -def txfDate(date): +def txfDate(date: datetime) -> str: """Returns a date string in the TXF format, which is MM/DD/YYYY.""" return date.strftime('%m/%d/%Y') -def isLongTerm(buy_date, sell_date): +def isLongTerm(buy_date: datetime, sell_date: datetime) -> bool: # To handle leap years, cannot use a standard number of days, i.e.: # sell_date - buy_date > timedelta(days=365) # - doesn't work for leap years diff --git a/vanguard.py b/vanguard.py index 010a953..93ac15a 100644 --- a/vanguard.py +++ b/vanguard.py @@ -22,9 +22,15 @@ * partial lot sales """ +from __future__ import annotations + import csv from datetime import datetime from decimal import Decimal +from typing import Optional + +from broker import Broker +from decorators import override import utils @@ -33,64 +39,68 @@ '"Principal Amount"', '"Net Amount"\n']) -class Vanguard: +class Vanguard(Broker): + @classmethod - def name(cls): + @override + def name(cls) -> str: return 'Vanguard' @classmethod - def isBuy(cls, dict): - return dict['Transaction Type'] == 'Buy' + def isBuy(cls, txn: dict[str, str]) -> bool: + return txn['Transaction Type'] == 'Buy' @classmethod - def isSell(cls, dict): - return dict['Transaction Type'] == 'Sell' + def isSell(cls, txn: dict[str, str]) -> bool: + return txn['Transaction Type'] == 'Sell' @classmethod - def date(cls, dict): + def date(cls, txn: dict[str, str]) -> datetime: """Returns date of transaction as datetime object.""" # Our input date format is YYYY/MM/DD. - return datetime.strptime(dict['Trade Date'], '%Y-%m-%d') + return datetime.strptime(txn['Trade Date'], '%Y-%m-%d') @classmethod - def symbol(cls, dict): - return dict['Symbol'] + def symbol(cls, txn: dict[str, str]) -> str: + return txn['Symbol'] @classmethod - def investmentName(cls, dict): - return dict['Investment Name'] + def investmentName(cls, txn: dict[str, str]) -> str: + return txn['Investment Name'] @classmethod - def numShares(cls, dict): - shares = int(dict['Shares']) - if cls.isSell(dict): + def numShares(cls, txn: dict[str, str]) -> int: + shares = int(txn['Shares']) + if cls.isSell(txn): return shares * -1 else: return shares @classmethod - def netAmount(cls, dict): - amount = Decimal(dict['Net Amount']) - if cls.isBuy(dict): + def netAmount(cls, txn: dict[str, str]) -> Decimal: + amount = Decimal(txn['Net Amount']) + if cls.isBuy(txn): return amount * -1 else: return amount @classmethod - def isFileForBroker(cls, filename): + @override + def isFileForBroker(cls, filename: str) -> bool: with open(filename) as f: first_line = f.readline() return first_line == FIRST_LINE @classmethod - def parseFileToTxnList(cls, filename, tax_year): + @override + def parseFileToTxnList(cls, filename: str, tax_year: Optional[int]) -> list[utils.Transaction]: txns = csv.reader(open(filename), delimiter=',', quotechar='"') - row_num = 0 - txn_list = [] - names = None - curr_txn = None - buy = None - sell = None + row_num: int = 0 + txn_list: list[utils.Transaction] = [] + names: list[str] = [] + curr_txn: Optional[utils.Transaction] = None + buy: dict[str, str] = {} + sell: dict[str, str] = {} for row in txns: row_num = row_num + 1 if row_num == 1: @@ -113,21 +123,23 @@ def parseFileToTxnList(cls, filename, tax_year): sell = txn_dict # Assume that sells follow the buys, so we can attach this sale to the # current buy txn we are processing. + assert curr_txn is not None assert cls.numShares(buy) == cls.numShares(sell) assert cls.symbol(buy) == cls.symbol(sell) assert cls.investmentName(buy) == cls.investmentName(sell) - curr_txn.sellDate = cls.date(sell) - curr_txn.sellDateStr = utils.txfDate(curr_txn.sellDate) + buyDate: datetime = curr_txn.buyDate + sellDate: datetime = cls.date(sell) + curr_txn.sellDateStr = utils.txfDate(sellDate) curr_txn.saleProceeds = cls.netAmount(sell) - if utils.isLongTerm(curr_txn.buyDate, curr_txn.sellDate): + if utils.isLongTerm(buyDate, sellDate): curr_txn.entryCode = 323 # "LT gain/loss - security" else: curr_txn.entryCode = 321 # "ST gain/loss - security" - assert curr_txn.sellDate >= curr_txn.buyDate - if tax_year and curr_txn.sellDate.year != tax_year: + assert sellDate >= buyDate, f'Sell date ({sellDate}) must be on or after buy date ({buyDate})' + if tax_year and sellDate.year != tax_year: utils.Warning('ignoring txn: "%s" as the sale is not from %d\n' % (curr_txn.desc, tax_year)) continue @@ -135,8 +147,8 @@ def parseFileToTxnList(cls, filename, tax_year): txn_list.append(curr_txn) # Clear both the buy and the sell as we have matched them up. - buy = None - sell = None + buy = {} + sell = {} curr_txn = None return txn_list