diff --git a/cpp/nvtabular/inference/categorify.cc b/cpp/nvtabular/inference/categorify.cc index e9b50c0cdd..603a39e33b 100644 --- a/cpp/nvtabular/inference/categorify.cc +++ b/cpp/nvtabular/inference/categorify.cc @@ -337,12 +337,12 @@ namespace nvtabular // this operator currently only supports CPU arrays .def_property_readonly("supports", [](py::object self) { - py::object supports = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("Supports"); + py::object supports = py::module_::import("nvtabular").attr("graph").attr("operator").attr("Supports"); return supports.attr("CPU_DICT_ARRAY"); }) .def_property_readonly("supported_formats", [](py::object self) { - py::object supported = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("DataFormats"); + py::object supported = py::module_::import("nvtabular").attr("graph").attr("operator").attr("DataFormats"); return supported.attr("NUMPY_DICT_ARRAY"); }); } diff --git a/nvtabular/ops/operator.py b/nvtabular/ops/operator.py index 0757557b12..a0aa99ab6a 100644 --- a/nvtabular/ops/operator.py +++ b/nvtabular/ops/operator.py @@ -13,6 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from merlin.dag import BaseOperator, ColumnSelector # noqa pylint: disable=unused-import +from merlin.dag import ( # noqa pylint: disable=unused-import + BaseOperator, + ColumnSelector, + DataFormats, +) -Operator = BaseOperator + +# Avoid TENSOR_TABLE by default (for now) +class Operator(BaseOperator): + @property + def supported_formats(self): + return DataFormats.PANDAS_DATAFRAME | DataFormats.CUDF_DATAFRAME diff --git a/tests/unit/ops/test_categorify.py b/tests/unit/ops/test_categorify.py index 41a69ef346..28c98854a2 100644 --- a/tests/unit/ops/test_categorify.py +++ b/tests/unit/ops/test_categorify.py @@ -734,3 +734,8 @@ def test_categorify_inference(): output_tensors = inference_op.transform(cats.input_columns, input_tensors) for key in input_tensors: assert output_tensors[key].dtype == np.dtype("int64") + + # Check results are consistent with python code path + expect = workflow.transform(df) + got = pd.DataFrame(output_tensors) + assert_eq(expect, got)