From 414867322079cce732d3fbaf2396c02b7abea50a Mon Sep 17 00:00:00 2001
From: edknv <109497216+edknv@users.noreply.github.com>
Date: Mon, 15 Jan 2024 14:19:55 -0800
Subject: [PATCH] upgrade base container to nemo:23.10 (#42)

* require cudf 23.12

* use the new ci image

* add nvidia pip index to tests

* lint

* separate gpu requirements

* require pytrec_eval

* install beir for testing

* move beir installation to dockerfile

* drop cudf 23.12 and use nemo container

* dummy commit to trigger ci using newly updated 23.10 container
---
 .github/workflows/gpu-ci.yml      |  4 ++--
 crossfit/backend/torch/model.py   |  2 ++
 crossfit/backend/torch/op/base.py |  5 +----
 crossfit/op/base.py               | 30 +++---------------------------
 crossfit/op/tokenize.py           | 21 +++++++++++++++------
 crossfit/report/beir/embed.py     |  2 --
 docker/ci/Dockerfile              |  5 +++--
 docker/ci/build_and_push.sh       |  2 +-
 requirements/base.txt             |  2 +-
 requirements/pytorch.txt          |  1 +
 10 files changed, 29 insertions(+), 45 deletions(-)

diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml
index abf74ed6..7b02c043 100644
--- a/.github/workflows/gpu-ci.yml
+++ b/.github/workflows/gpu-ci.yml
@@ -13,7 +13,7 @@ jobs:
   gpu-ci:
     runs-on: linux-amd64-gpu-p100-latest-1
     container:
-      image: nvcr.io/nvidian/crossfit-ci:23.09
+      image: nvcr.io/nvidian/crossfit-ci:23.10
       env:
         NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
       options: --shm-size=1G
@@ -32,7 +32,7 @@ jobs:
   benchmark:
     runs-on: linux-amd64-gpu-p100-latest-1
     container:
-      image: nvcr.io/nvidian/crossfit-ci:23.09
+      image: nvcr.io/nvidian/crossfit-ci:23.10
       env:
         NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
       options: --shm-size=1G
diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py
index 1745bd29..5e5aa36b 100644
--- a/crossfit/backend/torch/model.py
+++ b/crossfit/backend/torch/model.py
@@ -34,6 +34,8 @@ def call_on_worker(self, worker, *args, **kwargs):
         return worker.torch_model(*args, **kwargs)
 
     def get_model(self, worker):
+        if not hasattr(worker, "torch_model"):
+            self.load_on_worker(worker)
         return worker.torch_model
 
     def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py
index 829208d2..6ba60e95 100644
--- a/crossfit/backend/torch/op/base.py
+++ b/crossfit/backend/torch/op/base.py
@@ -50,9 +50,6 @@ def __init__(
         self.model_output_col = model_output_col
         self.pred_output_col = pred_output_col
 
-    def setup(self):
-        self.model.load_on_worker(self)
-
     @torch.no_grad()
     def call(self, data, partition_info=None):
         index = data.index
@@ -72,7 +69,7 @@ def call(self, data, partition_info=None):
         )
 
         all_outputs_ls = []
-        for output in loader.map(self.model.get_model(self)):
+        for output in loader.map(self.model.get_model(self.get_worker())):
             if isinstance(output, dict):
                 if self.model_output_col not in output:
                     raise ValueError(f"Column '{self.model_outupt_col}' not found in model output.")
diff --git a/crossfit/op/base.py b/crossfit/op/base.py
index abf9d3e5..14e213f3 100644
--- a/crossfit/op/base.py
+++ b/crossfit/op/base.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import uuid
 
 import dask.dataframe as dd
 from dask.distributed import get_worker, wait
@@ -27,7 +26,7 @@ def __init__(self, pre=None, cols=False, keep_cols=None):
         self.pre = pre
         self.cols = cols
         self.keep_cols = keep_cols or []
-        self.id = str(uuid.uuid4())
+        self.worker_name = getattr(self.get_worker(), "name", 0)
 
     def setup(self):
         pass
@@ -46,29 +45,6 @@ def get_worker(self):
 
         return worker
 
-    def _get_init_name(self):
-        init_name = f"setup_done_{self.id}"
-        return init_name
-
-    def setup_worker(self):
-        worker = self.get_worker()
-
-        self.worker_name = getattr(worker, "name", 0)
-        init_name = self._get_init_name()
-
-        if not hasattr(worker, init_name):
-            self.setup()
-            setattr(worker, init_name, True)
-
-    def teardown_worker(self):
-        worker = self.get_worker()
-
-        init_name = self._get_init_name()
-
-        if hasattr(worker, init_name):
-            delattr(worker, init_name)
-            self.teardown()
-
     def call_dask(self, data: dd.DataFrame):
         output = data.map_partitions(self, meta=self._build_dask_meta(data))
 
@@ -101,10 +77,10 @@ def add_keep_cols(self, data, output):
     def __call__(self, data, *args, partition_info=None, **kwargs):
         if isinstance(data, dd.DataFrame):
             output = self.call_dask(data, *args, **kwargs)
-            self.teardown_worker()
+            self.teardown()
             return output
 
-        self.setup_worker()
+        self.setup()
 
         if self.pre is not None:
             params = inspect.signature(self.pre).parameters
diff --git a/crossfit/op/tokenize.py b/crossfit/op/tokenize.py
index ec24a0f3..686ece64 100644
--- a/crossfit/op/tokenize.py
+++ b/crossfit/op/tokenize.py
@@ -39,14 +39,18 @@ def __init__(
         self.model = model
         self.max_length = max_length or model.max_seq_length()
 
-        # Make sure we download the tokenizer just once
-        GPUTokenizer.from_pretrained(self.model)
-
-    def setup(self):
-        self.tokenizer = GPUTokenizer.from_pretrained(self.model)
+        self.setup()
 
     def tokenize_strings(self, sentences, max_length=None):
-        return self.tokenizer(
+        worker = self.get_worker()
+
+        if hasattr(worker, "tokenizer"):
+            tokenizer = worker.tokenizer
+        else:
+            tokenizer = GPUTokenizer.from_pretrained(self.model)
+            worker.tokenizer = tokenizer
+
+        return worker.tokenizer(
             sentences,
             max_length=max_length or self.max_length,
             max_num_rows=len(sentences),
@@ -56,6 +60,11 @@ def tokenize_strings(self, sentences, max_length=None):
             add_special_tokens=True,
         )
 
+    def teardown(self):
+        worker = self.get_worker()
+        if hasattr(worker, "tokenizer"):
+            delattr(worker, "tokenizer")
+
     def call_column(self, data):
         if isinstance(data, cudf.DataFrame):
             raise ValueError(
diff --git a/crossfit/report/beir/embed.py b/crossfit/report/beir/embed.py
index 2af24550..98dcc3f0 100644
--- a/crossfit/report/beir/embed.py
+++ b/crossfit/report/beir/embed.py
@@ -52,7 +52,6 @@ def embed(
     else:
         return EmbeddingDatataset.from_dir(emb_dir, data=dataset)
 
-    dfs = []
     for dtype in ["query", "item"]:
         if os.path.exists(os.path.join(emb_dir, dtype)):
             continue
@@ -76,7 +75,6 @@ def embed(
 
         embeddings = pipe(df)
         embeddings.to_parquet(os.path.join(emb_dir, dtype))
-        dfs.append(df)
 
     output: EmbeddingDatataset = EmbeddingDatataset.from_dir(emb_dir, data=dataset)
     pred_path = os.path.join(emb_dir, "predictions")
diff --git a/docker/ci/Dockerfile b/docker/ci/Dockerfile
index 98600c43..b92cc2e1 100644
--- a/docker/ci/Dockerfile
+++ b/docker/ci/Dockerfile
@@ -1,8 +1,9 @@
-FROM nvcr.io/nvidia/pytorch:23.09-py3
+FROM nvcr.io/nvidia/nemo:23.10
 
 COPY . /tmp/crossfit/
 RUN cd /tmp/crossfit && \
-    pip install .[pytorch-dev] && \
+    python3 -m pip install .[pytorch-dev] && \
+    python3 -m pip install beir && \
     rm -r /tmp/crossfit
 
 ENV CF_HOME /root/.cf
diff --git a/docker/ci/build_and_push.sh b/docker/ci/build_and_push.sh
index 948b5dd4..ae7b9ae3 100755
--- a/docker/ci/build_and_push.sh
+++ b/docker/ci/build_and_push.sh
@@ -6,7 +6,7 @@
 set -e
 
 IMAGE_NAME=nvcr.io/nvidian/crossfit-ci
-IMAGE_TAG=23.09
+IMAGE_TAG=23.10
 
 docker build -t ${IMAGE_NAME}:${IMAGE_TAG} -f docker/ci/Dockerfile .
 
diff --git a/requirements/base.txt b/requirements/base.txt
index 9a34f039..68a2fbef 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -11,4 +11,4 @@ typing_extensions
 typing_utils
 tqdm
 rich
-pynvml>=11.0.0,<11.5
\ No newline at end of file
+pynvml>=11.0.0,<11.5
diff --git a/requirements/pytorch.txt b/requirements/pytorch.txt
index 1ca20faf..1419089c 100644
--- a/requirements/pytorch.txt
+++ b/requirements/pytorch.txt
@@ -2,3 +2,4 @@ torch>=1.0
 transformers
 curated-transformers
 bitsandbytes
+sentence-transformers
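
Note on the crossfit/op/base.py and crossfit/op/tokenize.py hunks above: they
replace the old per-op setup flags (the `setup_done_<uuid>` worker attributes)
with lazy, per-worker caching, where `get_model` and `tokenize_strings` look
the model or tokenizer up on the Dask worker object and load it on first use.
Below is a minimal sketch of that pattern, not code from this patch: it
assumes only dask.distributed, and the helper name `cached_on_worker` and the
local fallback holder are illustrative.

    from dask.distributed import get_worker


    class _LocalState:
        """Fallback attribute holder when no Dask worker is running."""


    _LOCAL_STATE = _LocalState()


    def cached_on_worker(attr, factory):
        """Return the per-worker singleton `attr`, building it on first use."""
        try:
            holder = get_worker()  # raises ValueError outside a worker thread
        except ValueError:
            holder = _LOCAL_STATE
        if not hasattr(holder, attr):
            # First call on this worker: build the object once and pin it to
            # the worker, so later partitions reuse it instead of reloading.
            setattr(holder, attr, factory())
        return getattr(holder, attr)

Loading once per worker process means heavyweight models and tokenizers
survive across partitions instead of being rebuilt for each one, and
`teardown()` can simply `delattr` the cached attribute, as the new
`teardown` in tokenize.py does.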