Skip to content

Commit

Permalink
add mypy checks
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Sep 23, 2023
1 parent 414c040 commit 44191d8
Show file tree
Hide file tree
Showing 19 changed files with 154 additions and 105 deletions.
15 changes: 7 additions & 8 deletions .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ Do you have a question? Before you do, please read the following guidelines.

### Question or problem?

For quick questions there's no need to open an issue. Check first if the
question isn't already answered on the [FAQ](../faq) section. If not, reach
us through the [discussions](https://github.com/tvdboom/ATOM/discussions)
page or on the [slack](https://join.slack.com/t/atom-alm7229/shared_invite/zt-upd8uc0z-LL63MzBWxFf5tVWOGCBY5g)
channel.
For quick questions, there's no need to open an issue. Check first if the
question isn't already answered in the FAQ section. If not, reach us
through the [discussions](https://github.com/tvdboom/ATOM/discussions) page or on the [slack](https://join.slack.com/t/atom-alm7229/shared_invite/zt-upd8uc0z-LL63MzBWxFf5tVWOGCBY5g) channel.


### Report a bug?
Expand Down Expand Up @@ -99,17 +97,18 @@ review and accept your changes.
* Make sure that your code is properly commented with docstrings and
comments explaining your rationale behind non-obvious coding practices.
* Run [isort](https://pycqa.github.io/isort/): `isort atom tests`.
* Run [flake8](https://github.com/john-hen/Flake8-pyproject): `flake8 --show-source --statistics atom tests`.
* Run [flake8](https://github.com/pycqa/flake8): `flake8 --show-source --statistics atom tests`.
* Run [mypy](https://www.mypy-lang.org/): `mypy atom tests`.

If your contribution requires a new library dependency:

* Double-check that the new dependency is easy to install via pip and Anaconda.
* The library should support Python 3.8 and higher.
* The library should support Python 3.9 and higher.
* Make sure the code works with the latest version of the library.
* Update the dependencies in the documentation.
* Add the library with the minimum required version to `pyproject.toml`.

After submitting your pull request, GitHub will automatically run the tests
on your changes and make sure that the updated code builds successfully.
The checks are run on Python 3.8, 3.9, 3.10 and 3.11, on Ubuntu and Windows.
The checks are run on Python 3.9, 3.10 and 3.11, on Ubuntu and Windows.
We also use services that automatically check code style and test coverage.
16 changes: 15 additions & 1 deletion .github/workflows/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ jobs:
- name: Apply linting
run: flake8 --show-source --statistics atom tests

mypy:
runs-on: ubuntu-latest
steps:
- name: Check out source repository
uses: actions/checkout@v3
- name: Set up Python environment
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: pip install mypy types-requests pandas-stubs
- name: Check type hints
run: mypy atom tests

code-quality-codeql:
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -72,7 +86,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]
steps:
- name: Check out source repository
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
--- | ---
**Repository** | [![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) [![Conda Recipe](https://img.shields.io/badge/recipe-atom--ml-green.svg)](https://anaconda.org/conda-forge/atom-ml) [![License: MIT](https://img.shields.io/github/license/tvdboom/ATOM)](https://opensource.org/licenses/MIT) [![Downloads](https://static.pepy.tech/badge/atom-ml)](https://pepy.tech/project/atom-ml)
**Release** | [![pdm-managed](https://img.shields.io/badge/pdm-managed-blueviolet)](https://pdm.fming.dev) [![PyPI version](https://img.shields.io/pypi/v/atom-ml)](https://pypi.org/project/atom-ml/) [![Conda Version](https://img.shields.io/conda/vn/conda-forge/atom-ml.svg)](https://anaconda.org/conda-forge/atom-ml) [![DOI](https://zenodo.org/badge/195069958.svg)](https://zenodo.org/badge/latestdoi/195069958)
**Compatibility** | [![Python 3.8\|3.9\|3.10\|3.11](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11-blue?logo=python)](https://www.python.org) [![Conda Platforms](https://img.shields.io/conda/pn/conda-forge/atom-ml.svg)](https://anaconda.org/conda-forge/atom-ml)
**Compatibility** | [![Python 3.9\|3.10\|3.11](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11-blue?logo=python)](https://www.python.org) [![Conda Platforms](https://img.shields.io/conda/pn/conda-forge/atom-ml.svg)](https://anaconda.org/conda-forge/atom-ml)
**Build status** | [![Build Status](https://github.com/tvdboom/ATOM/workflows/ATOM/badge.svg)](https://github.com/tvdboom/ATOM/actions) [![Azure Pipelines](https://dev.azure.com/conda-forge/feedstock-builds/_apis/build/status/atom-ml-feedstock?branchName=master)](https://dev.azure.com/conda-forge/feedstock-builds/_build/latest?definitionId=10822&branchName=master) [![codecov](https://codecov.io/gh/tvdboom/ATOM/branch/master/graph/badge.svg)](https://codecov.io/gh/tvdboom/ATOM)
**Code analysis** | [![PEP8](https://img.shields.io/badge/code%20style-pep8-orange.svg)](https://www.python.org/dev/peps/pep-0008/) [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)

Expand Down
4 changes: 2 additions & 2 deletions atom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def __init__(
)


@beartype
class ATOMForecaster(BaseTransformer, ATOM):
"""Main class for forecasting tasks.
Expand Down Expand Up @@ -566,7 +567,6 @@ class ATOMForecaster(BaseTransformer, ATOM):
"""

@beartype
def __init__(
self,
*arrays,
Expand Down Expand Up @@ -612,6 +612,7 @@ def __init__(
)


@beartype
class ATOMRegressor(BaseTransformer, ATOM):
"""Main class for regression tasks.
Expand Down Expand Up @@ -812,7 +813,6 @@ class ATOMRegressor(BaseTransformer, ATOM):
"""

@beartype
def __init__(
self,
*arrays,
Expand Down
2 changes: 1 addition & 1 deletion atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def get_data(new_t: str) -> Series:
"""
if pd.api.types.is_sparse(column):
# If already sparse array, cast directly to new sparse type
# If already sparse array, cast directly to a new sparse type
return column.astype(pd.SparseDtype(new_t, column.dtype.fill_value))
else:
if dense2sparse and name not in lst(self.target): # Skip target cols
Expand Down
56 changes: 27 additions & 29 deletions atom/data_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@
TomekLinks,
)
from scipy.stats import zscore
from sklearn.base import BaseEstimator, clone
from sklearn.base import BaseEstimator
from sklearn.impute import KNNImputer

from atom.basetransformer import BaseTransformer
from atom.utils.types import (
Bool, DataFrame, DataFrameTypes, DiscretizerStrats, Engine, Estimator,
Features, Float, Int, NJobs, NumericalStrats, Pandas, PrunerStrats, Scalar,
ScalerStrats, Sequence, SequenceTypes, Series, SeriesTypes, Target,
Verbose,
Transformer, Verbose,
)
from atom.utils.utils import (
CustomDict, bk, check_is_fitted, composed, crash, get_cols, it, lst, merge,
Expand Down Expand Up @@ -832,7 +832,7 @@ def transform(
X = X.drop(name, axis=1)

Check notice on line 832 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
continue

elif any(t in dtype for t in ("object", "category", "string")):
elif dtype in ("object", "category", "string"):
if self.strip_categorical:
# Strip strings from blank spaces
X[name] = column.apply(
Expand Down Expand Up @@ -1429,7 +1429,7 @@ class Encoder(BaseEstimator, TransformerMixin, BaseTransformer):

def __init__(
self,
strategy: str | Estimator = "Target",
strategy: str | Transformer = "Target",
*,
max_onehot: Int | None = 10,
ordinal: dict[str, Sequence] | None = None,
Expand Down Expand Up @@ -1482,7 +1482,7 @@ def fit(self, X: Features, y: Target | None = None) -> Encoder:
self._check_feature_names(X, reset=True)
self._check_n_features(X, reset=True)

self.mapping_ = {}
self.mapping_ = defaultdict(dict)

Check notice on line 1485 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute mapping_ defined outside __init__
self._to_value = defaultdict(list)

Check notice on line 1486 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _to_value defined outside __init__
self._categories = {}

Check notice on line 1487 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _categories defined outside __init__
self._encoders = {}

Check notice on line 1488 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _encoders defined outside __init__
Expand Down Expand Up @@ -1511,20 +1511,14 @@ def fit(self, X: Features, y: Target | None = None) -> Encoder:
f"Invalid value for the strategy parameter, got {self.strategy}. "
f"Choose from: {', '.join(strategies)}."
)
estimator = strategies[self.strategy](
handle_missing="return_nan",
handle_unknown="value",
**self.kwargs,
)
elif not all(hasattr(self.strategy, attr) for attr in ("fit", "transform")):
raise TypeError(
"Invalid type for the strategy parameter. A custom"
"estimator must have a fit and transform method."
)
estimator = strategies[self.strategy]
elif callable(self.strategy):
estimator = self.strategy(**self.kwargs)
else:
estimator = self.strategy
else:
raise ValueError(
f"Invalid value for the strategy parameter, got {self.strategy}. "
"For customs estimators, a class is expected, but got an instance."
)

if self.max_onehot is None:
max_onehot = 0
Expand Down Expand Up @@ -1558,7 +1552,7 @@ def fit(self, X: Features, y: Target | None = None) -> Encoder:
X[name] = column.replace(category, self.value)

# Get the unique categories before fitting
self._categories[name] = column.sort_values().unique().tolist()
self._categories[name] = column.dropna().sort_values().unique().tolist()

# Perform encoding type dependent on number of unique values
ordinal = self.ordinal or {}
Expand Down Expand Up @@ -1598,7 +1592,13 @@ def fit(self, X: Features, y: Target | None = None) -> Encoder:
args = [X[[name]]]
if "y" in sign(estimator.fit):
args.append(bk.DataFrame(y).iloc[:, 0])
self._encoders[name] = clone(estimator).fit(*args)

self._encoders[name] = estimator(
cols=[name],
handle_missing="return_nan",
handle_unknown="value",
**self.kwargs,
).fit(*args)

# Create encoding of unique values for mapping
data = self._encoders[name].transform(
Expand All @@ -1612,7 +1612,6 @@ def fit(self, X: Features, y: Target | None = None) -> Encoder:

# Only mapping 1 - 1 column
if data.shape[1] == 1:
self.mapping_[name] = {}
for idx, value in data[name].items():
self.mapping_[name][idx] = value

Expand Down Expand Up @@ -1642,30 +1641,29 @@ def transform(self, X: Features, y: Target | None = None) -> DataFrame:
self._log("Encoding categorical columns...", 1)

for name, column in X[self._cat_cols].items():
# Convert uncommon classes to "other"
# Convert infrequent classes to value
if self._to_value[name]:
X[name] = column.replace(self._to_value[name], self.value)

n_classes = len(column.unique())
self._log(
f" --> {self._encoders[name].__class__.__name__[:-7]}-encoding "
f"feature {name}. Contains {n_classes} classes.", 2
f"feature {name}. Contains {column.nunique()} classes.", 2
)

# Count the propagated missing values
n_nans = column.isna().sum()
if n_nans:
# Count the propagated missingX[[name]] values
if n_nans := column.isna().sum():
self._log(f" --> Propagating {n_nans} missing values.", 2)

# Get the new encoded columns
# TODO: category_encoders can't handle pd.NA
# https://github.com/scikit-learn-contrib/category_encoders/issues/424
new_cols = self._encoders[name].transform(X[[name]])

# Drop _nan columns (since missing values are propagated)
new_cols = new_cols[[col for col in new_cols if not col.endswith("_nan")]]
new_cols = new_cols.loc[:, ~new_cols.columns.str.endswith("_nan")]

# Check for unknown classes
uc = len([i for i in column.unique() if i not in self._categories[name]])
if uc:
if uc := len(column.dropna()[~column.isin(self._categories[name])]):
self._log(f" --> Handling {uc} unknown classes.", 2)

# Insert the new columns at old location
Expand Down
2 changes: 1 addition & 1 deletion atom/plots/predictionplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def plot_errors(
# Fit the points using linear regression
from atom.models import OrdinaryLeastSquares
model = OrdinaryLeastSquares(goal=self.goal, branches=self._branches)
estimator = model._get_est().fit(pd.DataFrame(y_true), y_pred)
estimator = model._get_est().fit(bk.DataFrame(y_true), y_pred)

Check warning on line 683 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Incorrect call arguments

Unexpected argument

Check warning on line 683 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Incorrect call arguments

Unexpected argument

Check notice on line 683 in atom/plots/predictionplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_est of a class

fig.add_trace(
self._draw_line(
Expand Down
Loading

0 comments on commit 44191d8

Please sign in to comment.