Merge pull request #18 from climatepolicyradar/feature/enable-mps-inference

Enable inference on M1 GPUs, outside of Docker.
kdutia authored Feb 13, 2024
2 parents a45cbeb + 2f10bf8 commit 6f4bacd
Showing 4 changed files with 53 additions and 51 deletions.
8 changes: 4 additions & 4 deletions cli/text2embeddings.py
@@ -67,7 +67,7 @@
 )
 @click.option(
     "--device",
-    type=click.Choice(["cuda", "cpu"]),
+    type=click.Choice(["cuda", "mps", "cpu"]),
     help="Device to use for embeddings generation",
     required=True,
     default="cpu",
@@ -92,15 +92,15 @@ def run_as_cli(
     Each embeddings file is called {id}.json where {id} is the document ID of the
     input. Its first line is the description embedding and all other lines are
     embeddings of each of the text blocks in the document in order. Encoding will
-    automatically run on the GPU if one is available.
+    run on the CPU unless device is set to 'cuda' or 'mps'.

     Args: input_dir: Directory containing JSON files output_dir: Directory to save
     embeddings to s3: Whether we are reading from and writing to S3. redo: Redo
     encoding for files that have already been parsed. By default, files with IDs that
     already exist in the output directory are skipped. limit (Optional[int]):
     Optionally limit the number of text samples to process. Useful for debugging.
-    device (str): Device to use for embeddings generation. Must be either "cuda" or
-    "cpu".
+    device (str): Device to use for embeddings generation. Must be either "cuda", "mps",
+    or "cpu".
     """
     # FIXME: This solution assumes that we have a json document with language = en (
     # supported target language) for every document in the parser output. This isn't
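
The diff itself doesn't show how the --device value reaches the encoder. As a rough sketch of what MPS-aware encoding looks like with PyTorch and sentence-transformers (the fallback helper and the way the model is constructed here are illustrative assumptions, not the repository's actual code):

import torch
from sentence_transformers import SentenceTransformer

def resolve_device(requested: str) -> str:
    # Fall back to CPU if the requested accelerator isn't available.
    if requested == "cuda" and torch.cuda.is_available():
        return "cuda"
    if requested == "mps" and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# msmarco-distilbert-dot-v5 is the default SBERT_MODEL in src/config.py.
model = SentenceTransformer("msmarco-distilbert-dot-v5", device=resolve_device("mps"))
embeddings = model.encode(["A text block from a parsed document."])

On the command line, this corresponds to passing --device mps to cli/text2embeddings.py, which the updated click.Choice now accepts.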
92 changes: 45 additions & 47 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -18,6 +18,7 @@ aws-error-utils = "^2.7.0"
 moto = "^4.1.11"
 torch = "2.0.0"
 cpr-data-access = {git = "https://github.com/climatepolicyradar/data-access.git", tag = "0.4.0"}
+python-dotenv = "^1.0.1"

 [tool.poetry.dev-dependencies]
 black = "^22.1.0"
3 changes: 3 additions & 0 deletions src/config.py
@@ -3,6 +3,9 @@
 import os
 from typing import Set
 import re
+from dotenv import load_dotenv, find_dotenv
+
+load_dotenv(find_dotenv())

 SBERT_MODEL: str = os.getenv("SBERT_MODEL", "msmarco-distilbert-dot-v5")
 INDEX_ENCODER_CACHE_FOLDER: str = os.getenv("INDEX_ENCODER_CACHE_FOLDER", "/models")
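
With python-dotenv wired in, local runs outside Docker can pick up configuration from a .env file located by find_dotenv() instead of exported environment variables. A hypothetical .env for an M1 machine (the values are illustrative; only SBERT_MODEL's default matches the repository):

# .env at the project root, loaded by src/config.py via load_dotenv(find_dotenv())
SBERT_MODEL=msmarco-distilbert-dot-v5
INDEX_ENCODER_CACHE_FOLDER=./models

Note that load_dotenv() does not override variables already set in the environment, so exported values still take precedence over the .env file.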
