Skip to content

Commit

Permalink
Update test_polars.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Sep 7, 2024
1 parent 698774b commit 2a9b2b5
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions skpro/datatypes/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
if _check_soft_dependencies(["polars", "pyarrow"], severity="none"):
import polars as pl

from skpro.datatypes._table._check import check_polars_table
from skpro.datatypes._table._convert import convert_pandas_to_polars_eager
from skpro.datatypes import check_is_mtype, convert

TEST_ALPHAS = [0.05, 0.1, 0.25]

Expand Down Expand Up @@ -43,12 +42,15 @@ def estimator():
return _estimator


def _pd_to_pl(df):
return convert(df, from_type="pd_Series_Table", to_type="polars_eager_table")

@pytest.fixture
def polars_load_diabetes_polars(polars_load_diabetes_pandas):
X_train, X_test, y_train = polars_load_diabetes_pandas
X_train_pl = convert_pandas_to_polars_eager(X_train)
X_test_pl = convert_pandas_to_polars_eager(X_test)
y_train_pl = convert_pandas_to_polars_eager(y_train)
X_train_pl = _pd_to_pl(X_train)
X_test_pl = _pd_to_pl(X_test)
y_train_pl = _pd_to_pl(y_train)

# drop the index in the polars frame
X_train_pl = X_train_pl.drop(["__index__"])
Expand All @@ -60,9 +62,9 @@ def polars_load_diabetes_polars(polars_load_diabetes_pandas):

def polars_load_diabetes_polars_with_index(polars_load_diabetes_pandas):
X_train, X_test, y_train = polars_load_diabetes_pandas
X_train_pl = convert_pandas_to_polars_eager(X_train)
X_test_pl = convert_pandas_to_polars_eager(X_test)
y_train_pl = convert_pandas_to_polars_eager(y_train)
X_train_pl = _pd_to_pl(X_train)
X_test_pl = _pd_to_pl(X_test)
y_train_pl = _pd_to_pl(y_train)

return [X_train_pl, X_test_pl, y_train_pl]

Expand All @@ -83,9 +85,9 @@ def test_polars_eager_conversion_methods(
X_train, X_test, y_train = polars_load_diabetes_pandas
X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars

assert check_polars_table(X_train_pl)
assert check_polars_table(X_test_pl)
assert check_polars_table(y_train_pl)
assert check_is_mtype(X_train_pl, "polars_eager_table")
assert check_is_mtype(X_test_pl, "polars_eager_table")
assert check_is_mtype(y_train_pl, "polars_eager_table")

assert (X_train.values == X_train_pl.to_numpy()).all()
assert (X_test.values == X_test_pl.to_numpy()).all()
Expand Down

0 comments on commit 2a9b2b5

Please sign in to comment.