Skip to content

Commit

Permalink
feat(datafusion): add TimestampFromUNIX and subtract/add operations
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Nov 8, 2023
1 parent 0ab933a commit 2bffa5a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
28 changes: 23 additions & 5 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
parenthesize,
)
from ibis.backends.base.sqlglot.datatypes import PostgresType
from ibis.common.temporal import IntervalUnit
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType

Expand Down Expand Up @@ -108,7 +108,6 @@ def translate_val(op, **_):
ops.ArrayContains: "array_contains",
ops.ArrayLength: "array_length",
ops.ArrayRemove: "array_remove_all",
ops.StringLength: "length",
}

for _op, _name in _simple_ops.items():
Expand Down Expand Up @@ -148,6 +147,11 @@ def _fmt(_, _name: str = _name, **kw):
ops.DateAdd: operator.add,
ops.DateSub: operator.sub,
ops.DateDiff: operator.sub,
ops.TimestampDiff: operator.sub,
ops.TimestampSub: operator.sub,
ops.TimestampAdd: operator.add,
ops.IntervalAdd: operator.add,
ops.IntervalSubtract: operator.sub,
}


Expand Down Expand Up @@ -212,7 +216,7 @@ def _literal(op, *, value, dtype, **kw):
"DataFusion doesn't support subsecond interval resolutions"
)

return interval(value, unit=dtype.resolution.upper())
return interval(value, unit=dtype.unit.plural.lower())
elif dtype.is_timestamp():
return _to_timestamp(value, dtype, literal=True)
elif dtype.is_date():
Expand Down Expand Up @@ -780,10 +784,24 @@ def is_nan(op, *, arg, **_):


@translate_val.register(ops.ArrayStringJoin)
def array_string_join(op, *, sep, arg):
def array_string_join(op, *, sep, arg, **_):
return F.array_join(arg, sep)


@translate_val.register(ops.FindInSet)
def array_string_find(op, *, needle, values):
def array_string_find(op, *, needle, values, **_):
return F.coalesce(F.array_position(F.make_array(*values), needle), 0)


@translate_val.register(ops.TimestampFromUNIX)
def timestamp_from_unix(op, *, arg, unit, **_):
    """Compile ``ops.TimestampFromUNIX`` for the DataFusion backend.

    Parameters
    ----------
    op
        The operation node being translated (not inspected here).
    arg
        The already-translated expression holding the UNIX epoch value.
    unit
        The ``TimestampUnit`` that ``arg`` is expressed in.
    **_
        Any additional translated keyword arguments; ignored.

    Returns
    -------
    A ``from_unixtime(arg)`` call when ``unit`` is seconds, otherwise an
    ``arrow_cast(arg, "Timestamp(<Unit>, None)")`` call where ``<Unit>`` is
    the capitalized unit name (e.g. ``Millisecond``).

    Raises
    ------
    com.UnsupportedOperationError
        If ``unit`` is not second, millisecond, microsecond or nanosecond.
    """
    # Seconds take the dedicated function; every other supported resolution
    # goes through a cast to the matching Arrow timestamp type.
    if unit == TimestampUnit.SECOND:
        return F.from_unixtime(arg)

    castable_units = (
        TimestampUnit.MILLISECOND,
        TimestampUnit.MICROSECOND,
        TimestampUnit.NANOSECOND,
    )
    if unit not in castable_units:
        raise com.UnsupportedOperationError(f"Unsupported unit {unit}")

    # Enum member names are upper-case; the cast target spells the unit
    # with only the first letter capitalized.
    arrow_type = f"Timestamp({unit.name.capitalize()}, None)"
    return F.arrow_cast(arg, arrow_type)
17 changes: 12 additions & 5 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,11 @@ def convert_to_offset(x):
"CalciteContextException: Cannot apply '-' to arguments of type '<TIMESTAMP(9)> - <TIMESTAMP(0)>'."
),
),
pytest.mark.broken(
["datafusion"],
raises=Exception,
reason="pyarrow.lib.ArrowInvalid: Casting from duration[us] to duration[s] would lose data",
),
],
),
param(
Expand Down Expand Up @@ -1186,12 +1191,16 @@ def convert_to_offset(x):
raises=com.UnsupportedOperationError,
reason="DATE_DIFF is not supported in Flink",
),
pytest.mark.broken(
["datafusion"],
raises=Exception,
reason="pyarrow.lib.ArrowNotImplementedError: Unsupported cast",
),
],
),
],
)
@pytest.mark.notimpl(["mssql", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(["datafusion"], raises=BaseException)
def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
expr = expr_fn(alltypes, backend).name("tmp")
expected = expected_fn(df, backend)
Expand Down Expand Up @@ -1377,9 +1386,7 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
),
],
)
@pytest.mark.notimpl(
["datafusion", "sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError)
def test_temporal_binop_pandas_timedelta(
backend, con, alltypes, df, timedelta, temporal_fn
):
Expand Down Expand Up @@ -1775,7 +1782,7 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern):
],
)
@pytest.mark.notimpl(
["datafusion", "mysql", "postgres", "sqlite", "druid", "oracle"],
["mysql", "postgres", "sqlite", "druid", "oracle"],
raises=com.OperationNotDefinedError,
)
def test_integer_to_timestamp(backend, con, unit):
Expand Down

0 comments on commit 2bffa5a

Please sign in to comment.