Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Dec 8, 2023
1 parent b194a56 commit d13cba2
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip setuptools
pip install -U pytest pytest-cov
pip install -U pytest pytest-mock pytest-cov
pip install -e .[full]
- name: Run tests w/ coverage
run: pytest --cov=atom --cov-report=xml tests/
Expand Down
15 changes: 6 additions & 9 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def _get_parameters(self, trial: Trial) -> dict[str, Any]:
"""Get the trial's hyperparameters.
This method fetches the suggestions from the trial and
rounds floats to the 4th digit.
rounds floats to the fourth digit.
Parameters
----------
Expand Down Expand Up @@ -656,14 +656,14 @@ def _get_pred(
attr = attribute
break

df = self.branch._get_rows(rows)
X, y = self.branch._get_rows(rows, return_X_y=True)

Check notice on line 659 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Check notice on line 659 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

# Filter for indices in dataset (required for sh and ts)
X = df.iloc[df.index.isin(self._all.index), :self.n_features]
y_true = df.iloc[df.index.isin(self._all.index), self.n_features:]
X = X.loc[X.index.isin(self._all.index)]

Check notice on line 662 in atom/basemodel.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
y_true = y.loc[y.index.isin(self._all.index)]

if self.task.is_forecast:
y_pred = self._prediction(X.index, X=check_empty(X), verbose=0, method=attr)
y_pred = self._prediction(fh=X.index, X=check_empty(X), verbose=0, method=attr)
else:
y_pred = self._prediction(X.index, verbose=0, method=attr)

Expand Down Expand Up @@ -757,10 +757,7 @@ def _score_from_pred(

# Forecasting models can have first prediction NaN
if self.task.is_forecast and all(x.isna()[0] for x in get_cols(y_pred)):
(
y_true,
y_pred,
) = y_true.iloc[1:], y_pred.iloc[1:]
y_true, y_pred = y_true.iloc[1:], y_pred.iloc[1:]

if self.task is Task.multiclass_multioutput_classification:
# Get the mean of the scores over the target columns
Expand Down
16 changes: 8 additions & 8 deletions atom/branch/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ def _get_columns(
inc.append(df.columns[int(col)])
else:
raise IndexError(
f"Invalid value for the columns parameter. Value {col} "
f"is out of range for data with {df.shape[1]} columns."
f"Invalid column selection. Value {col} is out "
f"of range for data with {df.shape[1]} columns."
)
elif isinstance(col, str):
for c in col.split("+"):
Expand All @@ -619,19 +619,19 @@ def _get_columns(
array.extend(df.select_dtypes(c).columns)
except TypeError:
raise ValueError(
"Invalid value for the columns parameter. "
f"Could not find any column that matches {c}."
"Invalid column selection. Could "
f"not find any column that matches {c}."
) from None

if len(inc) + len(exc) == 0:
raise ValueError(
"Invalid value for the columns parameter, got "
f"{columns}. At least one column has to be selected."
f"Invalid column selection, got {columns}. "
f"At least one column has to be selected."
)
elif inc and exc:
raise ValueError(
"Invalid value for the columns parameter. You can either "
"include or exclude columns, not combinations of these."
"Invalid column selection. You can either include "
"or exclude columns, not combinations of these."
)
elif exc:
# If columns were excluded with `!`, select all but those
Expand Down
4 changes: 2 additions & 2 deletions tests/test_atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_reset():
atom.scale()
atom.branch = "2"
atom.encode()
atom.run("LR", errors="raise")
atom.run("LR")
atom.reset(hard=True)
assert not atom.models
assert len(atom._branches) == 1
Expand Down Expand Up @@ -606,7 +606,7 @@ def test_ignore_columns():
"""Assert that columns can be ignored from transformations."""
atom = ATOMRegressor(X_reg, y_reg, ignore="age", random_state=1)
atom.scale()
atom.run("OLS", errors="raise")
atom.run("OLS")
assert "age" in atom
assert "age" not in atom.pipeline.named_steps["scaler"].feature_names_in_
assert "age" not in atom.ols.estimator.feature_names_in_
Expand Down
4 changes: 2 additions & 2 deletions tests/test_basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def test_continued_hyperparameter_tuning():
def test_continued_bootstrapping():
"""Assert that the bootstrapping method can be recalled."""
atom = ATOMClassifier(X_bin, y_bin, random_state=1)
atom.run("LGB", est_params={"n_estimators": 5}, errors="raise")
atom.run("LGB", est_params={"n_estimators": 5})
assert not hasattr(atom.lgb, "bootstrap")
atom.lgb.bootstrapping(3)
assert len(atom.lgb.bootstrap) == 3
Expand Down Expand Up @@ -605,7 +605,7 @@ def test_shape_property():
def test_columns_property():
"""Assert that the columns property returns the columns of the dataset."""
atom = ATOMClassifier(X_bin, y_bin, ignore=(0, 1), random_state=1)
atom.run("MNB", errors="raise")
atom.run("MNB")
assert len(atom.mnb.columns) == len(atom.columns) - 2


Expand Down
2 changes: 1 addition & 1 deletion tests/test_baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def test_input_is_train_test_with_parameter_y():

def test_input_is_train_test_for_forecast():
"""Assert that input train, test works for forecast tasks."""
trainer = DirectForecaster("ES", errors="raise", random_state=1)
trainer = DirectForecaster("ES", random_state=1)
trainer.run(fc_train, fc_test)
assert_series_equal(trainer.y, pd.concat([fc_train, fc_test]))

Expand Down
1 change: 0 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def test_pruning_non_sklearn(model):
atom.run(
models=model,
n_trials=7,
errors="raise",
est_params={"n_estimators": 10, "max_depth": 2},
ht_params={"pruner": PatientPruner(None, patience=1)},
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_get_plot_models_max_one():
def test_custom_title_and_legend(func):
"""Assert that title and legend can be customized."""
atom = ATOMClassifier(X_bin, y_bin, random_state=1)
atom.run("Tree", errors="raise")
atom.run("Tree")
atom.plot_roc(title={"text": "test", "x": 0}, legend={"font_color": "red"})
func.assert_called_once()

Expand Down

0 comments on commit d13cba2

Please sign in to comment.