diff --git a/cgt_calc/currency_converter.py b/cgt_calc/currency_converter.py index 3d7cd01e..92f0a70b 100644 --- a/cgt_calc/currency_converter.py +++ b/cgt_calc/currency_converter.py @@ -1,6 +1,7 @@ """Convert currencies to GBP using rate history.""" from __future__ import annotations +from collections import defaultdict import csv import datetime from decimal import Decimal @@ -23,18 +24,22 @@ class CurrencyConverter: def __init__( self, exchange_rates_file: str | None = None, - initial_data: dict[str, dict[str, Decimal]] | None = None, + initial_data: dict[datetime.date, dict[str, Decimal]] | None = None, ): """Load data from exchange_rates_file and optionally from initial_data.""" self.exchange_rates_file = exchange_rates_file read_data = self._read_exchange_rates_file(exchange_rates_file) - self.cache = {**read_data, **(initial_data or {})} + self.cache: dict[datetime.date, dict[str, Decimal]] = { + **read_data, + **(initial_data or {}), + } + self.session = requests.Session() @staticmethod def _read_exchange_rates_file( exchange_rates_file: str | None, - ) -> dict[str, dict[str, Decimal]]: - cache: dict[str, dict[str, Decimal]] = {} + ) -> defaultdict[datetime.date, dict[str, Decimal]]: + cache: defaultdict[datetime.date, dict[str, Decimal]] = defaultdict(dict) if exchange_rates_file is None: return cache path = Path(exchange_rates_file) @@ -48,17 +53,16 @@ def _read_exchange_rates_file( if sorted(EXCHANGE_RATES_HEADER) != sorted(line.keys()): raise ParsingError( exchange_rates_file, - f"invalid columns {line.keys()}," + f"invalid columns {line.keys()}, " f"they should be {EXCHANGE_RATES_HEADER}", ) - if line["month"] not in cache: - cache[line["month"]] = {} - cache[line["month"]][line["currency"]] = Decimal(line["rate"]) + date = datetime.date.fromisoformat(line["month"]) + cache[date][line["currency"]] = Decimal(line["rate"]) return cache @staticmethod def _write_exchange_rates_file( - exchange_rates_file: str | None, data: dict[str, dict[str, Decimal]] + exchange_rates_file: str | None, data: dict[datetime.date, dict[str, Decimal]] ) -> None: if exchange_rates_file is None: return @@ -69,18 +73,30 @@ def _write_exchange_rates_file( for symbol, rate in rates.items() ] writer = csv.writer(fout) - writer.writerows([EXCHANGE_RATES_HEADER] + data_rows) - - def _query_hmrc_api(self, month_str: str) -> None: - url = ( - "http://www.hmrc.gov.uk/softwaredevelopers/rates/" - f"exrates-monthly-{month_str}.xml" - ) - response = requests.get(url, timeout=10) + writer.writerows([EXCHANGE_RATES_HEADER, *data_rows]) + + def _query_hmrc_api(self, date: datetime.date) -> None: + # Pre 2021 we need to use the old HMRC endpoint + if date.year < 2021: + month_str = date.strftime("%m%y") + url = ( + "http://www.hmrc.gov.uk/softwaredevelopers/rates/" + f"exrates-monthly-{month_str}.xml" + ) + else: + month_str = date.strftime("%Y-%m") + url = ( + "https://www.trade-tariff.service.gov.uk/api/v2/" + f"exchange_rates/files/monthly_xml_{month_str}.xml" + ) + + response = self.session.get(url, timeout=10) + if not response.ok: raise ParsingError( url, f"HMRC API returned a {response.status_code} response" ) + tree = ElementTree.fromstring(response.text) rates = { str(getattr(row.find("currencyCode"), "text", None)).upper(): Decimal( @@ -90,18 +106,17 @@ def _query_hmrc_api(self, month_str: str) -> None: } if None in rates or None in rates.values(): raise ParsingError(url, "HMRC API produced invalid/unknown data") - self.cache[month_str] = rates + self.cache[date] = rates self._write_exchange_rates_file(self.exchange_rates_file, self.cache) def currency_to_gbp_rate(self, currency: str, date: datetime.date) -> Decimal: """Get GBP/currency rate at given date.""" assert is_date(date) - month_str = date.strftime("%m%y") - if month_str not in self.cache: - self._query_hmrc_api(month_str) - if currency not in self.cache[month_str]: + if date not in self.cache: + self._query_hmrc_api(date) + if currency not in self.cache[date]: raise ExchangeRateMissingError(currency, date) - return self.cache[month_str][currency] + return self.cache[date][currency] def to_gbp(self, amount: Decimal, currency: str, date: datetime.date) -> Decimal: """Convert amount from given currency to GBP.""" diff --git a/tests/test_calc.py b/tests/test_calc.py index b01ade2c..3ed9066a 100644 --- a/tests/test_calc.py +++ b/tests/test_calc.py @@ -772,14 +772,12 @@ def test_basic( tax_year: int, broker_transactions: list[BrokerTransaction], expected: float, - gbp_prices: dict[str, dict[str, Decimal]] | None, + gbp_prices: dict[datetime.date, dict[str, Decimal]] | None, calculation_log: CalculationLog | None, ) -> None: """Generate basic tests for test data.""" if gbp_prices is None: - gbp_prices = { - t.date.strftime("%m%y"): {"USD": Decimal(1)} for t in broker_transactions - } + gbp_prices = {t.date: {"USD": Decimal(1)} for t in broker_transactions} converter = CurrencyConverter(None, gbp_prices) initial_prices = InitialPrices({}) calculator = CapitalGainsCalculator(tax_year, converter, initial_prices)