diff --git a/tests/unit/workflow/test_workflow_subgraphs.py b/tests/unit/workflow/test_workflow_subgraphs.py
index d1521dc211..72fae8ef47 100644
--- a/tests/unit/workflow/test_workflow_subgraphs.py
+++ b/tests/unit/workflow/test_workflow_subgraphs.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 
-import glob
 import math
 import os
 
@@ -22,13 +21,11 @@
 import pytest
 from pandas.api.types import is_integer_dtype
 
-import nvtabular as nvt
-from merlin.core import dispatch
 from merlin.core.dispatch import HAS_GPU
 from merlin.core.utils import set_dask_client
 from merlin.dag.ops.subgraph import Subgraph
-from nvtabular import Dataset, Workflow, ops
-from tests.conftest import get_cats
+from nvtabular import Workflow, ops
+from tests.conftest import assert_eq, get_cats
 
 
 @pytest.mark.parametrize("gpu_memory_frac", [0.01, 0.1])
@@ -98,23 +95,12 @@ def get_norms(tar):
     assert len(cats1.tolist()) == len(cats_expected1.tolist())
 
     # Write to new "shuffled" and "processed" dataset
-    workflow.transform(dataset).to_parquet(
-        tmpdir,
-        out_files_per_proc=10,
-        shuffle=nvt.io.Shuffle.PER_PARTITION,
-    )
-
-    dataset_2 = Dataset(glob.glob(str(tmpdir) + "/*.parquet"), part_mem_fraction=gpu_memory_frac)
-
-    df_pp = dispatch.concat(list(dataset_2.to_iter()), axis=0)
+    df_pp = workflow.transform(dataset).to_ddf().compute()
 
     if engine == "parquet":
         assert is_integer_dtype(df_pp["name-cat"].dtype)
         assert is_integer_dtype(df_pp["name-string"].dtype)
 
-    num_rows, num_row_groups, col_names = dispatch.read_parquet_metadata(str(tmpdir) + "/_metadata")
-    assert num_rows == len(df_pp)
-
     subgraph_cat = workflow.get_subgraph("cat_graph")
     subgraph_cont = workflow.get_subgraph("cont_graph")
     assert isinstance(subgraph_cat, Workflow)
@@ -127,3 +113,12 @@ def get_norms(tar):
     with pytest.raises(ValueError) as exc:
         workflow.get_subgraph("not_exist")
     assert "No subgraph named" in str(exc.value)
+
+    sub_cat_df = subgraph_cat.transform(dataset).to_ddf().compute()
+    assert assert_eq(sub_cat_df.reset_index(drop=True), df_pp[cat_names].reset_index(drop=True))
+
+    cont_names = [name + concat_ops for name in cont_names]
+    sub_cont_df = subgraph_cont.transform(dataset).to_ddf().compute()
+    assert assert_eq(
+        sub_cont_df[cont_names].reset_index(drop=True), df_pp[cont_names].reset_index(drop=True)
+    )