diff --git a/docs/docs/metrics/write-your-own.md b/docs/docs/metrics/write-your-own.md
index 65f96b2be..0b87cbb52 100644
--- a/docs/docs/metrics/write-your-own.md
+++ b/docs/docs/metrics/write-your-own.md
@@ -28,8 +28,7 @@ Your implementation should call `writer.write(, )` for eve
 from loguru import logger
 
 from encord_active.lib.common.iterator import Iterator
-from encord_active.lib.metrics.metric import Metric
-from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType
+from encord_active.lib.metrics.metric import AnnotationType, DataType, Metric, MetricType
 from encord_active.lib.metrics.writer import CSVMetricWriter
 
 logger = logger.opt(colors=True)
diff --git a/docs/docs/sdk/initialize-a-project.mdx b/docs/docs/sdk/initialize-a-project.mdx
index 834a9bee4..45fb3e956 100644
--- a/docs/docs/sdk/initialize-a-project.mdx
+++ b/docs/docs/sdk/initialize-a-project.mdx
@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import List
 
 from encord_active.lib.metrics.execute import run_metrics
-from encord_active.lib.metrics.types import EmbeddingType
+from encord_active.lib.metrics.metric import EmbeddingType
 from encord_active.lib.project.local import ProjectExistsError, init_local_project
 
 # 1. Choose images to import
@@ -88,7 +88,7 @@ from pathlib import Path
 from typing import List
 
 from encord_active.lib.metrics.execute import run_metrics
-from encord_active.lib.metrics.types import EmbeddingType
+from encord_active.lib.metrics.metric import EmbeddingType
 from encord_active.lib.project.local import ProjectExistsError, init_local_project
 
 # 1. Choose images and label files to import
diff --git a/docs/docs/sdk/run-metrics.md b/docs/docs/sdk/run-metrics.md
index 218ba51e7..259a24b21 100644
--- a/docs/docs/sdk/run-metrics.md
+++ b/docs/docs/sdk/run-metrics.md
@@ -45,7 +45,7 @@ There is a utility function you can use to run targeted subsets of metrics:
 from encord_active.lib.metrics.execute import (
     run_metrics_by_embedding_type,
 )
-from encord_active.lib.metrics.types import EmbeddingType
+from encord_active.lib.metrics.metric import EmbeddingType
 
 run_metrics_by_embedding_type(EmbeddingType.IMAGE, **options)
 run_metrics_by_embedding_type(EmbeddingType.OBJECT, **options)
diff --git a/examples/building-a-custom-metric-function.ipynb b/examples/building-a-custom-metric-function.ipynb
index 8a1494376..288bf80ca 100644
--- a/examples/building-a-custom-metric-function.ipynb
+++ b/examples/building-a-custom-metric-function.ipynb
@@ -104,8 +104,7 @@
     "from typing import List, Optional, Union\n",
     "\n",
     "from encord_active.lib.common.iterator import Iterator\n",
-    "from encord_active.lib.metrics.metric import Metric\n",
-    "from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType\n",
+    "from encord_active.lib.metrics.metric import AnnotationType, DataType, MetricType, Metric\n",
     "from encord_active.lib.metrics.writer import CSVMetricWriter\n",
     "\n",
     "class ExampleMetric(Metric):\n",
@@ -443,8 +442,7 @@
     "import numpy as np\n",
     "from encord_active.lib.common import utils\n",
     "from encord_active.lib.common.iterator import Iterator\n",
-    "from encord_active.lib.metrics.metric import Metric\n",
-    "from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType\n",
+    "from encord_active.lib.metrics.metric import AnnotationType, DataType, Metric, MetricType\n",
     "from encord_active.lib.metrics.writer import CSVMetricWriter\n",
     "from loguru import logger\n",
     "\n",
