From d26e776a199d472f703bde5e7ea3775739ec0ab3 Mon Sep 17 00:00:00 2001
From: Julio Perez <37191411+jperez999@users.noreply.github.com>
Date: Fri, 9 Jun 2023 14:37:00 -0400
Subject: [PATCH] working subgraphs in workflow (#1842)

* working subgraphs in workflow

* add in subgraph transform to compare

* change to subworkflows in workflow instead of subgraph4

* remove reset_index calls in asserts of test
---
 nvtabular/workflow/workflow.py                |  26 +++++
 .../unit/workflow/test_workflow_subgraphs.py  | 108 +++++++++++++++++++
 2 files changed, 134 insertions(+)
 create mode 100644 tests/unit/workflow/test_workflow_subgraphs.py

diff --git a/nvtabular/workflow/workflow.py b/nvtabular/workflow/workflow.py
index 909c6bad05..9c5e944a37 100755
--- a/nvtabular/workflow/workflow.py
+++ b/nvtabular/workflow/workflow.py
@@ -142,6 +142,11 @@ def fit_schema(self, input_schema: Schema):
         self.graph.construct_schema(input_schema)
         return self
 
+    @property
+    def subworkflows(self):
+        """Returns the keys of ``self.graph.subgraphs`` (the subgraph names) as a list."""
+        return list(self.graph.subgraphs.keys())
+
     @property
     def input_dtypes(self):
         return self.graph.input_dtypes
@@ -165,6 +170,27 @@ def output_node(self):
     def _input_columns(self):
         return self.graph._input_columns()
 
+    def get_subworkflow(self, subgraph_name: str) -> "Workflow":
+        """Returns a new workflow built from the named subgraph.
+
+        Parameters
+        ----------
+        subgraph_name : str
+            Name of a subgraph in this workflow's graph (see ``subworkflows``).
+
+        Returns
+        -------
+        Workflow
+            A workflow constructed from the subgraph's output node.
+
+        Raises
+        ------
+        ValueError
+            If no subgraph with the given name exists.
+        """
+        subgraph = self.graph.subgraph(subgraph_name)
+        return Workflow(subgraph.output_node)
+
     def remove_inputs(self, input_cols) -> "Workflow":
         """Removes input columns from the workflow.
 
diff --git a/tests/unit/workflow/test_workflow_subgraphs.py b/tests/unit/workflow/test_workflow_subgraphs.py
new file mode 100644
index 0000000000..047edac246
--- /dev/null
+++ b/tests/unit/workflow/test_workflow_subgraphs.py
@@ -0,0 +1,108 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+import numpy as np
+import pytest
+from pandas.api.types import is_integer_dtype
+
+from merlin.core.utils import set_dask_client
+from merlin.dag.ops.subgraph import Subgraph
+from nvtabular import Workflow, ops
+from tests.conftest import assert_eq
+
+
+@pytest.mark.parametrize("gpu_memory_frac", [0.01, 0.1])
+@pytest.mark.parametrize("engine", ["parquet", "csv", "csv-no-header"])
+@pytest.mark.parametrize("dump", [True, False])
+@pytest.mark.parametrize("replace", [True, False])
+def test_workflow_subgraphs(tmpdir, client, df, dataset, gpu_memory_frac, engine, dump, replace):
+    """Checks that named subgraphs can be extracted from a fitted workflow.
+
+    Each subworkflow returned by ``get_subworkflow`` must be a ``Workflow``
+    whose transform output matches the same columns of the full workflow's
+    output, with and without a save/load round trip (``dump``). Asking for
+    an unknown subgraph name must raise ``ValueError``.
+    """
+    cat_names = ["name-cat", "name-string"] if engine == "parquet" else ["name-string"]
+    cont_names = ["x", "y", "id"]
+    label_name = ["label"]
+
+    norms = ops.Normalize()
+    cat_features = cat_names >> ops.Categorify()
+    if replace:
+        cont_features = cont_names >> ops.FillMissing() >> ops.LogOp >> norms
+    else:
+        fillmissing_logop = (
+            cont_names
+            >> ops.FillMissing()
+            >> ops.LogOp
+            >> ops.Rename(postfix="_FillMissing_1_LogOp_1")
+        )
+        cont_features = cont_names + fillmissing_logop >> norms
+
+    set_dask_client(client=client)
+    wkflow_ops = Subgraph("cat_graph", cat_features) + Subgraph("cont_graph", cont_features)
+    workflow = Workflow(wkflow_ops + label_name)
+
+    workflow.fit(dataset)
+
+    if dump:
+        workflow_dir = os.path.join(tmpdir, "workflow")
+        workflow.save(workflow_dir)
+        workflow = None
+
+        workflow = Workflow.load(workflow_dir)
+
+    # NOTE(review): get_norms is never called in this test; it and the numpy
+    # import it needs look like leftovers and are candidates for removal.
+    def get_norms(tar):
+        ser_median = tar.dropna().quantile(0.5, interpolation="linear")
+        gdf = tar.fillna(ser_median)
+        gdf = np.log(gdf + 1)
+        return gdf
+
+    concat_ops = "_FillMissing_1_LogOp_1"
+    if replace:
+        concat_ops = ""
+
+    df_pp = workflow.transform(dataset).to_ddf().compute()
+
+    if engine == "parquet":
+        assert is_integer_dtype(df_pp["name-cat"].dtype)
+    assert is_integer_dtype(df_pp["name-string"].dtype)
+
+    subgraph_cat = workflow.get_subworkflow("cat_graph")
+    subgraph_cont = workflow.get_subworkflow("cont_graph")
+    assert isinstance(subgraph_cat, Workflow)
+    assert isinstance(subgraph_cont, Workflow)
+    # the nodes will not compare equal once the workflow is saved out and loaded back
+    if not dump:
+        assert subgraph_cat.output_node == cat_features
+        assert subgraph_cont.output_node == cont_features
+    # check failure path works as expected
+    with pytest.raises(ValueError) as exc:
+        workflow.get_subworkflow("not_exist")
+    assert "No subgraph named" in str(exc.value)
+
+    # test transform results from subgraph
+    sub_cat_df = subgraph_cat.transform(dataset).to_ddf().compute()
+    assert_eq(sub_cat_df, df_pp[cat_names])
+
+    cont_names = [name + concat_ops for name in cont_names]
+    sub_cont_df = subgraph_cont.transform(dataset).to_ddf().compute()
+    assert_eq(sub_cont_df[cont_names], df_pp[cont_names])