Skip to content

Commit

Permalink
add in subgraph transform to compare
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Jun 9, 2023
1 parent 0d26d97 commit 96cd7c0
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions tests/unit/workflow/test_workflow_subgraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,18 @@
# limitations under the License.
#

import glob
import math
import os

import numpy as np
import pytest
from pandas.api.types import is_integer_dtype

import nvtabular as nvt
from merlin.core import dispatch
from merlin.core.dispatch import HAS_GPU
from merlin.core.utils import set_dask_client
from merlin.dag.ops.subgraph import Subgraph
from nvtabular import Dataset, Workflow, ops
from tests.conftest import get_cats
from nvtabular import Workflow, ops
from tests.conftest import assert_eq, get_cats


@pytest.mark.parametrize("gpu_memory_frac", [0.01, 0.1])
Expand Down Expand Up @@ -98,23 +95,12 @@ def get_norms(tar):
assert len(cats1.tolist()) == len(cats_expected1.tolist())

# Write to new "shuffled" and "processed" dataset
workflow.transform(dataset).to_parquet(
tmpdir,
out_files_per_proc=10,
shuffle=nvt.io.Shuffle.PER_PARTITION,
)

dataset_2 = Dataset(glob.glob(str(tmpdir) + "/*.parquet"), part_mem_fraction=gpu_memory_frac)

df_pp = dispatch.concat(list(dataset_2.to_iter()), axis=0)
df_pp = workflow.transform(dataset).to_ddf().compute()

if engine == "parquet":
assert is_integer_dtype(df_pp["name-cat"].dtype)
assert is_integer_dtype(df_pp["name-string"].dtype)

num_rows, num_row_groups, col_names = dispatch.read_parquet_metadata(str(tmpdir) + "/_metadata")
assert num_rows == len(df_pp)

subgraph_cat = workflow.get_subgraph("cat_graph")
subgraph_cont = workflow.get_subgraph("cont_graph")
assert isinstance(subgraph_cat, Workflow)
Expand All @@ -127,3 +113,12 @@ def get_norms(tar):
with pytest.raises(ValueError) as exc:
workflow.get_subgraph("not_exist")
assert "No subgraph named" in str(exc.value)

sub_cat_df = subgraph_cat.transform(dataset).to_ddf().compute()
assert assert_eq(sub_cat_df.reset_index(drop=True), df_pp[cat_names].reset_index(drop=True))

cont_names = [name + concat_ops for name in cont_names]
sub_cont_df = subgraph_cont.transform(dataset).to_ddf().compute()
assert assert_eq(
sub_cont_df[cont_names].reset_index(drop=True), df_pp[cont_names].reset_index(drop=True)
)

0 comments on commit 96cd7c0

Please sign in to comment.