Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

86 get subset kwargs do not work if column name has spaces #96

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,15 @@ def create_index(self):
with pytest.raises(RuntimeError):
_ = TestDataset().index

def test_warning_index_no_valid_attribute(self):
a-mosquito marked this conversation as resolved.
Show resolved Hide resolved
class TestDataset(Dataset):
def create_index(self):
# Return index with some invalid column names
return pd.DataFrame(np.zeros((10, 3)), columns=["a", "b ", "else"])

with pytest.warns(RuntimeWarning):
_ = TestDataset().index

@pytest.mark.parametrize(
("level", "what_to_expect"),
[
Expand Down
13 changes: 13 additions & 0 deletions tpcp/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base class for all datasets."""
import warnings
from keyword import iskeyword
from typing import Dict, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union, cast, overload

import numpy as np
Expand Down Expand Up @@ -64,6 +65,8 @@ def _create_check_index(self):
While we can not catch all related issues (i.e. determinism across different machines), this should catch the
most obvious ones.

Furthermore, we check if all columns of the index are valid Python attribute names and throw a warning if not.

In case, creating the index twice is too expensive, users can overwrite this method.
But better to catch errors early.
"""
Expand All @@ -83,6 +86,16 @@ def _create_check_index(self):
"explicitly using `sort_values`."
)

invalid_elements = [s for s in index_1.columns if not s.isidentifier() or iskeyword(s)]
if invalid_elements:
warnings.warn(
f"Some of your index columns are not valid Python attribute names: {invalid_elements}. "
a-mosquito marked this conversation as resolved.
Show resolved Hide resolved
f"This will cause issues when using further methods such as `get_subset`, `group_label`, "
a-mosquito marked this conversation as resolved.
Show resolved Hide resolved
f"`group_labels`, and `datapoint_label`.",
RuntimeWarning,
stacklevel=1,
)

return index_1

@property
Expand Down
Loading