handle copying of partitioned stat files when saving workflow #1838

Merged (3 commits) on Jun 7, 2023
nvtabular/ops/categorify.py: 16 changes (14 additions, 2 deletions)
@@ -1861,12 +1861,24 @@ def _copy_storage(existing_stats, existing_path, new_path, copy):
     existing_fs = get_fs_token_paths(existing_path)[0]
     new_fs = get_fs_token_paths(new_path)[0]
     new_locations = {}
+
     for column, existing_file in existing_stats.items():
         new_file = existing_file.replace(str(existing_path), str(new_path))
         if copy and new_file != existing_file:
             new_fs.makedirs(os.path.dirname(new_file), exist_ok=True)
-            with new_fs.open(new_file, "wb") as output:
-                output.write(existing_fs.open(existing_file, "rb").read())
+
+            # For some ops, the existing "file" is a directory containing `part.N.parquet` files.
+            # In that case, new_file is actually a directory and we will iterate through the "part"
+            # files and copy them individually
+            if os.path.isdir(existing_file):
+                new_fs.makedirs(new_file, exist_ok=True)
+                for existing_file_part in existing_fs.ls(existing_file):
+                    new_file_part = os.path.join(new_file, os.path.basename(existing_file_part))
+                    with new_fs.open(new_file_part, "wb") as output:
+                        output.write(existing_fs.open(existing_file_part, "rb").read())
+            else:
+                with new_fs.open(new_file, "wb") as output:
+                    output.write(existing_fs.open(existing_file, "rb").read())
 
         new_locations[column] = new_file
 
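The comment in the hunk above describes the case this PR fixes: for some ops (TargetEncoding in the new test further down), the recorded stats path is not a single parquet file but a directory of `part.N.parquet` files, so a plain file copy fails when the workflow is saved to a new location. A minimal standalone sketch of the same copy pattern using only public fsspec calls; the local-filesystem choice, function name, and paths are illustrative and not taken from the PR:

import os

import fsspec


def copy_stat(existing_file, new_file):
    # Illustrative sketch: assumes local paths; _copy_storage above resolves
    # the source and destination filesystems from the workflow paths instead.
    fs = fsspec.filesystem("file")
    if os.path.isdir(existing_file):
        # Partitioned stats: mirror the directory and copy each part file.
        fs.makedirs(new_file, exist_ok=True)
        for part in fs.ls(existing_file):
            dest = os.path.join(new_file, os.path.basename(part))
            with fs.open(dest, "wb") as output, fs.open(part, "rb") as src:
                output.write(src.read())
    else:
        # Single-file stats: this branch is all the pre-PR code handled.
        fs.makedirs(os.path.dirname(new_file), exist_ok=True)
        with fs.open(new_file, "wb") as output, fs.open(existing_file, "rb") as src:
            output.write(src.read())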
nvtabular/workflow/workflow.py: 15 changes (10 additions, 5 deletions)
@@ -17,12 +17,13 @@
 import inspect
 import json
 import logging
+import os
 import sys
 import time
 import types
 import warnings
 from functools import singledispatchmethod
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import cloudpickle
 import fsspec
@@ -295,12 +296,12 @@ def _getmodules(cls, fs):
 
         return [mod for mod in result if mod.__name__ not in exclusions]
 
-    def save(self, path, modules_byvalue=None):
+    def save(self, path: Union[str, os.PathLike], modules_byvalue=None):
         """Save this workflow to disk
 
         Parameters
         ----------
-        path: str
+        path: Union[str, os.PathLike]
             The path to save the workflow to
         modules_byvalue:
             A list of modules that should be serialized by value. This
@@ -314,6 +315,8 @@ def save(self, path, modules_byvalue=None):
         # avoid a circular import getting the version
         from nvtabular import __version__ as nvt_version
 
+        path = str(path)
+
         fs = fsspec.get_fs_token_paths(path)[0]
 
         fs.makedirs(path, exist_ok=True)
@@ -385,12 +388,12 @@ def save(self, path, modules_byvalue=None):
                 cloudpickle.unregister_pickle_by_value(sys.modules[m])
 
     @classmethod
-    def load(cls, path, client=None) -> "Workflow":
+    def load(cls, path: Union[str, os.PathLike], client=None) -> "Workflow":
         """Load up a saved workflow object from disk
 
         Parameters
         ----------
-        path: str
+        path: Union[str, os.PathLike]
             The path to load the workflow from
         client: distributed.Client, optional
             The Dask distributed client to use for multi-gpu processing and multi-node processing
@@ -403,6 +406,8 @@ def load(cls, path, client=None) -> "Workflow":
         # avoid a circular import getting the version
         from nvtabular import __version__ as nvt_version
 
+        path = str(path)
+
         fs = fsspec.get_fs_token_paths(path)[0]
 
         # check version information from the metadata blob, and warn if we have a mismatch
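Because save and load now coerce `path` with `str(path)`, callers can pass a `pathlib.Path` (or any other `os.PathLike`) as well as a plain string. A small usage sketch under that assumption; the data, column name, and save location are illustrative:

from pathlib import Path

import pandas as pd

import nvtabular as nvt
from merlin.io import Dataset

dataset = Dataset(pd.DataFrame({"a": [1, 2, 2, 3]}))
workflow = nvt.Workflow(["a"] >> nvt.ops.Categorify())
workflow.fit(dataset)

save_dir = Path("/tmp/example_workflow")  # illustrative location
workflow.save(save_dir)                   # a pathlib.Path now works like a str
restored = nvt.Workflow.load(save_dir)
transformed = restored.transform(dataset).compute()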
tests/unit/workflow/test_workflow.py: 37 changes (37 additions, 0 deletions)
@@ -671,6 +671,43 @@ def test_workflow_saved_schema(tmpdir):
         assert node.output_schema is not None
 
 
+def test_stat_op_workflow_roundtrip(tmpdir):
+    """
+    Categorify and TargetEncoding produce intermediate stats files that must be properly
+    saved and re-loaded.
+    """
+    N = 100
+
+    df = Dataset(
+        make_df(
+            {
+                "a": np.random.randint(0, 100000, N),
+                "item_id": np.random.randint(0, 100, N),
+                "user_id": np.random.randint(0, 100, N),
+                "click": np.random.randint(0, 2, N),
+            }
+        ),
+    )
+
+    outputs = ["a"] >> nvt.ops.Categorify()
+
+    continuous = (
+        ["user_id", "item_id"]
+        >> nvt.ops.TargetEncoding(["click"], kfold=1, p_smooth=20)
+        >> nvt.ops.Normalize()
+    )
+    outputs += continuous
+    wf = nvt.Workflow(outputs)
+
+    wf.fit(df)
+    expected = wf.transform(df).compute()
+    wf.save(tmpdir)
+
+    wf2 = nvt.Workflow.load(tmpdir)
+    transformed = wf2.transform(df).compute()
+    assert_eq(transformed, expected)
+
+
 def test_workflow_infer_modules_byvalue(tmp_path):
     module_fn = tmp_path / "not_a_real_module.py"
     sys.path.append(str(tmp_path))
tox.ini: 5 changes (2 additions, 3 deletions)
@@ -42,10 +42,9 @@ deps =
     pytest
     pytest-cov
 commands =
-    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
-    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git
-    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/models.git
+    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git@{posargs:main}
+    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git@{posargs:main}
+    python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/models.git@{posargs:main}
     python -m pytest --cov-report term --cov merlin -rxs tests/unit
 
 [testenv:test-merlin]
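The `@{posargs:main}` suffix installs the companion Merlin repos from `main` by default, while letting CI or a developer pass a branch or tag as the tox positional argument so the unit tests run against a matching release line. A hedged example; the env name and branch are illustrative, not taken from this PR:

    tox -e test-cpu -- release-23.06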