Skip to content

Commit

Permalink
Merge pull request #75 from Promptly-Technologies-LLC/71-lint-with-mypy
Browse files Browse the repository at this point in the history
71 lint with mypy
  • Loading branch information
chriscarrollsmith authored Feb 8, 2025
2 parents 5554261 + e61940c commit 8880816
Show file tree
Hide file tree
Showing 12 changed files with 330 additions and 250 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,9 @@ cython_debug/

# Quarto
/.quarto/

# AI chat histories
.specstory

# Mac
.DS_Store
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ We welcome contributions to improve `imfp`! Here's how you can help:
- Fork and clone the repository and open a terminal in the repository directory
- Install [uv](https://astral.sh/setup-uv/) with `curl -LsSf https://astral.sh/uv/install.sh | sh`
- Install the dependencies with `uv sync`
- Install a git hook to enforce conventional commits with `curl -o- https://raw.githubusercontent.com/tapsellorg/conventional-commits-git-hook/master/scripts/install.sh | sh`
- Install a git hook to enforce conventional commits with `curl -o- https://raw.githubusercontent.com/chriscarrollsmith/conventional-commits-git-hook/master/scripts/install.sh | sh`
- Create a fix, commit it with an ["Angular-style Conventional Commit"](https://www.conventionalcommits.org/en/v1.0.0-beta.4/) message, and push it to your fork
- Open a pull request to our `main` branch

Expand Down
5 changes: 3 additions & 2 deletions experimental/download_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# recommend trying this.

import imfp
from pandas import DataFrame

# Set a custom wait time
_imf_wait_time = 10

# Attempt to download databases sequentially
databases = imfp.imf_databases()
datasets = {"database_names": [], "dataframes": []}
databases: DataFrame = imfp.imf_databases()
datasets: dict[str, list[DataFrame | None]] = {"database_names": [], "dataframes": []}
for database_id in databases["database_id"]:
datasets["database_names"].append(database_id)
try:
Expand Down
5 changes: 3 additions & 2 deletions experimental/download_indicators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This script tries to download whole databases indicator by indicator

import imfp
from pandas import DataFrame

# Examine database list
databases = imfp.imf_databases()
Expand All @@ -11,8 +12,8 @@

# Try to download the database indicator by indicator (Note that some databases don't
# use the 'indicator' parameter, so this won't work with every database)
indicators = imfp.imf_parameters("IFS")["indicator"]
datasets = {"indicator_names": [], "dataframes": []}
indicators: DataFrame = imfp.imf_parameters("IFS")["indicator"]
datasets: dict[str, list[DataFrame | None]] = {"indicator_names": [], "dataframes": []}
for indicator in indicators["input_code"]:
datasets["indicator_names"].append(indicator)
try:
Expand Down
13 changes: 5 additions & 8 deletions imfp/admin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from os import environ
from warnings import warn
from typing import Union
import type_enforced


@type_enforced.Enforcer
def set_imf_app_name(name: str = "imfp") -> None:
"""
Set the IMF Application Name.
Expand All @@ -25,11 +27,8 @@ def set_imf_app_name(name: str = "imfp") -> None:
imf_app_name("my_custom_app_name")
"""

if not isinstance(name, str) or len(name) > 255:
raise ValueError(
"Please provide a valid string as the application "
"name (max length: 255 characters)."
)
if len(name) > 255:
raise ValueError("Application name must be no longer than 255 characters.")

if name == "imfp" or name == "":
warn(
Expand All @@ -52,6 +51,7 @@ def set_imf_app_name(name: str = "imfp") -> None:
return None


@type_enforced.Enforcer
def set_imf_wait_time(wait_time: Union[int, float] = 1.5) -> None:
"""
Set the IMF wait time as an environment variable.
Expand All @@ -63,9 +63,6 @@ def set_imf_wait_time(wait_time: Union[int, float] = 1.5) -> None:
TypeError: If the provided wait_time is not a numeric value (int or float).
ValueError: If the provided wait_time is not greater than 0.
"""
if not isinstance(wait_time, (int, float)):
raise TypeError("Rate limit wait time must be a numeric value (int or float).")

if wait_time >= 0:
environ["IMF_WAIT_TIME"] = str(wait_time)
else:
Expand Down
63 changes: 44 additions & 19 deletions imfp/data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from pandas import DataFrame, Series, concat
from typing import overload, Literal
from warnings import warn
from .utils import _download_parse, _imf_dimensions, _imf_metadata
from urllib.parse import urlencode
from pandas import DataFrame, Series, concat
import type_enforced

from .utils import _download_parse, _imf_dimensions, _imf_metadata

logger = logging.getLogger(__name__)


@type_enforced.Enforcer
def imf_databases(times: int = 3) -> DataFrame:
"""
List IMF database IDs and descriptions
Expand Down Expand Up @@ -45,9 +49,11 @@ def imf_databases(times: int = 3) -> DataFrame:
return database_list


@type_enforced.Enforcer
def imf_parameters(database_id: str, times: int = 2) -> dict[str, DataFrame]:
"""
List input parameters and available parameter values for use in
making API requests from a given IMF database.
Parameters
Expand All @@ -74,9 +80,6 @@ def imf_parameters(database_id: str, times: int = 2) -> dict[str, DataFrame]:
# Commodity Price System database
params = imf_parameters(database_id='PCPS')
"""
if not database_id:
raise ValueError("Must supply database_id. Use imf_databases to find.")

url = "http://dataservices.imf.org/REST/SDMX_JSON.svc/CodeList/"
try:
codelist = _imf_dimensions(database_id, times)
Expand Down Expand Up @@ -124,6 +127,7 @@ def fetch_parameter_data(k, url, times):
return parameter_list


@type_enforced.Enforcer
def imf_parameter_defs(
database_id: str, times: int = 3, inputs_only: bool = True
) -> DataFrame:
Expand Down Expand Up @@ -156,9 +160,6 @@ def imf_parameter_defs(
# the Primary Commodity Price System database
param_defs = imf_parameter_defs(database_id='PCPS')
"""
if not database_id:
raise ValueError("Must supply database_id. Use imf_databases to find.")

try:
parameterlist = _imf_dimensions(database_id, times, inputs_only)[
["parameter", "description"]
Expand All @@ -175,17 +176,48 @@ def imf_parameter_defs(
return parameterlist


@overload
def imf_dataset(
database_id: str,
parameters: dict | None = None,
start_year: int | str | None = None,
end_year: int | str | None = None,
return_raw: bool = False,
print_url: bool = False,
times: int = 3,
include_metadata: Literal[False] = False,
**kwargs,
) -> DataFrame:
...


@overload
def imf_dataset(
database_id: str,
parameters: dict | None = None,
start_year: int | str | None = None,
end_year: int | str | None = None,
return_raw: bool = False,
print_url: bool = False,
times: int = 3,
include_metadata: Literal[True] = True,
**kwargs,
) -> tuple[dict, DataFrame]:
...


@type_enforced.Enforcer
def imf_dataset(
database_id: str,
parameters: dict = None,
start_year: int = None,
end_year: int = None,
parameters: dict | None = None,
start_year: int | str | None = None,
end_year: int | str | None = None,
return_raw: bool = False,
print_url: bool = False,
times: int = 3,
include_metadata: bool = False,
**kwargs,
) -> DataFrame | tuple[DataFrame, DataFrame]:
) -> DataFrame | tuple[dict, DataFrame]:
"""
Download a data series from the IMF.
Expand Down Expand Up @@ -223,13 +255,6 @@ def imf_dataset(
database header, and whose second item is the pandas DataFrame. If
return_raw == True, returns the raw JSON fetched from the API endpoint.
"""

if database_id is None:
raise ValueError("Missing required database_id argument.")

if not isinstance(database_id, str):
raise ValueError("database_id must be a string.")

years = {}
if start_year is not None:
try:
Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[mypy]

[mypy-type_enforced.*]
ignore_missing_imports = True
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"requests == 2.32.3",
"statsmodels == 0.14.4",
"tabulate == 0.9.0",
"type-enforced>=1.10.1",
]

[project.urls]
Expand Down Expand Up @@ -112,5 +113,6 @@ dev = [
"jupyter == 1.1.1",
"seaborn == 0.13.2",
"quarto == 0.1.0",
"scikit-learn == 1.5.2"
"scikit-learn == 1.5.2",
"mypy>=1.14.1",
]
13 changes: 5 additions & 8 deletions tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ def test_set_imf_app_name():
with pytest.warns(UserWarning):
set_imf_app_name("imfp")

with pytest.raises(ValueError):
with pytest.raises(TypeError):
set_imf_app_name(None)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
set_imf_app_name(float("nan"))
with pytest.raises(TypeError):
set_imf_app_name(["z", "z"])

with pytest.raises(ValueError):
set_imf_app_name("z" * 256)
with pytest.raises(ValueError):
set_imf_app_name(["z", "z"])

set_imf_app_name("imfr_admin_functions_tester")
assert os.getenv("IMF_APP_NAME") == "imfr_admin_functions_tester"
Expand Down Expand Up @@ -57,7 +58,3 @@ def test_set_imf_wait_time(env_setup_teardown):
# Test with invalid input (negative value)
with pytest.raises(ValueError):
set_imf_wait_time(-1)


if __name__ == "__main__":
pytest.main()
4 changes: 0 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,3 @@ def test_imf_dataset_include_metadata(set_options):
assert isinstance(output[0], dict)
assert isinstance(output[1], pd.core.frame.DataFrame)
assert all([not pd.isna(value) for value in output[0].values()])


if __name__ == "__main__":
pytest.main()
5 changes: 1 addition & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
set_imf_wait_time,
)
from imfp.utils import _imf_save_response, _imf_use_cache
from typing import List

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -222,7 +223,3 @@ def test_bad_request(set_options):
_download_parse(URL)

assert "too large" in str(excinfo.value)


if __name__ == "__main__":
pytest.main()
Loading

0 comments on commit 8880816

Please sign in to comment.