Skip to content

Commit

Permalink
fixing type hints 2
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Feb 25, 2024
1 parent 63ac60d commit c3d9c6b
Show file tree
Hide file tree
Showing 12 changed files with 34 additions and 30 deletions.
7 changes: 3 additions & 4 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,12 +626,11 @@ def _no_data_sets(

except ValueError as ex:
# Clarify common error with stratification for multioutput tasks
if "least populated class" in str(ex) and isinstance(y, pd.DataFrame):
if isinstance(y, pd.DataFrame):
raise ValueError(
"Stratification for multioutput tasks is applied over all target "
"columns, which results in a least populated class that has only "
"one member. Either select only one column to stratify over, or "
"set the parameter stratify=False."
"columns. Either select only one column to stratify over, or set "
"the parameter stratify=False."
) from ex
else:
raise ex
Expand Down
5 changes: 3 additions & 2 deletions atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, Literal, TypeVar, overload
from typing import Literal, TypeVar, overload

import joblib
import mlflow
Expand All @@ -33,7 +33,8 @@
from atom.utils.types import (
Backend, Bool, Engine, EngineDataOptions, EngineEstimatorOptions,
EngineTuple, Estimator, FeatureNamesOut, Int, IntLargerEqualZero, Pandas,
Severity, Verbose, Warnings, XSelector, YSelector, bool_t, int_t, YReturn, XReturn
Severity, Verbose, Warnings, XReturn, XSelector, YReturn, YSelector,
bool_t, int_t,
)
from atom.utils.utils import (
check_dependency, crash, lst, make_sklearn, to_df, to_tabular,
Expand Down
2 changes: 1 addition & 1 deletion atom/data/dataengines.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import modin.pandas as md
import polars as pl
import pyarrow as pa
import pyspark.sql as psql
import pyspark.pandas as ps
import pyspark.sql as psql


class DataEngine(metaclass=ABCMeta):
Expand Down
8 changes: 4 additions & 4 deletions atom/data_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
from collections import defaultdict
from collections.abc import Hashable
from typing import Any, Literal, TypeVar, cast, overload
from typing import Any, Literal, TypeVar, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -52,9 +52,9 @@
Bins, Bool, CategoricalStrats, DiscretizerStrats, Engine,
EngineDataOptions, EngineTuple, Estimator, FloatLargerZero, Int,
IntLargerEqualZero, IntLargerTwo, IntLargerZero, NJobs, NormalizerStrats,
NumericalStrats, Pandas, Predictor, PrunerStrats, Scalar, ScalerStrats,
SeasonalityModels, Sequence, Transformer, Verbose, XConstructor,
YConstructor, sequence_t, XReturn, YReturn,
NumericalStrats, Predictor, PrunerStrats, Scalar, ScalerStrats,
SeasonalityModels, Sequence, Transformer, Verbose, XConstructor, XReturn,
YConstructor, YReturn, sequence_t,
)
from atom.utils.utils import (
Goal, check_is_fitted, get_col_names, get_col_order, get_cols, it, lst,
Expand Down
10 changes: 5 additions & 5 deletions atom/feature_engineering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from collections.abc import Hashable
from random import sample
from typing import Any, Literal
from typing import Any, Literal, cast

import featuretools as ft
import numpy as np
Expand All @@ -36,7 +36,7 @@
Bool, Engine, FeatureSelectionSolvers, FeatureSelectionStrats,
FloatLargerEqualZero, FloatLargerZero, FloatZeroToOneInc,
IntLargerEqualZero, IntLargerZero, NJobs, Operators, Scalar, Sequence,
Verbose, XConstructor, YConstructor, XReturn
Verbose, XConstructor, XReturn, YConstructor,
)
from atom.utils.utils import (
Goal, Task, check_is_fitted, check_scaling, get_custom_scorer, is_sparse,
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def objective_function(model, X_train, y_train, X_valid, y_valid, scoring):
Xt = to_df(X)
yt = to_tabular(y, index=Xt.index)

if yt is None and self.strategy != "pca":
if yt is None and self.strategy not in ("pca", "sfm", None):
raise ValueError(
"Invalid value for the y parameter. Value cannot "
f"be None for strategy='{self.strategy}'."
Expand Down Expand Up @@ -1248,7 +1248,7 @@ def objective_function(model, X_train, y_train, X_valid, y_valid, scoring):
# PCA requires the features to be scaled
if not check_scaling(Xt):
self.scaler_ = Scaler(device=self.device, engine=self.engine)
Xt = self.scaler_.fit_transform(Xt)
Xt = cast(pd.DataFrame, self.scaler_.fit_transform(Xt))

estimator = self._get_est_class("PCA", "decomposition")
solver_param = "svd_solver"
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def transform(self, X: XConstructor, y: YConstructor | None = None) -> XReturn:

if self.scaler_:
self._log(" --> Scaling features...", 2)
Xt = self.scaler_.transform(Xt)
Xt = cast(pd.DataFrame, self.scaler_.transform(Xt))

Xt = self._estimator.transform(Xt).iloc[:, :self._estimator._comps]

Expand Down
2 changes: 1 addition & 1 deletion atom/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from atom.data_cleaning import TransformerMixin
from atom.utils.types import (
Bool, Engine, FloatLargerZero, Sequence, VectorizerStarts, Verbose,
XConstructor, YConstructor, bool_t, XReturn
XConstructor, XReturn, YConstructor, bool_t,
)
from atom.utils.utils import (
check_is_fitted, check_nltk_module, get_corpus, is_sparse, merge, to_df,
Expand Down
7 changes: 4 additions & 3 deletions atom/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@

from atom.utils.types import (
Bool, EngineDataOptions, EngineTuple, Estimator, FHConstructor, Float,
Pandas, Scalar, Sequence, Verbose, XConstructor, YConstructor, YReturn, XReturn
Pandas, Scalar, Sequence, Verbose, XConstructor, XReturn, YConstructor,
YReturn,
)
from atom.utils.utils import (
NotFittedError, adjust, check_is_fitted, fit_one, fit_transform_one,
transform_one, variable_return, to_df, to_tabular
NotFittedError, adjust, check_is_fitted, fit_one, fit_transform_one, to_df,
to_tabular, transform_one, variable_return,
)


Expand Down
2 changes: 1 addition & 1 deletion atom/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

import os
from collections.abc import Callable, Hashable, Iterable, Iterator
from collections.abc import Callable, Hashable, Iterator
from importlib.util import find_spec
from typing import (
TYPE_CHECKING, Annotated, Any, Literal, NamedTuple, SupportsIndex,
Expand Down
12 changes: 7 additions & 5 deletions atom/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from inspect import Parameter, signature
from itertools import cycle
from types import GeneratorType, MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload, cast
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload

import numpy as np
import pandas as pd
Expand All @@ -32,6 +32,7 @@
from pandas._libs.missing import NAType
from pandas._typing import Axes, Dtype
from pandas.api.types import is_numeric_dtype
from pandas.core.generic import NDFrame
from sklearn.base import BaseEstimator
from sklearn.base import OneToOneFeatureMixin as FMixin
from sklearn.metrics import (
Expand All @@ -46,9 +47,9 @@
Bool, EngineDataOptions, EngineTuple, Estimator, FeatureNamesOut, Float,
IndexSelector, Int, IntLargerEqualZero, MetricFunction, Model, Pandas,
Predictor, Scalar, Scorer, Segment, Sequence, SPTuple, Transformer,
Verbose, XConstructor, YConstructor, int_t, segment_t, sequence_t, XReturn, YReturn
Verbose, XConstructor, XReturn, YConstructor, YReturn, int_t, segment_t,
sequence_t,
)
from pandas.core.generic import NDFrame


if TYPE_CHECKING:
Expand Down Expand Up @@ -2180,7 +2181,7 @@ def name_cols(

# If columns were added or removed
temp_cols = []
for i, (name, column) in enumerate(df.items()):
for i, column in enumerate(get_cols(df)):
# equal_nan=True fails for non-numeric dtypes
mask = original_df.apply( # type: ignore[type-var]
lambda c: np.array_equal(
Expand Down Expand Up @@ -2483,7 +2484,8 @@ def prepare_df(out: XConstructor, og: pd.DataFrame) -> pd.DataFrame:
elif "X" not in params:
return X, y # If y is None and no X in transformer, skip the transformer

out: YConstructor | tuple[XConstructor, YConstructor] = getattr(transformer, method)(**kwargs, **transform_params)
caller = getattr(transformer, method)
out: YConstructor | tuple[XConstructor, YConstructor] = caller(**kwargs, **transform_params)

# Transform can return X, y or both
X_new: pd.DataFrame | None
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

from _pytest.monkeypatch import MonkeyPatch

from atom.utils.types import DataFrame, Pandas, Sequence, XSelector, XConstructor
from atom.utils.types import (
DataFrame, Pandas, Sequence, XConstructor,
)


class DummyTransformer(TransformerMixin, BaseEstimator):
Expand Down
1 change: 0 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from unittest.mock import MagicMock, patch

import dask.dataframe as dd
import modin.pandas as md
import numpy as np
import pandas as pd
import polars as pl
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ def test_missing_values_are_propagated():
def test_unknown_classes_are_imputed():
"""Assert that unknown classes are imputed."""
encoder = Encoder()
encoder.fit(["a", "b", "b", "a"])
assert encoder.transform(["c"]).iloc[0, 0] == -1.0
encoder.fit([["a"], ["b"], ["b"], ["a"]])
assert encoder.transform([["c"]]).iloc[0, 0] == -1.0


def test_ordinal_encoder():
Expand Down

0 comments on commit c3d9c6b

Please sign in to comment.