Skip to content

Commit

Permalink
Added test_tokenize.py to test get_model_output from Model class.
Browse files Browse the repository at this point in the history
Handle some review comments.

Signed-off-by: Ahmed Umair <[email protected]>
  • Loading branch information
Umair Ahmed committed Sep 23, 2024
1 parent f81a61d commit e885474
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 6 deletions.
9 changes: 4 additions & 5 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import cudf
import cupy as cp
from crossfit.backend.cudf.series import (
Expand All @@ -24,7 +26,7 @@ class Model:
def __init__(self, path_or_name: str, max_mem_gb: int = 16, model_output_type: str = "numeric"):
self.path_or_name = path_or_name
self.max_mem_gb = max_mem_gb
if model_output_type == "numeric" or model_output_type == "string":
if model_output_type in ["numeric", "string"]:
self.model_output_type = model_output_type
else:
raise ValueError(
Expand Down Expand Up @@ -66,10 +68,7 @@ def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cu
)

if self.model_output_type == "string":
all_outputs = []
for output in all_outputs_ls:
for o in output:
all_outputs.append(o)
all_outputs = [o for output in all_outputs_ls for o in output]
out[pred_output_col] = cudf.Series(data=all_outputs, index=_index)
del all_outputs_ls
del loader
Expand Down
17 changes: 16 additions & 1 deletion examples/custom_ct2_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
from dataclasses import dataclass
from functools import lru_cache
Expand Down Expand Up @@ -141,7 +156,7 @@ def main():
model = ModelForSeq2SeqModel(Config)
pipe = op.Sequential(
op.Tokenizer(
model, cols=[args.input_column], tokenizer_type="sentencepiece", max_length=255
model, cols=[args.input_column], tokenizer_type="default", max_length=255
),
op.Predictor(model, sorted_data_loader=True, batch_size=args.batch_size),
repartition=args.partitions,
Expand Down
67 changes: 67 additions & 0 deletions tests/op/test_model_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest.mock import patch

# Optional dependencies: pytest.importorskip skips this whole test module
# when any of the packages below cannot be imported.
cp = pytest.importorskip("cupy")
cudf = pytest.importorskip("cudf")
dask_cudf = pytest.importorskip("dask_cudf")
dd = pytest.importorskip("dask.dataframe")
pd = pytest.importorskip("pandas")
transformers = pytest.importorskip("transformers")
torch = pytest.importorskip("torch")

# Imported after the guards above, which is why the E402 (module-level import
# not at top of file) lint is suppressed.
# NOTE(review): presumably crossfit itself needs the packages above at import
# time -- confirm.
import crossfit as cf  # noqa: E402


# The torch loader submodule is guarded separately and skips the module if it
# cannot be imported.
cf_loader = pytest.importorskip("crossfit.backend.torch.loader")


@pytest.mark.parametrize("trust_remote_code", ["y"])
def test_model_output_int(trust_remote_code, model_name="ai4bharat/indictrans2-en-indic-1B"):
    """Check ``get_model_output`` with the default model output type.

    Feeds a tensor of single-integer rows to ``get_model_output`` and asserts
    that the result is a ``cudf.DataFrame`` whose ``"translation"`` column
    holds ``int`` values.

    ``builtins.input`` is patched to return ``trust_remote_code`` for the
    duration of the test.
    NOTE(review): presumably this answers an interactive trust-remote-code
    prompt raised while loading the model -- confirm.
    """
    with patch("builtins.input", return_value=trust_remote_code):
        token_frame = cudf.DataFrame({"input_ids": [[11, 12, 13], [14, 15, 16], [17, 18, 19]]})
        row_index = token_frame.index.copy()
        hf_model = cf.HFModel(model_name)

        # One single-element output row per input row.
        raw_outputs = torch.tensor([[4], [7], [10]])
        seq_loader = cf_loader.SortedSeqLoader(token_frame, hf_model)
        output_column = "translation"

        result = hf_model.get_model_output(raw_outputs, row_index, seq_loader, output_column)

        assert isinstance(result, cudf.DataFrame)
        first_value = result["translation"][0][0]
        assert isinstance(first_value, int)


@pytest.mark.parametrize("trust_remote_code", ["y"])
def test_model_output_str(trust_remote_code, model_name="ai4bharat/indictrans2-en-indic-1B"):
    """Check ``get_model_output`` when ``model_output_type="string"``.

    Feeds a nested list of strings to ``get_model_output`` and asserts that
    the result is a ``cudf.DataFrame`` whose ``"translation"`` column yields
    ``str`` values.

    ``builtins.input`` is patched to return ``trust_remote_code`` for the
    duration of the test.
    NOTE(review): presumably this answers an interactive trust-remote-code
    prompt raised while loading the model -- confirm.
    """
    with patch("builtins.input", return_value=trust_remote_code):
        token_ids = [[18264, 7728, 8], [123, 99, 2258], [3115, 125, 123]]
        token_frame = cudf.DataFrame({"input_ids": token_ids})
        row_index = token_frame.index.copy()
        hf_model = cf.HFModel(model_name, model_output_type="string")

        # One single-string output row per input row.
        raw_outputs = [["▁हमारे▁परीक्षण▁डेटा"], ["▁पर▁हमारे▁दो"], ["▁दूरी▁कार्यों▁की"]]
        seq_loader = cf_loader.SortedSeqLoader(token_frame, hf_model)
        output_column = "translation"

        result = hf_model.get_model_output(raw_outputs, row_index, seq_loader, output_column)

        assert isinstance(result, cudf.DataFrame)
        first_value = result["translation"][0][0]
        assert isinstance(first_value, str)

0 comments on commit e885474

Please sign in to comment.