Skip to content

Commit

Permalink
set_output=pandas
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Nov 25, 2023
1 parent 35687c1 commit dfbb0d9
Show file tree
Hide file tree
Showing 37 changed files with 771 additions and 748 deletions.
9 changes: 8 additions & 1 deletion atom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
"""
Automated Tool for Optimized Modeling (ATOM)
Author: Mavs
Description: Import API and version.
Description: Import API and version, and set configuration.
"""

import pandas as pd
import sklearn

from atom.api import ATOMClassifier, ATOMForecaster, ATOMModel, ATOMRegressor
from atom.utils.constants import __version__


pd.options.mode.copy_on_write = True
sklearn.set_config(transform_output="pandas")
2 changes: 1 addition & 1 deletion atom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from logging import Logger
from pathlib import Path
from typing import TypeVar

from beartype import beartype
from beartype.typing import TypeVar
from joblib.memory import Memory
from sklearn.base import clone

Expand Down
17 changes: 11 additions & 6 deletions atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
import os
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Iterator
from copy import deepcopy
from logging import Logger
from pathlib import Path
from platform import machine, platform, python_build, python_version
from types import MappingProxyType
from typing import Any, Literal, TypeVar

import dill as pickle
import numpy as np
import pandas as pd
from beartype import beartype
from beartype.typing import Any, Callable, Iterator, Literal, Sequence, TypeVar
from joblib.memory import Memory
from pandas._typing import DtypeObj

