diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py index 3d4d15be1ce..e6a93e3f656 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py @@ -9,6 +9,15 @@ import pyarrow as pa import pylibcudf as plc +from pylibcudf.strings.convert.convert_floats import from_floats, is_float, to_floats +from pylibcudf.strings.convert.convert_integers import ( + from_integers, + is_integer, + to_integers, +) +from pylibcudf.traits import is_floating_point + +from polars.exceptions import InvalidOperationError from cudf_polars.containers import Column from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr @@ -35,7 +44,7 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None: self.children = (value,) if not dtypes.can_cast(value.dtype, self.dtype): raise NotImplementedError( - f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" + f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}" ) def do_evaluate( @@ -48,7 +57,45 @@ def do_evaluate( """Evaluate this expression given a dataframe for context.""" (child,) = self.children column = child.evaluate(df, context=context, mapping=mapping) - return Column(plc.unary.cast(column.obj, self.dtype)).sorted_like(column) + if ( + self.dtype.id() == plc.TypeId.STRING + or column.obj.type().id() == plc.TypeId.STRING + ): + result = self._handle_string_cast(column) + else: + result = plc.unary.cast(column.obj, self.dtype) + return Column(result).sorted_like(column) + + def _handle_string_cast(self, column: Column) -> plc.Column: + if self.dtype.id() == plc.TypeId.STRING: + if is_floating_point(column.obj.type()): + result = from_floats(column.obj) + else: + result = from_integers(column.obj) + else: + if is_floating_point(self.dtype): + floats = is_float(column.obj) + if not plc.interop.to_arrow( + plc.reduce.reduce( + floats, + plc.aggregation.all(), + plc.DataType(plc.TypeId.BOOL8), + ) + ).as_py(): + raise InvalidOperationError("Conversion from `str` failed.") + result = to_floats(column.obj, self.dtype) + else: + integers = is_integer(column.obj) + if not plc.interop.to_arrow( + plc.reduce.reduce( + integers, + plc.aggregation.all(), + plc.DataType(plc.TypeId.BOOL8), + ) + ).as_py(): + raise InvalidOperationError("Conversion from `str` failed.") + result = to_integers(column.obj, self.dtype) + return result def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 4154a404e98..fb7ed0aaf2b 100644 --- a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -9,6 +9,7 @@ import pyarrow as pa import pylibcudf as plc +from pylibcudf.traits import is_floating_point, is_integral_not_bool from typing_extensions import assert_never import polars as pl @@ -45,6 +46,10 @@ def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType: return typ +def _is_int_or_float(dtype: plc.DataType) -> bool: + return is_integral_not_bool(dtype) or is_floating_point(dtype) + + def can_cast(from_: plc.DataType, to: plc.DataType) -> bool: """ Can we cast (via :func:`~.pylibcudf.unary.cast`) between two datatypes. @@ -61,9 +66,13 @@ def can_cast(from_: plc.DataType, to: plc.DataType) -> bool: True if casting is supported, False otherwise """ return ( - plc.traits.is_fixed_width(to) - and plc.traits.is_fixed_width(from_) - and plc.unary.is_supported_cast(from_, to) + ( + plc.traits.is_fixed_width(to) + and plc.traits.is_fixed_width(from_) + and plc.unary.is_supported_cast(from_, to) + ) + or (from_.id() == plc.TypeId.STRING and _is_int_or_float(to)) + or (_is_int_or_float(from_) and to.id() == plc.TypeId.STRING) ) diff --git a/python/cudf_polars/tests/expressions/test_casting.py b/python/cudf_polars/tests/expressions/test_casting.py index 3e003054338..0722a0f198a 100644 --- a/python/cudf_polars/tests/expressions/test_casting.py +++ b/python/cudf_polars/tests/expressions/test_casting.py @@ -14,7 +14,7 @@ _supported_dtypes = [(pl.Int8(), pl.Int64())] _unsupported_dtypes = [ - (pl.String(), pl.Int64()), + (pl.Datetime("ns"), pl.Int64()), ] diff --git a/python/cudf_polars/tests/expressions/test_numeric_binops.py b/python/cudf_polars/tests/expressions/test_numeric_binops.py index 8f68bbc460c..fa1ec3c19e4 100644 --- a/python/cudf_polars/tests/expressions/test_numeric_binops.py +++ b/python/cudf_polars/tests/expressions/test_numeric_binops.py @@ -8,7 +8,6 @@ from cudf_polars.testing.asserts import ( assert_gpu_result_equal, - assert_ir_translation_raises, ) dtypes = [ @@ -114,12 +113,3 @@ def test_binop_with_scalar(left_scalar, right_scalar): q = df.select(lop / rop) assert_gpu_result_equal(q) - - -def test_numeric_to_string_cast_fails(): - df = pl.DataFrame( - {"a": [1, 1, 2, 3, 3, 4, 1], "b": [None, 2, 3, 4, 5, 6, 7]} - ).lazy() - q = df.select(pl.col("a").cast(pl.String)) - - assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/expressions/test_stringfunction.py b/python/cudf_polars/tests/expressions/test_stringfunction.py index 4f6850ac977..d1655e5e466 100644 --- a/python/cudf_polars/tests/expressions/test_stringfunction.py +++ b/python/cudf_polars/tests/expressions/test_stringfunction.py @@ -40,6 +40,34 @@ def ldf(with_nulls): ) +@pytest.fixture(params=[pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]) +def numeric_type(request): + return request.param + + +@pytest.fixture +def str_to_numeric_data(with_nulls): + a = ["1", "2", "3", "4", "5", "6"] + if with_nulls: + a[4] = None + return pl.LazyFrame({"a": a}) + + +@pytest.fixture +def str_from_numeric_data(with_nulls, numeric_type): + a = [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + if with_nulls: + a[4] = None + return pl.LazyFrame({"a": pl.Series(a, dtype=numeric_type)}) + + slice_cases = [ (1, 3), (0, 3), @@ -337,3 +365,23 @@ def test_unsupported_regex_raises(pattern): q = df.select(pl.col("a").str.contains(pattern, strict=True)) assert_ir_translation_raises(q, NotImplementedError) + + +def test_string_to_numeric(str_to_numeric_data, numeric_type): + query = str_to_numeric_data.select(pl.col("a").cast(numeric_type)) + assert_gpu_result_equal(query) + + +def test_string_from_numeric(str_from_numeric_data): + query = str_from_numeric_data.select(pl.col("a").cast(pl.String)) + assert_gpu_result_equal(query) + + +def test_string_to_numeric_invalid(numeric_type): + df = pl.LazyFrame({"a": ["a", "b", "c"]}) + q = df.select(pl.col("a").cast(numeric_type)) + assert_collect_raises( + q, + polars_except=pl.exceptions.InvalidOperationError, + cudf_except=pl.exceptions.ComputeError, + )