Skip to content

Commit

Permalink
Fix apply
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Apr 27, 2022
1 parent a057995 commit 7666f99
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 13 deletions.
6 changes: 4 additions & 2 deletions mars/dataframe/base/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _infer_df_func_returns(self, df, dtypes, dtype=None, name=None, index=None):
if self.output_types is not None and (
dtypes is not None or dtype is not None
):
ret_dtypes = dtypes if dtypes is not None else (dtype, name)
ret_dtypes = dtypes if dtypes is not None else (name, dtype)
ret_index_value = parse_index(index) if index is not None else None
self._elementwise = False
return ret_dtypes, ret_index_value
Expand Down Expand Up @@ -473,7 +473,9 @@ def __call__(self, df_or_series, dtypes=None, dtype=None, name=None, index=None)
self._axis = validate_axis(axis, df_or_series)

if df_or_series.op.output_types[0] == OutputType.dataframe:
return self._call_dataframe(df_or_series, dtypes=dtypes, index=index)
return self._call_dataframe(
df_or_series, dtypes=dtypes, dtype=dtype, name=name, index=index
)
else:
return self._call_series(
df_or_series, dtypes=dtypes, dtype=dtype, name=name, index=index
Expand Down
15 changes: 14 additions & 1 deletion mars/dataframe/base/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_rechunk():
assert series2.nsplits == series.nsplits


def test_data_frame_apply():
def test_dataframe_apply():
cols = [chr(ord("A") + i) for i in range(10)]
df_raw = pd.DataFrame(dict((c, [i**2 for i in range(20)]) for c in cols))

Expand All @@ -231,6 +231,10 @@ def df_func_with_err(v):
assert len(v) > 2
return v.sort_values()

def df_series_func_with_err(v):
assert len(v) > 2
return 0

with pytest.raises(TypeError):
df.apply(df_func_with_err)

Expand All @@ -240,6 +244,15 @@ def df_func_with_err(v):
assert r.op.output_types[0] == OutputType.dataframe
assert r.op.elementwise is False

r = df.apply(
df_series_func_with_err, output_type="series", dtype=object, name="output"
)
assert r.dtype == np.dtype("O")
assert r.shape == (df.shape[-1],)
assert r.op._op_type_ == opcodes.APPLY
assert r.op.output_types[0] == OutputType.series
assert r.op.elementwise is False

r = df.apply("ffill")
assert r.op._op_type_ == opcodes.FILL_NA

Expand Down
4 changes: 3 additions & 1 deletion mars/dataframe/groupby/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def get(self):
"skew": lambda x, bias=False: x.skew(bias=bias),
"kurt": lambda x, bias=False: x.kurt(bias=bias),
"kurtosis": lambda x, bias=False: x.kurtosis(bias=bias),
"nunique": lambda x: x.nunique(),
}
_series_col_name = "col_name"

Expand Down Expand Up @@ -720,7 +721,8 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
result = (result,)

if out.ndim == 2:
result = tuple(r.to_frame().T for r in result)
if result[0].ndim == 1:
result = tuple(r.to_frame().T for r in result)
if op.stage == OperandStage.agg:
result = tuple(r.astype(out.dtypes) for r in result)
else:
Expand Down
1 change: 1 addition & 0 deletions mars/dataframe/reduction/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def where_function(cond, var1, var2):
"skew": lambda x, skipna=True, bias=False: x.skew(skipna=skipna, bias=bias),
"kurt": lambda x, skipna=True, bias=False: x.kurt(skipna=skipna, bias=bias),
"kurtosis": lambda x, skipna=True, bias=False: x.kurtosis(skipna=skipna, bias=bias),
"nunique": lambda x: x.nunique(),
}


Expand Down
4 changes: 3 additions & 1 deletion mars/dataframe/reduction/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,13 +972,15 @@ def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps:
else:
map_func_name, agg_func_name = step_func_name, step_func_name

op_custom_reduction = getattr(t.op, "custom_reduction", None)

# build agg description
agg_funcs.append(
ReductionAggStep(
agg_input_key,
map_func_name,
agg_func_name,
custom_reduction,
op_custom_reduction or custom_reduction,
t.key,
output_limit,
t.op.get_reduction_args(axis=self._axis),
Expand Down
9 changes: 1 addition & 8 deletions mars/dataframe/reduction/custom_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,7 @@ class DataFrameCustomReduction(DataFrameReductionOperand, DataFrameReductionMixi
_op_type_ = OperandDef.CUSTOM_REDUCTION
_func_name = "custom_reduction"

_custom_reduction = AnyField("custom_reduction")

def __init__(self, custom_reduction=None, **kw):
super().__init__(_custom_reduction=custom_reduction, **kw)

@property
def custom_reduction(self):
return self._custom_reduction
custom_reduction = AnyField("custom_reduction")

@property
def is_atomic(self):
Expand Down
6 changes: 6 additions & 0 deletions mars/dataframe/reduction/tests/test_reduction_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,12 @@ def test_nunique(setup, check_ref_counts):
expected = data1.nunique(axis=1)
pd.testing.assert_series_equal(result, expected)

# test with agg func
df = md.DataFrame(data1, chunk_size=3)
result = df.agg("nunique").execute().fetch()
expected = data1.agg("nunique")
pd.testing.assert_series_equal(result, expected)


@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
def test_use_arrow_dtype_n_unique(setup, check_ref_counts):
Expand Down

0 comments on commit 7666f99

Please sign in to comment.