From 52e96ea1c9f9db636b4fd88dad4fb006b43dcfca Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Mon, 15 May 2023 19:07:17 +0200 Subject: [PATCH 1/8] edit installation instructions in readme --- README.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 754dc795..ffd201b7 100644 --- a/README.rst +++ b/README.rst @@ -150,9 +150,12 @@ If you choose to pursue this way, first install Poetry and add it to your PATH poetry install -All the dependencies will be installed at their required versions. -If you also want to install the optional Sphinx dependencies to build the documentation, -add the flag :code:`-E docs` to the command above. +All the dependencies will be installed at their required versions. Consider adding the following flags to the command above: + +- :code:`-E transformers` if you want to use models and datasets from `Hugging Face `_. +- :code:`-E docs` if you want to install Sphinx dependencies to build the documentation. +- :code:`-E notebooks` if you want to work with Jupyter notebooks. + Finally, you can either access the virtualenv that Poetry created by typing :code:`poetry shell`, or execute commands within the virtualenv using the :code:`run` command, e.g. :code:`poetry run python`. From 6cb6581b9574306dc4acfdabc4fe1b92cfcb1241 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Mon, 15 May 2023 21:25:25 +0200 Subject: [PATCH 2/8] bump up version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9adeaa0b..beedbe3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws-fortuna" -version = "0.1.15" +version = "0.1.16" description = "A Library for Uncertainty Quantification." authors = ["Gianluca Detommaso ", "Alberto Gasparin "] license = "Apache-2.0" From b2540c18ef0cfaaf23b49b87f8ba157187b537f0 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Tue, 18 Jul 2023 10:30:59 +0200 Subject: [PATCH 3/8] make small change in readme because of publish to pypi error --- README.rst | 6 +- poetry.lock | 274 +++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 266 insertions(+), 14 deletions(-) diff --git a/README.rst b/README.rst index 54018894..b1687322 100644 --- a/README.rst +++ b/README.rst @@ -183,12 +183,12 @@ We offer a simple pipeline that allows you to run Fortuna on Amazon SageMaker wi 3. Create an `S3 bucket `_. You will need this to dump the results from your training jobs on Amazon Sagemaker. -3. Write a configuration `yaml` file. This will include your AWS details, the path to the entrypoint script that you want +4. Write a configuration `yaml` file. This will include your AWS details, the path to the entrypoint script that you want to run on Amazon SageMaker, the arguments to pass to the script, the path to the S3 bucket where you want to dump the results, the metrics to monitor, and more. - See `here `_ for an example. + Check `this file `_ for an example. -4. Finally, given :code:`config_dir`, that is the absolute path to the main configuration directory, +5. Finally, given :code:`config_dir`, that is the absolute path to the main configuration directory, and :code:`config_filename`, that is the name of the main configuration file (without .yaml extension), enter Python and run the following: diff --git a/poetry.lock b/poetry.lock index 5450bd50..4f6aa650 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "absl-py" version = "1.4.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -15,6 +16,7 @@ files = [ name = "absolufy-imports" version = "0.3.1" description = "A tool to automatically replace relative imports with absolute ones." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -26,6 +28,7 @@ files = [ name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -134,6 +137,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -148,6 +152,7 @@ frozenlist = ">=1.1.0" name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -159,6 +164,7 @@ files = [ name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" +category = "main" optional = true python-versions = "*" files = [ @@ -169,6 +175,7 @@ files = [ name = "anyio" version = "3.6.2" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" optional = true python-versions = ">=3.6.2" files = [ @@ -189,6 +196,7 @@ trio = ["trio (>=0.16,<0.22)"] name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = "*" files = [ @@ -200,6 +208,7 @@ files = [ name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "main" optional = false python-versions = "*" files = [ @@ -211,6 +220,7 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -230,6 +240,7 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -267,6 +278,7 @@ tests = ["pytest"] name = "array-record" version = "0.2.0" description = "A file format that achieves a new frontier of IO efficiency" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -283,6 +295,7 @@ etils = {version = "*", extras = ["epath"]} name = "arrow" version = "1.2.3" description = "Better dates & times for Python" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -297,6 +310,7 @@ python-dateutil = ">=2.7.0" name = "asttokens" version = "2.2.1" description = "Annotate AST trees with source code positions" +category = "main" optional = false python-versions = "*" files = [ @@ -314,6 +328,7 @@ test = ["astroid", "pytest"] name = "astunparse" version = "1.6.3" description = "An AST unparser for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -329,6 +344,7 @@ wheel = ">=0.23.0,<1.0" name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -340,6 +356,7 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -358,6 +375,7 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "babel" version = "2.12.1" description = "Internationalization utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -372,6 +390,7 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "main" optional = false python-versions = "*" files = [ @@ -383,6 +402,7 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" +category = "main" optional = false python-versions = ">=3.6.0" files = [ @@ -401,6 +421,7 @@ lxml = ["lxml"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -419,6 +440,7 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "boto3" version = "1.26.145" description = "The AWS SDK for Python" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -438,6 +460,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.145" description = "Low-level, data-driven core of boto 3." +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -457,6 +480,7 @@ crt = ["awscrt (==0.16.9)"] name = "cached-property" version = "1.5.2" description = "A decorator for caching properties in classes." +category = "main" optional = false python-versions = "*" files = [ @@ -468,6 +492,7 @@ files = [ name = "cachetools" version = "5.3.0" description = "Extensible memoizing collections and decorators" +category = "main" optional = false python-versions = "~=3.7" files = [ @@ -479,6 +504,7 @@ files = [ name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -490,6 +516,7 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." +category = "main" optional = false python-versions = "*" files = [ @@ -566,6 +593,7 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." +category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -577,6 +605,7 @@ files = [ name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -661,6 +690,7 @@ files = [ name = "chex" version = "0.1.7" description = "Chex: Testing made fun, in JAX!" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -681,6 +711,7 @@ typing-extensions = {version = ">=4.2.0", markers = "python_version < \"3.11\""} name = "click" version = "8.1.3" description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -695,6 +726,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cloudpickle" version = "2.2.1" description = "Extended pickling support for Python objects" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -706,6 +738,7 @@ files = [ name = "codespell" version = "2.2.4" description = "Codespell" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -723,6 +756,7 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -734,6 +768,7 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -753,6 +788,7 @@ typing = ["mypy (>=0.990)"] name = "contextlib2" version = "21.6.0" description = "Backports and enhancements for the contextlib module" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -764,6 +800,7 @@ files = [ name = "contourpy" version = "1.0.7" description = "Python library for calculating contours of 2D quadrilateral grids" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -838,6 +875,7 @@ test-no-images = ["pytest"] name = "coverage" version = "7.2.5" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -904,6 +942,7 @@ toml = ["tomli"] name = "cryptography" version = "41.0.0" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -945,6 +984,7 @@ test-randomorder = ["pytest-randomly"] name = "cycler" version = "0.11.0" description = "Composable style cycles" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -956,6 +996,7 @@ files = [ name = "datasets" version = "2.12.0" description = "HuggingFace community-driven open-source library of datasets" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -999,6 +1040,7 @@ vision = ["Pillow (>=6.2.1)"] name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1026,6 +1068,7 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1037,6 +1080,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1048,6 +1092,7 @@ files = [ name = "dill" version = "0.3.6" description = "serialize all of python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1062,6 +1107,7 @@ graph = ["objgraph (>=1.7.2)"] name = "distlib" version = "0.3.6" description = "Distribution utilities" +category = "dev" optional = false python-versions = "*" files = [ @@ -1073,6 +1119,7 @@ files = [ name = "dm-tree" version = "0.1.8" description = "Tree is a library for working with nested data structures." +category = "main" optional = false python-versions = "*" files = [ @@ -1121,6 +1168,7 @@ files = [ name = "docutils" version = "0.19" description = "Docutils -- Python Documentation Utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1132,6 +1180,7 @@ files = [ name = "et-xmlfile" version = "1.1.0" description = "An implementation of lxml.xmlfile for the standard library" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1143,6 +1192,7 @@ files = [ name = "etils" version = "1.2.0" description = "Collection of common python utils" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1177,6 +1227,7 @@ lazy-imports = ["etils[ecolab]"] name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1191,6 +1242,7 @@ test = ["pytest (>=6)"] name = "executing" version = "1.2.0" description = "Get the currently executing AST node of a frame, and other information" +category = "main" optional = false python-versions = "*" files = [ @@ -1205,6 +1257,7 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fastjsonschema" version = "2.16.3" description = "Fastest Python implementation of JSON schema" +category = "main" optional = false python-versions = "*" files = [ @@ -1219,6 +1272,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.0" description = "A platform independent file lock." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1234,6 +1288,7 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "p name = "flatbuffers" version = "23.5.9" description = "The FlatBuffers serialization format for Python" +category = "main" optional = false python-versions = "*" files = [ @@ -1245,6 +1300,7 @@ files = [ name = "flax" version = "0.6.10" description = "Flax: A neural network library for JAX designed for flexibility" +category = "main" optional = false python-versions = "*" files = [ @@ -1271,6 +1327,7 @@ testing = ["atari-py (==0.2.5)", "clu", "einops", "gym (==0.18.3)", "jaxlib", "j name = "fonttools" version = "4.39.4" description = "Tools to manipulate font files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1296,6 +1353,7 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "fqdn" version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" +category = "main" optional = true python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" files = [ @@ -1307,6 +1365,7 @@ files = [ name = "frozendict" version = "2.3.8" description = "A simple immutable dictionary" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1353,6 +1412,7 @@ files = [ name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1436,6 +1496,7 @@ files = [ name = "fsspec" version = "2023.5.0" description = "File-system specification" +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -1475,6 +1536,7 @@ tqdm = ["tqdm"] name = "gast" version = "0.4.0" description = "Python AST that abstracts the underlying Python version" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1486,6 +1548,7 @@ files = [ name = "google-auth" version = "2.18.0" description = "Google Authentication Library" +category = "main" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" files = [ @@ -1511,6 +1574,7 @@ requests = ["requests (>=2.20.0,<3.0.0dev)"] name = "google-auth-oauthlib" version = "1.0.0" description = "Google Authentication Library" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1529,6 +1593,7 @@ tool = ["click (>=6.0.0)"] name = "google-pasta" version = "0.2.0" description = "pasta is an AST-based Python refactoring library" +category = "main" optional = false python-versions = "*" files = [ @@ -1544,6 +1609,7 @@ six = "*" name = "googleapis-common-protos" version = "1.59.0" description = "Common protobufs used in Google APIs" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1561,6 +1627,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] name = "greenlet" version = "2.0.2" description = "Lightweight in-process concurrent programming" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" files = [ @@ -1634,6 +1701,7 @@ test = ["objgraph", "psutil"] name = "grpcio" version = "1.54.0" description = "HTTP/2-based RPC framework" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1691,6 +1759,7 @@ protobuf = ["grpcio-tools (>=1.54.0)"] name = "h5py" version = "3.8.0" description = "Read and write HDF5 files from Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1728,6 +1797,7 @@ numpy = ">=1.14.5" name = "html5lib" version = "1.1" description = "HTML parser based on the WHATWG HTML specification" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1749,6 +1819,7 @@ lxml = ["lxml"] name = "huggingface-hub" version = "0.14.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -1780,6 +1851,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "hydra-core" version = "1.3.2" description = "A framework for elegantly configuring complex applications" +category = "main" optional = true python-versions = "*" files = [ @@ -1788,7 +1860,7 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = "==4.9.*" +antlr4-python3-runtime = ">=4.9.0,<4.10.0" importlib-resources = {version = "*", markers = "python_version < \"3.9\""} omegaconf = ">=2.2,<2.4" packaging = "*" @@ -1797,6 +1869,7 @@ packaging = "*" name = "identify" version = "2.5.24" description = "File identification library for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1811,6 +1884,7 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1822,6 +1896,7 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1833,6 +1908,7 @@ files = [ name = "importlib-metadata" version = "4.13.0" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1852,6 +1928,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.12.0" description = "Read resources from Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1870,6 +1947,7 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1881,6 +1959,7 @@ files = [ name = "ipykernel" version = "6.23.0" description = "IPython Kernel for Jupyter" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1894,7 +1973,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1914,6 +1993,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.2" description = "IPython: Productive Interactive Computing" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1953,6 +2033,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" +category = "main" optional = true python-versions = "*" files = [ @@ -1964,6 +2045,7 @@ files = [ name = "ipywidgets" version = "8.0.6" description = "Jupyter interactive widgets" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1985,6 +2067,7 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "isoduration" version = "20.11.0" description = "Operations with ISO 8601 durations" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1999,6 +2082,7 @@ arrow = ">=0.15.0" name = "jax" version = "0.4.10" description = "Differentiate, compile, and transform Numpy code." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2029,6 +2113,7 @@ tpu = ["jaxlib (==0.4.10)", "libtpu-nightly (==0.1.dev20230511)", "requests"] name = "jaxlib" version = "0.4.10" description = "XLA library for JAX" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2055,6 +2140,7 @@ scipy = ">=1.7" name = "jedi" version = "0.18.2" description = "An autocompletion tool for Python that can be used for text editors." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2074,6 +2160,7 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2091,6 +2178,7 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2102,6 +2190,7 @@ files = [ name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2113,6 +2202,7 @@ files = [ name = "jsonpointer" version = "2.3" description = "Identify specific nodes in a JSON document (RFC 6901)" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2124,6 +2214,7 @@ files = [ name = "jsonschema" version = "4.17.3" description = "An implementation of JSON Schema validation for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2153,6 +2244,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter" version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." +category = "main" optional = true python-versions = "*" files = [ @@ -2173,6 +2265,7 @@ qtconsole = "*" name = "jupyter-cache" version = "0.6.1" description = "A defined interface for working with a cache of jupyter notebooks." +category = "dev" optional = false python-versions = "~=3.8" files = [ @@ -2200,6 +2293,7 @@ testing = ["coverage", "ipykernel", "jupytext", "matplotlib", "nbdime", "nbforma name = "jupyter-client" version = "8.2.0" description = "Jupyter protocol implementation and client libraries" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2209,7 +2303,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -2223,6 +2317,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-console" version = "6.6.3" description = "Jupyter terminal console" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2234,7 +2329,7 @@ files = [ ipykernel = ">=6.14" ipython = "*" jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" prompt-toolkit = ">=3.0.30" pygments = "*" pyzmq = ">=17" @@ -2247,6 +2342,7 @@ test = ["flaky", "pexpect", "pytest"] name = "jupyter-core" version = "5.3.0" description = "Jupyter core package. A base package on which Jupyter projects rely." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2267,6 +2363,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-events" version = "0.6.3" description = "Jupyter Event System library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2291,6 +2388,7 @@ test = ["click", "coverage", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>= name = "jupyter-server" version = "2.5.0" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2303,7 +2401,7 @@ anyio = ">=3.1.0" argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" jupyter-events = ">=0.4.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" @@ -2326,6 +2424,7 @@ test = ["ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", " name = "jupyter-server-terminals" version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2345,6 +2444,7 @@ test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2356,6 +2456,7 @@ files = [ name = "jupyterlab-widgets" version = "3.0.7" description = "Jupyter interactive widgets for JupyterLab" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2367,6 +2468,7 @@ files = [ name = "jupytext" version = "1.14.5" description = "Jupyter notebooks as Markdown documents, Julia, Python or R scripts" +category = "dev" optional = false python-versions = "~=3.6" files = [ @@ -2389,6 +2491,7 @@ toml = ["toml"] name = "keras" version = "2.12.0" description = "Deep learning for humans." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2399,6 +2502,7 @@ files = [ name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2476,6 +2580,7 @@ files = [ name = "libclang" version = "16.0.0" description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier." +category = "main" optional = false python-versions = "*" files = [ @@ -2493,6 +2598,7 @@ files = [ name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -2585,6 +2691,7 @@ source = ["Cython (>=0.29.7)"] name = "markdown" version = "3.4.3" description = "Python implementation of John Gruber's Markdown." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2602,6 +2709,7 @@ testing = ["coverage", "pyyaml"] name = "markdown-it-py" version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2626,6 +2734,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2685,6 +2794,7 @@ files = [ name = "matplotlib" version = "3.7.1" description = "Python plotting package" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2747,6 +2857,7 @@ python-dateutil = ">=2.7" name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -2761,6 +2872,7 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.3.5" description = "Collection of plugins for markdown-it-py" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2780,6 +2892,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2791,6 +2904,7 @@ files = [ name = "mistune" version = "2.0.5" description = "A sane Markdown parser with useful plugins and renderers" +category = "main" optional = true python-versions = "*" files = [ @@ -2802,6 +2916,7 @@ files = [ name = "ml-dtypes" version = "0.1.0" description = "" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2837,6 +2952,7 @@ dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] name = "msgpack" version = "1.0.5" description = "MessagePack serializer" +category = "main" optional = false python-versions = "*" files = [ @@ -2909,6 +3025,7 @@ files = [ name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2992,6 +3109,7 @@ files = [ name = "multiprocess" version = "0.70.14" description = "better multiprocessing and multithreading in python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3018,6 +3136,7 @@ dill = ">=0.3.6" name = "multitasking" version = "0.0.11" description = "Non-blocking Python methods using decorators" +category = "dev" optional = false python-versions = "*" files = [ @@ -3029,6 +3148,7 @@ files = [ name = "myst-nb" version = "0.17.2" description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3057,6 +3177,7 @@ testing = ["beautifulsoup4", "coverage (>=6.4,<8.0)", "ipykernel (>=5.5,<6.0)", name = "myst-parser" version = "0.18.1" description = "An extended commonmark compliant parser, with bridges to docutils & sphinx." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3083,6 +3204,7 @@ testing = ["beautifulsoup4", "coverage[toml]", "pytest (>=6,<7)", "pytest-cov", name = "nbclassic" version = "1.0.0" description = "Jupyter Notebook as a Jupyter Server extension." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3118,6 +3240,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-jupyter", "pytest-p name = "nbclient" version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3127,7 +3250,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" nbformat = ">=5.1" traitlets = ">=5.3" @@ -3140,6 +3263,7 @@ test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "p name = "nbconvert" version = "7.4.0" description = "Converting Jupyter Notebooks" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3178,6 +3302,7 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.8.0" description = "The Jupyter Notebook format" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3199,6 +3324,7 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nbsphinx" version = "0.8.12" description = "Jupyter Notebook Tools for Sphinx" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3218,6 +3344,7 @@ traitlets = ">=5" name = "nbsphinx-link" version = "1.3.0" description = "A sphinx extension for including notebook files outside sphinx source root" +category = "main" optional = true python-versions = "*" files = [ @@ -3233,6 +3360,7 @@ sphinx = ">=1.8" name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3244,6 +3372,7 @@ files = [ name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" +category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -3258,6 +3387,7 @@ setuptools = "*" name = "notebook" version = "6.5.4" description = "A web-based notebook environment for interactive computing" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3292,6 +3422,7 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.3" description = "A shim layer for notebook traits and config" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3309,6 +3440,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" name = "numpy" version = "1.23.5" description = "NumPy is the fundamental package for array computing with Python." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3346,6 +3478,7 @@ files = [ name = "oauthlib" version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3362,6 +3495,7 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "omegaconf" version = "2.3.0" description = "A flexible configuration library" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3370,13 +3504,14 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = "==4.9.*" +antlr4-python3-runtime = ">=4.9.0,<4.10.0" PyYAML = ">=5.1.0" [[package]] name = "openpyxl" version = "3.1.2" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3391,6 +3526,7 @@ et-xmlfile = "*" name = "opt-einsum" version = "3.3.0" description = "Optimizing numpys einsum function" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -3409,6 +3545,7 @@ tests = ["pytest", "pytest-cov", "pytest-pep8"] name = "optax" version = "0.1.5" description = "A gradient processing and optimisation library in JAX." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3427,6 +3564,7 @@ numpy = ">=1.18.0" name = "orbax-checkpoint" version = "0.2.2" description = "Orbax Checkpoint" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3455,6 +3593,7 @@ dev = ["flax", "pytest", "pytest-xdist"] name = "packaging" version = "23.1" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3466,6 +3605,7 @@ files = [ name = "pandas" version = "2.0.1" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3532,6 +3672,7 @@ xml = ["lxml (>=4.6.3)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3543,6 +3684,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3558,6 +3700,7 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathos" version = "0.3.0" description = "parallel graph management and execution in heterogeneous computing" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3575,6 +3718,7 @@ ppft = ">=1.7.6.6" name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." +category = "main" optional = false python-versions = "*" files = [ @@ -3589,6 +3733,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "main" optional = false python-versions = "*" files = [ @@ -3600,6 +3745,7 @@ files = [ name = "pillow" version = "9.5.0" description = "Python Imaging Library (Fork)" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3679,6 +3825,7 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3690,6 +3837,7 @@ files = [ name = "platformdirs" version = "3.5.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3705,6 +3853,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3720,6 +3869,7 @@ testing = ["pytest", "pytest-benchmark"] name = "pox" version = "0.3.2" description = "utilities for filesystem exploration and automated builds" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3731,6 +3881,7 @@ files = [ name = "ppft" version = "1.7.6.6" description = "distributed and parallel python" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3745,6 +3896,7 @@ dill = ["dill (>=0.3.6)"] name = "pre-commit" version = "3.3.1" description = "A framework for managing and maintaining multi-language pre-commit hooks." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3763,6 +3915,7 @@ virtualenv = ">=20.10.0" name = "prometheus-client" version = "0.16.0" description = "Python client for the Prometheus monitoring system." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3777,6 +3930,7 @@ twisted = ["twisted"] name = "promise" version = "2.3" description = "Promises/A+ implementation for Python" +category = "dev" optional = false python-versions = "*" files = [ @@ -3793,6 +3947,7 @@ test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark", name = "prompt-toolkit" version = "3.0.38" description = "Library for building powerful interactive command lines in Python" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3807,6 +3962,7 @@ wcwidth = "*" name = "protobuf" version = "3.20.3" description = "Protocol Buffers" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3838,6 +3994,7 @@ files = [ name = "protobuf3-to-dict" version = "0.1.5" description = "Ben Hodgson: A teeny Python library for creating Python dicts from protocol buffers and the reverse. Useful as an intermediate step before serialisation (e.g. to JSON). Kapor: upgrade it to PB3 and PY3, rename it to protobuf3-to-dict" +category = "main" optional = true python-versions = "*" files = [ @@ -3852,6 +4009,7 @@ six = "*" name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3878,6 +4036,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "main" optional = false python-versions = "*" files = [ @@ -3889,6 +4048,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" +category = "main" optional = false python-versions = "*" files = [ @@ -3903,6 +4063,7 @@ tests = ["pytest"] name = "pyarrow" version = "12.0.0" description = "Python library for Apache Arrow" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3940,6 +4101,7 @@ numpy = ">=1.16.6" name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3951,6 +4113,7 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3965,6 +4128,7 @@ pyasn1 = ">=0.4.6,<0.6.0" name = "pycparser" version = "2.21" description = "C parser in Python" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3976,6 +4140,7 @@ files = [ name = "pydata-sphinx-theme" version = "0.12.0" description = "Bootstrap-based Sphinx theme from the PyData community" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4000,6 +4165,7 @@ test = ["pydata-sphinx-theme[doc]", "pytest"] name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4014,6 +4180,7 @@ plugins = ["importlib-metadata"] name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" +category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -4028,6 +4195,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.3" description = "Persistent/Functional/Immutable data structures" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4064,6 +4232,7 @@ files = [ name = "pytest" version = "7.3.1" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4086,6 +4255,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.0.0" description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4104,6 +4274,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -4118,6 +4289,7 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4129,6 +4301,7 @@ files = [ name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -4140,6 +4313,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "main" optional = false python-versions = "*" files = [ @@ -4163,6 +4337,7 @@ files = [ name = "pywinpty" version = "2.0.10" description = "Pseudo terminal support for Windows from Python." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4178,6 +4353,7 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4227,6 +4403,7 @@ files = [ name = "pyzmq" version = "25.0.2" description = "Python bindings for 0MQ" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4316,6 +4493,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "qtconsole" version = "5.4.3" description = "Jupyter Qt console" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -4342,6 +4520,7 @@ test = ["flaky", "pytest", "pytest-qt"] name = "qtpy" version = "2.3.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4359,6 +4538,7 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] name = "regex" version = "2023.5.5" description = "Alternative regular expression module, to replace re." +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4456,6 +4636,7 @@ files = [ name = "requests" version = "2.31.0" description = "Python HTTP for Humans." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4477,6 +4658,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-oauthlib" version = "1.3.1" description = "OAuthlib authentication support for Requests." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -4495,6 +4677,7 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "responses" version = "0.18.0" description = "A utility library for mocking out the `requests` Python library." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4513,6 +4696,7 @@ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=4.6)", "pytest-cov", name = "rfc3339-validator" version = "0.1.4" description = "A pure python RFC3339 validator" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -4527,6 +4711,7 @@ six = "*" name = "rfc3986-validator" version = "0.1.1" description = "Pure python rfc3986 validator" +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -4538,6 +4723,7 @@ files = [ name = "rich" version = "13.3.5" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -4557,6 +4743,7 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" +category = "main" optional = false python-versions = ">=3.6,<4" files = [ @@ -4571,6 +4758,7 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.6.1" description = "An Amazon S3 Transfer Manager" +category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -4588,6 +4776,7 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] name = "safetensors" version = "0.3.1" description = "Fast and Safe Tensor serialization" +category = "main" optional = true python-versions = "*" files = [ @@ -4648,6 +4837,7 @@ torch = ["torch (>=1.10)"] name = "sagemaker" version = "2.161.0" description = "Open source library for training and deploying models on Amazon SageMaker." +category = "main" optional = true python-versions = ">= 3.6" files = [ @@ -4683,6 +4873,7 @@ test = ["Jinja2 (==3.0.3)", "PyYAML (==6.0)", "apache-airflow (==2.6.0)", "apach name = "sagemaker-utils" version = "0.3.6" description = "Helper functions to work with SageMaker" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4700,6 +4891,7 @@ yaspin = "*" name = "schema" version = "0.7.5" description = "Simple data validation library" +category = "main" optional = true python-versions = "*" files = [ @@ -4714,6 +4906,7 @@ contextlib2 = ">=0.5.5" name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4756,6 +4949,7 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" +category = "main" optional = false python-versions = "<3.12,>=3.8" files = [ @@ -4794,6 +4988,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "send2trash" version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" +category = "main" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -4810,6 +5005,7 @@ win32 = ["pywin32"] name = "setuptools" version = "67.7.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4826,6 +5022,7 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4837,6 +5034,7 @@ files = [ name = "smdebug-rulesconfig" version = "1.0.1" description = "SMDebug RulesConfig" +category = "main" optional = true python-versions = ">=2.7" files = [ @@ -4848,6 +5046,7 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4859,6 +5058,7 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." +category = "main" optional = false python-versions = "*" files = [ @@ -4870,6 +5070,7 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4881,6 +5082,7 @@ files = [ name = "sphinx" version = "5.3.0" description = "Python documentation generator" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4916,6 +5118,7 @@ test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast"] name = "sphinx-autodoc-typehints" version = "1.23.0" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4935,6 +5138,7 @@ type-comment = ["typed-ast (>=1.5.4)"] name = "sphinx-gallery" version = "0.11.1" description = "A Sphinx extension that builds an HTML version of any Python script and puts it into an examples gallery." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4949,6 +5153,7 @@ sphinx = ">=3" name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4964,6 +5169,7 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -4979,6 +5185,7 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4994,6 +5201,7 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5008,6 +5216,7 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5023,6 +5232,7 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5038,6 +5248,7 @@ test = ["pytest"] name = "sqlalchemy" version = "2.0.13" description = "Database Abstraction Library" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5085,7 +5296,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.2.0" [package.extras] @@ -5115,6 +5326,7 @@ sqlcipher = ["sqlcipher3-binary"] name = "stack-data" version = "0.6.2" description = "Extract data from python stack frames and tracebacks for informative displays" +category = "main" optional = false python-versions = "*" files = [ @@ -5134,6 +5346,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5148,6 +5361,7 @@ widechars = ["wcwidth"] name = "tblib" version = "1.7.0" description = "Traceback serialization library." +category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -5159,6 +5373,7 @@ files = [ name = "tensorboard" version = "2.12.3" description = "TensorBoard lets you watch Tensors Flow" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5183,6 +5398,7 @@ wheel = ">=0.26" name = "tensorboard-data-server" version = "0.7.0" description = "Fast data loading for TensorBoard" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5195,6 +5411,7 @@ files = [ name = "tensorflow-cpu" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5240,6 +5457,7 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-datasets" version = "4.9.2" description = "tensorflow/datasets is a library of datasets ready to use with TensorFlow." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -5306,6 +5524,7 @@ youtube-vis = ["pycocotools"] name = "tensorflow-estimator" version = "2.12.0" description = "TensorFlow Estimator." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5316,6 +5535,7 @@ files = [ name = "tensorflow-io-gcs-filesystem" version = "0.32.0" description = "TensorFlow IO" +category = "main" optional = false python-versions = ">=3.7, <3.12" files = [ @@ -5346,13 +5566,12 @@ tensorflow-rocm = ["tensorflow-rocm (>=2.12.0,<2.13.0)"] name = "tensorflow-macos" version = "2.12.0" description = "TensorFlow is an open source machine learning framework for everyone." +category = "main" optional = false python-versions = ">=3.8" files = [ {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:db464c88e10e927725997f9b872a21c9d057789d3b7e9a26e4ef1af41d0bcc8c"}, {file = "tensorflow_macos-2.12.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:172277c33cb1ae0da19f98c5bcd4946149cfa73c8ea05c6ba18365d58dd3c6f2"}, - {file = "tensorflow_macos-2.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:9c9b14fbb73ec4cb0f209722a1489020fd8614c92ae22589f2309c48cefdf21f"}, - {file = "tensorflow_macos-2.12.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:6a54539bd076746f69ae8bef7282f981674fe4dbf59c3a84c4af86ae6bae9d5c"}, {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e3fa53e63672fd71998bbd71cc5478c74dbe5a2d9291d1801c575358c28403c2"}, {file = "tensorflow_macos-2.12.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:5499312c21ed3ed47cc6b4cf861896e9564c2c32d8d3c2ef1437c5ca31adfc73"}, {file = "tensorflow_macos-2.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:84cb873c90be63efabfecca53fdc48b734a037d0750532b55cb7ce7c343b5cac"}, @@ -5387,6 +5606,7 @@ wrapt = ">=1.11.0,<1.15" name = "tensorflow-metadata" version = "1.13.1" description = "Library and standards for schema and statistics." +category = "dev" optional = false python-versions = ">=3.8,<4" files = [ @@ -5402,6 +5622,7 @@ protobuf = ">=3.20.3,<5" name = "tensorstore" version = "0.1.36" description = "Read and write large, multi-dimensional arrays" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5431,6 +5652,7 @@ numpy = ">=1.16.0" name = "termcolor" version = "2.3.0" description = "ANSI color formatting for output in terminal" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5445,6 +5667,7 @@ tests = ["pytest", "pytest-cov"] name = "terminado" version = "0.17.1" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5465,6 +5688,7 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -5476,6 +5700,7 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5494,6 +5719,7 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.13.3" description = "Fast and Customizable Tokenizers" +category = "main" optional = true python-versions = "*" files = [ @@ -5548,6 +5774,7 @@ testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -5559,6 +5786,7 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5570,6 +5798,7 @@ files = [ name = "toolz" version = "0.12.0" description = "List processing tools and functional utilities" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -5581,6 +5810,7 @@ files = [ name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." +category = "main" optional = false python-versions = ">= 3.8" files = [ @@ -5601,6 +5831,7 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5621,6 +5852,7 @@ telegram = ["requests"] name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5636,6 +5868,7 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "transformers" version = "4.30.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +category = "main" optional = true python-versions = ">=3.7.0" files = [ @@ -5705,6 +5938,7 @@ vision = ["Pillow"] name = "typing-extensions" version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5716,6 +5950,7 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" +category = "main" optional = false python-versions = ">=2" files = [ @@ -5727,6 +5962,7 @@ files = [ name = "uri-template" version = "1.2.0" description = "RFC 6570 URI Template Processor" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -5741,6 +5977,7 @@ dev = ["flake8 (<4.0.0)", "flake8-annotations", "flake8-bugbear", "flake8-commas name = "urllib3" version = "1.26.15" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -5757,6 +5994,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "virtualenv" version = "20.23.0" description = "Virtual Python Environment builder" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5777,6 +6015,7 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" +category = "main" optional = false python-versions = "*" files = [ @@ -5788,6 +6027,7 @@ files = [ name = "webcolors" version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5803,6 +6043,7 @@ tests = ["pytest", "pytest-cov"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "main" optional = false python-versions = "*" files = [ @@ -5814,6 +6055,7 @@ files = [ name = "websocket-client" version = "1.5.1" description = "WebSocket client for Python with low level API options" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5830,6 +6072,7 @@ test = ["websockets"] name = "werkzeug" version = "2.3.4" description = "The comprehensive WSGI web application library." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5847,6 +6090,7 @@ watchdog = ["watchdog (>=2.3)"] name = "wheel" version = "0.40.0" description = "A built-package format for Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5861,6 +6105,7 @@ test = ["pytest (>=6.0.0)"] name = "widgetsnbextension" version = "4.0.7" description = "Jupyter interactive widgets for Jupyter Notebook" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -5872,6 +6117,7 @@ files = [ name = "wrapt" version = "1.14.1" description = "Module for decorators, wrappers and monkey patching." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -5945,6 +6191,7 @@ files = [ name = "xlrd" version = "2.0.1" description = "Library for developers to extract data from Microsoft Excel (tm) .xls spreadsheet files" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -5961,6 +6208,7 @@ test = ["pytest", "pytest-cov"] name = "xxhash" version = "3.2.0" description = "Python binding for xxHash" +category = "main" optional = true python-versions = ">=3.6" files = [ @@ -6068,6 +6316,7 @@ files = [ name = "yarl" version = "1.9.2" description = "Yet another URL library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -6155,6 +6404,7 @@ multidict = ">=4.0" name = "yaspin" version = "2.3.0" description = "Yet Another Terminal Spinner" +category = "main" optional = true python-versions = ">=3.7.2,<4.0.0" files = [ @@ -6169,6 +6419,7 @@ termcolor = ">=2.2,<3.0" name = "yfinance" version = "0.2.18" description = "Download market data from Yahoo! Finance API" +category = "dev" optional = false python-versions = "*" files = [ @@ -6193,6 +6444,7 @@ requests = ">=2.26" name = "zipp" version = "3.15.0" description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" optional = false python-versions = ">=3.7" files = [ From bc64a0179383a26a08ab7fa2a75ef6e966b474ef Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Sun, 30 Jul 2023 18:24:25 +0200 Subject: [PATCH 4/8] bump up version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1687732c..4915c415 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws-fortuna" -version = "0.1.22" +version = "0.1.23" description = "A Library for Uncertainty Quantification." authors = ["Gianluca Detommaso ", "Alberto Gasparin "] license = "Apache-2.0" From 591d8425ebbca4c039ac08f73637027f8e126431 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 13 Sep 2023 15:39:17 +0200 Subject: [PATCH 5/8] refactor tabular analysis of benchmarks --- benchmarks/tabular/analysis.py | 757 ++++++--------------------------- 1 file changed, 137 insertions(+), 620 deletions(-) diff --git a/benchmarks/tabular/analysis.py b/benchmarks/tabular/analysis.py index 2608da52..464f0652 100644 --- a/benchmarks/tabular/analysis.py +++ b/benchmarks/tabular/analysis.py @@ -6,21 +6,17 @@ with open("tabular_results.json", "r") as j: metrics = json.loads(j.read()) +TOL = 1e-4 + # ~~~REGRESSION~~~ # MAP map_nlls = [ metrics["regression"][k]["map"]["nll"] for k in metrics["regression"].keys() ] -map_quantiles_nlls = np.percentile(map_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["map"]["picp"]) - for k in metrics["regression"].keys() + np.abs(metrics["regression"][k]["map"]["picp"] - 0.95) for k in metrics["regression"].keys() ] -map_quantiles_picp_errors = np.percentile( - map_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) map_times = [ metrics["regression"][k]["map"]["time"] for k in metrics["regression"].keys() @@ -31,185 +27,83 @@ metrics["regression"][k]["temp_scaling"]["nll"] for k in metrics["regression"].keys() ] -temp_scaling_quantiles_nlls = np.percentile( - temp_scaling_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_nlls = np.sum(np.array(temp_scaling_nlls) / np.array(map_nlls) <= 1) -winlose_temp_scaling_nlls = ( - f"{win_temp_scaling_nlls} / {len(map_nlls) - win_temp_scaling_nlls}" -) -rel_improve_temp_scaling_nlls = ( - np.array(map_nlls) - np.array(temp_scaling_nlls) -) / np.array(map_nlls) -max_loss_temp_scaling_nlls = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_temp_scaling_nlls[rel_improve_temp_scaling_nlls < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_nlls = ( - f"{np.round(np.median(rel_improve_temp_scaling_nlls), 2)}" -) - +win_temp_scaling_nlls = np.array(temp_scaling_nlls) < np.array(map_nlls) - TOL +lose_temp_scaling_nlls = np.array(temp_scaling_nlls) > np.array(map_nlls) + TOL temp_scaling_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["temp_scaling"]["picp"]) + np.abs(metrics["regression"][k]["temp_scaling"]["picp"] - 0.95) for k in metrics["regression"].keys() ] -temp_scaling_quantiles_picp_errors = np.percentile( - temp_scaling_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_picp_errors = np.sum( - np.array(temp_scaling_picp_errors) / np.array(map_picp_errors) <= 1 -) -winlose_temp_scaling_picp_errors = f"{win_temp_scaling_picp_errors} / {len(map_picp_errors) - win_temp_scaling_picp_errors}" -rel_improve_temp_scaling_picp_errors = ( - np.array(map_picp_errors) - np.array(temp_scaling_picp_errors) -) / np.array(map_picp_errors) -max_loss_temp_scaling_picp_errors = ( - str( - np.round( - 100 - * np.abs( - np.max( - rel_improve_temp_scaling_picp_errors[ - rel_improve_temp_scaling_picp_errors < 0 - ] - ) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_picp_errors = ( - f"{np.round(np.median(rel_improve_temp_scaling_picp_errors), 2)}" -) +win_temp_scaling_picp_errors = np.array(temp_scaling_picp_errors) < np.array(map_picp_errors) - TOL +lose_temp_scaling_picp_errors = np.array(temp_scaling_picp_errors) > np.array(map_picp_errors) + TOL temp_scaling_times = [ metrics["regression"][k]["temp_scaling"]["time"] for k in metrics["regression"].keys() ] +temp_scaling_best_win = np.max(np.array(map_picp_errors) - np.array(temp_scaling_picp_errors)) +temp_scaling_worst_loss = np.min(np.array(map_picp_errors) - np.array(temp_scaling_picp_errors)) + # CQR cqr_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["cqr"]["picp"]) + np.abs(metrics["regression"][k]["cqr"]["picp"] - 0.95) for k in metrics["regression"].keys() ] -cqr_quantiles_picp_errors = np.percentile( - cqr_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_cqr_picp_errors = np.sum(np.array(cqr_picp_errors) / np.array(map_picp_errors) <= 1) -winlose_cqr_picp_errors = ( - f"{win_cqr_picp_errors} / {len(map_picp_errors) - win_cqr_picp_errors}" -) -rel_improve_cqr_picp_errors = ( - np.array(map_picp_errors) - np.array(cqr_picp_errors) -) / np.array(map_picp_errors) -max_loss_cqr_picp_errors = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_cqr_picp_errors[rel_improve_cqr_picp_errors < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_cqr_picp_errors = f"{np.round(np.median(rel_improve_cqr_picp_errors), 2)}" +win_cqr_picp_errors = np.array(cqr_picp_errors) < np.array(map_picp_errors) - TOL +lose_cqr_picp_errors = np.array(cqr_picp_errors) > np.array(map_picp_errors) + TOL cqr_times = [ - metrics["regression"][k]["cqr"]["time"] for k in metrics["regression"].keys() -] - -# # TEMPERED CQR -temp_cqr_picp_errors = [ - np.abs(0.95 - metrics["regression"][k]["temp_cqr"]["picp"]) + metrics["regression"][k]["cqr"]["time"] for k in metrics["regression"].keys() ] -temp_cqr_quantiles_picp_errors = np.percentile( - temp_cqr_picp_errors, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_cqr_picp_errors = f"{np.sum(np.array(temp_cqr_picp_errors) / np.array(map_picp_errors) <= 1)} / {len(map_picp_errors)}" -med_improv_temp_cqr_picp_errors = f"{np.round(np.median((np.array(map_picp_errors) - np.array(temp_cqr_picp_errors)) / np.array(map_picp_errors)), 2)}" - -temp_cqr_times = [ - metrics["regression"][k]["temp_cqr"]["time"] for k in metrics["regression"].keys() -] -plt.figure(figsize=(8, 6)) -plt.suptitle("Quantile-quantile plots of metrics on regression datasets") +cqr_best_win = np.max(np.array(map_picp_errors) - np.array(cqr_picp_errors)) +cqr_worst_loss = np.min(np.array(map_picp_errors) - np.array(cqr_picp_errors)) -plt.subplot(2, 2, 1) -plt.title("NLL") -plt.scatter(map_quantiles_nlls, temp_scaling_quantiles_nlls, s=3) -_min, _max = min(map_quantiles_nlls.min(), temp_scaling_quantiles_nlls.min()), max( - map_quantiles_nlls.max(), temp_scaling_quantiles_nlls.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") -plt.grid() +### Regression plots -plt.subplot(2, 2, 2) -plt.title("PICP absolute error") -plt.scatter(map_quantiles_picp_errors, temp_scaling_quantiles_picp_errors, s=3) -_min, _max = min( - map_quantiles_picp_errors.min(), temp_scaling_quantiles_picp_errors.min() -), max(map_quantiles_picp_errors.max(), temp_scaling_quantiles_picp_errors.max()) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") +plt.figure(figsize=(10, 3)) +plt.suptitle("Scatter plots for regression datasets") +plt.subplot(1, 2, 1) +plt.title("PICP errors") plt.grid() - -plt.subplot(2, 2, 4) -plt.title("PICP absolute error") -plt.scatter(map_quantiles_picp_errors, cqr_quantiles_picp_errors, s=3) -_min, _max = min(map_quantiles_picp_errors.min(), cqr_quantiles_picp_errors.min()), max( - map_quantiles_picp_errors.max(), cqr_quantiles_picp_errors.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_picp_errors).min(), np.array(temp_scaling_picp_errors).min()), max(np.array(map_picp_errors).max(), np.array(temp_scaling_picp_errors).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("CQR quantiles") +plt.scatter(map_picp_errors, temp_scaling_picp_errors, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_picp_errors, lose_temp_scaling_picp_errors)]) +plt.xscale("log") +plt.yscale("log") + +plt.subplot(1, 2, 2) +plt.title("PICP errors") plt.grid() +plt.xlabel("MAP") +plt.ylabel("CQR") +_min, _max = min(np.array(map_picp_errors).min(), np.array(cqr_picp_errors).min()), max(np.array(map_picp_errors).max(), np.array(cqr_picp_errors).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_picp_errors, cqr_picp_errors, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_cqr_picp_errors, lose_cqr_picp_errors)]) +plt.xscale("log") +plt.yscale("log") plt.tight_layout() - plt.show() -print("~~~REGRESSION~~~\n") -print("## TEMPERATURE SCALING ##") -print( - f"Fraction of times temp_scaling is at least on a par w.r.t. the NLL: {winlose_temp_scaling_nlls}" -) -print( - f"Fraction of times temp_scaling is at least on a par w.r.t. the PICP error: {winlose_temp_scaling_picp_errors}" -) -print() -print( - f"Median of relative NLL improvement given by temp_scaling: {med_improv_temp_scaling_nlls}" -) -print( - f"Median of relative PICP error improvement given by temp_scaling: {med_improv_temp_scaling_picp_errors}" -) -print() -print() -print("## CQR ##") -print( - f"Fraction of times CQR is at least on a par w.r.t. the PICP error: {winlose_cqr_picp_errors}" -) -print() -print( - f"Median of relative PICP error improvement given by temp_scaling: {med_improv_cqr_picp_errors}" -) +plt.figure(figsize=(5, 3)) +plt.suptitle("Scatter plots for regression datasets on other metrics") +plt.title("NLL") +plt.grid() +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_nlls).min(), np.array(temp_scaling_nlls).min()), max(np.array(map_nlls).max(), np.array(temp_scaling_nlls).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_nlls, temp_scaling_nlls, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_nlls, lose_temp_scaling_nlls)]) +plt.xscale("log") +plt.yscale("log") + +plt.tight_layout() +plt.show() # ~~~CLASSIFICATION~~~ @@ -218,32 +112,23 @@ map_nlls = [ metrics["classification"][k]["map"]["nll"] for k in metrics["classification"].keys() ] -map_quantiles_nlls = np.percentile(map_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90]) map_mse = [ metrics["classification"][k]["map"]["mse"] for k in metrics["classification"].keys() ] -map_quantiles_mse = np.percentile(map_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_ece = [ metrics["classification"][k]["map"]["ece"] for k in metrics["classification"].keys() ] -map_quantiles_ece = np.percentile(map_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_rocauc = [ metrics["classification"][k]["map"]["rocauc"] for k in metrics["classification"].keys() if "rocauc" in metrics["classification"][k]["map"] ] -map_quantiles_rocauc = np.percentile(map_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_prauc = [ metrics["classification"][k]["map"]["prauc"] for k in metrics["classification"].keys() if "prauc" in metrics["classification"][k]["map"] ] -map_quantiles_prauc = np.percentile(map_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90]) - map_acc = [ metrics["classification"][k]["map"]["accuracy"] for k in metrics["classification"].keys() @@ -258,155 +143,31 @@ metrics["classification"][k]["temp_scaling"]["nll"] for k in metrics["classification"].keys() ] -temp_scaling_quantiles_nlls = np.percentile( - temp_scaling_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_nlls = np.sum(np.array(temp_scaling_nlls) / np.array(map_nlls) <= 1) -winlose_temp_scaling_nlls = ( - f"{win_temp_scaling_nlls} / {len(map_nlls) - win_temp_scaling_nlls}" -) -rel_improve_temp_scaling_nlls = ( - np.array(map_nlls) - np.array(temp_scaling_nlls) -) / np.array(map_nlls) -max_loss_temp_scaling_nlls = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_temp_scaling_nlls[rel_improve_temp_scaling_nlls < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_nlls = ( - f"{np.round(np.median(rel_improve_temp_scaling_nlls), 2)}" -) +win_temp_scaling_nlls = np.array(temp_scaling_nlls) < np.array(map_nlls) - TOL +lose_temp_scaling_nlls = np.array(temp_scaling_nlls) > np.array(map_nlls) + TOL temp_scaling_mse = [ metrics["classification"][k]["temp_scaling"]["mse"] for k in metrics["classification"].keys() ] -temp_scaling_quantiles_mse = np.percentile( - temp_scaling_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_mse = np.sum(np.array(temp_scaling_mse) / np.array(map_mse) <= 1) -winlose_temp_scaling_mse = ( - f"{win_temp_scaling_mse} / {len(map_mse) - win_temp_scaling_mse}" -) -rel_improve_temp_scaling_mse = ( - np.array(map_mse) - np.array(temp_scaling_mse) -) / np.array(map_mse) -max_loss_temp_scaling_mse = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_temp_scaling_mse[rel_improve_temp_scaling_mse < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_mse = f"{np.round(np.median(rel_improve_temp_scaling_mse), 2)}" - -temp_scaling_ece = [ - metrics["classification"][k]["temp_scaling"]["ece"] - for k in metrics["classification"].keys() -] -temp_scaling_quantiles_ece = np.percentile( - temp_scaling_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_ece = np.sum(np.array(temp_scaling_ece) / np.array(map_ece) <= 1) -winlose_temp_scaling_ece = ( - f"{win_temp_scaling_ece} / {len(map_ece) - win_temp_scaling_ece}" -) -rel_improve_temp_scaling_ece = ( - np.array(map_ece) - np.array(temp_scaling_ece) -) / np.array(map_ece) -max_loss_temp_scaling_ece = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_temp_scaling_ece[rel_improve_temp_scaling_ece < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_ece = f"{np.round(np.median(rel_improve_temp_scaling_ece), 2)}" +win_temp_scaling_mse = np.array(temp_scaling_mse) < np.array(map_mse) - TOL +lose_temp_scaling_mse = np.array(temp_scaling_mse) > np.array(map_mse) + TOL temp_scaling_rocauc = [ metrics["classification"][k]["temp_scaling"]["rocauc"] for k in metrics["classification"].keys() if "rocauc" in metrics["classification"][k]["temp_scaling"] ] -temp_scaling_quantiles_rocauc = np.percentile( - temp_scaling_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_rocauc = np.sum( - np.array(temp_scaling_rocauc) / np.array(map_rocauc) <= 1 -) -winlose_temp_scaling_rocauc = ( - f"{win_temp_scaling_rocauc} / {len(map_rocauc) - win_temp_scaling_rocauc}" -) -rel_improve_temp_scaling_rocauc = ( - np.array(map_rocauc) - np.array(temp_scaling_rocauc) -) / np.array(map_rocauc) -max_loss_temp_scaling_rocauc = ( - str( - np.round( - 100 - * np.abs( - np.max( - rel_improve_temp_scaling_rocauc[rel_improve_temp_scaling_rocauc < 0] - ) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_rocauc = ( - f"{np.round(np.median(rel_improve_temp_scaling_rocauc), 2)}" -) +win_temp_scaling_rocauc = np.array(temp_scaling_rocauc) < np.array(map_rocauc) - TOL +lose_temp_scaling_rocauc = np.array(temp_scaling_rocauc) > np.array(map_rocauc) + TOL temp_scaling_prauc = [ metrics["classification"][k]["temp_scaling"]["prauc"] for k in metrics["classification"].keys() if "prauc" in metrics["classification"][k]["temp_scaling"] ] -temp_scaling_quantiles_prauc = np.percentile( - temp_scaling_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_temp_scaling_prauc = np.sum(np.array(temp_scaling_prauc) / np.array(map_prauc) <= 1) -winlose_temp_scaling_prauc = ( - f"{win_temp_scaling_prauc} / {len(map_prauc) - win_temp_scaling_prauc}" -) -rel_improve_temp_scaling_prauc = ( - np.array(map_prauc) - np.array(temp_scaling_prauc) -) / np.array(map_prauc) -max_loss_temp_scaling_prauc = ( - str( - np.round( - 100 - * np.abs( - np.max( - rel_improve_temp_scaling_prauc[rel_improve_temp_scaling_prauc < 0] - ) - ), - 2, - ) - ) - + "%" -) -med_improv_temp_scaling_prauc = ( - f"{np.round(np.median(rel_improve_temp_scaling_prauc), 2)}" -) +win_temp_scaling_prauc = np.array(temp_scaling_prauc) < np.array(map_prauc) - TOL +lose_temp_scaling_prauc = np.array(temp_scaling_prauc) > np.array(map_prauc) + TOL temp_scaling_acc = [ metrics["classification"][k]["temp_scaling"]["accuracy"] @@ -418,132 +179,39 @@ for k in metrics["classification"].keys() ] +temp_scaling_best_win = np.max(np.array(map_mse) - np.array(temp_scaling_mse)) +temp_scaling_worst_loss = np.min(np.array(map_mse) - np.array(temp_scaling_mse)) + # MULTICALIBRATE CONF mc_conf_nlls = [ metrics["classification"][k]["mc_conf"]["nll"] for k in metrics["classification"].keys() ] -mc_conf_nlls = np.array(mc_conf_nlls) -mc_conf_quantiles_nlls = np.percentile( - mc_conf_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_mc_conf_nlls = np.sum(np.array(mc_conf_nlls) / np.array(np.array(map_nlls)) <= 1) -winlose_mc_conf_nlls = f"{win_mc_conf_nlls} / {len(mc_conf_nlls) - win_mc_conf_nlls}" -rel_improve_mc_conf_nlls = (np.array(map_nlls) - np.array(mc_conf_nlls)) / np.array( - map_nlls -) -max_loss_mc_conf_nlls = ( - str( - np.round( - 100 - * np.abs(np.max(rel_improve_mc_conf_nlls[rel_improve_mc_conf_nlls < 0])), - 2, - ) - ) - + "%" -) -med_improv_mc_conf_nlls = f"{np.round(np.median(rel_improve_mc_conf_nlls), 2)}" +win_mc_conf_nlls = np.array(mc_conf_nlls) < np.array(map_nlls) - TOL +lose_mc_conf_nlls = np.array(mc_conf_nlls) > np.array(map_nlls) + TOL mc_conf_mse = [ metrics["classification"][k]["mc_conf"]["mse"] for k in metrics["classification"].keys() ] -mc_conf_mse = np.array(mc_conf_mse) -mc_conf_quantiles_mse = np.percentile(mc_conf_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90]) -win_mc_conf_mse = np.sum(np.array(mc_conf_mse) / np.array(np.array(map_mse)) <= 1) -winlose_mc_conf_mse = f"{win_mc_conf_mse} / {len(mc_conf_mse) - win_mc_conf_mse}" -rel_improve_mc_conf_mse = (np.array(map_mse) - np.array(mc_conf_mse)) / np.array( - map_mse -) -max_loss_mc_conf_mse = ( - str( - np.round( - 100 * np.abs(np.max(rel_improve_mc_conf_mse[rel_improve_mc_conf_mse < 0])), - 2, - ) - ) - + "%" -) -med_improv_mc_conf_mse = f"{np.round(np.median(rel_improve_mc_conf_mse), 2)}" - -mc_conf_ece = [ - metrics["classification"][k]["mc_conf"]["ece"] - for k in metrics["classification"].keys() -] -mc_conf_quantiles_ece = np.percentile(mc_conf_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]) -win_mc_conf_ece = np.sum(np.array(mc_conf_ece) / np.array(map_ece) <= 1) -winlose_mc_conf_ece = f"{win_mc_conf_ece} / {len(map_ece) - win_mc_conf_ece}" -rel_improve_mc_conf_ece = (np.array(map_ece) - np.array(mc_conf_ece)) / np.array( - map_ece -) -max_loss_mc_conf_ece = ( - str( - np.round( - 100 * np.abs(np.max(rel_improve_mc_conf_ece[rel_improve_mc_conf_ece < 0])), - 2, - ) - ) - + "%" -) -med_improv_mc_conf_ece = f"{np.round(np.median(rel_improve_mc_conf_ece), 2)}" +win_mc_conf_mse = np.array(mc_conf_mse) < np.array(map_mse) - TOL +lose_mc_conf_mse = np.array(mc_conf_mse) > np.array(map_mse) + TOL mc_conf_rocauc = [ metrics["classification"][k]["mc_conf"]["rocauc"] for k in metrics["classification"].keys() + if "rocauc" in metrics["classification"][k]["mc_conf"] ] -mc_conf_rocauc = np.array(mc_conf_rocauc) -mc_conf_quantiles_rocauc = np.percentile( - mc_conf_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_mc_conf_rocauc = np.sum( - np.array(mc_conf_rocauc) / np.array(np.array(map_rocauc)) <= 1 -) -winlose_mc_conf_rocauc = ( - f"{win_mc_conf_rocauc} / {len(mc_conf_rocauc) - win_mc_conf_rocauc}" -) -rel_improve_mc_conf_rocauc = ( - np.array(map_rocauc) - np.array(mc_conf_rocauc) -) / np.array(map_rocauc) -max_loss_mc_conf_rocauc = ( - str( - np.round( - 100 - * np.abs( - np.max(rel_improve_mc_conf_rocauc[rel_improve_mc_conf_rocauc < 0]) - ), - 2, - ) - ) - + "%" -) -med_improv_mc_conf_rocauc = f"{np.round(np.median(rel_improve_mc_conf_rocauc), 2)}" +win_mc_conf_rocauc = np.array(mc_conf_rocauc) < np.array(map_rocauc) - TOL +lose_mc_conf_rocauc = np.array(mc_conf_rocauc) > np.array(map_rocauc) + TOL mc_conf_prauc = [ metrics["classification"][k]["mc_conf"]["prauc"] for k in metrics["classification"].keys() + if "prauc" in metrics["classification"][k]["mc_conf"] ] -mc_conf_prauc = np.array(mc_conf_prauc) -mc_conf_quantiles_prauc = np.percentile( - mc_conf_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -win_mc_conf_prauc = np.sum(np.array(mc_conf_prauc) / np.array(np.array(map_prauc)) <= 1) -winlose_mc_conf_prauc = ( - f"{win_mc_conf_prauc} / {len(mc_conf_prauc) - win_mc_conf_prauc}" -) -rel_improve_mc_conf_prauc = (np.array(map_prauc) - np.array(mc_conf_prauc)) / np.array( - map_prauc -) -max_loss_mc_conf_prauc = ( - str( - np.round( - 100 - * np.abs(np.max(rel_improve_mc_conf_prauc[rel_improve_mc_conf_prauc < 0])), - 2, - ) - ) - + "%" -) -med_improv_mc_conf_prauc = f"{np.round(np.median(rel_improve_mc_conf_prauc), 2)}" +win_mc_conf_prauc = np.array(mc_conf_prauc) < np.array(map_prauc) - TOL +lose_mc_conf_prauc = np.array(mc_conf_prauc) > np.array(map_prauc) + TOL mc_conf_acc = [ metrics["classification"][k]["mc_conf"]["accuracy"] @@ -555,244 +223,93 @@ for k in metrics["classification"].keys() ] -# TEMPERED MULTICALIBRATE CONF -temp_mc_conf_nlls = [ - metrics["classification"][k]["temp_mc_conf"]["nll"] - for k in metrics["classification"].keys() -] -temp_mc_conf_nlls = np.array(temp_mc_conf_nlls) -temp_mc_conf_quantiles_nlls = np.percentile( - temp_mc_conf_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_mc_conf_nlls = f"{np.sum(np.array(temp_mc_conf_nlls) / np.array(np.array(map_nlls)) <= 1)} / {len(temp_mc_conf_nlls)}" -med_improv_temp_mc_conf_nlls = f"{np.round(np.median((np.array(map_nlls) - np.array(temp_mc_conf_nlls)) / np.array(map_nlls)), 2)}" - -temp_mc_conf_mse = [ - metrics["classification"][k]["temp_mc_conf"]["mse"] - for k in metrics["classification"].keys() -] -temp_mc_conf_mse = np.array(temp_mc_conf_mse) -temp_mc_conf_quantiles_mse = np.percentile( - temp_mc_conf_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_mc_conf_mse = f"{np.sum(np.array(temp_mc_conf_mse) / np.array(np.array(map_mse)) <= 1)} / {len(temp_mc_conf_mse)}" -med_improv_temp_mc_conf_mse = f"{np.round(np.median((np.array(map_mse) - np.array(temp_mc_conf_mse)) / np.array(map_mse)), 2)}" - -temp_mc_conf_ece = [ - metrics["classification"][k]["temp_mc_conf"]["ece"] - for k in metrics["classification"].keys() -] -temp_mc_conf_quantiles_ece = np.percentile( - temp_mc_conf_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_mc_conf_ece = ( - f"{np.sum(np.array(temp_mc_conf_ece) / np.array(map_ece) <= 1)} / {len(map_ece)}" -) -med_improv_temp_mc_conf_ece = f"{np.round(np.median((np.array(map_ece) - np.array(temp_mc_conf_ece)) / np.array(map_ece)), 2)}" - -temp_mc_conf_rocauc = [ - metrics["classification"][k]["temp_mc_conf"]["rocauc"] - for k in metrics["classification"].keys() -] -temp_mc_conf_rocauc = np.array(temp_mc_conf_rocauc) -temp_mc_conf_quantiles_rocauc = np.percentile( - temp_mc_conf_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_mc_conf_rocauc = f"{np.sum(np.array(temp_mc_conf_rocauc) / np.array(np.array(map_rocauc)) <= 1)} / {len(temp_mc_conf_rocauc)}" -med_improv_temp_mc_conf_rocauc = f"{np.round(np.median((np.array(map_rocauc) - np.array(temp_mc_conf_rocauc)) / np.array(map_rocauc)), 2)}" - -temp_mc_conf_prauc = [ - metrics["classification"][k]["temp_mc_conf"]["prauc"] - for k in metrics["classification"].keys() -] -temp_mc_conf_prauc = np.array(temp_mc_conf_prauc) -temp_mc_conf_quantiles_prauc = np.percentile( - temp_mc_conf_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_temp_mc_conf_prauc = f"{np.sum(np.array(temp_mc_conf_prauc) / np.array(np.array(map_prauc)) <= 1)} / {len(temp_mc_conf_prauc)}" -med_improv_temp_mc_conf_prauc = f"{np.round(np.median((np.array(map_prauc) - np.array(temp_mc_conf_prauc)) / np.array(map_prauc)), 2)}" - -temp_mc_conf_acc = [ - metrics["classification"][k]["temp_mc_conf"]["accuracy"] - for k in metrics["classification"].keys() -] - -temp_mc_conf_times = [ - metrics["classification"][k]["temp_mc_conf"]["time"] - for k in metrics["classification"].keys() -] - -# MULTICALIBRATE PROB -idx_overlap = [ - i - for i, k in enumerate(metrics["classification"]) - if len(metrics["classification"][k]["mc_prob"]) -] +mc_conf_best_win = np.max(np.array(map_mse) - np.array(mc_conf_mse)) +mc_conf_worst_loss = np.min(np.array(map_mse) - np.array(mc_conf_mse)) -mc_prob_nlls = [ - metrics["classification"][k]["mc_prob"]["nll"] - for k in metrics["classification"].keys() - if "nll" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_nlls = np.percentile( - mc_prob_nlls, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_mc_prob_nlls = f"{np.sum(np.array(mc_prob_nlls) / np.array(map_nlls)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_nlls = f"{np.round(np.median((np.array(map_nlls)[idx_overlap] - np.array(mc_prob_nlls)) / np.array(map_nlls)[idx_overlap]), 2)}" - -mc_prob_mse = [ - metrics["classification"][k]["mc_prob"]["mse"] - for k in metrics["classification"].keys() - if "mse" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_mse = np.percentile(mc_prob_mse, [10, 20, 30, 40, 50, 60, 70, 80, 90]) -winlose_mc_prob_mse = f"{np.sum(np.array(mc_prob_mse) / np.array(map_mse)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_mse = f"{np.round(np.median((np.array(map_mse)[idx_overlap] - np.array(mc_prob_mse)) / np.array(map_mse)[idx_overlap]), 2)}" +### Classification plots -mc_prob_ece = [ - metrics["classification"][k]["mc_prob"]["ece"] - for k in metrics["classification"].keys() - if "ece" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_ece = np.percentile(mc_prob_ece, [10, 20, 30, 40, 50, 60, 70, 80, 90]) -winlose_mc_prob_ece = f"{np.sum(np.array(mc_prob_ece) / np.array(map_ece)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_ece = f"{np.round(np.median((np.array(map_ece)[idx_overlap] - np.array(mc_prob_ece)) / np.array(map_ece)[idx_overlap]), 2)}" +plt.figure(figsize=(10, 3)) +plt.suptitle("Scatter plots for classification datasets") +plt.subplot(1, 2, 1) +plt.title("MSE") +plt.grid() +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_mse).min(), np.array(temp_scaling_mse).min()), max(np.array(map_mse).max(), np.array(temp_scaling_mse).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_mse, temp_scaling_mse, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_mse, lose_temp_scaling_mse)]) -mc_prob_rocauc = [ - metrics["classification"][k]["mc_prob"]["rocauc"] - for k in metrics["classification"].keys() - if "rocauc" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_rocauc = np.percentile( - mc_prob_rocauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_mc_prob_rocauc = f"{np.sum(np.array(mc_prob_rocauc) / np.array(map_rocauc)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_rocauc = f"{np.round(np.median((np.array(map_rocauc)[idx_overlap] - np.array(mc_prob_rocauc)) / np.array(map_rocauc)[idx_overlap]), 2)}" - -mc_prob_prauc = [ - metrics["classification"][k]["mc_prob"]["prauc"] - for k in metrics["classification"].keys() - if "prauc" in metrics["classification"][k]["mc_prob"] -] -mc_prob_quantiles_prauc = np.percentile( - mc_prob_prauc, [10, 20, 30, 40, 50, 60, 70, 80, 90] -) -winlose_mc_prob_prauc = f"{np.sum(np.array(mc_prob_prauc) / np.array(map_prauc)[idx_overlap] <= 1)} / {len(idx_overlap)}" -med_improv_mc_prob_prauc = f"{np.round(np.median((np.array(map_prauc)[idx_overlap] - np.array(mc_prob_prauc)) / np.array(map_prauc)[idx_overlap]), 2)}" - -mc_prob_acc = [ - metrics["classification"][k]["mc_prob"]["accuracy"] - for k in metrics["classification"].keys() - if "accuracy" in metrics["classification"][k]["mc_prob"] -] +plt.subplot(1, 2, 2) +plt.title("MSE") +plt.grid() +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_mse).min(), np.array(mc_conf_mse).min()), max(np.array(map_mse).max(), np.array(mc_conf_mse).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_mse, mc_conf_mse, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_mse, lose_mc_conf_mse)]) -mc_prob_times = [ - metrics["classification"][k]["mc_prob"]["time"] - for k in metrics["classification"].keys() - if "time" in metrics["classification"][k]["mc_prob"] -] +plt.tight_layout() +plt.show() plt.figure(figsize=(10, 6)) -plt.suptitle("Quantile-quantile plots of metrics on classification datasets") - -plt.subplot(2, 4, 1) +plt.suptitle("Scatter plots for classification datasets on other metrics") +plt.subplot(3, 2, 1) plt.title("NLL") -plt.scatter(map_quantiles_nlls, temp_scaling_quantiles_nlls, s=3) -_min, _max = min(map_quantiles_nlls.min(), temp_scaling_quantiles_nlls.min()), max( - map_quantiles_nlls.max(), temp_scaling_quantiles_nlls.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") plt.grid() - -plt.subplot(2, 4, 2) -plt.title("MSE") -plt.scatter(map_quantiles_mse, temp_scaling_quantiles_mse, s=3) -_min, _max = min(map_quantiles_mse.min(), temp_scaling_quantiles_mse.min()), max( - map_quantiles_mse.max(), temp_scaling_quantiles_mse.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_nlls).min(), np.array(temp_scaling_nlls).min()), max(np.array(map_nlls).max(), np.array(temp_scaling_nlls).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") +plt.scatter(map_nlls, temp_scaling_nlls, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_nlls, lose_temp_scaling_nlls)]) +plt.xscale("log") +plt.yscale("log") + +plt.subplot(3, 2, 2) +plt.title("NLL") plt.grid() +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_nlls).min(), np.array(mc_conf_nlls).min()), max(np.array(map_nlls).max(), np.array(mc_conf_nlls).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_nlls, mc_conf_nlls, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_nlls, lose_mc_conf_nlls)]) +plt.xscale("log") +plt.yscale("log") -plt.subplot(2, 4, 3) +plt.subplot(3, 2, 3) plt.title("ROCAUC") -plt.scatter(map_quantiles_rocauc, temp_scaling_quantiles_rocauc, s=3) -_min, _max = min(map_quantiles_rocauc.min(), temp_scaling_quantiles_rocauc.min()), max( - map_quantiles_rocauc.max(), temp_scaling_quantiles_rocauc.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") plt.grid() - -plt.subplot(2, 4, 4) -plt.title("PRAUC") -plt.scatter(map_quantiles_prauc, temp_scaling_quantiles_prauc, s=3) -_min, _max = min(map_quantiles_prauc.min(), temp_scaling_quantiles_prauc.min()), max( - map_quantiles_prauc.max(), temp_scaling_quantiles_prauc.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_rocauc).min(), np.array(temp_scaling_rocauc).min()), max(np.array(map_rocauc).max(), np.array(temp_scaling_rocauc).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("temp scaling quantiles") -plt.grid() +plt.scatter(map_rocauc, temp_scaling_rocauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_rocauc, lose_temp_scaling_rocauc)]) -plt.subplot(2, 4, 5) -plt.title("NLL") -plt.scatter(map_quantiles_nlls, mc_conf_quantiles_nlls, s=3) -_min, _max = min(map_quantiles_nlls.min(), mc_conf_quantiles_nlls.min()), max( - map_quantiles_nlls.max(), mc_conf_quantiles_nlls.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") +plt.subplot(3, 2, 4) +plt.title("ROCAUC") plt.grid() - -plt.subplot(2, 4, 6) -plt.title("ECE") -plt.scatter(map_quantiles_ece, mc_conf_quantiles_ece, s=3) -_min, _max = min(map_quantiles_ece.min(), mc_conf_quantiles_ece.min()), max( - map_quantiles_ece.max(), mc_conf_quantiles_ece.max() -) +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_rocauc).min(), np.array(mc_conf_rocauc).min()), max(np.array(map_rocauc).max(), np.array(mc_conf_rocauc).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") -plt.grid() +plt.scatter(map_rocauc, mc_conf_rocauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_rocauc, lose_mc_conf_rocauc)]) -plt.subplot(2, 4, 6) -plt.title("MSE") -plt.scatter(map_quantiles_mse, mc_conf_quantiles_mse, s=3) -_min, _max = min(map_quantiles_mse.min(), mc_conf_quantiles_mse.min()), max( - map_quantiles_mse.max(), mc_conf_quantiles_mse.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") +plt.subplot(3, 2, 5) +plt.title("PRAUC") plt.grid() - -plt.subplot(2, 4, 7) -plt.title("ROCAUC") -plt.scatter(map_quantiles_rocauc, mc_conf_quantiles_rocauc, s=3) -_min, _max = min(map_quantiles_rocauc.min(), mc_conf_quantiles_rocauc.min()), max( - map_quantiles_rocauc.max(), mc_conf_quantiles_rocauc.max() -) +plt.xlabel("MAP") +plt.ylabel("temp scaling") +_min, _max = min(np.array(map_prauc).min(), np.array(temp_scaling_prauc).min()), max(np.array(map_prauc).max(), np.array(temp_scaling_prauc).max()) plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") -plt.grid() +plt.scatter(map_prauc, temp_scaling_prauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_temp_scaling_prauc, lose_temp_scaling_prauc)]) -plt.subplot(2, 4, 8) +plt.subplot(3, 2, 6) plt.title("PRAUC") -plt.scatter(map_quantiles_prauc, mc_conf_quantiles_prauc, s=3) -_min, _max = min(map_quantiles_prauc.min(), mc_conf_quantiles_prauc.min()), max( - map_quantiles_prauc.max(), mc_conf_quantiles_prauc.max() -) -plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) -plt.xlabel("MAP quantiles") -plt.ylabel("TLMC quantiles") plt.grid() +plt.xlabel("MAP") +plt.ylabel("TLMC") +_min, _max = min(np.array(map_prauc).min(), np.array(mc_conf_prauc).min()), max(np.array(map_prauc).max(), np.array(mc_conf_prauc).max()) +plt.plot([_min, _max], [_min, _max], color="gray", linestyle="--", alpha=0.2) +plt.scatter(map_prauc, mc_conf_prauc, s=3, color=["C2" if w else "C3" if l else "grey" for w, l in zip(win_mc_conf_prauc, lose_mc_conf_prauc)]) plt.tight_layout() plt.show() From 452a8821e6fc5a737f2d3ff92b6d37bf04162600 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Wed, 22 Nov 2023 19:44:30 +0100 Subject: [PATCH 6/8] add option to balance the calibration data --- benchmarks/hallucination/mmlu/run.py | 23 +++--- fortuna/hallucination/base.py | 111 ++++++++++++++++----------- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/benchmarks/hallucination/mmlu/run.py b/benchmarks/hallucination/mmlu/run.py index e5db717e..bf7fa06b 100644 --- a/benchmarks/hallucination/mmlu/run.py +++ b/benchmarks/hallucination/mmlu/run.py @@ -61,18 +61,17 @@ samples = [samples[i] for i in perm] calib_size = int(np.ceil(CALIB_FRAC * tot_size)) - calib_choices, calib_questions, calib_targets = [], [], [] - test_choices, test_questions, test_targets = [], [], [] + calib_answers, calib_questions, calib_targets = [], [], [] + test_answers, test_questions, test_targets = [], [], [] for i, sample in enumerate(samples): if i < calib_size: calib_questions.append(sample["question"]) - calib_choices.append(sample["choices"]) - calib_targets.append(sample["targets"]) + calib_answers.append(sample["choices"][0]) + calib_targets.append(int(sample["targets"] == 0)) else: test_questions.append(sample["question"]) - # test the first answer for each question - test_choices.append(sample["choices"][0]) - test_targets.append(sample["targets"] == 0) + test_answers.append(sample["choices"][0]) + test_targets.append(int(sample["targets"] == 0)) test_targets = np.array(test_targets) # calibrate @@ -81,7 +80,7 @@ ) status = calibrator.fit( - texts=calib_choices, + texts=calib_answers, contexts=calib_questions, targets=calib_targets, ) @@ -90,17 +89,17 @@ # test test_probs = calibrator.predict_proba( - texts=test_choices, contexts=test_questions, calibrate=False + texts=test_answers, contexts=test_questions, calibrate=False ) test_preds = calibrator.predict( - texts=test_choices, contexts=test_questions, probs=test_probs + texts=test_answers, contexts=test_questions, probs=test_probs ) calib_test_probs = calibrator.predict_proba( - texts=test_choices, contexts=test_questions + texts=test_answers, contexts=test_questions ) calib_test_preds = calibrator.predict( - texts=test_choices, contexts=test_questions, probs=calib_test_probs + texts=test_answers, contexts=test_questions, probs=calib_test_probs ) # measure diff --git a/fortuna/hallucination/base.py b/fortuna/hallucination/base.py index e0c6666d..97e7af40 100644 --- a/fortuna/hallucination/base.py +++ b/fortuna/hallucination/base.py @@ -32,6 +32,7 @@ def __init__( scoring_fn: Optional[ Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor] ] = None, + seed: int = 0, ): """ A hallucination multicalibrator class. @@ -72,13 +73,16 @@ def __init__( self.grouping_model = None self.multicalibrator = None self._quantiles = None + self.rng = np.random.default_rng(seed) def fit( self, - texts: Union[List[str], List[List[str]]], + texts: List[str], contexts: List[str], - targets: List[str], + targets: List[int], + batch_size: int = 16, quantile_group_scores_threshold: float = 0.8, + balance: bool = True, ) -> Dict: """ Fit the multicalibrator. @@ -99,31 +103,29 @@ def fit( A list of target variables to be used for calibration. If `texts` is a list of strings, `targets` should be binary variables indicating whether each of the strings in the `texts` list should be marked as positive given the corresponding `contexts`. - If `texts` is a list of lists of strings, - then `targets` should be a list of integers indicating the position of the strings in the inner lists that - should be marked as a positive class. + batch_size: int + The batch size. quantile_group_scores_threshold: float A threshold for which to compute the quantiles of the clustering scores. This will determine the groups. + balance: bool + Whether to balance the calibration data. Returns ------- Dict The status returned by fitting the multicalibrator. """ + if balance: + texts, contexts, targets = self._balance_data(texts, contexts, targets) + ( scores, embeddings, - which_choices, - ) = self._compute_scores_embeddings_which_choices( - texts=texts, contexts=contexts + ) = self._compute_scores_embeddings( + texts=texts, contexts=contexts, batch_size=batch_size ) - if len(which_choices): - targets = (which_choices == np.array(targets[: len(which_choices)])).astype( - int - ) - else: - targets = np.array(targets) + targets = np.array(targets, dtype="int32") embeddings = self.embedding_reduction_model.fit_transform(embeddings) embeddings = np.concatenate((embeddings, scores[:, None]), axis=1) @@ -150,6 +152,7 @@ def predict_proba( self, texts: List[str], contexts: List[str], + batch_size: int = 16, calibrate: bool = True, ) -> np.ndarray: """ @@ -163,6 +166,8 @@ def predict_proba( or a list of lists of strings (e.g. a list of multi-choice answers). contexts: List[str] A list of contexts (e.g. a list of questions). + batch_size: int + The batch size. calibrate: bool Whether to calibration the initial probability estimates. @@ -174,12 +179,8 @@ def predict_proba( if self.multicalibrator is None: raise ValueError("`fit` must be called before this method.") - ( - scores, - embeddings, - _, - ) = self._compute_scores_embeddings_which_choices( - texts=texts, contexts=contexts + (scores, embeddings) = self._compute_scores_embeddings( + texts=texts, contexts=contexts, batch_size=batch_size ) if not calibrate: return scores @@ -198,6 +199,7 @@ def predict( self, texts: List[str], contexts: List[str], + batch_size: int = 16, calibrate: bool = True, probs: Optional[np.ndarray] = None, ) -> np.ndarray: @@ -213,6 +215,8 @@ def predict( or a list of lists of strings (e.g. a list of multi-choice answers). contexts: List[str] A list of contexts (e.g. a list of questions). + batch_size: int + The batch size. calibrate: bool Whether to calibration the initial probability estimates. probs: Optional[np.ndarray] @@ -225,7 +229,10 @@ def predict( """ if probs is None: probs = self.predict_proba( - texts=texts, contexts=contexts, calibrate=calibrate + texts=texts, + contexts=contexts, + batch_size=batch_size, + calibrate=calibrate, ) return (probs >= 0.5).astype(int) @@ -235,45 +242,42 @@ def _get_groups(self, group_scores: np.ndarray): bool ) - def _compute_scores_embeddings_which_choices( - self, - texts: Union[List[str], List[List[str]]], - contexts: List[str], - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + def _compute_scores_embeddings( + self, texts: List[str], contexts: List[str], batch_size: int + ) -> Tuple[np.ndarray, np.ndarray]: scores = [] embeddings = [] - which_choices = [] - - for text, context in tqdm(zip(texts, contexts)): - _logits, _scores = self._get_logits_scores(text, context) - _embeddings = _logits.mean(1) - if isinstance(text, list): - which_choice = np.argmax(_scores) - which_choices.append(which_choice) - scores.append(_scores[which_choice]) - embeddings.append(_embeddings[which_choice, None]) - elif isinstance(text, str): - embeddings.append(_embeddings) - scores.append(_scores[0]) + + gen = self._batch(texts, contexts, batch_size) + + for batch_texts, batch_contexts in tqdm( + gen, total=int(np.ceil(len(texts) / batch_size)) + ): + logits, _scores = self._get_logits_scores(batch_texts, batch_contexts) + embeddings.append(logits.mean(1)) + scores.append(_scores) return ( - np.array(scores), + np.concatenate(scores, axis=0).astype("float32"), np.concatenate(embeddings, axis=0), - np.array(which_choices), ) + @staticmethod + def _batch(texts: List[str], contexts: List[str], batch_size: int): + for i in range(0, len(texts), batch_size): + yield texts[i : i + batch_size], contexts[i : i + batch_size] + def _get_logits_scores( - self, text: str, context: str + self, texts: str, contexts: str ) -> Tuple[np.ndarray, np.ndarray]: - context_inputs = self.tokenizer(context, return_tensors="pt", padding=True).to( + context_inputs = self.tokenizer(contexts, return_tensors="pt", padding=True).to( self.generative_model.device ) - text_inputs = self.tokenizer(text, return_tensors="pt", padding=True).to( + text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to( self.generative_model.device ) inputs = { - k: torch.cat((context_inputs[k].repeat((v.shape[0], 1)), v), dim=1) - for k, v in text_inputs.items() + k: torch.cat((context_inputs[k], v), dim=1) for k, v in text_inputs.items() } with torch.no_grad(): @@ -287,6 +291,21 @@ def _get_logits_scores( return _logits.cpu().numpy(), _scores.cpu().numpy() + def _balance_data( + self, texts: List[str], contexts: List[str], targets: List[int] + ) -> Tuple[List[str], List[str], List[int]]: + idx0 = [i for i, y in enumerate(targets) if y == 0] + idx1 = [i for i, y in enumerate(targets) if y == 1] + len_diff = len(idx1) - len(idx0) + idx = self.rng.choice( + idx0 if len_diff > 0 else idx1, np.abs(len_diff), replace=True + ) + for i in idx: + texts.append(texts[i]) + contexts.append(contexts[i]) + targets.append(targets[i]) + return texts, contexts, targets + def save(self, path): state = dict( embedding_reduction_model=self.embedding_reduction_model, From 4d9dca99dcf7c857cc3784a7e6487f9446a79395 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Tue, 19 Dec 2023 12:06:26 +0100 Subject: [PATCH 7/8] fix bug in output sampling on cuda --- fortuna/prob_model/predictive/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fortuna/prob_model/predictive/base.py b/fortuna/prob_model/predictive/base.py index ba7cee6e..3d87387c 100644 --- a/fortuna/prob_model/predictive/base.py +++ b/fortuna/prob_model/predictive/base.py @@ -563,7 +563,8 @@ def _sample(key, _inputs): ) if distribute: outputs = jnp.stack( - list(map(lambda key: _sample(shard_prng_key(key), inputs), keys)) + list(map(lambda key: _sample(shard_prng_key(key), inputs), keys)), + axis=1 ) outputs = self._unshard_ensemble_arrays(outputs) else: From d2115911136a5db2917c2d3c247c19961ae19447 Mon Sep 17 00:00:00 2001 From: Gianluca Detommaso Date: Tue, 19 Dec 2023 18:01:32 +0100 Subject: [PATCH 8/8] refactor advi --- .../normalizing_flow/advi/advi_posterior.py | 100 +++++++++++------- 1 file changed, 59 insertions(+), 41 deletions(-) diff --git a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py index 8f06738e..5cd83647 100755 --- a/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py +++ b/fortuna/prob_model/posterior/normalizing_flow/advi/advi_posterior.py @@ -48,6 +48,7 @@ Array, OptaxOptimizer, Params, + Path, Status, ) from fortuna.utils.builtins import get_dynamic_scale_instance_from_model_dtype @@ -209,6 +210,61 @@ def _get_base_and_architecture( ) return base, architecture + def load_state(self, checkpoint_path: Path) -> None: + """ + Load the state of the posterior distribution from a checkpoint path. The checkpoint must be + compatible with the current probabilistic model. + + Parameters + ---------- + checkpoint_path: Path + Path to checkpoint file or directory to restore. + """ + try: + self.restore_checkpoint(checkpoint_path) + except ValueError: + raise ValueError( + f"No checkpoint was found in `checkpoint_path={checkpoint_path}`." + ) + self.state = PosteriorStateRepository(checkpoint_dir=checkpoint_path) + + state = self.state.get() + if state._encoded_which_params is None: + n_params = len(ravel_pytree(state.params)[0]) // 2 + which_params = None + else: + which_params = decode_encoded_tuple_of_lists_of_strings_to_array( + state._encoded_which_params + ) + n_params = len( + ravel_pytree( + nested_unpair( + d=state.params.unfreeze(), + key_paths=which_params, + labels=("mean", "log_std"), + )[1] + )[0] + ) + _base, _architecture = self._get_base_and_architecture(n_params) + _unravel = self._get_unravel( + FrozenDict( + nested_unpair( + d=state.params.unfreeze(), + key_paths=which_params, + labels=("mean", "log_std"), + )[0] + if which_params + else { + k: dict(params=v["params"]["mean"]) for k, v in state.params.items() + } + ), + which_params=which_params, + )[1] + + self._base = _base + self._architecture = _architecture + self._unravel = _unravel + def sample( self, rng: Optional[PRNGKeyArray] = None, @@ -238,47 +294,9 @@ def sample( rng = self.rng.get() state = self.state.get() - if self._base is None or self._unravel is None: - if state._encoded_which_params is None: - n_params = len(ravel_pytree(state.params)[0]) // 2 - which_params = None - else: - which_params = decode_encoded_tuple_of_lists_of_strings_to_array( - state._encoded_which_params - ) - n_params = len( - ravel_pytree( - nested_unpair( - d=state.params.unfreeze(), - key_paths=which_params, - labels=("mean", "log_std"), - )[1] - )[0] - ) - _base, _architecture = self._get_base_and_architecture(n_params) - _unravel = self._get_unravel( - FrozenDict( - nested_unpair( - d=state.params.unfreeze(), - key_paths=which_params, - labels=("mean", "log_std"), - )[0] - if which_params - else { - k: dict(params=v["params"]["mean"]) - for k, v in state.params.items() - } - ), - which_params=which_params, - )[1] - - self._base = _base - self._architecture = _architecture - self._unravel = _unravel - else: - _base = self._base - _architecture = self._architecture - _unravel = self._unravel + _base = self._base + _architecture = self._architecture + _unravel = self._unravel if state._encoded_which_params is None: means = _unravel(