Skip to content

Commit

Permalink
enabling tests
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-fong committed Aug 19, 2024
1 parent 70621fb commit ddbecdf
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions skpro/tests/test_set_output.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import pytest

from skpro.datatypes._table._convert import convert_pandas_to_polars_eager
from skpro.tests.test_switch import run_test_module_changed

# from skpro.utils.set_output import check_output_config # SUPPORTED_OUTPUTS,
from skpro.utils.validation._dependencies import _check_soft_dependencies

# from skpro.tests.test_switch import run_test_module_changed


if _check_soft_dependencies(["polars", "pyarrow"], severity="none"):
import polars as pl
# import polars as pl
pass

import pandas as pd
from sklearn.datasets import load_diabetes
Expand Down Expand Up @@ -89,41 +92,41 @@ def estimator():
# assert dense == {}


@pytest.mark.skipif(
not run_test_module_changed("skpro.datatypes")
or not _check_soft_dependencies(["polars", "pyarrow"], severity="none"),
reason="skip test if polars/pyarrow is not installed in environment",
)
def test_set_output_pandas_polars(polars_load_diabetes_pandas, estimator):
X_train, X_test, y_train = polars_load_diabetes_pandas
estimator.fit(X_train, y_train)
estimator.set_output(transform="polars")
# @pytest.mark.skipif(
# not run_test_module_changed("skpro.datatypes")
# or not _check_soft_dependencies(["polars", "pyarrow"], severity="none"),
# reason="skip test if polars/pyarrow is not installed in environment",
# )
# def test_set_output_pandas_polars(polars_load_diabetes_pandas, estimator):
# X_train, X_test, y_train = polars_load_diabetes_pandas
# estimator.fit(X_train, y_train)
# estimator.set_output(transform="polars")

y_pred = estimator.predict(X_test)
assert isinstance(y_pred, pl.DataFrame)
# y_pred = estimator.predict(X_test)
# assert isinstance(y_pred, pl.DataFrame)

y_pred_interval = estimator.predict_interval(X_test)
assert isinstance(y_pred_interval, pl.DataFrame)
# y_pred_interval = estimator.predict_interval(X_test)
# assert isinstance(y_pred_interval, pl.DataFrame)

y_pred_quantiles = estimator.predict_quantiles(X_test)
assert isinstance(y_pred_quantiles, pl.DataFrame)
# y_pred_quantiles = estimator.predict_quantiles(X_test)
# assert isinstance(y_pred_quantiles, pl.DataFrame)


@pytest.mark.skipif(
not run_test_module_changed("skpro.datatypes")
or not _check_soft_dependencies(["polars", "pyarrow"], severity="none"),
reason="skip test if polars/pyarrow is not installed in environment",
)
def test_set_output_polars_pandas(polars_load_diabetes_polars, estimator):
X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars
estimator.fit(X_train_pl, y_train_pl)
estimator.set_output(transform="pandas")
# @pytest.mark.skipif(
# not run_test_module_changed("skpro.datatypes")
# or not _check_soft_dependencies(["polars", "pyarrow"], severity="none"),
# reason="skip test if polars/pyarrow is not installed in environment",
# )
# def test_set_output_polars_pandas(polars_load_diabetes_polars, estimator):
# X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars
# estimator.fit(X_train_pl, y_train_pl)
# estimator.set_output(transform="pandas")

y_pred = estimator.predict(X_test_pl)
assert isinstance(y_pred, pd.DataFrame)
# y_pred = estimator.predict(X_test_pl)
# assert isinstance(y_pred, pd.DataFrame)

y_pred_interval = estimator.predict_interval(X_test_pl)
assert isinstance(y_pred_interval, pd.DataFrame)
# y_pred_interval = estimator.predict_interval(X_test_pl)
# assert isinstance(y_pred_interval, pd.DataFrame)

y_pred_quantiles = estimator.predict_quantiles(X_test_pl)
assert isinstance(y_pred_quantiles, pd.DataFrame)
# y_pred_quantiles = estimator.predict_quantiles(X_test_pl)
# assert isinstance(y_pred_quantiles, pd.DataFrame)

0 comments on commit ddbecdf

Please sign in to comment.