Skip to content

Commit

Permalink
Resolved linting issue.
Browse files Browse the repository at this point in the history
Signed-off-by: Ahmed Umair <[email protected]>
  • Loading branch information
Umair Ahmed committed Sep 27, 2024
1 parent 4f82e86 commit a170248
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
4 changes: 3 additions & 1 deletion crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import cudf
import cupy as cp

from crossfit.backend.cudf.series import (
create_list_series_from_1d_or_2d_ar,
create_nested_list_series_from_3d_ar,
Expand Down Expand Up @@ -60,10 +61,11 @@ def max_seq_length(self) -> int:
raise NotImplementedError()

def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
# importing here to avoid cyclic import error
from crossfit.backend.torch.loader import SortedSeqLoader

out = cudf.DataFrame(index=index)
_index = loader.sort_column(index.values) if type(loader) == SortedSeqLoader else index
_index = loader.sort_column(index.values) if type(loader) is SortedSeqLoader else index

if self.model_output_type == "string":
all_outputs = [o for output in all_outputs_ls for o in output]
Expand Down
1 change: 1 addition & 0 deletions crossfit/backend/torch/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Optional

import torch

from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
Expand Down
9 changes: 5 additions & 4 deletions examples/custom_ct2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
import argparse
from dataclasses import dataclass
from functools import lru_cache
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel
import dask_cudf
import crossfit as cf

import ctranslate2
import dask_cudf
from transformers import AutoConfig, AutoTokenizer

import crossfit as cf
from crossfit import op
from crossfit.backend.torch.hf.model import HFModel


@dataclass
class TranslationConfig:
Expand Down
4 changes: 2 additions & 2 deletions tests/op/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest.mock import patch

import pytest

cp = pytest.importorskip("cupy")
cudf = pytest.importorskip("cudf")
dask_cudf = pytest.importorskip("dask_cudf")
Expand All @@ -25,7 +26,6 @@

import crossfit as cf # noqa: E402


cf_loader = pytest.importorskip("crossfit.backend.torch.loader")


Expand Down

0 comments on commit a170248

Please sign in to comment.