diff --git a/poetry.lock b/poetry.lock
index dddae515e..c431c7045 100644
--- a/poetry.lock
+++ b/poetry.lock @@ -292,31 +292,6 @@ python-versions = ">=3.7" [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} -[[package]] -name = "clip" -version = "1.0" -description = "" -category = "main" -optional = false -python-versions = "*" -develop = false - -[package.dependencies] -ftfy = "*" -regex = "*" -torch = "*" -torchvision = "*" -tqdm = "*" - -[package.extras] -dev = ["pytest"] - -[package.source] -type = "git" -url = "https://github.com/openai/CLIP.git" -reference = "HEAD" -resolved_reference = "a9b1bf5920416aaeaec965c25dd9e8f98c864f16" - [[package]] name = "colorama" version = "0.4.6" @@ -525,7 +500,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.0" description = "A platform independent file lock." -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -563,17 +538,6 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" -[[package]] -name = "ftfy" -version = "6.1.1" -description = "Fixes mojibake and other problems with Unicode, after the fact" -category = "main" -optional = false -python-versions = ">=3.7,<4" - -[package.dependencies] -wcwidth = ">=0.2.5" - [[package]] name = "gitdb" version = "4.0.10" @@ -1287,20 +1251,6 @@ category = "dev" optional = false python-versions = ">=3.7" -[[package]] -name = "mpmath" -version = "1.3.0" -description = "Python library for arbitrary-precision floating-point arithmetic" -category = "main" -optional = false -python-versions = "*" - -[package.extras] -develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] -docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4)"] -tests = ["pytest (>=4.6)"] - [[package]] name = "mypy" version = "0.981" @@ -1451,21 +1401,6 @@ category = "main" optional = false python-versions = ">=3.5" -[[package]] -name = "networkx" -version = "3.1" -description = "Python package for creating and manipulating graphs and networks" -category = "main" -optional = false -python-versions = ">=3.8" - -[package.extras] -default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] -developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] -doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] -extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] -test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] - [[package]] name = "nodeenv" version = "1.7.0" @@ -2170,14 +2105,6 @@ packaging = "*" [package.extras] test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] -[[package]] -name = "regex" -version = "2023.5.5" -description = "Alternative regular expression module, to replace re." 
-category = "main" -optional = false -python-versions = ">=3.6" - [[package]] name = "requests" version = "2.29.0" @@ -2467,17 +2394,6 @@ python-versions = ">=3.6" plotly = ">=4.14.3" streamlit = ">=0.63" -[[package]] -name = "sympy" -version = "1.12" -description = "Computer algebra system (CAS) in Python" -category = "main" -optional = false -python-versions = ">=3.8" - -[package.dependencies] -mpmath = ">=0.19" - [[package]] name = "tabulate" version = "0.9.0" @@ -2585,35 +2501,29 @@ python-versions = ">=3.5" [[package]] name = "torch" -version = "2.0.1" +version = "1.12.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.7.0" [package.dependencies] -filelock = "*" -jinja2 = "*" -networkx = "*" -sympy = "*" typing-extensions = "*" -[package.extras] -opt-einsum = ["opt-einsum (>=3.3)"] - [[package]] name = "torchvision" -version = "0.15.2" +version = "0.13.1" description = "image and video datasets and models for torch deep learning" category = "main" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.0 || >=8.4.0" requests = "*" -torch = "2.0.1" +torch = "1.12.1" +typing-extensions = "*" [package.extras] scipy = ["scipy"] @@ -3025,7 +2935,7 @@ notebooks = ["jupyterlab", "ipywidgets"] [metadata] lock-version = "1.1" python-versions = ">=3.9,<3.9.7 || >3.9.7,<3.11" -content-hash = "ede337f6b5df598e4a61e704dd37d2cdfd160571836a78c17dd400ffa90ba82f" +content-hash = "9bf7547aa0f400b04b14f3f80026d7fad7f8843ea9591a54a2a24633b5b9d1cf" [metadata.files] aiofiles = [ @@ -3327,7 +3237,6 @@ click = [ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] -clip = [] colorama = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -3521,10 +3430,6 @@ fqdn = [ {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"}, {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] -ftfy = [ - {file = "ftfy-6.1.1-py3-none-any.whl", hash = "sha256:0ffd33fce16b54cccaec78d6ec73d95ad370e5df5a25255c8966a6147bd667ca"}, - {file = "ftfy-6.1.1.tar.gz", hash = "sha256:bfc2019f84fcd851419152320a6375604a0f1459c281b5b199b2cd0d2e727f8f"}, -] gitdb = [ {file = "gitdb-4.0.10-py3-none-any.whl", hash = "sha256:c286cf298426064079ed96a9e4a9d39e7f3e9bf15ba60701e95f5492f28415c7"}, {file = "gitdb-4.0.10.tar.gz", hash = "sha256:6eb990b69df4e15bad899ea868dc46572c3f75339735663b81de79b06f17eb9a"}, @@ -3970,10 +3875,6 @@ more-itertools = [ {file = "more-itertools-9.1.0.tar.gz", hash = "sha256:cabaa341ad0389ea83c17a94566a53ae4c9d07349861ecb14dc6d0345cf9ac5d"}, {file = "more_itertools-9.1.0-py3-none-any.whl", hash = "sha256:d2bc7f02446e86a68911e58ded76d6561eea00cddfb2a91e7019bbb586c799f3"}, ] -mpmath = [ - {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, - {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, -] 
mypy = [ {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, @@ -4028,10 +3929,6 @@ nest-asyncio = [ {file = "nest_asyncio-1.5.6-py3-none-any.whl", hash = "sha256:b9a953fb40dceaa587d109609098db21900182b16440652454a146cffb06e8b8"}, {file = "nest_asyncio-1.5.6.tar.gz", hash = "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"}, ] -networkx = [ - {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, - {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, -] nodeenv = [ {file = "nodeenv-1.7.0-py2.py3-none-any.whl", hash = "sha256:27083a7b96a25f2f5e1d8cb4b6317ee8aeda3bdd121394e5ac54e498028a042e"}, {file = "nodeenv-1.7.0.tar.gz", hash = "sha256:e0e7f7dfb85fc5394c6fe1e8fa98131a2473e04311a45afb6508f7cf1836fa2b"}, @@ -4648,96 +4545,6 @@ qtpy = [ {file = "QtPy-2.3.1-py3-none-any.whl", hash = "sha256:5193d20e0b16e4d9d3bc2c642d04d9f4e2c892590bd1b9c92bfe38a95d5a2e12"}, {file = "QtPy-2.3.1.tar.gz", hash = "sha256:a8c74982d6d172ce124d80cafd39653df78989683f760f2281ba91a6e7b9de8b"}, ] -regex = [ - {file = "regex-2023.5.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:48c9ec56579d4ba1c88f42302194b8ae2350265cb60c64b7b9a88dcb7fbde309"}, - {file = "regex-2023.5.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02f4541550459c08fdd6f97aa4e24c6f1932eec780d58a2faa2068253df7d6ff"}, - {file = "regex-2023.5.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53e22e4460f0245b468ee645156a4f84d0fc35a12d9ba79bd7d79bdcd2f9629d"}, - {file = "regex-2023.5.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b870b6f632fc74941cadc2a0f3064ed8409e6f8ee226cdfd2a85ae50473aa94"}, - {file = "regex-2023.5.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:171c52e320fe29260da550d81c6b99f6f8402450dc7777ef5ced2e848f3b6f8f"}, - {file = "regex-2023.5.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aad5524c2aedaf9aa14ef1bc9327f8abd915699dea457d339bebbe2f0d218f86"}, - {file = "regex-2023.5.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a0f874ee8c0bc820e649c900243c6d1e6dc435b81da1492046716f14f1a2a96"}, - {file = "regex-2023.5.5-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e645c757183ee0e13f0bbe56508598e2d9cd42b8abc6c0599d53b0d0b8dd1479"}, - {file = "regex-2023.5.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a4c5da39bca4f7979eefcbb36efea04471cd68db2d38fcbb4ee2c6d440699833"}, - {file = "regex-2023.5.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5e3f4468b8c6fd2fd33c218bbd0a1559e6a6fcf185af8bb0cc43f3b5bfb7d636"}, - {file = "regex-2023.5.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:59e4b729eae1a0919f9e4c0fc635fbcc9db59c74ad98d684f4877be3d2607dd6"}, - {file = "regex-2023.5.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ba73a14e9c8f9ac409863543cde3290dba39098fc261f717dc337ea72d3ebad2"}, - {file = "regex-2023.5.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0bbd5dcb19603ab8d2781fac60114fb89aee8494f4505ae7ad141a3314abb1f9"}, - {file = "regex-2023.5.5-cp310-cp310-win32.whl", 
hash = "sha256:40005cbd383438aecf715a7b47fe1e3dcbc889a36461ed416bdec07e0ef1db66"}, - {file = "regex-2023.5.5-cp310-cp310-win_amd64.whl", hash = "sha256:59597cd6315d3439ed4b074febe84a439c33928dd34396941b4d377692eca810"}, - {file = "regex-2023.5.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8f08276466fedb9e36e5193a96cb944928301152879ec20c2d723d1031cd4ddd"}, - {file = "regex-2023.5.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cd46f30e758629c3ee91713529cfbe107ac50d27110fdcc326a42ce2acf4dafc"}, - {file = "regex-2023.5.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2910502f718828cecc8beff004917dcf577fc5f8f5dd40ffb1ea7612124547b"}, - {file = "regex-2023.5.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:445d6f4fc3bd9fc2bf0416164454f90acab8858cd5a041403d7a11e3356980e8"}, - {file = "regex-2023.5.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18196c16a584619c7c1d843497c069955d7629ad4a3fdee240eb347f4a2c9dbe"}, - {file = "regex-2023.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33d430a23b661629661f1fe8395be2004006bc792bb9fc7c53911d661b69dd7e"}, - {file = "regex-2023.5.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:72a28979cc667e5f82ef433db009184e7ac277844eea0f7f4d254b789517941d"}, - {file = "regex-2023.5.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f764e4dfafa288e2eba21231f455d209f4709436baeebb05bdecfb5d8ddc3d35"}, - {file = "regex-2023.5.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:23d86ad2121b3c4fc78c58f95e19173790e22ac05996df69b84e12da5816cb17"}, - {file = "regex-2023.5.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:690a17db524ee6ac4a27efc5406530dd90e7a7a69d8360235323d0e5dafb8f5b"}, - {file = "regex-2023.5.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:1ecf3dcff71f0c0fe3e555201cbe749fa66aae8d18f80d2cc4de8e66df37390a"}, - {file = "regex-2023.5.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:811040d7f3dd9c55eb0d8b00b5dcb7fd9ae1761c454f444fd9f37fe5ec57143a"}, - {file = "regex-2023.5.5-cp311-cp311-win32.whl", hash = "sha256:c8c143a65ce3ca42e54d8e6fcaf465b6b672ed1c6c90022794a802fb93105d22"}, - {file = "regex-2023.5.5-cp311-cp311-win_amd64.whl", hash = "sha256:586a011f77f8a2da4b888774174cd266e69e917a67ba072c7fc0e91878178a80"}, - {file = "regex-2023.5.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b6365703e8cf1644b82104cdd05270d1a9f043119a168d66c55684b1b557d008"}, - {file = "regex-2023.5.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a56c18f21ac98209da9c54ae3ebb3b6f6e772038681d6cb43b8d53da3b09ee81"}, - {file = "regex-2023.5.5-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8b942d8b3ce765dbc3b1dad0a944712a89b5de290ce8f72681e22b3c55f3cc8"}, - {file = "regex-2023.5.5-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:844671c9c1150fcdac46d43198364034b961bd520f2c4fdaabfc7c7d7138a2dd"}, - {file = "regex-2023.5.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2ce65bdeaf0a386bb3b533a28de3994e8e13b464ac15e1e67e4603dd88787fa"}, - {file = "regex-2023.5.5-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fee0016cc35a8a91e8cc9312ab26a6fe638d484131a7afa79e1ce6165328a135"}, - {file = "regex-2023.5.5-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", 
hash = "sha256:18f05d14f14a812fe9723f13afafefe6b74ca042d99f8884e62dbd34dcccf3e2"}, - {file = "regex-2023.5.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:941b3f1b2392f0bcd6abf1bc7a322787d6db4e7457be6d1ffd3a693426a755f2"}, - {file = "regex-2023.5.5-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:921473a93bcea4d00295799ab929522fc650e85c6b9f27ae1e6bb32a790ea7d3"}, - {file = "regex-2023.5.5-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:e2205a81f815b5bb17e46e74cc946c575b484e5f0acfcb805fb252d67e22938d"}, - {file = "regex-2023.5.5-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:385992d5ecf1a93cb85adff2f73e0402dd9ac29b71b7006d342cc920816e6f32"}, - {file = "regex-2023.5.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:890a09cb0a62198bff92eda98b2b507305dd3abf974778bae3287f98b48907d3"}, - {file = "regex-2023.5.5-cp36-cp36m-win32.whl", hash = "sha256:821a88b878b6589c5068f4cc2cfeb2c64e343a196bc9d7ac68ea8c2a776acd46"}, - {file = "regex-2023.5.5-cp36-cp36m-win_amd64.whl", hash = "sha256:7918a1b83dd70dc04ab5ed24c78ae833ae8ea228cef84e08597c408286edc926"}, - {file = "regex-2023.5.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:338994d3d4ca4cf12f09822e025731a5bdd3a37aaa571fa52659e85ca793fb67"}, - {file = "regex-2023.5.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a69cf0c00c4d4a929c6c7717fd918414cab0d6132a49a6d8fc3ded1988ed2ea"}, - {file = "regex-2023.5.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f5e06df94fff8c4c85f98c6487f6636848e1dc85ce17ab7d1931df4a081f657"}, - {file = "regex-2023.5.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8906669b03c63266b6a7693d1f487b02647beb12adea20f8840c1a087e2dfb5"}, - {file = "regex-2023.5.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fda3e50abad8d0f48df621cf75adc73c63f7243cbe0e3b2171392b445401550"}, - {file = "regex-2023.5.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ac2b7d341dc1bd102be849d6dd33b09701223a851105b2754339e390be0627a"}, - {file = "regex-2023.5.5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fb2b495dd94b02de8215625948132cc2ea360ae84fe6634cd19b6567709c8ae2"}, - {file = "regex-2023.5.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:aa7d032c1d84726aa9edeb6accf079b4caa87151ca9fabacef31fa028186c66d"}, - {file = "regex-2023.5.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:3d45864693351c15531f7e76f545ec35000d50848daa833cead96edae1665559"}, - {file = "regex-2023.5.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21e90a288e6ba4bf44c25c6a946cb9b0f00b73044d74308b5e0afd190338297c"}, - {file = "regex-2023.5.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:10250a093741ec7bf74bcd2039e697f519b028518f605ff2aa7ac1e9c9f97423"}, - {file = "regex-2023.5.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6b8d0c153f07a953636b9cdb3011b733cadd4178123ef728ccc4d5969e67f3c2"}, - {file = "regex-2023.5.5-cp37-cp37m-win32.whl", hash = "sha256:10374c84ee58c44575b667310d5bbfa89fb2e64e52349720a0182c0017512f6c"}, - {file = "regex-2023.5.5-cp37-cp37m-win_amd64.whl", hash = "sha256:9b320677521aabf666cdd6e99baee4fb5ac3996349c3b7f8e7c4eee1c00dfe3a"}, - {file = "regex-2023.5.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb1c70ec1e594a547f38ad6bf5e3d60304ce7539e677c1429eebab115bce56e"}, - {file = "regex-2023.5.5-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:cf123225945aa58b3057d0fba67e8061c62d14cc8a4202630f8057df70189051"}, - {file = "regex-2023.5.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a99757ad7fe5c8a2bb44829fc57ced11253e10f462233c1255fe03888e06bc19"}, - {file = "regex-2023.5.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a623564d810e7a953ff1357f7799c14bc9beeab699aacc8b7ab7822da1e952b8"}, - {file = "regex-2023.5.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ced02e3bd55e16e89c08bbc8128cff0884d96e7f7a5633d3dc366b6d95fcd1d6"}, - {file = "regex-2023.5.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cbe6b5be3b9b698d8cc4ee4dee7e017ad655e83361cd0ea8e653d65e469468"}, - {file = "regex-2023.5.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a6e4b0e0531223f53bad07ddf733af490ba2b8367f62342b92b39b29f72735a"}, - {file = "regex-2023.5.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2e9c4f778514a560a9c9aa8e5538bee759b55f6c1dcd35613ad72523fd9175b8"}, - {file = "regex-2023.5.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:256f7f4c6ba145f62f7a441a003c94b8b1af78cee2cccacfc1e835f93bc09426"}, - {file = "regex-2023.5.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:bd7b68fd2e79d59d86dcbc1ccd6e2ca09c505343445daaa4e07f43c8a9cc34da"}, - {file = "regex-2023.5.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4a5059bd585e9e9504ef9c07e4bc15b0a621ba20504388875d66b8b30a5c4d18"}, - {file = "regex-2023.5.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:6893544e06bae009916a5658ce7207e26ed17385149f35a3125f5259951f1bbe"}, - {file = "regex-2023.5.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c64d5abe91a3dfe5ff250c6bb267ef00dbc01501518225b45a5f9def458f31fb"}, - {file = "regex-2023.5.5-cp38-cp38-win32.whl", hash = "sha256:7923470d6056a9590247ff729c05e8e0f06bbd4efa6569c916943cb2d9b68b91"}, - {file = "regex-2023.5.5-cp38-cp38-win_amd64.whl", hash = "sha256:4035d6945cb961c90c3e1c1ca2feb526175bcfed44dfb1cc77db4fdced060d3e"}, - {file = "regex-2023.5.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:50fd2d9b36938d4dcecbd684777dd12a407add4f9f934f235c66372e630772b0"}, - {file = "regex-2023.5.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d19e57f888b00cd04fc38f5e18d0efbd91ccba2d45039453ab2236e6eec48d4d"}, - {file = "regex-2023.5.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd966475e963122ee0a7118ec9024388c602d12ac72860f6eea119a3928be053"}, - {file = "regex-2023.5.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db09e6c18977a33fea26fe67b7a842f706c67cf8bda1450974d0ae0dd63570df"}, - {file = "regex-2023.5.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6164d4e2a82f9ebd7752a06bd6c504791bedc6418c0196cd0a23afb7f3e12b2d"}, - {file = "regex-2023.5.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84397d3f750d153ebd7f958efaa92b45fea170200e2df5e0e1fd4d85b7e3f58a"}, - {file = "regex-2023.5.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c3efee9bb53cbe7b285760c81f28ac80dc15fa48b5fe7e58b52752e642553f1"}, - {file = "regex-2023.5.5-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:144b5b017646b5a9392a5554a1e5db0000ae637be4971c9747566775fc96e1b2"}, - {file = 
"regex-2023.5.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:1189fbbb21e2c117fda5303653b61905aeeeea23de4a94d400b0487eb16d2d60"}, - {file = "regex-2023.5.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f83fe9e10f9d0b6cf580564d4d23845b9d692e4c91bd8be57733958e4c602956"}, - {file = "regex-2023.5.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:72aa4746993a28c841e05889f3f1b1e5d14df8d3daa157d6001a34c98102b393"}, - {file = "regex-2023.5.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:de2f780c3242ea114dd01f84848655356af4dd561501896c751d7b885ea6d3a1"}, - {file = "regex-2023.5.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:290fd35219486dfbc00b0de72f455ecdd63e59b528991a6aec9fdfc0ce85672e"}, - {file = "regex-2023.5.5-cp39-cp39-win32.whl", hash = "sha256:732176f5427e72fa2325b05c58ad0b45af341c459910d766f814b0584ac1f9ac"}, - {file = "regex-2023.5.5-cp39-cp39-win_amd64.whl", hash = "sha256:1307aa4daa1cbb23823d8238e1f61292fd07e4e5d8d38a6efff00b67a7cdb764"}, - {file = "regex-2023.5.5.tar.gz", hash = "sha256:7d76a8a1fc9da08296462a18f16620ba73bcbf5909e42383b253ef34d9d5141e"}, -] requests = [ {file = "requests-2.29.0-py3-none-any.whl", hash = "sha256:e8f3c9be120d3333921d213eef078af392fba3933ab7ed2d1cba3b56f2568c3b"}, {file = "requests-2.29.0.tar.gz", hash = "sha256:f2e34a75f4749019bb0e3effb66683630e4ffeaf75819fb51bebef1bf5aef059"}, @@ -4920,10 +4727,6 @@ streamlit-plotly-events = [ {file = "streamlit-plotly-events-0.0.6.tar.gz", hash = "sha256:1fe25dbf0e5d803aeb90253be04d7b395f5bcfdf3c654f96ff3c19424e7f9582"}, {file = "streamlit_plotly_events-0.0.6-py3-none-any.whl", hash = "sha256:e63fbe3c6a0746fdfce20060fc45ba5cd97805505c332b27372dcbd02c2ede29"}, ] -sympy = [ - {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, - {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, -] tabulate = [ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, @@ -4965,48 +4768,47 @@ toolz = [ {file = "toolz-0.12.0.tar.gz", hash = "sha256:88c570861c440ee3f2f6037c4654613228ff40c93a6c25e0eba70d17282c6194"}, ] torch = [ - {file = "torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ced00b3ba471856b993822508f77c98f48a458623596a4c43136158781e306a"}, - {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"}, - {file = "torch-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7c84e44d9002182edd859f3400deaa7410f5ec948a519cc7ef512c2f9b34d2c4"}, - {file = "torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:567f84d657edc5582d716900543e6e62353dbe275e61cdc36eda4929e46df9e7"}, - {file = "torch-2.0.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:787b5a78aa7917465e9b96399b883920c88a08f4eb63b5a5d2d1a16e27d2f89b"}, - {file = "torch-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e617b1d0abaf6ced02dbb9486803abfef0d581609b09641b34fa315c9c40766d"}, - {file = "torch-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b6019b1de4978e96daa21d6a3ebb41e88a0b474898fe251fd96189587408873e"}, - {file = "torch-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:dbd68cbd1cd9da32fe5d294dd3411509b3d841baecb780b38b3b7b06c7754434"}, - {file = "torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl", hash = 
"sha256:ef654427d91600129864644e35deea761fb1fe131710180b952a6f2e2207075e"}, - {file = "torch-2.0.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:25aa43ca80dcdf32f13da04c503ec7afdf8e77e3a0183dd85cd3e53b2842e527"}, - {file = "torch-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5ef3ea3d25441d3957348f7e99c7824d33798258a2bf5f0f0277cbcadad2e20d"}, - {file = "torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0882243755ff28895e8e6dc6bc26ebcf5aa0911ed81b2a12f241fc4b09075b13"}, - {file = "torch-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:f66aa6b9580a22b04d0af54fcd042f52406a8479e2b6a550e3d9f95963e168c8"}, - {file = "torch-2.0.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1adb60d369f2650cac8e9a95b1d5758e25d526a34808f7448d0bd599e4ae9072"}, - {file = "torch-2.0.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:1bcffc16b89e296826b33b98db5166f990e3b72654a2b90673e817b16c50e32b"}, - {file = "torch-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:e10e1597f2175365285db1b24019eb6f04d53dcd626c735fc502f1e8b6be9875"}, - {file = "torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:423e0ae257b756bb45a4b49072046772d1ad0c592265c5080070e0767da4e490"}, - {file = "torch-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8742bdc62946c93f75ff92da00e3803216c6cce9b132fbca69664ca38cfb3e18"}, - {file = "torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c62df99352bd6ee5a5a8d1832452110435d178b5164de450831a3a8cc14dc680"}, - {file = "torch-2.0.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:671a2565e3f63b8fe8e42ae3e36ad249fe5e567435ea27b94edaa672a7d0c416"}, + {file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"}, + {file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"}, + {file = "torch-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:e9c8f4a311ac29fc7e8e955cfb7733deb5dbe1bdaabf5d4af2765695824b7e0d"}, + {file = "torch-1.12.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:976c3f997cea38ee91a0dd3c3a42322785414748d1761ef926b789dfa97c6134"}, + {file = "torch-1.12.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:68104e4715a55c4bb29a85c6a8d57d820e0757da363be1ba680fa8cc5be17b52"}, + {file = "torch-1.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:743784ccea0dc8f2a3fe6a536bec8c4763bd82c1352f314937cb4008d4805de1"}, + {file = "torch-1.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b5dbcca369800ce99ba7ae6dee3466607a66958afca3b740690d88168752abcf"}, + {file = "torch-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a"}, + {file = "torch-1.12.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:8a34a2fbbaa07c921e1b203f59d3d6e00ed379f2b384445773bd14e328a5b6c8"}, + {file = "torch-1.12.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:42f639501928caabb9d1d55ddd17f07cd694de146686c24489ab8c615c2871f2"}, + {file = "torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0b44601ec56f7dd44ad8afc00846051162ef9c26a8579dda0a02194327f2d55e"}, + {file = "torch-1.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cd26d8c5640c3a28c526d41ccdca14cf1cbca0d0f2e14e8263a7ac17194ab1d2"}, + {file = "torch-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:42e115dab26f60c29e298559dbec88444175528b729ae994ec4c65d56fe267dd"}, + {file = "torch-1.12.1-cp38-none-macosx_10_9_x86_64.whl", hash = 
"sha256:a8320ba9ad87e80ca5a6a016e46ada4d1ba0c54626e135d99b2129a4541c509d"}, + {file = "torch-1.12.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:03e31c37711db2cd201e02de5826de875529e45a55631d317aadce2f1ed45aa8"}, + {file = "torch-1.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9b356aea223772cd754edb4d9ecf2a025909b8615a7668ac7d5130f86e7ec421"}, + {file = "torch-1.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cf6f54b43c0c30335428195589bd00e764a6d27f3b9ba637aaa8c11aaf93073"}, + {file = "torch-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d"}, + {file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"}, + {file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"}, ] torchvision = [ - {file = "torchvision-0.15.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7754088774e810c5672b142a45dcf20b1bd986a5a7da90f8660c43dc43fb850c"}, - {file = "torchvision-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37eb138e13f6212537a3009ac218695483a635c404b6cc1d8e0d0d978026a86d"}, - {file = "torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:54143f7cc0797d199b98a53b7d21c3f97615762d4dd17ad45a41c7e80d880e73"}, - {file = "torchvision-0.15.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1eefebf5fbd01a95fe8f003d623d941601c94b5cec547b420da89cb369d9cf96"}, - {file = "torchvision-0.15.2-cp310-cp310-win_amd64.whl", hash = "sha256:96fae30c5ca8423f4b9790df0f0d929748e32718d88709b7b567d2f630c042e3"}, - {file = "torchvision-0.15.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5f35f6bd5bcc4568e6522e4137fa60fcc72f4fa3e615321c26cd87e855acd398"}, - {file = "torchvision-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:757505a0ab2be7096cb9d2bf4723202c971cceddb72c7952a7e877f773de0f8a"}, - {file = "torchvision-0.15.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:012ad25cfd9019ff9b0714a168727e3845029be1af82296ff1e1482931fa4b80"}, - {file = "torchvision-0.15.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b02a7ffeaa61448737f39a4210b8ee60234bda0515a0c0d8562f884454105b0f"}, - {file = "torchvision-0.15.2-cp311-cp311-win_amd64.whl", hash = "sha256:10be76ceded48329d0a0355ac33da131ee3993ff6c125e4a02ab34b5baa2472c"}, - {file = "torchvision-0.15.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f12415b686dba884fb086f53ac803f692be5a5cdd8a758f50812b30fffea2e4"}, - {file = "torchvision-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:31211c01f8b8ec33b8a638327b5463212e79a03e43c895f88049f97af1bd12fd"}, - {file = "torchvision-0.15.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c55f9889e436f14b4f84a9c00ebad0d31f5b4626f10cf8018e6c676f92a6d199"}, - {file = "torchvision-0.15.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9a192f2aa979438f23c20e883980b23d13268ab9f819498774a6d2eb021802c2"}, - {file = "torchvision-0.15.2-cp38-cp38-win_amd64.whl", hash = "sha256:c07071bc8d02aa8fcdfe139ab6a1ef57d3b64c9e30e84d12d45c9f4d89fb6536"}, - {file = "torchvision-0.15.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4790260fcf478a41c7ecc60a6d5200a88159fdd8d756e9f29f0f8c59c4a67a68"}, - {file = "torchvision-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:987ab62225b4151a11e53fd06150c5258ced24ac9d7c547e0e4ab6fbca92a5ce"}, - {file = "torchvision-0.15.2-cp39-cp39-manylinux1_x86_64.whl", hash = 
"sha256:63df26673e66cba3f17e07c327a8cafa3cce98265dbc3da329f1951d45966838"}, - {file = "torchvision-0.15.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b85f98d4cc2f72452f6792ab4463a3541bc5678a8cdd3da0e139ba2fe8b56d42"}, - {file = "torchvision-0.15.2-cp39-cp39-win_amd64.whl", hash = "sha256:07c462524cc1bba5190c16a9d47eac1fca024d60595a310f23c00b4ffff18b30"}, + {file = "torchvision-0.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:19286a733c69dcbd417b86793df807bd227db5786ed787c17297741a9b0d0fc7"}, + {file = "torchvision-0.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:08f592ea61836ebeceb5c97f4d7a813b9d7dc651bbf7ce4401563ccfae6a21fc"}, + {file = "torchvision-0.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:ef5fe3ec1848123cd0ec74c07658192b3147dcd38e507308c790d5943e87b88c"}, + {file = "torchvision-0.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:099874088df104d54d8008f2a28539ca0117b512daed8bf3c2bbfa2b7ccb187a"}, + {file = "torchvision-0.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:8e4d02e4d8a203e0c09c10dfb478214c224d080d31efc0dbf36d9c4051f7f3c6"}, + {file = "torchvision-0.13.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5e631241bee3661de64f83616656224af2e3512eb2580da7c08e08b8c965a8ac"}, + {file = "torchvision-0.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:899eec0b9f3b99b96d6f85b9aa58c002db41c672437677b553015b9135b3be7e"}, + {file = "torchvision-0.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:83e9e2457f23110fd53b0177e1bc621518d6ea2108f570e853b768ce36b7c679"}, + {file = "torchvision-0.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7552e80fa222252b8b217a951c85e172a710ea4cad0ae0c06fbb67addece7871"}, + {file = "torchvision-0.13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f230a1a40ed70d51e463ce43df243ec520902f8725de2502e485efc5eea9d864"}, + {file = "torchvision-0.13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e9a563894f9fa40692e24d1aa58c3ef040450017cfed3598ff9637f404f3fe3b"}, + {file = "torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7cb789ceefe6dcd0dc8eeda37bfc45efb7cf34770eac9533861d51ca508eb5b3"}, + {file = "torchvision-0.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:87c137f343197769a51333076e66bfcd576301d2cd8614b06657187c71b06c4f"}, + {file = "torchvision-0.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:4d8bf321c4380854ef04613935fdd415dce29d1088a7ff99e06e113f0efe9203"}, + {file = "torchvision-0.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0298bae3b09ac361866088434008d82b99d6458fe8888c8df90720ef4b347d44"}, + {file = "torchvision-0.13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c5ed609c8bc88c575226400b2232e0309094477c82af38952e0373edef0003fd"}, + {file = "torchvision-0.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3567fb3def829229ec217c1e38f08c5128ff7fb65854cac17ebac358ff7aa309"}, + {file = "torchvision-0.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b167934a5943242da7b1e59318f911d2d253feeca0d13ad5d832b58eed943401"}, + {file = "torchvision-0.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0e77706cc90462653620e336bb90daf03d7bf1b88c3a9a3037df8d111823a56e"}, ] tornado = [ {file = "tornado-6.3.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:db181eb3df8738613ff0a26f49e1b394aade05034b01200a63e9662f347d4415"}, diff --git a/pyproject.toml b/pyproject.toml index 8637c9c2b..65c142164 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,39 +3,31 @@ name = "encord-active" version = "v0.1.58" description = "Enable users to improve machine 
learning models in an active learning fashion via data, label, and model quality."
 authors = ["Cord Technologies Limited "]
-classifiers = [
-    "Environment :: Console",
-    "Environment :: Web Environment",
-    "Intended Audience :: Developers",
-    "Intended Audience :: Education",
-    "Intended Audience :: Information Technology",
-    "Intended Audience :: Science/Research",
-    "Operating System :: OS Independent",
-    "Topic :: Scientific/Engineering",
-    "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    "Topic :: Scientific/Engineering :: Information Analysis",
-    "Topic :: Software Development",
-    "Topic :: Software Development :: Quality Assurance",
+classifiers=[
+  "Environment :: Console",
+  "Environment :: Web Environment",
+  "Intended Audience :: Developers",
+  "Intended Audience :: Education",
+  "Intended Audience :: Information Technology",
+  "Intended Audience :: Science/Research",
+  "Operating System :: OS Independent",
+  "Topic :: Scientific/Engineering",
+  "Topic :: Scientific/Engineering :: Artificial Intelligence",
+  "Topic :: Scientific/Engineering :: Information Analysis",
+  "Topic :: Software Development",
+  "Topic :: Software Development :: Quality Assurance"
 ]
 documentation = "https://docs.encord.com/active/docs"
 homepage = "https://encord.com/encord-active/"
-keywords = [
-    "encord",
-    "active",
-    "machine",
-    "learning",
-    "data",
-    "label",
-    "model",
-    "quality",
-    "test",
-]
+keywords = ["encord", "active", "machine", "learning", "data", "label", "model", "quality", "test"]
 readme = "README.md"
 repository = "https://github.com/encord-team/encord-active"
 include = ['.env']
-packages = [{ include = "encord_active", from = "src" }]
+packages = [
+    { include = "encord_active", from = "src" },
+]
 
 license = "Apache-2.0"
@@ -52,6 +44,8 @@ natsort = "^8.1.0"
 pandas = "^1.4.3"
 shapely = "^1.7.0"
 watchdog = "^2.1.9"
+torch = "^1.12.1"
+torchvision = "^0.13.1"
 faiss-cpu = "^1.7.2"
 matplotlib = "^3.5.3"
 scikit-learn = "^1.0.1"
@@ -69,28 +63,23 @@ rich = "^12.6.0"
 PyYAML = "^6.0"
 toml = "^0.10.2"
 pydantic = "^1.10.2"
-pycocotools = { version = "^2.0.6", optional = true }
+pycocotools = {version = "^2.0.6", optional = true}
 psutil = "^5.9.4"
 pandera = "^0.13.4"
-jupyterlab = { version = "^3.5.2", optional = true }
-ipywidgets = { version = "^8.0.4", optional = true }
+jupyterlab = {version = "^3.5.2", optional = true}
+ipywidgets = {version = "^8.0.4", optional = true}
 inquirerpy = "^0.3.4"
 statsmodels = "^0.13.5"
 umap-learn = "^0.5.3"
 streamlit-plotly-events = "^0.0.6"
 encord-active-components = "^0.0.12"
-llvmlite = "^0.39.1" # Pinning, as lower versions conflict with other libs
+llvmlite = "^0.39.1"  # Pinning, as lower versions conflict with other libs
 gitpython = "^3.1.31"
 prisma = "^0.8.2"
 fastapi = "^0.95.0"
-uvicorn = { extras = ["standard"], version = "^0.21.1" }
+uvicorn = {extras = ["standard"], version = "^0.21.1"}
 nodejs-bin = "^18.4.0a4"
 pyjwt = "^2.7.0"
-torch = "^2.0.0"
-clip = { git = "https://github.com/openai/CLIP.git" }
-torchvision = "^0.15.2"
-ftfy = "^6.1.1"
-regex = "^2023.5.5"
 
 [tool.poetry.extras]
 coco = ["pycocotools"]
diff --git a/src/encord_active/app/common/state.py b/src/encord_active/app/common/state.py
index eb91d60fb..a21a86930 100644
--- a/src/encord_active/app/common/state.py
+++ b/src/encord_active/app/common/state.py
@@ -13,7 +13,7 @@
 from encord_active.lib.db.connection import DBConnection
 from encord_active.lib.db.merged_metrics import MergedMetrics, initialize_merged_metrics
 from encord_active.lib.embeddings.utils import Embedding2DSchema
-from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.metrics.utils import MetricData, MetricSchema from encord_active.lib.model_predictions.reader import LabelSchema, OntologyObjectJSON from encord_active.lib.model_predictions.writer import OntologyClassificationJSON diff --git a/src/encord_active/app/label_onboarding/label_onboarding.py b/src/encord_active/app/label_onboarding/label_onboarding.py index 90e4329a4..a49b29cce 100644 --- a/src/encord_active/app/label_onboarding/label_onboarding.py +++ b/src/encord_active/app/label_onboarding/label_onboarding.py @@ -14,7 +14,7 @@ execute_metrics, get_metrics_by_embedding_type, ) -from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.project.project import Project diff --git a/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py b/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py index 8be71c2e5..cd66770cd 100644 --- a/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py +++ b/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py @@ -28,7 +28,7 @@ from encord_active.lib.charts.scopes import PredictionMatchScope from encord_active.lib.embeddings.dimensionality_reduction import get_2d_embedding_data from encord_active.lib.embeddings.utils import Embedding2DSchema -from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.metrics.utils import MetricSchema from encord_active.lib.model_predictions.classification_metrics import ( match_predictions_and_labels, @@ -253,7 +253,7 @@ def render_explorer(self): if EmbeddingType.IMAGE not in get_state().reduced_embeddings: get_state().reduced_embeddings[EmbeddingType.IMAGE] = get_2d_embedding_data( - get_state().project_paths, EmbeddingType.IMAGE + get_state().project_paths.embeddings, EmbeddingType.IMAGE ) metric_name = get_state().predictions.metric_datas_classification.selected_prediction diff --git a/src/encord_active/app/projects_page.py b/src/encord_active/app/projects_page.py index 9e75ec817..013b437f1 100644 --- a/src/encord_active/app/projects_page.py +++ b/src/encord_active/app/projects_page.py @@ -21,7 +21,7 @@ try_find_parent_project, ) from encord_active.lib.common.image_utils import show_image_and_draw_polygons -from encord_active.lib.metrics.types import AnnotationType +from encord_active.lib.metrics.metric import AnnotationType from encord_active.lib.metrics.utils import load_metric_metadata from encord_active.lib.model_predictions.writer import ( iterate_classification_attribute_options, diff --git a/src/encord_active/cli/main.py b/src/encord_active/cli/main.py index 9e80c36d5..b27df9376 100644 --- a/src/encord_active/cli/main.py +++ b/src/encord_active/cli/main.py @@ -205,7 +205,7 @@ def import_local_project( run_metrics_by_embedding_type, ) from encord_active.lib.metrics.heuristic.img_features import AreaMetric - from encord_active.lib.metrics.types import EmbeddingType + from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.project.local import ( NoFilesFoundError, ProjectExistsError, diff --git a/src/encord_active/lib/db/merged_metrics.py b/src/encord_active/lib/db/merged_metrics.py index 9a325575d..341b7dea9 100644 --- a/src/encord_active/lib/db/merged_metrics.py +++ 
b/src/encord_active/lib/db/merged_metrics.py @@ -8,7 +8,7 @@ from encord_active.lib.db.connection import DBConnection from encord_active.lib.db.tags import Tag, TagScope from encord_active.lib.labels.classification import ClassificationType -from encord_active.lib.metrics.types import DataType, EmbeddingType +from encord_active.lib.metrics.metric import DataType, EmbeddingType from encord_active.lib.metrics.utils import load_metric_metadata from encord_active.lib.project.project_file_structure import ProjectFileStructure diff --git a/src/encord_active/lib/embeddings/embeddings.py b/src/encord_active/lib/embeddings/cnn.py similarity index 56% rename from src/encord_active/lib/embeddings/embeddings.py rename to src/encord_active/lib/embeddings/cnn.py index b999e0b00..bcd04858f 100644 --- a/src/encord_active/lib/embeddings/embeddings.py +++ b/src/encord_active/lib/embeddings/cnn.py @@ -1,37 +1,71 @@ +import logging +import os import pickle import time from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np import torch +import torchvision.transforms as torch_transforms from encord.objects.common import PropertyType from encord.project_ontology.object_type import ObjectShape -from loguru import logger from PIL import Image +from torch import nn +from torchvision.models import EfficientNet_V2_S_Weights, efficientnet_v2_s +from torchvision.models.feature_extraction import create_feature_extractor from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_bbox_from_encord_label_object -from encord_active.lib.embeddings.models.clip_embedder import CLIPEmbedder -from encord_active.lib.embeddings.models.embedder_model import ImageEmbedder -from encord_active.lib.embeddings.utils import ClassificationAnswer, LabelEmbedding -from encord_active.lib.metrics.types import EmbeddingType -from encord_active.lib.project.project_file_structure import ProjectFileStructure +from encord_active.lib.embeddings.utils import ( + EMBEDDING_TYPE_TO_FILENAME, + ClassificationAnswer, + LabelEmbedding, +) +from encord_active.lib.metrics.metric import EmbeddingType +logger = logging.getLogger(__name__) +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def get_default_embedder() -> ImageEmbedder: - return CLIPEmbedder() +def get_model_and_transforms() -> Tuple[nn.Module, nn.Module]: + weights = EfficientNet_V2_S_Weights.DEFAULT + model = efficientnet_v2_s(weights=weights).to(DEVICE) + embedding_extractor = create_feature_extractor(model, return_nodes={"avgpool": "my_avgpool"}) + for p in embedding_extractor.parameters(): + p.requires_grad = False + embedding_extractor.eval() + return embedding_extractor, weights.transforms() -def assemble_object_batch(data_unit: dict, img_path: Path) -> List[Image.Image]: - try: - image = np.asarray(Image.open(img_path).convert("RGB")) - except OSError: - logger.warning(f"Image with path {img_path} seems to be broken. 
Skipping.") - return [] - img_h, img_w, *_ = image.shape - img_batch: List[Image.Image] = [] +def adjust_image_channels(image: torch.Tensor) -> torch.Tensor: + if image.shape[0] == 4: + image = image[:3] + elif image.shape[0] < 3: + image = image.repeat(3, 1, 1) + + return image + + +def image_path_to_tensor(image_path: Path) -> torch.Tensor: + image = Image.open(image_path.as_posix()) + transform = torch_transforms.ToTensor() + image = transform(image) + + image = adjust_image_channels(image) + + return image + + +def assemble_object_batch(data_unit: dict, img_path: Path, transforms: Optional[nn.Module]): + if transforms is None: + transforms = torch.nn.Sequential() + + try: + image = image_path_to_tensor(img_path) + except Exception: + return None + img_batch: List[torch.Tensor] = [] for obj in data_unit["labels"].get("objects", []): if obj["shape"] in [ @@ -42,8 +76,8 @@ def assemble_object_batch(data_unit: dict, img_path: Path) -> List[Image.Image]: try: out = get_bbox_from_encord_label_object( obj, - w=img_w, - h=img_h, + image.shape[2], + image.shape[1], ) if out is None: @@ -51,55 +85,25 @@ def assemble_object_batch(data_unit: dict, img_path: Path) -> List[Image.Image]: x, y, w, h = out img_patch = image[:, y : y + h, x : x + w] - img_batch.append(Image.fromarray(img_patch)) + img_batch.append(transforms(img_patch)) except Exception as e: logger.warning(f"Error with object {obj['objectHash']}: {e}") continue - return img_batch + return torch.stack(img_batch).to(DEVICE) if len(img_batch) > 0 else None @torch.inference_mode() -def generate_image_embeddings( - iterator: Iterator, feature_extractor: Optional[ImageEmbedder] = None, batch_size=100 -) -> List[LabelEmbedding]: +def generate_cnn_image_embeddings(iterator: Iterator) -> List[LabelEmbedding]: start = time.perf_counter() - if feature_extractor is None: - feature_extractor = get_default_embedder() - - raw_embeddings: list[np.ndarray] = [] - batch = [] - skip: set[int] = set() - for i, (data_unit, img_pth) in enumerate(iterator.iterate(desc="Embedding image data.")): - if img_pth is None: - skip.add(i) - continue - try: - batch.append(Image.open(img_pth).convert("RGB")) - except OSError: - logger.warning(f"Image with path {img_pth} seems to be broken. 
Skipping.") - skip.add(i) - continue - - if len(batch) >= batch_size: - raw_embeddings.append(feature_extractor.embed_images(batch)) - batch = [] - - if batch: - raw_embeddings.append(feature_extractor.embed_images(batch)) - - if len(raw_embeddings) > 1: - raw_np_embeddings = np.concatenate(raw_embeddings) - else: - raw_np_embeddings = raw_embeddings[0] + feature_extractor, transforms = get_model_and_transforms() collections: List[LabelEmbedding] = [] - offset = 0 - for i, (data_unit, img_pth) in enumerate(iterator.iterate(desc="Storing embeddings.")): - if i in skip: - offset += 1 + for data_unit, img_pth in iterator.iterate(desc="Embedding image data."): + embedding = get_embdding_for_image(feature_extractor, transforms, img_pth) + + if embedding is None: continue - embedding = raw_np_embeddings[i - offset] entry = LabelEmbedding( url=data_unit["data_link"], label_row=iterator.label_hash, @@ -123,48 +127,45 @@ def generate_image_embeddings( @torch.inference_mode() -def generate_object_embeddings( - iterator: Iterator, feature_extractor: Optional[ImageEmbedder] = None -) -> List[LabelEmbedding]: +def generate_cnn_object_embeddings(iterator: Iterator) -> List[LabelEmbedding]: start = time.perf_counter() - if feature_extractor is None: - feature_extractor = get_default_embedder() + feature_extractor, transforms = get_model_and_transforms() collections: List[LabelEmbedding] = [] for data_unit, img_pth in iterator.iterate(desc="Embedding object data."): if img_pth is None: continue - batch = assemble_object_batch(data_unit, img_pth) - if not batch: + batches = assemble_object_batch(data_unit, img_pth, transforms=transforms) + if batches is None: continue - embeddings = feature_extractor.embed_images(batch) - for obj, emb in zip(data_unit["labels"].get("objects", []), embeddings): - if obj["shape"] not in [ + embeddings = feature_extractor(batches)["my_avgpool"] + embeddings_torch = torch.flatten(embeddings, start_dim=1).cpu().detach().numpy() + + for obj, emb in zip(data_unit["labels"].get("objects", []), embeddings_torch): + if obj["shape"] in [ ObjectShape.POLYGON.value, ObjectShape.BOUNDING_BOX.value, ObjectShape.ROTATABLE_BOUNDING_BOX.value, ]: - continue - - last_edited_by = obj["lastEditedBy"] if "lastEditedBy" in obj.keys() else obj["createdBy"] - - entry = LabelEmbedding( - url=data_unit["data_link"], - label_row=iterator.label_hash, - data_unit=data_unit["data_hash"], - frame=iterator.frame, - labelHash=obj["objectHash"], - lastEditedBy=last_edited_by, - featureHash=obj["featureHash"], - name=obj["name"], - dataset_title=iterator.dataset_title, - embedding=emb, - classification_answers=None, - ) + last_edited_by = obj["lastEditedBy"] if "lastEditedBy" in obj.keys() else obj["createdBy"] + + entry = LabelEmbedding( + url=data_unit["data_link"], + label_row=iterator.label_hash, + data_unit=data_unit["data_hash"], + frame=iterator.frame, + labelHash=obj["objectHash"], + lastEditedBy=last_edited_by, + featureHash=obj["featureHash"], + name=obj["name"], + dataset_title=iterator.dataset_title, + embedding=emb, + classification_answers=None, + ) - collections.append(entry) + collections.append(entry) logger.info( f"Generating {len(iterator)} embeddings took {str(time.perf_counter() - start)} seconds", @@ -174,10 +175,8 @@ def generate_object_embeddings( @torch.inference_mode() -def generate_classification_embeddings( - iterator: Iterator, feature_extractor: Optional[ImageEmbedder] -) -> List[LabelEmbedding]: - image_collections = get_embeddings(iterator, embedding_type=EmbeddingType.IMAGE) 
+def generate_cnn_classification_embeddings(iterator: Iterator) -> List[LabelEmbedding]: + image_collections = get_cnn_embeddings(iterator, embedding_type=EmbeddingType.IMAGE) ontology_class_hash_to_index: dict[str, dict] = {} ontology_class_hash_to_question_hash: dict[str, str] = {} @@ -190,18 +189,14 @@ def generate_classification_embeddings( ontology_class_hash = class_label.feature_node_hash ontology_class_hash_to_index[ontology_class_hash] = {} ontology_class_hash_to_question_hash[ontology_class_hash] = class_question.feature_node_hash - for index, option in enumerate(class_question.options): # type: ignore + for index, option in enumerate(class_question.options): ontology_class_hash_to_index[ontology_class_hash][option.feature_node_hash] = index start = time.perf_counter() - if feature_extractor is None: - feature_extractor = get_default_embedder() + feature_extractor, transforms = get_model_and_transforms() collections = [] for data_unit, img_pth in iterator.iterate(desc="Embedding classification data."): - if img_pth is None: - continue - matching_image_collections = [ collection for collection in image_collections @@ -211,13 +206,7 @@ def generate_classification_embeddings( ] if not matching_image_collections: - try: - image = Image.open(img_pth).convert("RGB") - except OSError: - logger.warning(f"Image with path {img_pth} seems to be broken. Skipping.") - continue - - embedding = feature_extractor.embed_image(image) + embedding = get_embdding_for_image(feature_extractor, transforms, img_pth) else: embedding = matching_image_collections[0]["embedding"] @@ -276,43 +265,59 @@ def generate_classification_embeddings( return collections -def get_embeddings(iterator: Iterator, embedding_type: EmbeddingType, *, force: bool = False) -> List[LabelEmbedding]: +def get_cnn_embeddings( + iterator: Iterator, embedding_type: EmbeddingType, *, force: bool = False +) -> List[LabelEmbedding]: if embedding_type not in [EmbeddingType.CLASSIFICATION, EmbeddingType.IMAGE, EmbeddingType.OBJECT]: - raise Exception(f"Undefined embedding type '{embedding_type}' for get_embeddings method") + raise Exception(f"Undefined embedding type '{embedding_type}' for get_cnn_embeddings method") - pfs = ProjectFileStructure(iterator.cache_dir) - embedding_path = pfs.get_embeddings_file(embedding_type) + target_folder = os.path.join(iterator.cache_dir, "embeddings") + embedding_path = os.path.join(target_folder, f"{EMBEDDING_TYPE_TO_FILENAME[embedding_type]}") if force: logger.info("Regenerating CNN embeddings...") - embeddings = generate_embeddings(iterator, embedding_type, embedding_path) + cnn_embeddings = generate_cnn_embeddings(iterator, embedding_type, embedding_path) else: try: with open(embedding_path, "rb") as f: - embeddings = pickle.load(f) + cnn_embeddings = pickle.load(f) except FileNotFoundError: logger.info(f"{embedding_path} not found. 
Generating embeddings...") - embeddings = generate_embeddings(iterator, embedding_type, embedding_path) - return embeddings + cnn_embeddings = generate_cnn_embeddings(iterator, embedding_type, embedding_path) + + return cnn_embeddings -def generate_embeddings( - iterator: Iterator, embedding_type: EmbeddingType, target: Path, feature_extractor: Optional[ImageEmbedder] = None -): +def generate_cnn_embeddings(iterator: Iterator, embedding_type: EmbeddingType, target: str): if embedding_type == EmbeddingType.IMAGE: - embeddings = generate_image_embeddings(iterator, feature_extractor=feature_extractor) + cnn_embeddings = generate_cnn_image_embeddings(iterator) elif embedding_type == EmbeddingType.OBJECT: - embeddings = generate_object_embeddings(iterator, feature_extractor=feature_extractor) + cnn_embeddings = generate_cnn_object_embeddings(iterator) elif embedding_type == EmbeddingType.CLASSIFICATION: - embeddings = generate_classification_embeddings(iterator, feature_extractor=feature_extractor) + cnn_embeddings = generate_cnn_classification_embeddings(iterator) else: raise ValueError(f"Unsupported embedding type {embedding_type}") - if embeddings: - target.parent.mkdir(parents=True, exist_ok=True) - target.write_bytes(pickle.dumps(embeddings)) + if cnn_embeddings: + target_path = Path(target) + target_path.parent.mkdir(parents=True, exist_ok=True) + target_path.write_bytes(pickle.dumps(cnn_embeddings)) logger.info("Done!") - return embeddings + return cnn_embeddings + + +def get_embdding_for_image(feature_extractor, transforms, img_pth: Optional[Path] = None) -> Optional[np.ndarray]: + if img_pth is None: + return None + + try: + image = image_path_to_tensor(img_pth) + transformed_image = transforms(image).unsqueeze(0) + embedding = feature_extractor(transformed_image.to(DEVICE))["my_avgpool"] + return torch.flatten(embedding).cpu().detach().numpy() + except: + logger.error(f"Falied generating embedding for file: {img_pth}") + return None diff --git a/src/encord_active/lib/embeddings/dimensionality_reduction.py b/src/encord_active/lib/embeddings/dimensionality_reduction.py index e2c037471..fd3217baf 100644 --- a/src/encord_active/lib/embeddings/dimensionality_reduction.py +++ b/src/encord_active/lib/embeddings/dimensionality_reduction.py @@ -9,11 +9,11 @@ from pandera.typing import DataFrame from encord_active.lib.embeddings.utils import ( + EMBEDDING_REDUCED_TO_FILENAME, Embedding2DSchema, EmbeddingType, load_collections, ) -from encord_active.lib.project.project_file_structure import ProjectFileStructure warnings.filterwarnings("ignore", "n_neighbors is larger than the dataset size", category=UserWarning) MIN_SAMPLES = 4 # The number 4 is experimentally determined, less than this creates error for UMAP calculation @@ -23,9 +23,8 @@ def generate_2d_embedding_data(embedding_type: EmbeddingType, project_dir: Path) """ This function transforms high dimensional embedding data to 2D and saves it to a file """ - pfs = ProjectFileStructure(project_dir) - collections = load_collections(embedding_type, pfs.embeddings) + collections = load_collections(embedding_type, project_dir / "embeddings") if not collections: return @@ -63,14 +62,15 @@ def generate_2d_embedding_data(embedding_type: EmbeddingType, project_dir: Path) embeddings_2d_collection["x"].append(embeddings_2d[counter, 0]) embeddings_2d_collection["y"].append(embeddings_2d[counter, 1]) - target_path = pfs.get_embeddings_file(embedding_type, reduced=True) + target_path = Path(project_dir / "embeddings" / 
EMBEDDING_REDUCED_TO_FILENAME[embedding_type]) target_path.write_bytes(pickle.dumps(embeddings_2d_collection)) def get_2d_embedding_data( - project_file_structure: ProjectFileStructure, embedding_type: EmbeddingType + embeddings_path: Path, embedding_type: EmbeddingType ) -> Optional[DataFrame[Embedding2DSchema]]: - embedding_file_path = project_file_structure.get_embeddings_file(embedding_type, reduced=True) + + embedding_file_path = embeddings_path / EMBEDDING_REDUCED_TO_FILENAME[embedding_type] if not embedding_file_path.exists(): return None diff --git a/src/encord_active/lib/embeddings/models/__init__.py b/src/encord_active/lib/embeddings/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/encord_active/lib/embeddings/models/clip_embedder.py b/src/encord_active/lib/embeddings/models/clip_embedder.py deleted file mode 100644 index 4171f4229..000000000 --- a/src/encord_active/lib/embeddings/models/clip_embedder.py +++ /dev/null @@ -1,20 +0,0 @@ -import clip -import numpy as np -import torch -from PIL import Image - -from encord_active.lib.embeddings.models.embedder_model import ImageEmbedder - - -class CLIPEmbedder(ImageEmbedder): - def __init__(self): - super().__init__(supports_text_embeddings=True) - self.model, self.preprocess = clip.load("ViT-B/32", device=self.device) - - def embed_texts(self, texts: list[str]) -> np.ndarray: - tokens = torch.stack([clip.tokenize(t) for t in texts]) - return self.execute_with_largest_batch_size(self.model.encode_text, tokens) - - def embed_images(self, images: list[Image.Image]) -> np.ndarray: - tensors = torch.stack([self.preprocess(i) for i in images]) # type: ignore - return self.execute_with_largest_batch_size(self.model.encode_image, tensors) diff --git a/src/encord_active/lib/embeddings/models/embedder_model.py b/src/encord_active/lib/embeddings/models/embedder_model.py deleted file mode 100644 index f4d6573e4..000000000 --- a/src/encord_active/lib/embeddings/models/embedder_model.py +++ /dev/null @@ -1,114 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Callable - -import numpy as np -import torch -from PIL import Image - - -class ImageEmbedder(ABC): - def __init__(self, supports_text_embeddings: bool = False): - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.max_batch_size: dict[Callable[[torch.Tensor], torch.Tensor], int] = {} - self._supports_text_embeddings = supports_text_embeddings - - @torch.inference_mode() - def execute_with_largest_batch_size( - self, function: Callable[[torch.Tensor], torch.Tensor], input_: torch.Tensor - ) -> np.ndarray: - """ - Tries to utillize the GPU as much as possible by starting with the entire input. - If it doesn't fit on the GPU, half the input is tried. - If that doesn't fit halving again, and so on. - The largest successful batch size will be remembered for successive calls to the function. - - Args: - function: The function to be executed. - input_: The tensor input. - - Returns: - A numpy array with the result after concatenating the results of each batch. 
- - """ - if self.device == "cpu": - return function(input_.to(self.device)).numpy() - - n, *_ = input_.shape - if function in self.max_batch_size: - bs = self.max_batch_size[function] - else: - bs = input_.shape[0] - - bs_updated = False - while bs > 0: - try: - out = [] - for i in range(n // bs + 1): - start = i * bs - stop = (i + 1) * bs - - if start >= n: - break - - batch_out = function(input_[start:stop].to(self.device)).cpu() - out.append(batch_out) - - if bs_updated: - self.max_batch_size[function] = bs - - if len(out) == 1: - return out[0].numpy() - - return torch.concat(out, dim=0).numpy() - - except torch.cuda.OutOfMemoryError: # type: ignore - torch.cuda.empty_cache() - bs = bs // 2 - bs_updated = True - - raise RuntimeError( - 'Not enough GPU memory to compute embeddings. Consider disabling GPU with the `CUDA_VISIBLE_DEVICES=""' - ) - - def _ensure_text_embeddings_enabled(self): - if not self._supports_text_embeddings: - raise RuntimeError("Embedder does not support text embeddings") - - def embed_text(self, text: str) -> np.ndarray: - self._ensure_text_embeddings_enabled() - return self.embed_texts([text]).squeeze() - - def embed_image(self, image: Image.Image): - return self.embed_images([image]).squeeze() - - def embed_texts(self, texts: list[str]) -> np.ndarray: - """ - If the `supports_text_embeddings` flag is enabled in the `__init__` - command, provide a function for embedding text. - - Args: - texts: The list of length N of text that should be embedded. - - Returns: - A numpy array of shape [N, -1] with the result after concatenating - the results of each text. - - """ - self._ensure_text_embeddings_enabled() - return np.empty((0, 512)) - - @abstractmethod - def embed_images(self, images: list[Image.Image]) -> np.ndarray: - """ - The function for embeddings images. Note that you can utililze the - `self.execute_with_largest_batch_size` function to utilize the GPU - optimally. - - Args: - images: The list (of length N) of images that should be embedded. - - Returns: - A numpy array of shape [N, -1] with the resulting embeddings after - concatenating the results of each image. - """ - ... 
diff --git a/src/encord_active/lib/embeddings/utils.py b/src/encord_active/lib/embeddings/utils.py index ba0026869..c37087327 100644 --- a/src/encord_active/lib/embeddings/utils.py +++ b/src/encord_active/lib/embeddings/utils.py @@ -9,7 +9,7 @@ from faiss import IndexFlatL2 from pandera.typing import Series -from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.metrics.utils import IdentifierSchema diff --git a/src/encord_active/lib/encord/project_sync.py b/src/encord_active/lib/encord/project_sync.py index a32a8f1b6..574284fbb 100644 --- a/src/encord_active/lib/encord/project_sync.py +++ b/src/encord_active/lib/encord/project_sync.py @@ -12,12 +12,12 @@ from encord_active.lib.db.connection import DBConnection from encord_active.lib.db.merged_metrics import MergedMetrics from encord_active.lib.embeddings.utils import ( + EMBEDDING_REDUCED_TO_FILENAME, LabelEmbedding, load_collections, save_collections, ) -from encord_active.lib.metrics.metric import MetricMetadata -from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import EmbeddingType, MetricMetadata from encord_active.lib.project import ProjectFileStructure from encord_active.lib.project.metadata import fetch_project_meta @@ -51,7 +51,7 @@ def _update_identifiers(identifier: str): new_lr, new_du = renaming_map.get(old_lr, old_lr), renaming_map.get(old_du, old_du) return identifier.replace(old_du, new_du).replace(old_lr, new_lr) - embedding_file = project_file_structure.get_embeddings_file(embedding_type, reduced=True) + embedding_file = project_file_structure.embeddings / EMBEDDING_REDUCED_TO_FILENAME[embedding_type] if not embedding_file.is_file(): return @@ -133,14 +133,14 @@ def create_filtered_embeddings( save_collections(embedding_type, target_project_structure.embeddings, collection) for embedding_type in [EmbeddingType.IMAGE, EmbeddingType.CLASSIFICATION, EmbeddingType.OBJECT]: - curr_embedding_file = curr_project_structure.get_embeddings_file(embedding_type) - if not curr_embedding_file.exists(): + embedding_file_name = EMBEDDING_REDUCED_TO_FILENAME[embedding_type] + if not (curr_project_structure.embeddings / embedding_file_name).exists(): continue - embeddings = pickle.loads(curr_embedding_file.read_bytes()) + embeddings = pickle.loads(Path(curr_project_structure.embeddings / embedding_file_name).read_bytes()) embeddings_df = pd.DataFrame.from_dict(embeddings) embeddings_df = embeddings_df[embeddings_df["identifier"].isin(filtered_df.identifier)] filtered_embeddings = embeddings_df.to_dict(orient="list") - target_project_structure.get_embeddings_file(embedding_type).write_bytes(pickle.dumps(filtered_embeddings)) + (target_project_structure.embeddings / embedding_file_name).write_bytes(pickle.dumps(filtered_embeddings)) def get_filtered_objects(filtered_labels, label_row_hash, data_unit_hash, objects): diff --git a/src/encord_active/lib/metrics/acquisition_functions.py b/src/encord_active/lib/metrics/acquisition_functions.py index 09c926efb..a42747a4a 100644 --- a/src/encord_active/lib/metrics/acquisition_functions.py +++ b/src/encord_active/lib/metrics/acquisition_functions.py @@ -8,11 +8,11 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.labels.classification import ClassificationType from encord_active.lib.labels.object import ObjectShape -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import ( +from 
encord_active.lib.metrics.metric import ( AnnotationType, DataType, EmbeddingType, + Metric, MetricType, ) from encord_active.lib.metrics.writer import CSVMetricWriter diff --git a/src/encord_active/lib/metrics/example.py b/src/encord_active/lib/metrics/example.py index 88858e15c..6f601920e 100644 --- a/src/encord_active/lib/metrics/example.py +++ b/src/encord_active/lib/metrics/example.py @@ -1,8 +1,12 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/execute.py b/src/encord_active/lib/metrics/execute.py index b42b7d427..70b20c9f3 100644 --- a/src/encord_active/lib/metrics/execute.py +++ b/src/encord_active/lib/metrics/execute.py @@ -16,8 +16,15 @@ from encord_active.lib.common.writer import StatisticsObserver from encord_active.lib.labels.classification import ClassificationType from encord_active.lib.labels.object import ObjectShape -from encord_active.lib.metrics.metric import Metric, SimpleMetric, StatsMetadata -from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + EmbeddingType, + Metric, + MetricType, + SimpleMetric, + StatsMetadata, +) from encord_active.lib.metrics.utils import get_embedding_type from encord_active.lib.metrics.writer import CSVMetricWriter from encord_active.lib.model_predictions.writer import MainPredictionType @@ -26,9 +33,7 @@ logger = logger.opt(colors=True) -def get_metrics( - module: Optional[Union[str, list[str]]] = None, filter_func: Callable[[Type[Metric]], bool] = lambda x: True -): +def get_metrics(module: Optional[Union[str, list[str]]] = None, filter_func=lambda x: True): if module is None: module = ["geometric", "heuristic", "semantic"] elif isinstance(module, str): @@ -78,6 +83,18 @@ def run_metrics_by_embedding_type(embedding_type: EmbeddingType, **kwargs): execute_metrics(metrics, **kwargs) +def run_all_heuristic_metrics(): + run_metrics(filter_func=lambda x: x.metadata.metric_type == MetricType.HEURISTIC) + + +def run_all_image_metrics(): + run_metrics(filter_func=lambda x: x.metadata.data_type == DataType.IMAGE) + + +def run_all_polygon_metrics(): + run_metrics(filter_func=lambda x: x.metadata.annotation_type in [AnnotationType.OBJECT.POLYGON, AnnotationType.ALL]) + + def run_all_prediction_metrics(**kwargs): # Return all metrics that apply according to the prediction type. 
def filter_objects(m: Type[Metric]): @@ -92,7 +109,7 @@ def filter_objects(m: Type[Metric]): else: return isinstance(at, ObjectShape) - def filter_classifications(m: Type[Metric]) -> bool: + def filter_classifications(m: Type[Metric]): at = m().metadata.annotation_type # type: ignore if isinstance(at, list): @@ -111,7 +128,7 @@ def filter_classifications(m: Type[Metric]) -> bool: raise ValueError(f"Undefined prediction type {kwargs['prediction_type']}") -def run_metrics(filter_func: Callable[[Type[Metric]], bool] = lambda x: True, **kwargs): +def run_metrics(filter_func: Callable[[Metric], bool] = lambda x: True, **kwargs): metrics = list(map(load_metric, get_metrics(filter_func=filter_func))) execute_metrics(metrics, **kwargs) diff --git a/src/encord_active/lib/metrics/geometric/annotation_duplicates.py b/src/encord_active/lib/metrics/geometric/annotation_duplicates.py index da54d0fa4..fb4c068f3 100644 --- a/src/encord_active/lib/metrics/geometric/annotation_duplicates.py +++ b/src/encord_active/lib/metrics/geometric/annotation_duplicates.py @@ -2,8 +2,12 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_iou, get_polygon -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/geometric/hu_static.py b/src/encord_active/lib/metrics/geometric/hu_static.py index 19a7b17d0..cc688c5ca 100644 --- a/src/encord_active/lib/metrics/geometric/hu_static.py +++ b/src/encord_active/lib/metrics/geometric/hu_static.py @@ -8,8 +8,12 @@ from encord_active.lib.common.utils import get_object_coordinates, patch_sklearn_linalg from encord_active.lib.embeddings.hu_moments import get_hu_embeddings from encord_active.lib.embeddings.writer import CSVEmbeddingWriter -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/geometric/hu_temporal.py b/src/encord_active/lib/metrics/geometric/hu_temporal.py index eedfce229..30503d535 100644 --- a/src/encord_active/lib/metrics/geometric/hu_temporal.py +++ b/src/encord_active/lib/metrics/geometric/hu_temporal.py @@ -5,8 +5,12 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.embeddings.hu_moments import get_hu_embeddings -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/geometric/image_border_closeness.py b/src/encord_active/lib/metrics/geometric/image_border_closeness.py index b11aebe17..4eb9b0be5 100644 --- a/src/encord_active/lib/metrics/geometric/image_border_closeness.py +++ b/src/encord_active/lib/metrics/geometric/image_border_closeness.py @@ -2,8 +2,12 @@ from encord_active.lib.common.iterator import Iterator from 
encord_active.lib.common.utils import get_object_coordinates -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/geometric/object_size.py b/src/encord_active/lib/metrics/geometric/object_size.py index 51f585c9e..76e55d005 100644 --- a/src/encord_active/lib/metrics/geometric/object_size.py +++ b/src/encord_active/lib/metrics/geometric/object_size.py @@ -8,8 +8,12 @@ get_polygon, ) from encord_active.lib.labels.object import BoxShapes, ObjectShape -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py b/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py index a09403e27..90c43a197 100644 --- a/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py +++ b/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py @@ -4,8 +4,12 @@ from tqdm.auto import tqdm from encord_active.lib.common.iterator import Iterator -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/heuristic/_annotation_time.py b/src/encord_active/lib/metrics/heuristic/_annotation_time.py index b99ac64dd..d6481704d 100644 --- a/src/encord_active/lib/metrics/heuristic/_annotation_time.py +++ b/src/encord_active/lib/metrics/heuristic/_annotation_time.py @@ -1,8 +1,12 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py b/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py index 5d240a03c..4f587b1de 100644 --- a/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py +++ b/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py @@ -2,8 +2,12 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_iou, get_polygon -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/heuristic/img_features.py b/src/encord_active/lib/metrics/heuristic/img_features.py index 
2e1872c3d..1cd9b4689 100644 --- a/src/encord_active/lib/metrics/heuristic/img_features.py +++ b/src/encord_active/lib/metrics/heuristic/img_features.py @@ -5,8 +5,13 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_du_size -from encord_active.lib.metrics.metric import Metric, SimpleMetric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, + SimpleMetric, +) from encord_active.lib.metrics.writer import CSVMetricWriter @@ -39,6 +44,7 @@ def __init__( saturation_filters=[50, 255], value_filters=[20, 255], ): + super().__init__( title=f"{color_name} Values".title(), short_description=f"Ranks images by how {color_name.lower()} the average value of the image is.", diff --git a/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py b/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py index ccfec91d9..0ae3045eb 100644 --- a/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py +++ b/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py @@ -9,8 +9,12 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_iou, get_polygon -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) diff --git a/src/encord_active/lib/metrics/heuristic/object_counting.py b/src/encord_active/lib/metrics/heuristic/object_counting.py index f054cf8a7..9d69801de 100644 --- a/src/encord_active/lib/metrics/heuristic/object_counting.py +++ b/src/encord_active/lib/metrics/heuristic/object_counting.py @@ -1,6 +1,10 @@ from encord_active.lib.common.iterator import Iterator -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter diff --git a/src/encord_active/lib/metrics/heuristic/random.py b/src/encord_active/lib/metrics/heuristic/random.py index 4c655a877..9aea2d2bc 100644 --- a/src/encord_active/lib/metrics/heuristic/random.py +++ b/src/encord_active/lib/metrics/heuristic/random.py @@ -1,8 +1,12 @@ import numpy as np from encord_active.lib.common.iterator import Iterator -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.writer import CSVMetricWriter diff --git a/src/encord_active/lib/metrics/metric.py b/src/encord_active/lib/metrics/metric.py index 5b7101687..ee7f74fc4 100644 --- a/src/encord_active/lib/metrics/metric.py +++ b/src/encord_active/lib/metrics/metric.py @@ -1,24 +1,46 @@ from abc import ABC, abstractmethod +from enum import Enum from hashlib import md5 -from typing import List, Optional +from typing import List, Optional, Union import numpy as np from pydantic import BaseModel from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.writer 
import StatisticsObserver
-
-# TODO: delete AnnotationType import on major version bump 👇
-from encord_active.lib.metrics.types import (  # pylint: disable=unused-import
-    AnnotationType,
-    AnnotationTypeUnion,
-    DataType,
-    EmbeddingType,
-    MetricType,
-)
+from encord_active.lib.labels.classification import ClassificationType
+from encord_active.lib.labels.object import ObjectShape
 from encord_active.lib.metrics.writer import CSVMetricWriter
+class MetricType(str, Enum):
+    SEMANTIC = "semantic"
+    GEOMETRIC = "geometric"
+    HEURISTIC = "heuristic"
+
+
+class DataType(str, Enum):
+    IMAGE = "image"
+    SEQUENCE = "sequence"
+
+
+class EmbeddingType(str, Enum):
+    CLASSIFICATION = "classification"
+    OBJECT = "object"
+    HU_MOMENTS = "hu_moments"
+    IMAGE = "image"
+
+
+AnnotationTypeUnion = Union[ObjectShape, ClassificationType]
+
+
+class AnnotationType:
+    NONE: List[AnnotationTypeUnion] = []
+    OBJECT = ObjectShape
+    CLASSIFICATION = ClassificationType
+    ALL = [*OBJECT, *CLASSIFICATION]
+
+
 class StatsMetadata(BaseModel):
     min_value: float = 0.0
     max_value: float = 0.0
@@ -59,7 +81,7 @@ def __init__(
         long_description: str,
         metric_type: MetricType,
         data_type: DataType,
-        annotation_type: List[AnnotationTypeUnion],
+        annotation_type: List[Union[ObjectShape, ClassificationType]],
         embedding_type: Optional[EmbeddingType] = None,
         doc_url: Optional[str] = None,
     ):
@@ -88,7 +110,7 @@ def __init__(
         long_description: str,
         metric_type: MetricType,
         data_type: DataType,
-        annotation_type: List[AnnotationTypeUnion] = [],
+        annotation_type: List[Union[ObjectShape, ClassificationType]] = [],
         embedding_type: Optional[EmbeddingType] = None,
         doc_url: Optional[str] = None,
     ):
diff --git a/src/encord_active/lib/metrics/semantic/_class_uncertainty.py b/src/encord_active/lib/metrics/semantic/_class_uncertainty.py
index 02cdf4c79..dc201b6af 100644
--- a/src/encord_active/lib/metrics/semantic/_class_uncertainty.py
+++ b/src/encord_active/lib/metrics/semantic/_class_uncertainty.py
@@ -10,12 +10,12 @@ from torch.nn import LeakyReLU
 from encord_active.lib.common.iterator import Iterator
-from encord_active.lib.embeddings.embeddings import get_embeddings
-from encord_active.lib.metrics.metric import Metric
-from encord_active.lib.metrics.types import (
+from encord_active.lib.embeddings.cnn import get_cnn_embeddings
+from encord_active.lib.metrics.metric import (
     AnnotationType,
     DataType,
     EmbeddingType,
+    Metric,
     MetricType,
 )
 from encord_active.lib.metrics.writer import CSVMetricWriter
@@ -181,7 +181,7 @@ def get_batches_and_model(resnet_embeddings_df):
 def preliminaries(iterator):
     model_path = os.path.join(iterator.cache_dir, "models", f"{Path(__file__).stem}_classifier.pt")
     os.makedirs(os.path.dirname(model_path), exist_ok=True)
-    resnet_embeddings_df = get_embeddings(iterator, embedding_type=EmbeddingType.OBJECT, force=False)
+    resnet_embeddings_df = get_cnn_embeddings(iterator, embedding_type=EmbeddingType.OBJECT, force=False)
     batches, classifier, idx_to_counts, name_to_idx = get_batches_and_model(resnet_embeddings_df)
     if not os.path.isfile(model_path):
         train_model(classifier, model_path, batches, idx_to_counts, name_to_idx)
diff --git a/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py b/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py
index b796bfc56..4e9c24a99 100644
--- a/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py
+++ b/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py
@@ -12,9 +12,13 @@ from torchvision.models.segmentation import
DeepLabV3_MobileNet_V3_Large_Weights from encord_active.lib.common.iterator import Iterator -from encord_active.lib.metrics.metric import Metric +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.semantic._class_uncertainty import train_test_split -from encord_active.lib.metrics.types import AnnotationType, DataType, MetricType from encord_active.lib.metrics.writer import CSVMetricWriter DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/src/encord_active/lib/metrics/semantic/image_difficulty.py b/src/encord_active/lib/metrics/semantic/image_difficulty.py index 535aabbda..b0861e241 100644 --- a/src/encord_active/lib/metrics/semantic/image_difficulty.py +++ b/src/encord_active/lib/metrics/semantic/image_difficulty.py @@ -6,11 +6,10 @@ from sklearn.cluster import KMeans from encord_active.lib.common.iterator import Iterator -from encord_active.lib.embeddings.embeddings import get_embeddings +from encord_active.lib.embeddings.cnn import get_cnn_embeddings from encord_active.lib.embeddings.utils import LabelEmbedding from encord_active.lib.labels.classification import ClassificationType -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import DataType, EmbeddingType, MetricType +from encord_active.lib.metrics.metric import DataType, EmbeddingType, Metric, MetricType from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -91,7 +90,7 @@ def _get_difficulty_ranking(self, cluster_size: int) -> Dict[str, int]: def execute(self, iterator: Iterator, writer: CSVMetricWriter): if self.metadata.embedding_type: - self.collections = get_embeddings(iterator, embedding_type=self.metadata.embedding_type) + self.collections = get_cnn_embeddings(iterator, embedding_type=self.metadata.embedding_type) else: logger.error( f"[Skipping] No `embedding_type` provided for the {self.metadata.title} metric!" 
@@ -106,9 +105,7 @@ def execute(self, iterator: Iterator, writer: CSVMetricWriter): data_hash_to_score = self._get_difficulty_ranking(cluster_size) - for data_unit, _ in iterator.iterate(desc="Writing scores to a file"): - score = data_hash_to_score.get(data_unit["data_hash"]) - if score is not None: - writer.write(score=score) + for data_unit, img_pth in iterator.iterate(desc="Writing scores to a file"): + writer.write(score=data_hash_to_score[data_unit["data_hash"]]) else: logger.info("[Skipping] The embedding file is empty.") diff --git a/src/encord_active/lib/metrics/semantic/image_singularity.py b/src/encord_active/lib/metrics/semantic/image_singularity.py index 5456a5e9c..81bc4d646 100644 --- a/src/encord_active/lib/metrics/semantic/image_singularity.py +++ b/src/encord_active/lib/metrics/semantic/image_singularity.py @@ -10,10 +10,9 @@ from encord_active.lib.common.utils import ( fix_duplicate_image_orders_in_knn_graph_all_rows, ) -from encord_active.lib.embeddings.embeddings import get_embeddings +from encord_active.lib.embeddings.cnn import get_cnn_embeddings from encord_active.lib.embeddings.utils import LabelEmbedding -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import DataType, EmbeddingType, MetricType +from encord_active.lib.metrics.metric import DataType, EmbeddingType, Metric, MetricType from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -88,7 +87,7 @@ def score_images(self, project_hash: str, nearest_distances: np.ndarray, nearest def execute(self, iterator: Iterator, writer: CSVMetricWriter): if self.metadata.embedding_type: - self.collections = get_embeddings(iterator, embedding_type=self.metadata.embedding_type) + self.collections = get_cnn_embeddings(iterator, embedding_type=self.metadata.embedding_type) else: logger.error( f"[Skipping] No `embedding_type` provided for the {self.metadata.title} metric!" 
@@ -96,6 +95,7 @@ def execute(self, iterator: Iterator, writer: CSVMetricWriter): return if len(self.collections) > 0: + embeddings, db_index = self.convert_to_index() # For more information why we set the below threshold # see here: https://github.com/facebookresearch/faiss/wiki/Implementation-notes#matrix-multiplication-to-do-many-l2-distance-computations @@ -112,10 +112,10 @@ def execute(self, iterator: Iterator, writer: CSVMetricWriter): else: logger.info("[Skipping] The embedding file is empty.") - for data_unit, _ in iterator.iterate(desc="Writing scores to a file"): - data_unit_info = self.scores.get(data_unit["data_hash"]) - if data_unit_info is not None: - writer.write( - score=float(data_unit_info.score), - description=data_unit_info.description, - ) + for data_unit, img_pth in iterator.iterate(desc="Writing scores to a file"): + + data_unit_info = self.scores[data_unit["data_hash"]] + writer.write( + score=float(data_unit_info.score), + description=data_unit_info.description, + ) diff --git a/src/encord_active/lib/metrics/semantic/img_classification_quality.py b/src/encord_active/lib/metrics/semantic/img_classification_quality.py index 47ebe666f..23fa555be 100644 --- a/src/encord_active/lib/metrics/semantic/img_classification_quality.py +++ b/src/encord_active/lib/metrics/semantic/img_classification_quality.py @@ -9,16 +9,16 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator +from encord_active.lib.embeddings.cnn import get_cnn_embeddings from encord_active.lib.embeddings.dimensionality_reduction import ( generate_2d_embedding_data, ) -from encord_active.lib.embeddings.embeddings import get_embeddings from encord_active.lib.embeddings.utils import LabelEmbedding -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import ( +from encord_active.lib.metrics.metric import ( AnnotationType, DataType, EmbeddingType, + Metric, MetricType, ) from encord_active.lib.metrics.utils import is_multiclass_ontology @@ -181,6 +181,7 @@ def convert_nearest_labels_to_scores_for_all_questions( return collections_scores_all_questions def create_key_score_pairs(self, nearest_indexes: dict[str, np.ndarray]): + nearest_labels_all_questions = self.transform_neighbors_to_labels_for_all_questions(nearest_indexes) collections_scores_all_questions = self.convert_nearest_labels_to_scores_for_all_questions( nearest_labels_all_questions @@ -266,11 +267,11 @@ def execute(self, iterator: Iterator, writer: CSVMetricWriter): logger.info("[Skipping] No frame level classifications in the project ontology.") # TODO: move me somewhere else, this is here to ensure the generation of image embeddings - get_embeddings(iterator, embedding_type=EmbeddingType.IMAGE) + get_cnn_embeddings(iterator, embedding_type=EmbeddingType.IMAGE) generate_2d_embedding_data(EmbeddingType.IMAGE, iterator.cache_dir) if self.metadata.embedding_type: - self.collections = get_embeddings(iterator, embedding_type=self.metadata.embedding_type) + self.collections = get_cnn_embeddings(iterator, embedding_type=self.metadata.embedding_type) else: logger.error( f"[Skipping] No `embedding_type` provided for the {self.metadata.title} metric!" 
diff --git a/src/encord_active/lib/metrics/semantic/img_object_quality.py b/src/encord_active/lib/metrics/semantic/img_object_quality.py index 1fde9a6ae..717812e20 100644 --- a/src/encord_active/lib/metrics/semantic/img_object_quality.py +++ b/src/encord_active/lib/metrics/semantic/img_object_quality.py @@ -10,16 +10,16 @@ from encord_active.lib.common.utils import ( fix_duplicate_image_orders_in_knn_graph_all_rows, ) +from encord_active.lib.embeddings.cnn import get_cnn_embeddings from encord_active.lib.embeddings.dimensionality_reduction import ( generate_2d_embedding_data, ) -from encord_active.lib.embeddings.embeddings import get_embeddings from encord_active.lib.embeddings.utils import LabelEmbedding -from encord_active.lib.metrics.metric import Metric -from encord_active.lib.metrics.types import ( +from encord_active.lib.metrics.metric import ( AnnotationType, DataType, EmbeddingType, + Metric, MetricType, ) from encord_active.lib.metrics.writer import CSVMetricWriter @@ -117,7 +117,7 @@ def execute(self, iterator: Iterator, writer: CSVMetricWriter): return if self.metadata.embedding_type: - collections = get_embeddings(iterator, embedding_type=self.metadata.embedding_type) + collections = get_cnn_embeddings(iterator, embedding_type=self.metadata.embedding_type) else: logger.error( f"[Skipping] No `embedding_type` provided for the {self.metadata.title} metric!" diff --git a/src/encord_active/lib/metrics/types.py b/src/encord_active/lib/metrics/types.py deleted file mode 100644 index a4ed43211..000000000 --- a/src/encord_active/lib/metrics/types.py +++ /dev/null @@ -1,32 +0,0 @@ -from enum import Enum -from typing import Union - -from encord_active.lib.labels.classification import ClassificationType -from encord_active.lib.labels.object import ObjectShape - -AnnotationTypeUnion = Union[ObjectShape, ClassificationType] - - -class AnnotationType: - NONE: list[AnnotationTypeUnion] = [] - OBJECT = ObjectShape - CLASSIFICATION = ClassificationType - ALL = [*OBJECT, *CLASSIFICATION] - - -class MetricType(str, Enum): - SEMANTIC = "semantic" - GEOMETRIC = "geometric" - HEURISTIC = "heuristic" - - -class DataType(str, Enum): - IMAGE = "image" - SEQUENCE = "sequence" - - -class EmbeddingType(str, Enum): - CLASSIFICATION = "classification" - OBJECT = "object" - HU_MOMENTS = "hu_moments" - IMAGE = "image" diff --git a/src/encord_active/lib/metrics/utils.py b/src/encord_active/lib/metrics/utils.py index 6c1481bbc..d0286d248 100644 --- a/src/encord_active/lib/metrics/utils.py +++ b/src/encord_active/lib/metrics/utils.py @@ -13,11 +13,12 @@ from pandera.typing import DataFrame, Series from pydantic import ValidationError -from encord_active.lib.metrics.metric import MetricMetadata, StatsMetadata -from encord_active.lib.metrics.types import ( +from encord_active.lib.metrics.metric import ( AnnotationType, AnnotationTypeUnion, EmbeddingType, + MetricMetadata, + StatsMetadata, ) diff --git a/src/encord_active/lib/project/project_file_structure.py b/src/encord_active/lib/project/project_file_structure.py index c219da261..8aa575def 100644 --- a/src/encord_active/lib/project/project_file_structure.py +++ b/src/encord_active/lib/project/project_file_structure.py @@ -16,19 +16,6 @@ from encord_active.lib.db.connection import PrismaConnection from encord_active.lib.file_structure.base import BaseProjectFileStructure -from encord_active.lib.metrics.types import EmbeddingType - -EMBEDDING_TYPE_TO_FILENAME = { - EmbeddingType.IMAGE: "cnn_images.pkl", - EmbeddingType.CLASSIFICATION: "cnn_classifications.pkl", - 
EmbeddingType.OBJECT: "cnn_objects.pkl", -} - -EMBEDDING_REDUCED_TO_FILENAME = { - EmbeddingType.IMAGE: "cnn_images_reduced.pkl", - EmbeddingType.CLASSIFICATION: "cnn_classifications_reduced.pkl", - EmbeddingType.OBJECT: "cnn_objects_reduced.pkl", -} # To be deprecated when Encord Active version is >= 0.1.60. @@ -153,10 +140,6 @@ def metrics_meta(self) -> Path: def embeddings(self) -> Path: return self.project_dir / "embeddings" - def get_embeddings_file(self, type_: EmbeddingType, reduced: bool = False) -> Path: - lookup = EMBEDDING_REDUCED_TO_FILENAME if reduced else EMBEDDING_TYPE_TO_FILENAME - return self.embeddings / lookup[type_] - @property def predictions(self) -> Path: return self.project_dir / "predictions" diff --git a/src/encord_active/server/main.py b/src/encord_active/server/main.py index 6e2df2509..28e220b3c 100644 --- a/src/encord_active/server/main.py +++ b/src/encord_active/server/main.py @@ -122,7 +122,7 @@ def get_available_metrics(project: ProjectFileStructureDep, scope: Optional[Metr @app.get("/projects/{project}/2d_embeddings/{current_metric}") def get_2d_embeddings(project: ProjectFileStructureDep, current_metric: str): embedding_type = get_metric_embedding_type(project, current_metric) - embeddings_df = get_2d_embedding_data(project, embedding_type) + embeddings_df = get_2d_embedding_data(project.embeddings, embedding_type) if embeddings_df is None: raise HTTPException( diff --git a/src/encord_active/server/utils.py b/src/encord_active/server/utils.py index f6b1541f8..fd98abaaf 100644 --- a/src/encord_active/server/utils.py +++ b/src/encord_active/server/utils.py @@ -10,7 +10,7 @@ from encord_active.lib.db.helpers.tags import to_grouped_tags from encord_active.lib.embeddings.utils import SimilaritiesFinder from encord_active.lib.labels.object import ObjectShape -from encord_active.lib.metrics.types import EmbeddingType +from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.metrics.utils import ( MetricScope, get_embedding_type,