diff --git a/tests/unit/test_tf4rec.py b/tests/unit/test_tf4rec.py index 46dd66054a..737515e928 100644 --- a/tests/unit/test_tf4rec.py +++ b/tests/unit/test_tf4rec.py @@ -14,7 +14,7 @@ NUM_ROWS = 10000 -def test_tf4rec(): +def test_tf4rec(tmpdir): inputs = { "user_session": np.random.randint(1, 10000, NUM_ROWS), "product_id": np.random.randint(1, 51996, NUM_ROWS), @@ -29,7 +29,7 @@ def test_tf4rec(): cat_feats = ( ["user_session", "product_id", "category_id"] - >> nvt.ops.Categorify() + >> nvt.ops.Categorify(out_path=str(tmpdir)) >> nvt.ops.LambdaOp(lambda col: col + 1) )