Skip to content

Commit

Permalink
correctly implement check_for_invalid_characters and test it
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 18, 2024
1 parent 84774ef commit 5e1edfe
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
11 changes: 6 additions & 5 deletions bfabric/scripts/bfabric_save_csv2dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ def check_for_invalid_characters(data: pl.DataFrame, invalid_characters: str) ->
if not invalid_characters:
return
invalid_columns_df = data.select(pl.col(pl.String).str.contains_any(list(invalid_characters)).any())
if invalid_columns_df.is_empty():
return
invalid_columns = (
invalid_columns_df.transpose(include_header=True, header_name="column")
.filter(pl.col("column_0"))
.select("column")
.to_numpy()
.filter(pl.col("column_0"))["column"]
.to_list()
)
if len(invalid_columns) > 0:
raise RuntimeError(f"Invalid characters found in columns: {invalid_columns[0]}")
if invalid_columns:
raise RuntimeError(f"Invalid characters found in columns: {invalid_columns}")


def main() -> None:
Expand Down
60 changes: 60 additions & 0 deletions bfabric/tests/unit/scripts/test_save_csv2dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest
import polars as pl

from bfabric.scripts.bfabric_save_csv2dataset import check_for_invalid_characters


def test_check_for_invalid_characters_no_invalid_chars():
    """String columns containing none of the forbidden characters must pass the check."""
    clean_frame = pl.DataFrame(
        {
            "col1": ["abc", "def", "ghi"],
            "col2": ["123", "456", "789"],
        }
    )

    # Returning normally is the assertion here: any exception fails the test.
    check_for_invalid_characters(clean_frame, "!@#")


def test_check_for_invalid_characters_with_invalid_chars():
    """A single offending column must be named in the RuntimeError message."""
    frame = pl.DataFrame(
        {
            "col1": ["abc", "d!ef", "ghi"],  # "d!ef" carries the forbidden "!"
            "col2": ["123", "456", "789"],
        }
    )

    with pytest.raises(RuntimeError) as raised:
        check_for_invalid_characters(frame, "!@#")

    message = str(raised.value)
    assert "Invalid characters found in columns: ['col1']" in message


def test_check_for_invalid_characters_multiple_columns():
    """Every offending column must be listed, while clean columns are left out."""
    frame = pl.DataFrame(
        {
            "col1": ["abc", "d!ef", "ghi"],  # contains "!"
            "col2": ["123", "45@6", "789"],  # contains "@"
            "col3": ["xyz", "uvw", "rst"],  # clean, must not be reported
        }
    )

    with pytest.raises(RuntimeError) as raised:
        check_for_invalid_characters(frame, "!@#")

    message = str(raised.value)
    assert "Invalid characters found in columns: ['col1', 'col2']" in message


def test_check_for_invalid_characters_empty_invalid_chars():
    """Passing an empty invalid-character string must not raise."""
    frame = pl.DataFrame(
        {
            "col1": ["abc", "def", "ghi"],
            "col2": ["123", "456", "789"],
        }
    )

    # Returning normally is the assertion here: any exception fails the test.
    check_for_invalid_characters(frame, "")


def test_check_for_invalid_characters_empty_dataframe():
    """A DataFrame with no columns at all must pass the check without raising."""
    no_columns = pl.DataFrame()

    # Returning normally is the assertion here: any exception fails the test.
    check_for_invalid_characters(no_columns, "!@#")


def test_check_for_invalid_characters_non_string_columns():
    """Integer and float columns next to a clean string column must not raise."""
    mixed_frame = pl.DataFrame(
        {
            "col1": [1, 2, 3],
            "col2": [4.5, 5.6, 6.7],
            "col3": ["abc", "def", "ghi"],
        }
    )

    # Returning normally is the assertion here: any exception fails the test.
    check_for_invalid_characters(mixed_frame, "!@#")


# Allow running this test module directly as a script, in addition to pytest discovery.
if __name__ == "__main__":
    pytest.main()

0 comments on commit 5e1edfe

Please sign in to comment.