From 5bc01039ef989f985a21d2e595ad8ad0c3eafcc1 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 5 Jun 2023 16:27:47 +0100 Subject: [PATCH 01/10] Use Distributed helper for client fixture in conftest.py (#1830) * Use Distributed helper for client fixture * reduce rtol in test_multihot_empty_rows --- tests/conftest.py | 7 ++++--- tests/unit/framework_utils/test_tf_layers.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2598ced6408..0ca0ad87489 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,10 +52,10 @@ def assert_eq(a, b, *args, **kwargs): import pytest from asvdb import ASVDb, BenchmarkInfo, utils -from dask.distributed import Client, LocalCluster from numba import cuda import nvtabular +from merlin.core.utils import Distributed from merlin.dag.node import iter_nodes REPO_ROOT = Path(__file__).parent.parent @@ -97,8 +97,9 @@ def assert_eq(a, b, *args, **kwargs): @pytest.fixture(scope="module") def client(): - cluster = LocalCluster(n_workers=2) - client = Client(cluster) + distributed = Distributed(n_workers=2) + cluster = distributed.cluster + client = distributed.client yield client client.close() cluster.close() diff --git a/tests/unit/framework_utils/test_tf_layers.py b/tests/unit/framework_utils/test_tf_layers.py index 106be0fa457..38e2778cab7 100644 --- a/tests/unit/framework_utils/test_tf_layers.py +++ b/tests/unit/framework_utils/test_tf_layers.py @@ -318,4 +318,4 @@ def test_multihot_empty_rows(): ) y_hat = model(x).numpy() - np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-06) + np.testing.assert_allclose(y_hat, multi_hot_embedding_rows, rtol=1e-05) From 669dc50db74b997e2465c58c1b965055048123ae Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 5 Jun 2023 19:17:13 +0100 Subject: [PATCH 02/10] Get visible devices from env var if set (#1831) --- ...3-Running-on-multiple-GPUs-or-on-CPU.ipynb | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb b/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb index aba2647567d..3c90574ff5f 100644 --- a/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb +++ b/examples/03-Running-on-multiple-GPUs-or-on-CPU.ipynb @@ -27,6 +27,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "77464844", "metadata": {}, @@ -53,6 +54,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "1c5598ae", "metadata": {}, @@ -92,6 +94,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "63ac0cf2", "metadata": {}, @@ -100,6 +103,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "4def0005", "metadata": {}, @@ -123,6 +127,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d7c3f9ea", "metadata": {}, @@ -148,6 +153,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "728c3009", "metadata": {}, @@ -176,11 +182,15 @@ "\n", "# Deploy a Single-Machine Multi-GPU Cluster\n", "protocol = \"tcp\" # \"tcp\" or \"ucx\"\n", + "\n", "if numba.cuda.is_available():\n", " NUM_GPUS = list(range(len(numba.cuda.gpus)))\n", "else:\n", " NUM_GPUS = []\n", - "visible_devices = \",\".join([str(n) for n in NUM_GPUS]) # Delect devices to place workers\n", + "try:\n", + " visible_devices = os.environ[\"CUDA_VISIBLE_DEVICES\"]\n", + "except KeyError:\n", + " visible_devices = \",\".join([str(n) for n in NUM_GPUS]) # Delect devices to place workers\n", "device_limit_frac = 0.7 # Spill GPU-Worker memory to host at this limit.\n", 
"device_pool_frac = 0.8\n", "part_mem_frac = 0.15\n", @@ -206,6 +216,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d14dc098", "metadata": {}, @@ -242,6 +253,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "0576affe", "metadata": {}, @@ -589,6 +601,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "94ef0024", "metadata": {}, @@ -599,6 +612,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "768fc24e", "metadata": {}, @@ -622,6 +636,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "61785127", "metadata": {}, @@ -678,6 +693,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "01ea40bb", "metadata": {}, @@ -686,6 +702,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "987f3274", "metadata": {}, @@ -714,6 +731,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "b06c962e", "metadata": {}, @@ -745,6 +763,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d28ae761", "metadata": {}, @@ -755,6 +774,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "4e07864d", "metadata": {}, @@ -763,6 +783,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "8f971a22", "metadata": {}, From 2ca6df0d95401348303e314f43102009eb2a1fe3 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 5 Jun 2023 20:06:13 +0100 Subject: [PATCH 03/10] Add ops to NVT import (#1834) --- nvtabular/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nvtabular/__init__.py b/nvtabular/__init__.py index 71d597b6cc2..1759cde1833 100644 --- a/nvtabular/__init__.py +++ b/nvtabular/__init__.py @@ -21,8 +21,7 @@ from merlin.core import dispatch, utils # noqa from merlin.dag import ColumnSelector from merlin.schema import ColumnSchema, Schema -from nvtabular import workflow # noqa -from nvtabular import _version +from nvtabular import _version, ops, workflow # noqa # suppress some warnings with cudf warning about column ordering with dlpack # and numba warning about deprecated environment variables From 33479810c53d111f1658653f1fcb531d58b987de Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 5 Jun 2023 20:55:30 +0100 Subject: [PATCH 04/10] Pass keyword argument for axis in dataframe any method (#1833) --- nvtabular/ops/categorify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 2f4285738f8..7fbfa1de774 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -1025,7 +1025,7 @@ def _top_level_groupby(df, options: FitOptions = None, spill=True): del df_gb # Extract null groups into gb_null - isnull = gb.isnull().any(1) + isnull = gb.isnull().any(axis=1) gb_null = gb[~isnull] gb = gb[isnull] if not len(gb_null): From 69fbd57808530417121030ea8aff5b1bae654c68 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 5 Jun 2023 21:39:01 +0100 Subject: [PATCH 05/10] Use tmpdir for Categorify out_path in test_tf4rec (#1832) Co-authored-by: Adam Laiacano <108741458+nv-alaiacano@users.noreply.github.com> --- tests/unit/test_tf4rec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_tf4rec.py b/tests/unit/test_tf4rec.py index 46dd66054a5..737515e9286 100644 --- a/tests/unit/test_tf4rec.py +++ b/tests/unit/test_tf4rec.py @@ -14,7 +14,7 @@ NUM_ROWS = 10000 -def test_tf4rec(): +def test_tf4rec(tmpdir): inputs = { "user_session": np.random.randint(1, 10000, NUM_ROWS), "product_id": np.random.randint(1, 51996, NUM_ROWS), @@ -29,7 
+29,7 @@ def test_tf4rec(): cat_feats = ( ["user_session", "product_id", "category_id"] - >> nvt.ops.Categorify() + >> nvt.ops.Categorify(out_path=str(tmpdir)) >> nvt.ops.LambdaOp(lambda col: col + 1) ) From f1165d8c023b4cf86a9fa176c51565496792d348 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 7 Jun 2023 13:55:19 +0100 Subject: [PATCH 06/10] Convert index to array so that assignment works later on (#1836) --- nvtabular/ops/categorify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 7fbfa1de774..1cfdbf34a65 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -1693,7 +1693,7 @@ def _encode( expr = df[selection_l.names[0]].isna() for _name in selection_l.names[1:]: expr = expr & df[_name].isna() - nulls = df[expr].index + nulls = df[expr].index.values if use_collection or not search_sorted: if list_col: From 66c6e3a1b240b6b3addf96af5021a1c4f57e9d5b Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 7 Jun 2023 14:23:28 +0100 Subject: [PATCH 07/10] Remove n_workers=2 from Distributed in client fixture (#1835) --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0ca0ad87489..3c3ae4373bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,7 +97,7 @@ def assert_eq(a, b, *args, **kwargs): @pytest.fixture(scope="module") def client(): - distributed = Distributed(n_workers=2) + distributed = Distributed() cluster = distributed.cluster client = distributed.client yield client From 4b7957a9566dec057251885ef1a5728e54ff7f07 Mon Sep 17 00:00:00 2001 From: Adam Laiacano <108741458+nv-alaiacano@users.noreply.github.com> Date: Wed, 7 Jun 2023 12:34:15 -0400 Subject: [PATCH 08/10] handle copying of partitioned stat files when saving workflow (#1838) * handle copying of partitioned stat files when saving workflow * undo changes to tox --- nvtabular/ops/categorify.py | 16 ++++++++++-- nvtabular/workflow/workflow.py | 15 +++++++---- tests/unit/workflow/test_workflow.py | 37 ++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 1cfdbf34a65..556e2a005a9 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -1861,12 +1861,24 @@ def _copy_storage(existing_stats, existing_path, new_path, copy): existing_fs = get_fs_token_paths(existing_path)[0] new_fs = get_fs_token_paths(new_path)[0] new_locations = {} + for column, existing_file in existing_stats.items(): new_file = existing_file.replace(str(existing_path), str(new_path)) if copy and new_file != existing_file: new_fs.makedirs(os.path.dirname(new_file), exist_ok=True) - with new_fs.open(new_file, "wb") as output: - output.write(existing_fs.open(existing_file, "rb").read()) + + # For some ops, the existing "file" is a directory containing `part.N.parquet` files. 
+ # In that case, new_file is actually a directory and we will iterate through the "part" + # files and copy them individually + if os.path.isdir(existing_file): + new_fs.makedirs(new_file, exist_ok=True) + for existing_file_part in existing_fs.ls(existing_file): + new_file_part = os.path.join(new_file, os.path.basename(existing_file_part)) + with new_fs.open(new_file_part, "wb") as output: + output.write(existing_fs.open(existing_file_part, "rb").read()) + else: + with new_fs.open(new_file, "wb") as output: + output.write(existing_fs.open(existing_file, "rb").read()) new_locations[column] = new_file diff --git a/nvtabular/workflow/workflow.py b/nvtabular/workflow/workflow.py index 41f8b104778..909c6bad05c 100755 --- a/nvtabular/workflow/workflow.py +++ b/nvtabular/workflow/workflow.py @@ -17,12 +17,13 @@ import inspect import json import logging +import os import sys import time import types import warnings from functools import singledispatchmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import cloudpickle import fsspec @@ -295,12 +296,12 @@ def _getmodules(cls, fs): return [mod for mod in result if mod.__name__ not in exclusions] - def save(self, path, modules_byvalue=None): + def save(self, path: Union[str, os.PathLike], modules_byvalue=None): """Save this workflow to disk Parameters ---------- - path: str + path: Union[str, os.PathLike] The path to save the workflow to modules_byvalue: A list of modules that should be serialized by value. This @@ -314,6 +315,8 @@ def save(self, path, modules_byvalue=None): # avoid a circular import getting the version from nvtabular import __version__ as nvt_version + path = str(path) + fs = fsspec.get_fs_token_paths(path)[0] fs.makedirs(path, exist_ok=True) @@ -385,12 +388,12 @@ def save(self, path, modules_byvalue=None): cloudpickle.unregister_pickle_by_value(sys.modules[m]) @classmethod - def load(cls, path, client=None) -> "Workflow": + def load(cls, path: Union[str, os.PathLike], client=None) -> "Workflow": """Load up a saved workflow object from disk Parameters ---------- - path: str + path: Union[str, os.PathLike] The path to load the workflow from client: distributed.Client, optional The Dask distributed client to use for multi-gpu processing and multi-node processing @@ -403,6 +406,8 @@ def load(cls, path, client=None) -> "Workflow": # avoid a circular import getting the version from nvtabular import __version__ as nvt_version + path = str(path) + fs = fsspec.get_fs_token_paths(path)[0] # check version information from the metadata blob, and warn if we have a mismatch diff --git a/tests/unit/workflow/test_workflow.py b/tests/unit/workflow/test_workflow.py index ff6b57a4103..a5e2688ba7e 100755 --- a/tests/unit/workflow/test_workflow.py +++ b/tests/unit/workflow/test_workflow.py @@ -671,6 +671,43 @@ def test_workflow_saved_schema(tmpdir): assert node.output_schema is not None +def test_stat_op_workflow_roundtrip(tmpdir): + """ + Categorify and TargetEncoding produce intermediate stats files that must be properly + saved and re-loaded. 
+ """ + N = 100 + + df = Dataset( + make_df( + { + "a": np.random.randint(0, 100000, N), + "item_id": np.random.randint(0, 100, N), + "user_id": np.random.randint(0, 100, N), + "click": np.random.randint(0, 2, N), + } + ), + ) + + outputs = ["a"] >> nvt.ops.Categorify() + + continuous = ( + ["user_id", "item_id"] + >> nvt.ops.TargetEncoding(["click"], kfold=1, p_smooth=20) + >> nvt.ops.Normalize() + ) + outputs += continuous + wf = nvt.Workflow(outputs) + + wf.fit(df) + expected = wf.transform(df).compute() + wf.save(tmpdir) + + wf2 = nvt.Workflow.load(tmpdir) + transformed = wf2.transform(df).compute() + assert_eq(transformed, expected) + + def test_workflow_infer_modules_byvalue(tmp_path): module_fn = tmp_path / "not_a_real_module.py" sys.path.append(str(tmp_path)) From 25151f73f1f388f596a086dc9e271008486c0c88 Mon Sep 17 00:00:00 2001 From: Julio Perez <37191411+jperez999@users.noreply.github.com> Date: Thu, 8 Jun 2023 15:23:22 -0400 Subject: [PATCH 09/10] How to for embedding op with categorify from scratch (#1827) * embedding op test from start * setup asserts to verify logic * use make series to handle cpu-gpu env --------- Co-authored-by: Karl Higley Co-authored-by: rnyak <16246900+rnyak@users.noreply.github.com> --- tests/unit/workflow/test_workflow.py | 58 +++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/unit/workflow/test_workflow.py b/tests/unit/workflow/test_workflow.py index a5e2688ba7e..77009da1354 100755 --- a/tests/unit/workflow/test_workflow.py +++ b/tests/unit/workflow/test_workflow.py @@ -27,9 +27,11 @@ import nvtabular as nvt from merlin.core import dispatch from merlin.core.compat import cudf, dask_cudf -from merlin.core.dispatch import HAS_GPU, make_df +from merlin.core.dispatch import HAS_GPU, create_multihot_col, make_df, make_series from merlin.core.utils import set_dask_client from merlin.dag import ColumnSelector, postorder_iter_nodes +from merlin.dataloader.loader_base import LoaderBase as Loader +from merlin.dataloader.ops.embeddings import EmbeddingOperator from merlin.schema import Tags from nvtabular import Dataset, Workflow, ops from tests.conftest import assert_eq, get_cats, mycols_csv @@ -774,3 +776,57 @@ def test_workflow_auto_infer_modules_byvalue(tmp_path): os.unlink(str(tmp_path / "not_a_real_module.py")) Workflow.load(str(tmp_path / "identity-workflow")) + + +@pytest.mark.parametrize("cpu", [None, "cpu"] if HAS_GPU else ["cpu"]) +def test_embedding_cat_export_import(tmpdir, cpu): + string_ids = ["alpha", "bravo", "charlie", "delta", "foxtrot"] + training_data = make_df( + { + "string_id": string_ids, + } + ) + training_data["embeddings"] = create_multihot_col( + [0, 5, 10, 15, 20, 25], make_series(np.random.rand(25)) + ) + + cat_op = nvt.ops.Categorify() + + # first workflow that categorifies all data + graph1 = ["string_id"] >> cat_op + emb_res = Workflow(graph1 + ["embeddings"]).fit_transform( + Dataset(training_data, cpu=(cpu is not None)) + ) + npy_path = str(tmpdir / "embeddings.npy") + emb_res.to_npy(npy_path) + + embeddings = np.load(npy_path) + # second workflow that categorifies the embedding table data + df = make_df({"string_id": np.random.choice(string_ids, 30)}) + graph2 = ["string_id"] >> cat_op + train_res = Workflow(graph2).transform(Dataset(df, cpu=(cpu is not None))) + + data_loader = Loader( + train_res, + batch_size=1, + transforms=[ + EmbeddingOperator( + embeddings[:, 1:], + id_lookup_table=embeddings[:, 0].astype(int), + lookup_key="string_id", + ) + ], + shuffle=False, + 
device=cpu, + ) + origin_df = train_res.to_ddf().merge(emb_res.to_ddf(), on="string_id", how="left").compute() + for idx, batch in enumerate(data_loader): + batch + b_df = batch[0].to_df() + org_df = origin_df.iloc[idx] + if not cpu: + assert (b_df["string_id"].to_numpy() == org_df["string_id"].to_numpy()).all() + assert (b_df["embeddings"].list.leaves == org_df["embeddings"].list.leaves).all() + else: + assert (b_df["string_id"].values == org_df["string_id"]).all() + assert b_df["embeddings"].values[0] == org_df["embeddings"].tolist() From d26e776a199d472f703bde5e7ea3775739ec0ab3 Mon Sep 17 00:00:00 2001 From: Julio Perez <37191411+jperez999@users.noreply.github.com> Date: Fri, 9 Jun 2023 14:37:00 -0400 Subject: [PATCH 10/10] working subgraphs in workflow (#1842) * working subgraphs in workflow * add in subgraph transform to compare * change to subworkflows in workflow instead of subgraph4 * remove reset_index calls in asserts of test --- nvtabular/workflow/workflow.py | 8 ++ .../unit/workflow/test_workflow_subgraphs.py | 99 +++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 tests/unit/workflow/test_workflow_subgraphs.py diff --git a/nvtabular/workflow/workflow.py b/nvtabular/workflow/workflow.py index 909c6bad05c..9c5e944a37c 100755 --- a/nvtabular/workflow/workflow.py +++ b/nvtabular/workflow/workflow.py @@ -142,6 +142,10 @@ def fit_schema(self, input_schema: Schema): self.graph.construct_schema(input_schema) return self + @property + def subworkflows(self): + return list(self.graph.subgraphs.keys()) + @property def input_dtypes(self): return self.graph.input_dtypes @@ -165,6 +169,10 @@ def output_node(self): def _input_columns(self): return self.graph._input_columns() + def get_subworkflow(self, subgraph_name): + subgraph = self.graph.subgraph(subgraph_name) + return Workflow(subgraph.output_node) + def remove_inputs(self, input_cols) -> "Workflow": """Removes input columns from the workflow. diff --git a/tests/unit/workflow/test_workflow_subgraphs.py b/tests/unit/workflow/test_workflow_subgraphs.py new file mode 100644 index 00000000000..047edac2462 --- /dev/null +++ b/tests/unit/workflow/test_workflow_subgraphs.py @@ -0,0 +1,99 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os + +import numpy as np +import pytest +from pandas.api.types import is_integer_dtype + +from merlin.core.utils import set_dask_client +from merlin.dag.ops.subgraph import Subgraph +from nvtabular import Workflow, ops +from tests.conftest import assert_eq + + +@pytest.mark.parametrize("gpu_memory_frac", [0.01, 0.1]) +@pytest.mark.parametrize("engine", ["parquet", "csv", "csv-no-header"]) +@pytest.mark.parametrize("dump", [True, False]) +@pytest.mark.parametrize("replace", [True, False]) +def test_workflow_subgraphs(tmpdir, client, df, dataset, gpu_memory_frac, engine, dump, replace): + cat_names = ["name-cat", "name-string"] if engine == "parquet" else ["name-string"] + cont_names = ["x", "y", "id"] + label_name = ["label"] + + norms = ops.Normalize() + cat_features = cat_names >> ops.Categorify() + if replace: + cont_features = cont_names >> ops.FillMissing() >> ops.LogOp >> norms + else: + fillmissing_logop = ( + cont_names + >> ops.FillMissing() + >> ops.LogOp + >> ops.Rename(postfix="_FillMissing_1_LogOp_1") + ) + cont_features = cont_names + fillmissing_logop >> norms + + set_dask_client(client=client) + wkflow_ops = Subgraph("cat_graph", cat_features) + Subgraph("cont_graph", cont_features) + workflow = Workflow(wkflow_ops + label_name) + + workflow.fit(dataset) + + if dump: + workflow_dir = os.path.join(tmpdir, "workflow") + workflow.save(workflow_dir) + workflow = None + + workflow = Workflow.load(workflow_dir) + + def get_norms(tar): + ser_median = tar.dropna().quantile(0.5, interpolation="linear") + gdf = tar.fillna(ser_median) + gdf = np.log(gdf + 1) + return gdf + + concat_ops = "_FillMissing_1_LogOp_1" + if replace: + concat_ops = "" + + df_pp = workflow.transform(dataset).to_ddf().compute() + + if engine == "parquet": + assert is_integer_dtype(df_pp["name-cat"].dtype) + assert is_integer_dtype(df_pp["name-string"].dtype) + + subgraph_cat = workflow.get_subworkflow("cat_graph") + subgraph_cont = workflow.get_subworkflow("cont_graph") + assert isinstance(subgraph_cat, Workflow) + assert isinstance(subgraph_cont, Workflow) + # will not be the same nodes of saved out and loaded back + if not dump: + assert subgraph_cat.output_node == cat_features + assert subgraph_cont.output_node == cont_features + # check failure path works as expected + with pytest.raises(ValueError) as exc: + workflow.get_subworkflow("not_exist") + assert "No subgraph named" in str(exc.value) + + # test transform results from subgraph + sub_cat_df = subgraph_cat.transform(dataset).to_ddf().compute() + assert_eq(sub_cat_df, df_pp[cat_names]) + + cont_names = [name + concat_ops for name in cont_names] + sub_cont_df = subgraph_cont.transform(dataset).to_ddf().compute() + assert_eq(sub_cont_df[cont_names], df_pp[cont_names])
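
A minimal usage sketch for the subworkflow API introduced in the final patch
(`Workflow.subworkflows` and `Workflow.get_subworkflow`), assuming an NVTabular
build that includes these changes. The column names, sizes, and paths below are
illustrative only and not taken from the patches; the subgraph names mirror the
new test_workflow_subgraphs.py test.

    import tempfile
    from pathlib import Path

    import numpy as np

    from merlin.core.dispatch import make_df
    from merlin.dag.ops.subgraph import Subgraph
    from nvtabular import Dataset, Workflow, ops

    # Illustrative toy data (column names are hypothetical).
    N = 1000
    dataset = Dataset(
        make_df(
            {
                "item_id": np.random.randint(0, 100, N),
                "price": np.random.rand(N),
            }
        )
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Group related ops into named subgraphs; out_path keeps Categorify's
        # stat files inside the temporary directory (as in patch 05).
        cat_graph = Subgraph("cat_graph", ["item_id"] >> ops.Categorify(out_path=tmpdir))
        cont_graph = Subgraph("cont_graph", ["price"] >> ops.FillMissing() >> ops.Normalize())

        workflow = Workflow(cat_graph + cont_graph)
        workflow.fit(dataset)

        # New in the final patch: list subworkflow names and fetch a fitted
        # subworkflow by name (returns a Workflow; unknown names raise ValueError).
        print(workflow.subworkflows)  # e.g. ['cat_graph', 'cont_graph']
        cat_workflow = workflow.get_subworkflow("cat_graph")
        cat_only = cat_workflow.transform(dataset).to_ddf().compute()

        # save()/load() accept os.PathLike as well as str (patch 08), and any
        # per-column stat files written by Categorify are copied alongside the workflow.
        workflow.save(Path(tmpdir) / "saved_workflow")
        reloaded = Workflow.load(Path(tmpdir) / "saved_workflow")

Fetching a subworkflow after fitting the parent preserves the fitted statistics,
which is what the new test asserts by comparing each subworkflow's transform
output against the corresponding columns of the full workflow's output.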