diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py index 1cfdbf34a6..556e2a005a 100644 --- a/nvtabular/ops/categorify.py +++ b/nvtabular/ops/categorify.py @@ -1861,12 +1861,24 @@ def _copy_storage(existing_stats, existing_path, new_path, copy): existing_fs = get_fs_token_paths(existing_path)[0] new_fs = get_fs_token_paths(new_path)[0] new_locations = {} + for column, existing_file in existing_stats.items(): new_file = existing_file.replace(str(existing_path), str(new_path)) if copy and new_file != existing_file: new_fs.makedirs(os.path.dirname(new_file), exist_ok=True) - with new_fs.open(new_file, "wb") as output: - output.write(existing_fs.open(existing_file, "rb").read()) + + # For some ops, the existing "file" is a directory containing `part.N.parquet` files. + # In that case, new_file is actually a directory and we will iterate through the "part" + # files and copy them individually + if os.path.isdir(existing_file): + new_fs.makedirs(new_file, exist_ok=True) + for existing_file_part in existing_fs.ls(existing_file): + new_file_part = os.path.join(new_file, os.path.basename(existing_file_part)) + with new_fs.open(new_file_part, "wb") as output: + output.write(existing_fs.open(existing_file_part, "rb").read()) + else: + with new_fs.open(new_file, "wb") as output: + output.write(existing_fs.open(existing_file, "rb").read()) new_locations[column] = new_file diff --git a/nvtabular/workflow/workflow.py b/nvtabular/workflow/workflow.py index 41f8b10477..909c6bad05 100755 --- a/nvtabular/workflow/workflow.py +++ b/nvtabular/workflow/workflow.py @@ -17,12 +17,13 @@ import inspect import json import logging +import os import sys import time import types import warnings from functools import singledispatchmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import cloudpickle import fsspec @@ -295,12 +296,12 @@ def _getmodules(cls, fs): return [mod for mod in result if mod.__name__ not in exclusions] - def save(self, path, modules_byvalue=None): + def save(self, path: Union[str, os.PathLike], modules_byvalue=None): """Save this workflow to disk Parameters ---------- - path: str + path: Union[str, os.PathLike] The path to save the workflow to modules_byvalue: A list of modules that should be serialized by value. This @@ -314,6 +315,8 @@ def save(self, path, modules_byvalue=None): # avoid a circular import getting the version from nvtabular import __version__ as nvt_version + path = str(path) + fs = fsspec.get_fs_token_paths(path)[0] fs.makedirs(path, exist_ok=True) @@ -385,12 +388,12 @@ def save(self, path, modules_byvalue=None): cloudpickle.unregister_pickle_by_value(sys.modules[m]) @classmethod - def load(cls, path, client=None) -> "Workflow": + def load(cls, path: Union[str, os.PathLike], client=None) -> "Workflow": """Load up a saved workflow object from disk Parameters ---------- - path: str + path: Union[str, os.PathLike] The path to load the workflow from client: distributed.Client, optional The Dask distributed client to use for multi-gpu processing and multi-node processing @@ -403,6 +406,8 @@ def load(cls, path, client=None) -> "Workflow": # avoid a circular import getting the version from nvtabular import __version__ as nvt_version + path = str(path) + fs = fsspec.get_fs_token_paths(path)[0] # check version information from the metadata blob, and warn if we have a mismatch diff --git a/tests/unit/workflow/test_workflow.py b/tests/unit/workflow/test_workflow.py index ff6b57a410..a5e2688ba7 100755 --- a/tests/unit/workflow/test_workflow.py +++ b/tests/unit/workflow/test_workflow.py @@ -671,6 +671,43 @@ def test_workflow_saved_schema(tmpdir): assert node.output_schema is not None +def test_stat_op_workflow_roundtrip(tmpdir): + """ + Categorify and TargetEncoding produce intermediate stats files that must be properly + saved and re-loaded. + """ + N = 100 + + df = Dataset( + make_df( + { + "a": np.random.randint(0, 100000, N), + "item_id": np.random.randint(0, 100, N), + "user_id": np.random.randint(0, 100, N), + "click": np.random.randint(0, 2, N), + } + ), + ) + + outputs = ["a"] >> nvt.ops.Categorify() + + continuous = ( + ["user_id", "item_id"] + >> nvt.ops.TargetEncoding(["click"], kfold=1, p_smooth=20) + >> nvt.ops.Normalize() + ) + outputs += continuous + wf = nvt.Workflow(outputs) + + wf.fit(df) + expected = wf.transform(df).compute() + wf.save(tmpdir) + + wf2 = nvt.Workflow.load(tmpdir) + transformed = wf2.transform(df).compute() + assert_eq(transformed, expected) + + def test_workflow_infer_modules_byvalue(tmp_path): module_fn = tmp_path / "not_a_real_module.py" sys.path.append(str(tmp_path))