Skip to content

Commit

Permalink
Merge branch 'main' into vjawa/fix_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa authored Jul 24, 2024
2 parents f7e86c5 + 0ecc5d3 commit d2e1eae
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 23 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ Install this library using `pip`:

pip install crossfit

### Installation from source (for cuda 12.x)

```
git clone https://github.com/rapidsai/crossfit.git
cd crossfit
pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"
```

## Usage

Usage instructions go here.
Expand Down
77 changes: 56 additions & 21 deletions crossfit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,41 @@
from crossfit.metric import *
from crossfit.op import *


class LazyLoader:
def __init__(self, name):
self._name = name
self._module = None
self._error = None

def _load(self):
if self._module is None and self._error is None:
try:
parts = self._name.split(".")
module_name = ".".join(parts[:-1])
attribute_name = parts[-1]
module = __import__(module_name, fromlist=[attribute_name])
self._module = getattr(module, attribute_name)
except ImportError as e:
self._error = e
except AttributeError as e:
self._error = AttributeError(
f"Module '{module_name}' has no attribute '{attribute_name}'"
)

def __getattr__(self, item):
self._load()
if self._error is not None:
raise ImportError(f"Failed to import {self._name}: {self._error}")
return getattr(self._module, item)

def __call__(self, *args, **kwargs):
self._load()
if self._error is not None:
raise ImportError(f"Failed to import {self._name}: {self._error}")
return self._module(*args, **kwargs)


__all__ = [
"Aggregator",
"backend",
Expand All @@ -40,25 +75,25 @@
"Serial",
]

# Using the lazy import function
HFModel = LazyLoader("crossfit.backend.torch.HFModel")
SentenceTransformerModel = LazyLoader("crossfit.backend.torch.SentenceTransformerModel")
TorchExactSearch = LazyLoader("crossfit.backend.torch.TorchExactSearch")
IRDataset = LazyLoader("crossfit.dataset.base.IRDataset")
MultiDataset = LazyLoader("crossfit.dataset.base.MultiDataset")
load_dataset = LazyLoader("crossfit.dataset.load.load_dataset")
embed = LazyLoader("crossfit.report.beir.embed.embed")
beir_report = LazyLoader("crossfit.report.beir.report.beir_report")

try:
from crossfit.backend.torch import HFModel, SentenceTransformerModel, TorchExactSearch
from crossfit.dataset.base import IRDataset, MultiDataset
from crossfit.dataset.load import load_dataset
from crossfit.report.beir.embed import embed
from crossfit.report.beir.report import beir_report

__all__.extend(
[
"embed",
"beir_report",
"load_dataset",
"TorchExactSearch",
"SentenceTransformerModel",
"HFModel",
"MultiDataset",
"IRDataset",
]
)
except ImportError as e:
pass
__all__.extend(
[
"embed",
"beir_report",
"load_dataset",
"TorchExactSearch",
"SentenceTransformerModel",
"HFModel",
"MultiDataset",
"IRDataset",
]
)
13 changes: 13 additions & 0 deletions requirements/cuda12x.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
cudf-cu12>=24.4
dask-cudf-cu12>=24.4
cuml-cu12>=24.4
pylibraft-cu12>=24.4
raft-dask-cu12>=24.4
cuvs-cu12>=24.4
dask-cuda>=24.6
torch>=2.0
transformers>=4.0
curated-transformers>=1.0
bitsandbytes>=0.30
sentence-transformers>=2.0
sentencepiece
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_long_description():

def read_requirements(filename):
base = os.path.abspath(os.path.dirname(__file__))
with codecs.open(os.path.join(base, filename), "rb", "utf-8") as f:
with codecs.open(os.path.join(base, filename), "r", "utf-8") as f:
lineiter = (line.strip() for line in f)
return [line for line in lineiter if line and not line.startswith("#")]

Expand All @@ -40,12 +40,15 @@ def read_requirements(filename):

requirements = {
"base": read_requirements("requirements/base.txt"),
"cuda12x": read_requirements("requirements/cuda12x.txt"),
"dev": _dev,
"tensorflow": read_requirements("requirements/tensorflow.txt"),
"pytorch": read_requirements("requirements/pytorch.txt"),
"jax": read_requirements("requirements/jax.txt"),
}

dev_requirements = {
"cuda12x-dev": requirements["cuda12x"] + _dev,
"tensorflow-dev": requirements["tensorflow"] + _dev,
"pytorch-dev": requirements["pytorch"] + _dev,
"jax-dev": requirements["jax"] + _dev,
Expand Down Expand Up @@ -75,6 +78,6 @@ def read_requirements(filename):
**dev_requirements,
"all": list(itertools.chain(*list(requirements.values()))),
},
python_requires=">=3.7",
python_requires=">=3.7, <3.12",
test_suite="tests",
)

0 comments on commit d2e1eae

Please sign in to comment.