diff --git a/tests/conftest.py b/tests/conftest.py index 3c3ae4373b..a5f5e1f6bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -279,7 +279,7 @@ def get_cats(workflow, col, stat_name="categories", cpu=False): # figure out the categorify node from the workflow graph cats = [ cg.op - for cg in iter_nodes([workflow.output_node]) + for cg in iter_nodes([workflow.output_node], flatten_subgraphs=True) if isinstance(cg.op, nvtabular.ops.Categorify) ] if len(cats) != 1: