Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Feb 7, 2024
1 parent 0c7209e commit ff9468e
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 124 deletions.
16 changes: 7 additions & 9 deletions atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from atom.utils.constants import CAT_TYPES, DEFAULT_MISSING, __version__
from atom.utils.types import (
Backend, Bins, Bool, CategoricalStrats, ColumnSelector, DataFrame,
DiscretizerStrats, Engine, Estimator, FeatureNamesOut,
DiscretizerStrats, Engine, EngineTuple, Estimator, FeatureNamesOut,
FeatureSelectionSolvers, FeatureSelectionStrats, FloatLargerEqualZero,
FloatLargerZero, FloatZeroToOneInc, Index, IndexSelector, Int,
IntLargerEqualZero, IntLargerTwo, IntLargerZero, MetricConstructor,
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(
)

self._config = DataConfig(
index=index,
index=index is not False,
shuffle=shuffle,
stratify=stratify,
n_rows=n_rows,
Expand All @@ -139,7 +139,7 @@ def __init__(

# Initialize the branch system and fill with data
self._branches = BranchManager(memory=self.memory)
self._branches.fill(*self._get_data(arrays, y=y))
self._branches.fill(*self._get_data(arrays, y=y, index=index))

self.ignore = ignore # type: ignore[assignment]
self.sp = sp # type: ignore[assignment]
Expand All @@ -163,9 +163,9 @@ def __init__(
)
if "cpu" not in self.device.lower():
self._log(f"Device: {self.device}", 1)
if self.engine.data != "pandas":
if self.engine.data != EngineTuple().data:
self._log(f"Data engine: {self.engine.data}", 1)
if self.engine.estimator != "sklearn":
if self.engine.estimator != EngineTuple().estimator:
self._log(f"Estimator engine: {self.engine.estimator}", 1)
if self.backend == "ray" or self.n_jobs > 1:
self._log(f"Parallelization backend: {self.backend}", 1)
Expand Down Expand Up @@ -1232,11 +1232,9 @@ def _add_transformer(
"""
if callable(transformer):
est_class = make_sklearn(transformer, feature_names_out=feature_names_out)
transformer_c = self._inherit(est_class())
transformer_c = self._inherit(transformer(), feature_names_out=feature_names_out)
else:
make_sklearn(transformer.__class__, feature_names_out=feature_names_out)
transformer_c = transformer
transformer_c = make_sklearn(transformer, feature_names_out=feature_names_out)

if any(m.branch is self.branch for m in self._models):
raise PermissionError(
Expand Down
6 changes: 3 additions & 3 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@
ClassMap, DataConfig, Goal, PlotCallback, ShapExplanation, Task,
TrialsCallback, adjust_verbosity, bk, cache, check_dependency, check_empty,
check_scaling, composed, crash, estimator_has_attr, flt, get_cols,
get_custom_scorer, has_task, it, lst, make_sklearn, merge, method_to_log,
rnd, sign, time_to_str, to_pandas,
get_custom_scorer, has_task, it, lst, merge, method_to_log, rnd, sign,
time_to_str, to_pandas,
)


Expand Down Expand Up @@ -326,7 +326,7 @@ def _est_class(self) -> type[Predictor]:
except (ModuleNotFoundError, AttributeError, IndexError):
mod = import_module(module)

return make_sklearn(getattr(mod, est_name))
return getattr(mod, est_name)

@property
def _shap(self) -> ShapExplanation:
Expand Down
151 changes: 77 additions & 74 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
from atom.pipeline import Pipeline
from atom.utils.constants import DF_ATTRS
from atom.utils.types import (
Bool, DataFrame, FloatZeroToOneExc, HarmonicsSelector, Int, IntLargerOne,
MetricConstructor, Model, ModelSelector, ModelsSelector, Pandas,
RowSelector, Seasonality, Segment, Sequence, Series, SPDict, SPTuple,
TargetSelector, YSelector, bool_t, dataframe_t, int_t, segment_t,
Bool, DataFrame, FloatZeroToOneExc, HarmonicsSelector, IndexSelector, Int,
IntLargerOne, MetricConstructor, Model, ModelSelector, ModelsSelector,
Pandas, RowSelector, Seasonality, Segment, Sequence, Series, SPDict,
SPTuple, TargetSelector, YSelector, bool_t, dataframe_t, int_t, segment_t,
sequence_t,
)
from atom.utils.utils import (
Expand Down Expand Up @@ -376,64 +376,12 @@ def get_single_sp(sp: Int | str) -> int:
else:
return flt([get_single_sp(x) for x in lst(sp)])

def _set_index(self, df: DataFrame, y: Pandas | None) -> DataFrame:
"""Assign an index to the dataframe.
Parameters
----------
df: dataframe
Dataset.
y: series, dataframe or None
Target column(s). Used to check that the provided index
is not one of the target columns. If None, the check is
skipped.
Returns
-------
dataframe
Dataset with updated indices.
"""
if self._config.index is True: # True gets caught by isinstance(int)
pass
elif self._config.index is False:
df = df.reset_index(drop=True)
elif isinstance(self._config.index, int_t):
if -df.shape[1] <= self._config.index <= df.shape[1]:
df = df.set_index(df.columns[int(self._config.index)], drop=True)
else:
raise IndexError(
f"Invalid value for the index parameter. Value {self._config.index} "
f"is out of range for a dataset with {df.shape[1]} columns."
)
elif isinstance(self._config.index, str):
if self._config.index in df:
df = df.set_index(self._config.index, drop=True)
else:
raise ValueError(
"Invalid value for the index parameter. "
f"Column {self._config.index} not found in the dataset."
)

if y is not None and df.index.name in (c.name for c in get_cols(y)):
raise ValueError(
"Invalid value for the index parameter. The index column "
f"can not be the same as the target column, got {df.index.name}."
)

if df.index.duplicated().any():
raise ValueError(
"Invalid value for the index parameter. There are duplicate indices "
"in the dataset. Use index=False to reset the index to RangeIndex."
)

return df

def _get_data(
self,
arrays: tuple,
y: YSelector = -1,
*,
index: IndexSelector = False,
) -> tuple[DataContainer, DataFrame | None]:
"""Get data sets from a sequence of indexables.
Expand All @@ -448,6 +396,9 @@ def _get_data(
y: int, str or sequence, default=-1
Transformed target column.
index: bool, int, str or sequence, default=False
Index parameter as provided in constructor.
Returns
-------
DataContainer
Expand Down Expand Up @@ -488,6 +439,60 @@ def _subsample(df: DataFrame) -> DataFrame:
else:
return df.iloc[sorted(random.sample(range(len(df)), k=n_rows))]

def _set_index(df: DataFrame, y: Pandas | None) -> DataFrame:

Check notice on line 442 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'y' from outer scope
"""Assign an index to the dataframe.
Parameters
----------
df: dataframe
Dataset.
y: series, dataframe or None
Target column(s). Used to check that the provided index
is not one of the target columns. If None, the check is
skipped.
Returns
-------
dataframe
Dataset with updated indices.
"""
if index is True: # True gets caught by isinstance(int)
pass
elif index is False:
df = df.reset_index(drop=True)
elif isinstance(index, int_t):
if -df.shape[1] <= index <= df.shape[1]:
df = df.set_index(df.columns[int(index)], drop=True)
else:
raise IndexError(
f"Invalid value for the index parameter. Value {index} "
f"is out of range for a dataset with {df.shape[1]} columns."
)
elif isinstance(index, str):
if index in df:
df = df.set_index(index, drop=True)
else:
raise ValueError(
"Invalid value for the index parameter. "
f"Column {index} not found in the dataset."
)

if y is not None and df.index.name in (c.name for c in get_cols(y)):
raise ValueError(
"Invalid value for the index parameter. The index column "
f"can not be the same as the target column, got {df.index.name}."
)

if df.index.duplicated().any():
raise ValueError(
"Invalid value for the index parameter. There are duplicate indices "
"in the dataset. Use index=False to reset the index to RangeIndex."
)

return df

def _no_data_sets(
X: DataFrame,

Check notice on line 497 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas,

Check notice on line 498 in atom/baserunner.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'y' from outer scope
Expand Down Expand Up @@ -525,14 +530,13 @@ def _no_data_sets(
)
data = _subsample(data)

if isinstance(self._config.index, sequence_t):
if len(self._config.index) != len(data):
if isinstance(index, sequence_t):
if len(index) != len(data):
raise IndexError(
"Invalid value for the index parameter. Length of "
f"index ({len(self._config.index)}) doesn't match "
f"that of the dataset ({len(data)})."
"Invalid value for the index parameter. Length of index "
f"({len(index)}) doesn't match that of the dataset ({len(data)})."
)
data.index = self._config.index
data.index = index

if len(data) < 5:
raise ValueError(
Expand Down Expand Up @@ -585,7 +589,7 @@ def _no_data_sets(
stratify=self._config.get_stratify_columns(data, y),
)

complete_set = self._set_index(bk.concat([train, test, holdout]), y)
complete_set = _set_index(bk.concat([train, test, holdout]), y)

container = DataContainer(
data=(data := complete_set.iloc[: len(data)]),
Expand Down Expand Up @@ -682,23 +686,22 @@ def _has_data_sets(
)

# If the index is a sequence, assign it before shuffling
if isinstance(self._config.index, sequence_t):
if isinstance(index, sequence_t):
len_data = len(train) + len(test)
if holdout is not None:
len_data += len(holdout)

if len(self._config.index) != len_data:
if len(index) != len_data:
raise IndexError(
"Invalid value for the index parameter. Length of "
f"index ({len(self._config.index)}) doesn't match "
f"that of the data sets ({len_data})."
"Invalid value for the index parameter. Length of index "
f"({len(index)}) doesn't match that of the data sets ({len_data})."
)
train.index = self._config.index[: len(train)]
test.index = self._config.index[len(train): len(train) + len(test)]
train.index = index[: len(train)]
test.index = index[len(train): len(train) + len(test)]
if holdout is not None:
holdout.index = self._config.index[-len(holdout):]
holdout.index = index[-len(holdout):]

complete_set = self._set_index(bk.concat([train, test, holdout]), y_test)
complete_set = _set_index(bk.concat([train, test, holdout]), y_test)

container = DataContainer(
data=(data := complete_set.iloc[:len(train) + len(test)]),
Expand Down
25 changes: 20 additions & 5 deletions atom/basetransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@

from atom.utils.types import (
Backend, Bool, DataFrame, Engine, EngineDataOptions,
EngineEstimatorOptions, EngineTuple, Estimator, Int, IntLargerEqualZero,
Pandas, Sequence, Severity, Verbose, Warnings, XSelector, YSelector,
bool_t, dataframe_t, int_t, sequence_t,
EngineEstimatorOptions, EngineTuple, Estimator, FeatureNamesOut, Int,
IntLargerEqualZero, Pandas, Sequence, Severity, Verbose, Warnings,
XSelector, YSelector, bool_t, dataframe_t, int_t, sequence_t,
)
from atom.utils.utils import (
crash, flt, lst, make_sklearn, n_cols, to_df, to_pandas,
Expand Down Expand Up @@ -359,7 +359,11 @@ def _device_id(self) -> int:

# Methods ====================================================== >>

def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator:
def _inherit(
self,
obj: T_Estimator, fixed: tuple[str, ...] = (),
feature_names_out: FeatureNamesOut = "one-to-one",
) -> T_Estimator:
"""Inherit parameters from parent.
Utility method to set the sp (seasonal period), n_jobs and
Expand All @@ -375,6 +379,17 @@ def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator
fixed: tuple of str, default=()
Fixed parameters that should not be overriden.
feature_names_out: "one-to-one", callable or None, default="one-to-one"
Determines the list of feature names that will be returned
by the `get_feature_names_out` method.
- If None: The `get_feature_names_out` method is not defined.
- If "one-to-one": The output feature names will be equal to
the input feature names.
- If callable: Function that takes positional arguments self
and a sequence of input feature names. It must return a
sequence of output feature names.
Returns
-------
Estimator
Expand All @@ -392,7 +407,7 @@ def _inherit(self, obj: T_Estimator, fixed: tuple[str, ...] = ()) -> T_Estimator
else:
obj.set_params(**{p: lst(self._config.sp.sp)[0]})

return obj
return make_sklearn(obj, feature_names_out=feature_names_out)

def _get_est_class(self, name: str, module: str) -> type[Estimator]:
"""Import a class from a module.
Expand Down
Loading

0 comments on commit ff9468e

Please sign in to comment.