Check notice on line 28 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _typing of a class
from scipy import stats
Expand All @@ -34,7 +35,7 @@
from atom.branch import Branch, BranchManager
from atom.data_cleaning import (
Balancer, Cleaner, Discretizer, Encoder, Imputer, Normalizer, Pruner,
Scaler,
Scaler, TransformerMixin,
)
from atom.feature_engineering import (
FeatureExtractor, FeatureGenerator, FeatureGrouper, FeatureSelector,
Expand All @@ -55,8 +56,9 @@
FloatZeroToOneInc, Index, IndexSelector, Int, IntLargerEqualZero,
IntLargerTwo, IntLargerZero, MetricConstructor, ModelsConstructor, NItems,
NJobs, NormalizerStrats, NumericalStrats, Operators, Pandas, PrunerStrats,
RowSelector, Scalar, ScalerStrats, Series, TargetSelector, Transformer,
TSIndex, VectorizerStarts, Verbose, Warnings, XSelector, YSelector,
RowSelector, Scalar, ScalerStrats, Sequence, Series, TargetSelector,
Transformer, VectorizerStarts, Verbose, Warnings, XSelector, YSelector,
sequence_t, tsindex_t,
)
from atom.utils.utils import (
ClassMap, DataConfig, DataContainer, Goal, adjust_verbosity, bk,
Expand Down Expand Up @@ -540,7 +542,7 @@ def eda(

if isinstance(rows, str):
rows_c = [(self.branch._get_rows(rows), rows)]

Check notice on line 544 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class
elif isinstance(rows, Sequence):
elif isinstance(rows, sequence_t):
rows_c = [(self.branch._get_rows(r), r) for r in rows]

Check notice on line 546 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class
elif isinstance(rows, dict):
rows_c = [(self.branch._get_rows(v), k) for k, v in rows.items()]

Check notice on line 548 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class
Expand Down Expand Up @@ -937,7 +939,7 @@ def stats(self, _vb: Int = -2, /):
for set_ in ("train", "test", "holdout"):
if (data := getattr(self, set_)) is not None:
self._log(f"{set_.capitalize()} set size: {len(data)}", _vb)
if isinstance(self.branch.train.index, TSIndex):
if isinstance(self.branch.train.index, tsindex_t):
self._log(f" --> From: {min(data.index)} To: {max(data.index)}", _vb)

self._log("-" * 37, _vb)
Expand Down Expand Up @@ -1147,6 +1149,9 @@ def _add_transformer(
)
transformer_c._cols = inc

# Add custom cloning method to keep internal attrs
transformer_c.__class__.__sklearn_clone__ = TransformerMixin.__sklearn_clone__

if hasattr(transformer_c, "fit"):
if not transformer_c.__module__.startswith("atom"):
self._log(f"Fitting {transformer_c.__class__.__name__}...", 1)
Expand Down
20 changes: 10 additions & 10 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from importlib import import_module
from logging import Logger
from pathlib import Path
from typing import overload
from typing import Any, Literal, overload
from unittest.mock import patch

import dill as pickle
Expand All @@ -30,7 +30,6 @@
from beartype.roar import (
BeartypeCallHintParamViolation, BeartypeCallHintReturnViolation,
)
from beartype.typing import Any, Literal
from joblib.memory import Memory
from joblib.parallel import Parallel, delayed
from mlflow.data import from_pandas
Expand Down Expand Up @@ -70,7 +69,8 @@
HT, Backend, Bool, DataFrame, Engine, FHSelector, Float, FloatZeroToOneExc,
Int, IntLargerEqualZero, MetricConstructor, NJobs, Pandas,
PredictionMethod, Predictor, RowSelector, Scalar, Scorer, Sequence, Stages,
TargetSelector, Verbose, Warnings, XSelector, YSelector,
TargetSelector, Verbose, Warnings, XSelector, YSelector, dataframe_t,
float_t, int_t,
)
from atom.utils.utils import (
ClassMap, DataConfig, Goal, PlotCallback, ShapExplanation, Task,
Expand Down Expand Up @@ -281,7 +281,7 @@ def __contains__(self, item: str) -> bool:
return item in self.dataset

def __getitem__(self, item: Int | str | list) -> Pandas:
if isinstance(item, Int):
if isinstance(item, int_t):
return self.dataset[self.columns[item]]
else:
return self.dataset[item] # Get a subset of the dataset
Expand Down Expand Up @@ -431,7 +431,7 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
Estimator instance.
"""
# Separate the parameters for the estimator from those in sub-estimators
# Separate the params for the estimator from those in sub-estimators
base_params, sub_params = {}, {}
for name, value in params.items():
if "__" not in name:
Expand Down Expand Up @@ -818,7 +818,7 @@ def _get_score(
else:
if threshold and self.task.is_binary and hasattr(self, "predict_proba"):
y_true, y_pred = self._get_pred(rows, attr="predict_proba")
if isinstance(y_pred, DataFrame):
if isinstance(y_pred, dataframe_t):
# Update every target column with its corresponding threshold
for i, value in enumerate(threshold):
y_pred.iloc[:, i] = (y_pred.iloc[:, i] > value).astype("int")
Expand Down Expand Up @@ -974,13 +974,13 @@ def fit_model(
# Follow the same stratification strategy as atom
cols = self._config.get_stratify_columns(self.og.train, self.og.y_train)

if isinstance(cv := self._ht["cv"], Int):
if isinstance(cv := self._ht["cv"], int_t):
if self.task.is_forecast:
if cv == 1:
splitter = SingleWindowSplitter(range(1, len(self.og.test)))
else:
splitter = TimeSeriesSplit(n_splits=cv)
elif isinstance(self._ht["cv"], Int):
elif isinstance(self._ht["cv"], int_t):
# We use ShuffleSplit instead of K-fold because it
# works with n_splits=1 and multioutput stratification
if cols is None:
Expand Down Expand Up @@ -1805,7 +1805,7 @@ def inference(*X) -> Scalar | str | list[Scalar | str]:
conv = lambda elem: elem.item() if hasattr(elem, "item") else elem

y_pred = self.inverse_transform(y=self.predict([X], verbose=0), verbose=0)
if isinstance(y_pred, DataFrame):
if isinstance(y_pred, dataframe_t):
return [conv(elem) for elem in y_pred.iloc[0, :]]
else:
return conv(y_pred[0])
Expand Down Expand Up @@ -2028,7 +2028,7 @@ def evaluate(
Scores of the model.
"""
if isinstance(threshold, Float):
if isinstance(threshold, float_t):
threshold_c = [threshold] * self.branch._data.n_cols # Length=n_targets

Check notice on line 2032 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
elif len(threshold) != self.branch._data.n_cols:

Check notice on line 2033 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
raise ValueError(
Expand Down
29 changes: 15 additions & 14 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import random
import re
from abc import ABCMeta
from collections.abc import Hashable
from copy import deepcopy
from functools import cached_property
from pathlib import Path
from typing import Any

import dill as pickle
import pandas as pd
from beartype import beartype
from beartype.typing import Any, Hashable, Sequence
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.metaestimators import available_if
Expand All @@ -34,7 +35,7 @@
from atom.utils.types import (
Bool, DataFrame, FloatZeroToOneExc, Int, MetricConstructor, Model,
ModelSelector, ModelsSelector, Pandas, RowSelector, Scalar, Segment,
Series, YSelector,
Sequence, Series, YSelector, dataframe_t, int_t, segment_t, sequence_t,
)
from atom.utils.utils import (
ClassMap, DataContainer, Task, bk, check_is_fitted, composed, crash,
Expand Down Expand Up @@ -108,7 +109,7 @@ def __getitem__(self, item: Int | str | list) -> Any:
"This instance has no dataset annexed to it. "
"Use the run method before calling __getitem__."
)
elif isinstance(item, Int):
elif isinstance(item, int_t):
return self.dataset[self.columns[item]]
elif isinstance(item, str):
if item in self._branches:
Expand Down Expand Up @@ -288,7 +289,7 @@ def _set_index(self, df: DataFrame, y: Pandas | None) -> DataFrame:
pass
elif self._config.index is False:
df = df.reset_index(drop=True)
elif isinstance(self._config.index, Int):
elif isinstance(self._config.index, int_t):
if -df.shape[1] <= self._config.index <= df.shape[1]:
df = df.set_index(df.columns[int(self._config.index)], drop=True)
else:
Expand Down Expand Up @@ -414,7 +415,7 @@ def _no_data_sets(
)
data = _subsample(data)

if isinstance(self._config.index, Sequence):
if isinstance(self._config.index, sequence_t):
if len(self._config.index) != len(data):
raise IndexError(
"Invalid value for the index parameter. Length of "
Expand Down Expand Up @@ -485,7 +486,7 @@ def _no_data_sets(

except ValueError as ex:
# Clarify common error with stratification for multioutput tasks
if "least populated class" in str(ex) and isinstance(y, DataFrame):
if "least populated class" in str(ex) and isinstance(y, dataframe_t):
raise ValueError(
"Stratification for multioutput tasks is applied over all target "
"columns, which results in a least populated class that has only "
Expand Down Expand Up @@ -571,7 +572,7 @@ def _has_data_sets(
)

# If the index is a sequence, assign it before shuffling
if isinstance(self._config.index, Sequence):
if isinstance(self._config.index, sequence_t):
len_data = len(train) + len(test)
if holdout is not None:
len_data += len(holdout)
Expand Down Expand Up @@ -604,7 +605,7 @@ def _has_data_sets(
# Process input arrays ===================================== >>

if len(arrays) == 0:
if self._goal.name == "forecast" and not isinstance(y, Int | str):
if self._goal.name == "forecast" and not isinstance(y, (*int_t, str)):
# arrays=() and y=y for forecasting
sets = _no_data_sets(*self._check_input(y=y))
elif not self.branch._container:

Check notice on line 611 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _container of a class
Expand All @@ -625,7 +626,7 @@ def _has_data_sets(
X_train, y_train = self._check_input(arrays[0][0], arrays[0][1])

Check notice on line 626 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
X_test, y_test = self._check_input(arrays[1][0], arrays[1][1])

Check notice on line 627 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
sets = _has_data_sets(X_train, y_train, X_test, y_test)
elif isinstance(arrays[1], Int | str) or n_cols(arrays[1]) == 1:
elif isinstance(arrays[1], (*int_t, str)) or n_cols(arrays[1]) == 1:
if not self._goal.name == "forecast":
# arrays=(X, y)
sets = _no_data_sets(*self._check_input(arrays[0], arrays[1]))
Expand Down Expand Up @@ -729,11 +730,11 @@ def _get_models(
exc: list[Model] = []
if models is None:
inc = self._models.values()
elif isinstance(models, Segment):
elif isinstance(models, segment_t):
inc = get_segment(self._models, models)
else:
for model in lst(models):
if isinstance(model, Int):
if isinstance(model, int_t):
try:
inc.append(self._models[model])
except KeyError:
Expand Down Expand Up @@ -788,7 +789,7 @@ def _get_models(

return list(dict.fromkeys(inc)) # Avoid duplicates

def _delete_models(self, models: str | Sequence):
def _delete_models(self, models: str | Model | Sequence[str | Model]):
"""Delete models.
Remove models from the instance. All attributes are deleted
Expand All @@ -797,7 +798,7 @@ def _delete_models(self, models: str | Sequence):
Parameters
----------
models: str or sequence
models: str, Model or sequence
Model(s) to delete.
"""
Expand Down Expand Up @@ -1239,7 +1240,7 @@ def stacking(
f"{model.fullname} can not perform {self.task} tasks."
)

kwargs["final_estimator"] = model._get_est()
kwargs["final_estimator"] = model._get_est({})

Check notice on line 1243 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_est of a class

self._models.append(Stacking(models=models_c, name=name, **kw_model, **kwargs))

Expand Down
6 changes: 3 additions & 3 deletions atom/basetrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import traceback
from abc import ABCMeta
from datetime import datetime as dt
from typing import Any

import joblib
import mlflow
import numpy as np
import ray
from beartype.typing import Any
from joblib import Parallel, delayed
from optuna import Study, create_study

Expand All @@ -26,7 +26,7 @@
from atom.data_cleaning import BaseTransformer
from atom.models import MODELS, CustomModel
from atom.plots import RunnerPlot
from atom.utils.types import Model, Sequence
from atom.utils.types import Model, sequence_t
from atom.utils.utils import (
ClassMap, DataConfig, Goal, Task, check_dependency, get_custom_scorer, lst,
sign, time_to_str,
Expand Down Expand Up @@ -104,7 +104,7 @@ def _check_param(self, param: str, value: Any) -> dict:
Parameter with model names as keys.
"""
if isinstance(value, Sequence):
if isinstance(value, sequence_t):
if len(value) != len(self._models):
raise ValueError(
f"Invalid value for the {param} parameter. The length "
Expand Down
Loading

0 comments on commit dfbb0d9

Please sign in to comment.