diff --git a/cli/text2embeddings.py b/cli/text2embeddings.py index 3d9bd85..bd0a68f 100644 --- a/cli/text2embeddings.py +++ b/cli/text2embeddings.py @@ -67,7 +67,7 @@ ) @click.option( "--device", - type=click.Choice(["cuda", "cpu"]), + type=click.Choice(["cuda", "mps", "cpu"]), help="Device to use for embeddings generation", required=True, default="cpu", @@ -92,15 +92,15 @@ def run_as_cli( Each embeddings file is called {id}.json where {id} is the document ID of the input. Its first line is the description embedding and all other lines are embeddings of each of the text blocks in the document in order. Encoding will - automatically run on the GPU if one is available. + run CPU unless device is set to 'cuda' or 'mps'. Args: input_dir: Directory containing JSON files output_dir: Directory to save embeddings to s3: Whether we are reading from and writing to S3. redo: Redo encoding for files that have already been parsed. By default, files with IDs that already exist in the output directory are skipped. limit (Optional[int]): Optionally limit the number of text samples to process. Useful for debugging. - device (str): Device to use for embeddings generation. Must be either "cuda" or - "cpu". + device (str): Device to use for embeddings generation. Must be either "cuda", "mps", + or "cpu". """ # FIXME: This solution assumes that we have a json document with language = en ( # supported target language) for every document in the parser output. This isn't diff --git a/poetry.lock b/poetry.lock index 82082c3..97ba4e1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -983,6 +983,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1768,7 +1778,7 @@ name = "pycparser" version = "2.21" description = "C parser in Python" optional = false -python-versions = "*" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, @@ -1958,6 +1968,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "python-json-logger" version = "2.0.7" @@ -1992,6 +2016,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1999,8 +2024,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2017,6 +2050,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2024,6 +2058,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2407,25 +2442,26 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo [[package]] name = "sentence-transformers" -version = "2.2.2" +version = "2.3.1" description = "Multilingual text embeddings" optional = false -python-versions = ">=3.6.0" +python-versions = ">=3.8.0" files = [ - {file = "sentence-transformers-2.2.2.tar.gz", hash = "sha256:dbc60163b27de21076c9a30d24b5b7b6fa05141d68cf2553fa9a77bf79a29136"}, + {file = "sentence-transformers-2.3.1.tar.gz", hash = "sha256:d589d85a464f45338cdbdf99ea715f8068e1fb01c582e0bcdbf60bcf3eade6d0"}, + {file = "sentence_transformers-2.3.1-py3-none-any.whl", hash = "sha256:285d6637726c3b002186aa4b8bcace1101364b32671fb605297c4c2636b8190e"}, ] [package.dependencies] -huggingface-hub = ">=0.4.0" +huggingface-hub = ">=0.15.1" nltk = "*" numpy = "*" +Pillow = "*" scikit-learn = "*" scipy = "*" sentencepiece = "*" -torch = ">=1.6.0" -torchvision = "*" +torch = ">=1.11.0" tqdm = "*" -transformers = ">=4.6.0,<5.0.0" +transformers = ">=4.32.0,<5.0.0" [[package]] name = "sentencepiece" @@ -2714,44 +2750,6 @@ typing-extensions = "*" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -[[package]] -name = "torchvision" -version = "0.15.1" -description = "image and video datasets and models for torch deep learning" -optional = false -python-versions = ">=3.8" -files = [ - {file = "torchvision-0.15.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc10d48e9a60d006d0c1b48dea87f1ec9b63d856737d592f7c5c44cd87f3f4b7"}, - {file = "torchvision-0.15.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3708d3410fdcaf6280e358cda9de2a4ab06cc0b4c0fd9aeeac550ec2563a887e"}, - {file = "torchvision-0.15.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:d4de10c837f1493c1c54344388e300a06c96914c6cc55fcb2527c21f2f010bbd"}, - {file = "torchvision-0.15.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:b82fcc5abc9b5c96495c76596a1573025cc1e09d97d2d6fda717c44b9ca45881"}, - {file = "torchvision-0.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:c84e97d8cc4fe167d87adad0a2a6424cff90544365545b20669bc50e6ea46875"}, - {file = "torchvision-0.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:97b90eb3b7333a31d049c4ccfd1064361e8491874959d38f466af64d67418cef"}, - {file = "torchvision-0.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6b60e1c839ae2a071befbba69b17468d67feafdf576e90ff9645bfbee998de17"}, - {file = "torchvision-0.15.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:13f71a3372d9168b01481a754ebaa171207f3dc455bf2fd86906c69222443738"}, - {file = "torchvision-0.15.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b2e8394726009090b40f6cc3a95cc878cc011dfac3d8e7a6060c79213d360880"}, - {file = "torchvision-0.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:2852f501189483187ce9eb0ccd01b3f4f0918d29057e4a18b3cce8dad9a8a964"}, - {file = "torchvision-0.15.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e5861baaeea87d19b6fd7d131e11a4a6bd17be14234c490a259bb360775e9520"}, - {file = "torchvision-0.15.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e714f362b9d8217cf4d68509b679ebc9ddf128cfe80f6c1def8e3f8a18466e75"}, - {file = "torchvision-0.15.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:43624accad1e47f16824be4db37ad678dd89326ad90b69c9c6363eeb22b9467e"}, - {file = "torchvision-0.15.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:7fe9b0cd3311b0db9e6d45ffab594ced06418fa4e2aa15eb2e60d55e5c51135c"}, - {file = "torchvision-0.15.1-cp38-cp38-win_amd64.whl", hash = "sha256:b45324ea4911a23a4b00b5a15cdbe36d47f93137206dab9f8c606d81b69dd3a7"}, - {file = "torchvision-0.15.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1dfdec7c7df967330bba3341a781e0c047d4e0163e67164a9918500362bf7d91"}, - {file = "torchvision-0.15.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c153710186cec0338d4fff411459a57ddbc8504436123ca73b3f0bdc26ff918c"}, - {file = "torchvision-0.15.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ff4e650aa601f32ab97bce06704868dd2baad69ca4d454fa1f0012a51199f2bc"}, - {file = "torchvision-0.15.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e9b4bb2a15849391df0415d2f76dd36e6528e4253f7b69322b7a0d682535544b"}, - {file = "torchvision-0.15.1-cp39-cp39-win_amd64.whl", hash = "sha256:21e6beb69e77ef6575c4fdd0ab332b96e8a7f144eee0d333acff469c827a4b5e"}, -] - -[package.dependencies] -numpy = "*" -pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -requests = "*" -torch = "2.0.0" - -[package.extras] -scipy = ["scipy"] - [[package]] name = "tqdm" version = "4.66.1" @@ -3198,4 +3196,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "~3.9" -content-hash = "a24553fd7a42ac559e92918bcbd66d82ef33c612bab0b82fd3e7d80028de20d1" +content-hash = "f2e49221bcfc547489177a844a2baea7775f0cffb3ced19b18a65b36ac9f0f50" diff --git a/pyproject.toml b/pyproject.toml index c63e68d..186fd73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ aws-error-utils = "^2.7.0" moto = "^4.1.11" torch = "2.0.0" cpr-data-access = {git = "https://github.com/climatepolicyradar/data-access.git", tag = "0.4.0"} +python-dotenv = "^1.0.1" [tool.poetry.dev-dependencies] black = "^22.1.0" diff --git a/src/config.py b/src/config.py index 0b7d497..73dc82f 100644 --- a/src/config.py +++ b/src/config.py @@ -3,6 +3,9 @@ import os from typing import Set import re +from dotenv import load_dotenv, find_dotenv + +load_dotenv(find_dotenv()) SBERT_MODEL: str = os.getenv("SBERT_MODEL", "msmarco-distilbert-dot-v5") INDEX_ENCODER_CACHE_FOLDER: str = os.getenv("INDEX_ENCODER_CACHE_FOLDER", "/models")