From a1702488edc73e4b0fecab3e9619b7f16b64e874 Mon Sep 17 00:00:00 2001 From: Umair Ahmed Date: Fri, 27 Sep 2024 16:53:15 +0530 Subject: [PATCH] Resolved linting issue. Signed-off-by: Ahmed Umair --- crossfit/backend/torch/model.py | 4 +++- crossfit/backend/torch/op/base.py | 1 + examples/custom_ct2_model.py | 9 +++++---- tests/op/test_model_function.py | 4 ++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py index f9cc093..868f09b 100644 --- a/crossfit/backend/torch/model.py +++ b/crossfit/backend/torch/model.py @@ -15,6 +15,7 @@ import cudf import cupy as cp + from crossfit.backend.cudf.series import ( create_list_series_from_1d_or_2d_ar, create_nested_list_series_from_3d_ar, @@ -60,10 +61,11 @@ def max_seq_length(self) -> int: raise NotImplementedError() def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame: + # importing here to avoid cyclic import error from crossfit.backend.torch.loader import SortedSeqLoader out = cudf.DataFrame(index=index) - _index = loader.sort_column(index.values) if type(loader) == SortedSeqLoader else index + _index = loader.sort_column(index.values) if type(loader) is SortedSeqLoader else index if self.model_output_type == "string": all_outputs = [o for output in all_outputs_ls for o in output] diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py index 0038ce1..960de95 100644 --- a/crossfit/backend/torch/op/base.py +++ b/crossfit/backend/torch/op/base.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional + import torch from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader diff --git a/examples/custom_ct2_model.py b/examples/custom_ct2_model.py index 37770e0..4ae04a2 100644 --- a/examples/custom_ct2_model.py +++ b/examples/custom_ct2_model.py @@ -16,14 +16,15 @@ import argparse from dataclasses import dataclass from functools import lru_cache -from crossfit import op -from crossfit.backend.torch.hf.model import HFModel -import dask_cudf -import crossfit as cf import ctranslate2 +import dask_cudf from transformers import AutoConfig, AutoTokenizer +import crossfit as cf +from crossfit import op +from crossfit.backend.torch.hf.model import HFModel + @dataclass class TranslationConfig: diff --git a/tests/op/test_model_function.py b/tests/op/test_model_function.py index 26fbb88..7ffd1eb 100644 --- a/tests/op/test_model_function.py +++ b/tests/op/test_model_function.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import patch +import pytest + cp = pytest.importorskip("cupy") cudf = pytest.importorskip("cudf") dask_cudf = pytest.importorskip("dask_cudf") @@ -25,7 +26,6 @@ import crossfit as cf # noqa: E402 - cf_loader = pytest.importorskip("crossfit.backend.torch.loader")