From 7666f99a484ef5aa1b04000459e83971c0b7943b Mon Sep 17 00:00:00 2001 From: Wenjun Si Date: Tue, 26 Apr 2022 16:47:16 +0800 Subject: [PATCH] Fix apply --- mars/dataframe/base/apply.py | 6 ++++-- mars/dataframe/base/tests/test_base.py | 15 ++++++++++++++- mars/dataframe/groupby/aggregation.py | 4 +++- mars/dataframe/reduction/aggregation.py | 1 + mars/dataframe/reduction/core.py | 4 +++- mars/dataframe/reduction/custom_reduction.py | 9 +-------- .../reduction/tests/test_reduction_execution.py | 6 ++++++ 7 files changed, 32 insertions(+), 13 deletions(-) diff --git a/mars/dataframe/base/apply.py b/mars/dataframe/base/apply.py index 0c0a00ed25..c4c675369f 100644 --- a/mars/dataframe/base/apply.py +++ b/mars/dataframe/base/apply.py @@ -295,7 +295,7 @@ def _infer_df_func_returns(self, df, dtypes, dtype=None, name=None, index=None): if self.output_types is not None and ( dtypes is not None or dtype is not None ): - ret_dtypes = dtypes if dtypes is not None else (dtype, name) + ret_dtypes = dtypes if dtypes is not None else (name, dtype) ret_index_value = parse_index(index) if index is not None else None self._elementwise = False return ret_dtypes, ret_index_value @@ -473,7 +473,9 @@ def __call__(self, df_or_series, dtypes=None, dtype=None, name=None, index=None) self._axis = validate_axis(axis, df_or_series) if df_or_series.op.output_types[0] == OutputType.dataframe: - return self._call_dataframe(df_or_series, dtypes=dtypes, index=index) + return self._call_dataframe( + df_or_series, dtypes=dtypes, dtype=dtype, name=name, index=index + ) else: return self._call_series( df_or_series, dtypes=dtypes, dtype=dtype, name=name, index=index diff --git a/mars/dataframe/base/tests/test_base.py b/mars/dataframe/base/tests/test_base.py index 4de93d5c97..faadf2a0a7 100644 --- a/mars/dataframe/base/tests/test_base.py +++ b/mars/dataframe/base/tests/test_base.py @@ -217,7 +217,7 @@ def test_rechunk(): assert series2.nsplits == series.nsplits -def test_data_frame_apply(): +def test_dataframe_apply(): cols = [chr(ord("A") + i) for i in range(10)] df_raw = pd.DataFrame(dict((c, [i**2 for i in range(20)]) for c in cols)) @@ -231,6 +231,10 @@ def df_func_with_err(v): assert len(v) > 2 return v.sort_values() + def df_series_func_with_err(v): + assert len(v) > 2 + return 0 + with pytest.raises(TypeError): df.apply(df_func_with_err) @@ -240,6 +244,15 @@ def df_func_with_err(v): assert r.op.output_types[0] == OutputType.dataframe assert r.op.elementwise is False + r = df.apply( + df_series_func_with_err, output_type="series", dtype=object, name="output" + ) + assert r.dtype == np.dtype("O") + assert r.shape == (df.shape[-1],) + assert r.op._op_type_ == opcodes.APPLY + assert r.op.output_types[0] == OutputType.series + assert r.op.elementwise is False + r = df.apply("ffill") assert r.op._op_type_ == opcodes.FILL_NA diff --git a/mars/dataframe/groupby/aggregation.py b/mars/dataframe/groupby/aggregation.py index 7d593814c5..7039e517b8 100644 --- a/mars/dataframe/groupby/aggregation.py +++ b/mars/dataframe/groupby/aggregation.py @@ -94,6 +94,7 @@ def get(self): "skew": lambda x, bias=False: x.skew(bias=bias), "kurt": lambda x, bias=False: x.kurt(bias=bias), "kurtosis": lambda x, bias=False: x.kurtosis(bias=bias), + "nunique": lambda x: x.nunique(), } _series_col_name = "col_name" @@ -720,7 +721,8 @@ def _do_custom_agg(op, custom_reduction, *input_objs): result = (result,) if out.ndim == 2: - result = tuple(r.to_frame().T for r in result) + if result[0].ndim == 1: + result = tuple(r.to_frame().T for r in result) if op.stage == OperandStage.agg: result = tuple(r.astype(out.dtypes) for r in result) else: diff --git a/mars/dataframe/reduction/aggregation.py b/mars/dataframe/reduction/aggregation.py index 6945ecec9a..4248dcdcd9 100644 --- a/mars/dataframe/reduction/aggregation.py +++ b/mars/dataframe/reduction/aggregation.py @@ -78,6 +78,7 @@ def where_function(cond, var1, var2): "skew": lambda x, skipna=True, bias=False: x.skew(skipna=skipna, bias=bias), "kurt": lambda x, skipna=True, bias=False: x.kurt(skipna=skipna, bias=bias), "kurtosis": lambda x, skipna=True, bias=False: x.kurtosis(skipna=skipna, bias=bias), + "nunique": lambda x: x.nunique(), } diff --git a/mars/dataframe/reduction/core.py b/mars/dataframe/reduction/core.py index 14d4050038..5ba42af40f 100644 --- a/mars/dataframe/reduction/core.py +++ b/mars/dataframe/reduction/core.py @@ -972,13 +972,15 @@ def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps: else: map_func_name, agg_func_name = step_func_name, step_func_name + op_custom_reduction = getattr(t.op, "custom_reduction", None) + # build agg description agg_funcs.append( ReductionAggStep( agg_input_key, map_func_name, agg_func_name, - custom_reduction, + op_custom_reduction or custom_reduction, t.key, output_limit, t.op.get_reduction_args(axis=self._axis), diff --git a/mars/dataframe/reduction/custom_reduction.py b/mars/dataframe/reduction/custom_reduction.py index 6995e2c51c..d80554e067 100644 --- a/mars/dataframe/reduction/custom_reduction.py +++ b/mars/dataframe/reduction/custom_reduction.py @@ -23,14 +23,7 @@ class DataFrameCustomReduction(DataFrameReductionOperand, DataFrameReductionMixi _op_type_ = OperandDef.CUSTOM_REDUCTION _func_name = "custom_reduction" - _custom_reduction = AnyField("custom_reduction") - - def __init__(self, custom_reduction=None, **kw): - super().__init__(_custom_reduction=custom_reduction, **kw) - - @property - def custom_reduction(self): - return self._custom_reduction + custom_reduction = AnyField("custom_reduction") @property def is_atomic(self): diff --git a/mars/dataframe/reduction/tests/test_reduction_execution.py b/mars/dataframe/reduction/tests/test_reduction_execution.py index 4d4d28cad9..b976475dad 100644 --- a/mars/dataframe/reduction/tests/test_reduction_execution.py +++ b/mars/dataframe/reduction/tests/test_reduction_execution.py @@ -671,6 +671,12 @@ def test_nunique(setup, check_ref_counts): expected = data1.nunique(axis=1) pd.testing.assert_series_equal(result, expected) + # test with agg func + df = md.DataFrame(data1, chunk_size=3) + result = df.agg("nunique").execute().fetch() + expected = data1.agg("nunique") + pd.testing.assert_series_equal(result, expected) + @pytest.mark.skipif(pa is None, reason="pyarrow not installed") def test_use_arrow_dtype_n_unique(setup, check_ref_counts):