
Commit

add sp
tvdboom committed Dec 15, 2023
1 parent d13cba2 commit eabeea8
Showing 85 changed files with 1,801 additions and 712 deletions.
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
@@ -91,7 +91,7 @@ and accept your changes.
* Update the documentation so all of your changes are reflected there.
* Adhere to [PEP 8](https://peps.python.org/pep-0008/) standards.
* Use a maximum of 99 characters per line. Try to keep docstrings below
74 characters.
80 characters.
* Update the project unit tests to test your code changes as thoroughly
as possible.
* Make sure that your code is properly commented with docstrings and
37 changes: 29 additions & 8 deletions atom/api.py
@@ -17,8 +17,9 @@

from atom.atom import ATOM
from atom.utils.types import (
Backend, Bool, ColumnSelector, Engine, IndexSelector, IntLargerEqualZero,
NJobs, Predictor, Scalar, Verbose, Warnings, YSelector,
Backend, Bool, ColumnSelector, Engine, IndexSelector, Int,
IntLargerEqualZero, NJobs, Predictor, Scalar, Sequence, Verbose, Warnings,
YSelector,
)
from atom.utils.utils import Goal

@@ -35,13 +36,13 @@ def ATOMModel(
needs_scaling: Bool = False,
native_multilabel: Bool = False,
native_multioutput: Bool = False,
has_validation: str | None = None,
validation: str | None = None,
) -> T_Predictor:
"""Convert an estimator to a model that can be ingested by atom.
This function adds the relevant attributes to the estimator so
that they can be used by atom. Note that only estimators that
follow [sklearn's API][api] are compatible.
This function adds the relevant tags to the estimator so that they
can be used by `atom`. Note that only estimators that follow
[sklearn's API][api] are compatible.
Read more about custom models in the [user guide][custom-models].
@@ -75,7 +76,7 @@ def ATOMModel(
If False and the task is multioutput, a multioutput
meta-estimator is wrapped around the estimator.
has_validation: str or None, default=None
validation: str or None, default=None
Whether the model allows [in-training validation][].
- If None: No support for in-training validation.
@@ -121,7 +122,7 @@ def ATOMModel(
estimator_c.needs_scaling = needs_scaling
estimator_c.native_multioutput = native_multioutput
estimator_c.native_multilabel = native_multilabel
estimator_c.has_validation = has_validation
estimator_c.validation = validation

return estimator_c
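For reference, a hedged usage sketch of the renamed argument (illustrative, not part of this commit; only parameters visible in this hunk are shown):

from sklearn.linear_model import SGDClassifier
from atom import ATOMModel

# Tag a scikit-learn estimator so atom can ingest it as a custom model.
# `validation` (formerly `has_validation`) names the estimator parameter
# that controls the number of in-training validation steps.
model = ATOMModel(
    SGDClassifier(),
    needs_scaling=True,
    native_multilabel=False,
    native_multioutput=False,
    validation="max_iter",
)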

@@ -453,6 +454,24 @@ class ATOMForecaster(ATOM):
and model training. The features are still used in the remaining
methods.
sp: int, str, sequence or None, default=None
[Seasonal period][seasonality] of the time series.
- If None: No seasonal period.
- If int: Seasonal period, e.g., 7 for weekly data, and 12 for
monthly data.
- If str:
- Seasonal period provided as [PeriodAlias][], e.g., "M" for
12 or "H" for 24.
- "index": The frequency of the data index is mapped to a
seasonal period.
- "infer": Automatically infer the seasonal period from the
data (calls [get_seasonal_period][self-get_seasonal_period]
under the hood, using default parameters).
- If sequence: Multiple seasonal periods provided as int or str.
test_size: int or float, default=0.2
- If <=1: Fraction of the dataset to include in the test set.
- If >1: Number of rows to include in the test set.
@@ -592,6 +611,7 @@ def __init__(
*arrays,
y: YSelector = -1,
ignore: ColumnSelector | None = None,
sp: Int | str | Sequence[Int | str] | None = None,
n_rows: Scalar = 1,
test_size: Scalar = 0.2,
holdout_size: Scalar | None = None,
@@ -611,6 +631,7 @@ def __init__(
y=y,
index=True,
ignore=ignore,
sp=sp,
test_size=test_size,
holdout_size=holdout_size,
shuffle=False,
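A hedged usage sketch of the new `sp` argument documented above (illustrative values, not part of the diff):

import pandas as pd
from atom import ATOMForecaster

# Monthly series, so the seasonal period is 12.
y = pd.Series(range(48), index=pd.period_range("2020-01", periods=48, freq="M"))

# Equivalent declarations: an int (12), a period alias ("M"),
# "index" to map it from the data index, or "infer".
atom = ATOMForecaster(y, sp=12, test_size=0.2, verbose=1)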
24 changes: 14 additions & 10 deletions atom/atom.py
@@ -56,7 +56,7 @@
NJobs, NormalizerStrats, NumericalStrats, Operators, Pandas, PrunerStrats,
RowSelector, Scalar, ScalerStrats, Sequence, Series, TargetSelector,
Transformer, VectorizerStarts, Verbose, Warnings, XSelector, YSelector,
sequence_t, tsindex_t,
sequence_t,
)
from atom.utils.utils import (
ClassMap, DataConfig, DataContainer, Goal, adjust_verbosity, bk,
@@ -95,6 +95,7 @@ def __init__(
y: YSelector = -1,
index: IndexSelector = False,
ignore: ColumnSelector | None = None,
sp: Int | str | Sequence[Int | str] | None = None,
shuffle: Bool = True,
stratify: IndexSelector = True,
n_rows: Scalar = 1,
@@ -133,18 +134,19 @@ def __init__(
holdout_size=holdout_size,
)

self._log("<< ================== ATOM ================== >>", 1)

# Initialize the branch system and fill with data
self._branches = BranchManager(memory=self.memory)
self._branches.fill(*self._get_data(arrays, y=y))

self.ignore = ignore # type: ignore[assignment]
self.sp = sp # type: ignore[assignment]

self.missing = DEFAULT_MISSING

self._models = ClassMap()
self._metric = ClassMap()

self._log("<< ================== ATOM ================== >>", 1)
self._log("\nConfiguration ==================== >>", 1)
self._log(f"Algorithm task: {self.task}.", 1)
if self.n_jobs > 1:
@@ -747,8 +749,8 @@ def load(cls, filename: str | Path, data: tuple[Any, ...] | None = None) -> ATOM
if atom._config.index is False:
branch._container = DataContainer(
data=(dataset := branch._container.data.reset_index(drop=True)),
train_idx=dataset.index[: len(branch._container.train_idx)],
test_idx=dataset.index[-len(branch._container.test_idx) :],
train_idx=dataset.index[:len(branch._container.train_idx)],
test_idx=dataset.index[-len(branch._container.test_idx):],
n_cols=branch._container.n_cols,
)

@@ -956,11 +958,13 @@ def stats(self, _vb: Int = -2, /):
"""
self._log("Dataset stats " + "=" * 20 + " >>", _vb)
self._log(f"Shape: {self.shape}", _vb)
if self.task.is_forecast and self.sp:
self._log(f"Seasonal period: {self.sp}", _vb)

for set_ in ("train", "test", "holdout"):
if (data := getattr(self, set_)) is not None:
self._log(f"{set_.capitalize()} set size: {len(data)}", _vb)
if isinstance(self.branch.train.index, tsindex_t):
for ds in ("train", "test", "holdout"):
if (data := getattr(self, ds)) is not None:
self._log(f"{ds.capitalize()} set size: {len(data)}", _vb)
if self.task.is_forecast:
self._log(f" --> From: {min(data.index)} To: {max(data.index)}", _vb)

self._log("-" * 37, _vb)
@@ -1231,7 +1235,7 @@ def _add_transformer(
self.branch._container = DataContainer(
data=(data := self.dataset.reset_index(drop=True)),
train_idx=data.index[: len(self.branch._data.train_idx)],

test_idx=data.index[-len(self.branch._data.test_idx) :],
test_idx=data.index[-len(self.branch._data.test_idx):],
n_cols=self.branch._data.n_cols,
)
if self.branch._holdout is not None:
82 changes: 71 additions & 11 deletions atom/basemodel.py
@@ -253,13 +253,23 @@ def __init__(
self._branch = branches.current
self._train_idx = len(self.branch._data.train_idx) # Can change for sh and ts

if self.needs_scaling and not check_scaling(self.X, pipeline=self.pipeline):
self.scaler = Scaler().fit(self.X_train)
if hasattr(self, "needs_scaling"):
if self.needs_scaling and not check_scaling(self.X, pipeline=self.pipeline):
self.scaler = Scaler().fit(self.X_train)

def __repr__(self) -> str:
"""Display class name."""
return f"{self.__class__.__name__}()"

def __dir__(self) -> list[str]:
"""Add additional attrs from __getattr__ to the dir."""
attrs = list(super().__dir__())
if "_branch" in self.__dict__:
attrs += [x for x in dir(self.branch) if not x.startswith("_")]
attrs += list(DF_ATTRS)
attrs += list(self.columns)
return attrs
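A hedged note on what the new __dir__ enables (hypothetical interactive session; the model and attribute names are illustrative):

# After running a model, dir() on it also lists branch attributes
# (e.g., "dataset", "X_train") and the dataset's column names, which
# improves tab completion in interactive sessions.
attrs = dir(atom.lr)
print("dataset" in attrs, "X_train" in attrs)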

def __getattr__(self, item: str) -> Any:
"""Get attributes from branch or data."""
if "_branch" in self.__dict__:
@@ -449,9 +459,10 @@ def _get_est(self, params: dict[str, Any]) -> Predictor:
estimator = MultiOutputClassifier(estimator)
elif self.task.is_regression:
estimator = MultiOutputRegressor(estimator)
elif hasattr(self, "_estimators") and self._goal.name not in self._estimators:
# Forecasting task with a regressor
estimator = make_reduction(estimator)
elif self.task.is_forecast:
if hasattr(self, "_estimators") and self._goal.name not in self._estimators:
# Forecasting task with a regressor
estimator = make_reduction(estimator)

return self._inherit(estimator)
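A standalone sketch of what the forecasting branch above does (assumes sktime's make_reduction, which atom wraps; the regressor choice is illustrative):

from sklearn.ensemble import RandomForestRegressor
from sktime.forecasting.compose import make_reduction

# Wrap a tabular regressor so it behaves as a forecaster, mirroring
# the make_reduction(estimator) call for regressors on forecast tasks.
forecaster = make_reduction(RandomForestRegressor(), strategy="recursive")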

@@ -494,13 +505,13 @@ def _fit_estimator(
Fitted instance.
"""
if self.has_validation and hasattr(estimator, "partial_fit") and validation:
if getattr(self, "validation", False) and hasattr(estimator, "partial_fit") and validation:
# Loop over first parameter in estimator
try:
steps = estimator.get_params()[self.has_validation]
steps = estimator.get_params()[self.validation]
except KeyError:
# For meta-estimators like multioutput
steps = estimator.get_params()[f"estimator__{self.has_validation}"]
steps = estimator.get_params()[f"estimator__{self.validation}"]

for step in range(steps):
kwargs = {}
@@ -533,8 +544,8 @@ def _fit_estimator(

if trial.should_prune():
# Hacky solution to add the pruned step to the output
if self.has_validation in trial.params:
trial.params[self.has_validation] = f"{step}/{steps}"
if self.validation in trial.params:
trial.params[self.validation] = f"{step}/{steps}"

trial.set_user_attr("estimator", estimator)
raise TrialPruned
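A generic, self-contained sketch of the in-training validation loop above (plain scikit-learn, not atom's internal API):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=200, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

est = SGDClassifier(random_state=0)
steps = 10  # atom reads this from the estimator parameter named by `validation`
for step in range(steps):
    # One incremental training step, then an evaluation, mimicking the
    # per-step scoring (and optional trial pruning) in _fit_estimator.
    est.partial_fit(X_train, y_train, classes=np.unique(y))
    score = est.score(X_test, y_test)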
@@ -1308,7 +1319,7 @@ def name(self, value: str):
"""Change the model's name."""
# Drop the acronym if provided by the user
if re.match(f"{self.acronym}_", value, re.I):
value = value[len(self.acronym) + 1 :]
value = value[len(self.acronym) + 1:]

# Add the acronym in front (with right capitalization)
self._name = f"{self.acronym}{f'_{value}' if value else ''}"
@@ -2437,6 +2448,32 @@ def transform(
class ClassRegModel(BaseModel):
"""Classification and regression models."""

def get_tags(self) -> dict[str, Any]:
"""Get the model's tags.
Return class parameters that provide general information about
the estimator's characteristics.
Returns
-------
dict
Model's tags.
"""
return {
"acronym": self.acronym,
"fullname": self.fullname,
"estimator": self._est_class,
"module": self._est_class.__module__.split(".")[0] + self._module,
"handles_missing": self.handles_missing,
"needs_scaling": self.needs_scaling,
"accepts_sparse": self.accepts_sparse,
"native_multilabel": self.native_multilabel,
"native_multioutput": self.native_multioutput,
"validation": self.validation,
"supports_engines": ", ".join(self.supports_engines),
}
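A hedged usage sketch of the new get_tags method (dataset and model are illustrative; assumes the standard ATOMClassifier workflow):

from sklearn.datasets import load_breast_cancer
from atom import ATOMClassifier

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

atom = ATOMClassifier(X, y, verbose=0)
atom.run(models="LR")

# The tag dictionary exposes the renamed `validation` key
# (formerly `has_validation`) alongside the other model traits.
print(atom.lr.get_tags()["validation"])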

@overload
def _prediction(
self,
@@ -2845,6 +2882,29 @@ def score(
class ForecastModel(BaseModel):
"""Forecasting models."""

def get_tags(self) -> dict[str, Any]:
"""Get the model's tags.
Return class parameters that provide general information about
the estimator's characteristics.
Returns
-------
dict
Model's tags.
"""
return {
"acronym": self.acronym,
"fullname": self.fullname,
"estimator": self._est_class.__name__,
"module": self._est_class.__module__.split(".")[0] + self._module,
"handles_missing": self.handles_missing,
"in_sample_prediction": self.in_sample_prediction,
"native_multivariate": self.native_multivariate,
"supports_engines": ", ".join(self.supports_engines),
}

@overload
def _prediction(
self,
