diff --git a/CHANGELOG.md b/CHANGELOG.md index 70412291..d1fdf7e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 root kernel for the squared exponential kernel. (https://github.com/gchq/coreax/pull/883) - Added a new coreset algorithm Kernel Thinning. (https://github.com/gchq/coreax/pull/915) - Added (loose) lower bounds to all direct dependencies. (https://github.com/gchq/coreax/pull/920) +- Added Kernel Thinning to existing benchmarking tests. (https://github.com/gchq/coreax/pull/927) ### Fixed @@ -48,6 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Pylint pre-commit hook is now configured as the Pylint docs recommend. (https://github.com/gchq/coreax/pull/899) - Type annotations so that core coreax package passes Pyright. (https://github.com/gchq/coreax/pull/906) - Type annotations so that the example scripts pass Pyright. (https://github.com/gchq/coreax/pull/921) +- `KernelThinning` now computes swap probability correctly. (https://github.com/gchq/coreax/pull/932) - Incorrectly-implemented tests for the gradients of `PeriodicKernel`. (https://github.com/gchq/coreax/pull/936) ### Changed diff --git a/benchmark/blobs_benchmark.py b/benchmark/blobs_benchmark.py index 029adcae..07548813 100644 --- a/benchmark/blobs_benchmark.py +++ b/benchmark/blobs_benchmark.py @@ -47,6 +47,7 @@ from coreax.metrics import KSD, MMD from coreax.solvers import ( KernelHerding, + KernelThinning, RandomSample, RPCholesky, Solver, @@ -102,6 +103,7 @@ def setup_solvers( coreset_size: int, sq_exp_kernel: SquaredExponentialKernel, stein_kernel: SteinKernel, + delta: float, random_seed: int = 45, ) -> list[tuple[str, _Solver]]: """ @@ -109,13 +111,16 @@ def setup_solvers( :param coreset_size: The size of the coresets to be generated by the solvers. :param sq_exp_kernel: A Squared Exponential kernel for KernelHerding and RPCholesky. 
+ The square root kernel for KernelThinning is also derived from this kernel. :param stein_kernel: A Stein kernel object used for the SteinThinning solver. + :param delta: The delta parameter for KernelThinning solver. :param random_seed: An integer seed for the random number generator. :return: A list of tuples, where each tuple contains the name of the solver and the corresponding solver object. """ random_key = jax.random.PRNGKey(random_seed) + sqrt_kernel = sq_exp_kernel.get_sqrt_kernel(dim=2) return [ ( "KernelHerding", @@ -141,6 +146,16 @@ def setup_solvers( regularise=False, ), ), + ( + "KernelThinning", + KernelThinning( + coreset_size=coreset_size, + kernel=sq_exp_kernel, + random_key=random_key, + delta=delta, + sqrt_kernel=sqrt_kernel, + ), + ), ] diff --git a/benchmark/david_benchmark.py b/benchmark/david_benchmark.py index 2734d54d..5c79d0c3 100644 --- a/benchmark/david_benchmark.py +++ b/benchmark/david_benchmark.py @@ -39,8 +39,8 @@ import numpy as np from jax import random -from benchmark.mnist_benchmark import get_solver_name, initialise_solvers from coreax import Data +from coreax.benchmark_util import get_solver_name, initialise_solvers from examples.david_map_reduce_weighted import downsample_opencv MAX_8BIT = 255 @@ -65,7 +65,6 @@ def benchmark_coreset_algorithms( """ # Base directory of the current script base_dir = os.path.dirname(os.path.abspath(__file__)) - # Convert to absolute paths using os.path.join if not in_path.is_absolute(): in_path = Path(os.path.join(base_dir, in_path)) diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index 16c95467..6973fe19 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -40,7 +40,6 @@ import json import os import time -from collections.abc import Callable from typing import Any, NamedTuple, Optional, Union import equinox as eqx @@ -56,16 +55,7 @@ from torchvision import transforms from coreax import Data -from coreax.kernels import 
SquaredExponentialKernel, SteinKernel, median_heuristic -from coreax.score_matching import KernelDensityMatching -from coreax.solvers import ( - KernelHerding, - MapReduce, - RandomSample, - RPCholesky, - Solver, - SteinThinning, -) +from coreax.benchmark_util import get_solver_name, initialise_solvers # Convert PyTorch dataset to JAX arrays @@ -426,91 +416,6 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr return train_data_jax, train_targets_jax, test_data_jax, test_targets_jax -def initialise_solvers( - train_data_umap: Data, key: jax.random.PRNGKey -) -> list[Callable[[int], Solver]]: - """ - Initialise and return a list of solvers for various coreset algorithms. - - Set up solvers for Kernel Herding, Stein Thinning, Random Sampling, and Randomised - Cholesky methods. Each solver has different parameter requirements. Some solvers - can utilise MapReduce, while others cannot,and some require specific kernels. - This setup allows them to be called by passing only the coreset size, - enabling easy integration in a loop for benchmarking. - - :param train_data_umap: The UMAP-transformed training data used for - length scale estimation for ``SquareExponentialKernel``. - :param key: The random key for initialising random solvers. - :return: A list of solvers functions for different coreset algorithms. - """ - # Set up kernel using median heuristic - num_data_points = len(train_data_umap) - num_samples_length_scale = min(num_data_points, 300) - random_seed = 45 - generator = np.random.default_rng(random_seed) - idx = generator.choice(num_data_points, num_samples_length_scale, replace=False) - length_scale = median_heuristic(train_data_umap[idx]) - kernel = SquaredExponentialKernel(length_scale=length_scale) - - def _get_herding_solver(_size: int) -> MapReduce: - """ - Set up KernelHerding to use ``MapReduce``. 
- - Create a KernelHerding solver with the specified size and return - it along with a MapReduce object for reducing a large dataset like - MNIST dataset. - - :param _size: The size of the coreset to be generated. - :return: A tuple containing the solver name and the MapReduce solver. - """ - herding_solver = KernelHerding(_size, kernel) - return MapReduce(herding_solver, leaf_size=3 * _size) - - def _get_stein_solver(_size: int) -> MapReduce: - """ - Set up Stein Thinning to use ``MapReduce``. - - Create a SteinThinning solver with the specified coreset size, - using ``KernelDensityMatching`` score function for matching on - a subset of the dataset. - - :param _size: The size of the coreset to be generated. - :return: A tuple containing the solver name and the MapReduce solver. - """ - # Generate small dataset for ScoreMatching for Stein Kernel - - score_function = KernelDensityMatching(length_scale=length_scale).match( - train_data_umap[idx] - ) - stein_kernel = SteinKernel(kernel, score_function) - stein_solver = SteinThinning( - coreset_size=_size, kernel=stein_kernel, regularise=False - ) - return MapReduce(stein_solver, leaf_size=3 * _size) - - def _get_random_solver(_size: int) -> RandomSample: - """ - Set up Random Sampling to generate a coreset. - - :param _size: The size of the coreset to be generated. - :return: A tuple containing the solver name and the RandomSample solver. - """ - random_solver = RandomSample(_size, key) - return random_solver - - def _get_rp_solver(_size: int) -> RPCholesky: - """ - Set up Randomised Cholesky solver. - - :param _size: The size of the coreset to be generated. - :return: A tuple containing the solver name and the RPCholesky solver. 
- """ - rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key) - return rp_solver - - return [_get_random_solver, _get_rp_solver, _get_herding_solver, _get_stein_solver] - - def train_model( data_bundle: dict[str, jnp.ndarray], key: jax.random.PRNGKey, @@ -593,25 +498,6 @@ def save_results(results: dict) -> None: print(f"Data has been saved to {file_name}") -def get_solver_name(solver: Callable[[int], Solver]) -> str: - """ - Get the name of the solver. - - This function extracts and returns the name of the solver class. - If ``_solver`` is an instance of :class:`~coreax.solvers.MapReduce`, it retrieves - the name of the :class:`~coreax.solvers.MapReduce.base_solver` class instead. - - :param solver: An instance of a solver, such as `MapReduce` or `RandomSample`. - :return: The name of the solver class. - """ - # Evaluate solver function to get an instance to interrogate - # Don't just inspect type annotations, as they may be incorrect - not robust - solver_instance = solver(1) - if isinstance(solver_instance, MapReduce): - return type(solver_instance.base_solver).__name__ - return type(solver_instance).__name__ - - # pylint: disable=too-many-locals def main() -> None: """ diff --git a/benchmark/mnist_benchmark_results.json b/benchmark/mnist_benchmark_results.json index 63d18543..feadc718 100644 --- a/benchmark/mnist_benchmark_results.json +++ b/benchmark/mnist_benchmark_results.json @@ -3,535 +3,669 @@ "25": { "0": { "accuracy": 0.47499004006385803, - "time_taken": 21.129530332 + "time_taken": 23.26633542699983 }, "1": { "accuracy": 0.5094035863876343, - "time_taken": 21.620098648999942 + "time_taken": 24.154528033000133 }, "2": { - "accuracy": 0.4429771900177002, - "time_taken": 23.587054420000072 + "accuracy": 0.4431772530078888, + "time_taken": 26.308422624999366 }, "3": { - "accuracy": 0.5266107320785522, - "time_taken": 20.613561648000086 + "accuracy": 0.5268108248710632, + "time_taken": 22.50174796200008 }, "4": { "accuracy": 
0.4708881676197052, - "time_taken": 23.3405117719999 + "time_taken": 25.894998888999908 } }, "50": { "0": { "accuracy": 0.6092996001243591, - "time_taken": 7.480583503999981 + "time_taken": 7.9610507360002885 }, "1": { "accuracy": 0.6098998785018921, - "time_taken": 9.492431753999995 + "time_taken": 10.142288449999796 }, "2": { "accuracy": 0.5316997766494751, - "time_taken": 7.016472408000027 + "time_taken": 7.382097613000042 }, "3": { "accuracy": 0.6038997769355774, - "time_taken": 7.433440193000024 + "time_taken": 7.2605028240004685 }, "4": { "accuracy": 0.6276999711990356, - "time_taken": 15.224367368999992 + "time_taken": 15.606078997999248 } }, "100": { "0": { "accuracy": 0.704299807548523, - "time_taken": 5.593184323999992 + "time_taken": 6.080028390000734 }, "1": { "accuracy": 0.7321001887321472, - "time_taken": 5.385782182999947 + "time_taken": 5.851846229999865 }, "2": { "accuracy": 0.7243001461029053, - "time_taken": 9.138968591000094 + "time_taken": 9.87153262100037 }, "3": { "accuracy": 0.7279003262519836, - "time_taken": 3.660555053000053 + "time_taken": 3.8054748680006014 }, "4": { "accuracy": 0.6853998899459839, - "time_taken": 6.721555722999938 + "time_taken": 7.10781191300066 } }, "500": { "0": { "accuracy": 0.849459171295166, - "time_taken": 4.1964316430000395 + "time_taken": 4.452629182000237 }, "1": { "accuracy": 0.8390424847602844, - "time_taken": 2.6543577679999544 + "time_taken": 2.864639194999654 }, "2": { "accuracy": 0.8586738705635071, - "time_taken": 3.7907679610000287 + "time_taken": 4.056492333999813 }, "3": { "accuracy": 0.8433493971824646, - "time_taken": 3.211584036999966 + "time_taken": 3.1211573710006633 }, "4": { "accuracy": 0.8508613705635071, - "time_taken": 4.339650910000046 + "time_taken": 4.570619415000692 } }, "1000": { "0": { "accuracy": 0.8806089758872986, - "time_taken": 3.810672871999998 + "time_taken": 4.099615649000043 }, "1": { "accuracy": 0.8755007982254028, - "time_taken": 2.6573232700000062 + "time_taken": 
2.84869260799951 }, "2": { "accuracy": 0.8828125, - "time_taken": 2.6886053809998884 + "time_taken": 2.8027847920002387 }, "3": { "accuracy": 0.8731971383094788, - "time_taken": 2.8564365010001893 + "time_taken": 2.9280455670004812 }, "4": { "accuracy": 0.8818109035491943, - "time_taken": 3.027272230000108 + "time_taken": 3.1732710930000394 } }, "5000": { "0": { "accuracy": 0.9258814454078674, - "time_taken": 4.428039572999978 + "time_taken": 4.808443182000701 }, "1": { "accuracy": 0.9238781929016113, - "time_taken": 3.322857870000007 + "time_taken": 3.5261693180000293 }, "2": { "accuracy": 0.9277844429016113, - "time_taken": 3.3129011699998046 + "time_taken": 3.4832477119998657 }, "3": { "accuracy": 0.9291867017745972, - "time_taken": 4.800562907000085 + "time_taken": 5.071235248999983 }, "4": { "accuracy": 0.9294871687889099, - "time_taken": 3.492984100000058 + "time_taken": 3.7022924619996047 } } }, "RPCholesky": { "25": { "0": { - "accuracy": 0.47358962893486023, - "time_taken": 19.895207870000036 + "accuracy": 0.4749898314476013, + "time_taken": 17.966321985999457 }, "1": { - "accuracy": 0.5369148850440979, - "time_taken": 14.838139795000075 + "accuracy": 0.5499197840690613, + "time_taken": 16.858777587000077 }, "2": { - "accuracy": 0.5467190146446228, - "time_taken": 17.321421823000037 + "accuracy": 0.5490198731422424, + "time_taken": 21.207069888999285 }, "3": { - "accuracy": 0.4283711016178131, - "time_taken": 20.84139633099994 + "accuracy": 0.4454781413078308, + "time_taken": 19.69648443999995 }, "4": { - "accuracy": 0.4912963807582855, - "time_taken": 17.36893893699994 + "accuracy": 0.5133053064346313, + "time_taken": 24.770281666999836 } }, "50": { "0": { - "accuracy": 0.625399649143219, - "time_taken": 13.748352533000002 + "accuracy": 0.5790995359420776, + "time_taken": 7.9108207769995715 }, "1": { - "accuracy": 0.6291998624801636, - "time_taken": 8.75706090899996 + "accuracy": 0.6649996638298035, + "time_taken": 8.449880023000333 }, "2": { - 
"accuracy": 0.5662998557090759, - "time_taken": 6.946477493999964 + "accuracy": 0.6247993111610413, + "time_taken": 9.283834525000202 }, "3": { - "accuracy": 0.6297994256019592, - "time_taken": 11.587993342000118 + "accuracy": 0.6294994950294495, + "time_taken": 14.32392578100007 }, "4": { - "accuracy": 0.5483998656272888, - "time_taken": 14.694229403999998 + "accuracy": 0.5775997638702393, + "time_taken": 13.211208218999673 } }, "100": { "0": { - "accuracy": 0.725000262260437, - "time_taken": 8.696777773000008 + "accuracy": 0.7272999882698059, + "time_taken": 9.433186399000078 }, "1": { - "accuracy": 0.7001000046730042, - "time_taken": 5.543955067999946 + "accuracy": 0.6856998801231384, + "time_taken": 7.170746434999273 }, "2": { - "accuracy": 0.6311002373695374, - "time_taken": 4.008666183000059 + "accuracy": 0.6285001635551453, + "time_taken": 6.053287822000129 }, "3": { - "accuracy": 0.6850997805595398, - "time_taken": 7.576385713000036 + "accuracy": 0.6865997910499573, + "time_taken": 5.618571209000038 }, "4": { - "accuracy": 0.646899938583374, - "time_taken": 9.464091529000143 + "accuracy": 0.6816001534461975, + "time_taken": 8.63684380199993 } }, "500": { "0": { - "accuracy": 0.8492588400840759, - "time_taken": 5.850811007999994 + "accuracy": 0.8495593070983887, + "time_taken": 5.951769608000177 }, "1": { - "accuracy": 0.8279246687889099, - "time_taken": 3.400919842999997 + "accuracy": 0.8297275900840759, + "time_taken": 3.830718762000288 }, "2": { - "accuracy": 0.8026843070983887, - "time_taken": 4.008895577999965 + "accuracy": 0.8172075152397156, + "time_taken": 4.784693309000431 }, "3": { - "accuracy": 0.8110977411270142, - "time_taken": 4.032388095999977 + "accuracy": 0.805588960647583, + "time_taken": 4.134177151000586 }, "4": { - "accuracy": 0.8451522588729858, - "time_taken": 3.68542980899997 + "accuracy": 0.8328325152397156, + "time_taken": 4.010456953999892 } }, "1000": { "0": { - "accuracy": 0.8804086446762085, - "time_taken": 5.837933116000045 + 
"accuracy": 0.8742988705635071, + "time_taken": 5.9872287300004245 }, "1": { - "accuracy": 0.8630809187889099, - "time_taken": 4.17050370100003 + "accuracy": 0.8658854365348816, + "time_taken": 4.616141986999537 }, "2": { - "accuracy": 0.8666867017745972, - "time_taken": 4.835822229000087 + "accuracy": 0.8723958730697632, + "time_taken": 4.796533407000425 }, "3": { "accuracy": 0.8480569124221802, - "time_taken": 4.71791837700016 + "time_taken": 4.392059116000382 }, "4": { - "accuracy": 0.8788061141967773, - "time_taken": 4.380343813000081 + "accuracy": 0.8732972741127014, + "time_taken": 4.212848541999847 } }, "5000": { "0": { - "accuracy": 0.9239783883094788, - "time_taken": 29.51819806499998 + "accuracy": 0.9250801205635071, + "time_taken": 29.899051755000073 }, "1": { - "accuracy": 0.9287860989570618, - "time_taken": 27.981147562000046 + "accuracy": 0.9318910241127014, + "time_taken": 28.640717678999863 }, "2": { - "accuracy": 0.9247796535491943, - "time_taken": 27.539323438999872 + "accuracy": 0.9252804517745972, + "time_taken": 27.780483479000395 }, "3": { - "accuracy": 0.9285857677459717, - "time_taken": 27.830323820999865 + "accuracy": 0.9284855723381042, + "time_taken": 28.07316877300036 }, "4": { - "accuracy": 0.9238781929016113, - "time_taken": 28.524467420000065 + "accuracy": 0.9221754670143127, + "time_taken": 28.88361871699999 } } }, "KernelHerding": { "25": { "0": { - "accuracy": 0.5501202940940857, - "time_taken": 29.94118039799997 + "accuracy": 0.4633856415748596, + "time_taken": 21.417966922000232 }, "1": { - "accuracy": 0.5024006366729736, - "time_taken": 19.38333930700003 + "accuracy": 0.41996797919273376, + "time_taken": 21.351254336999773 }, "2": { - "accuracy": 0.5388156771659851, - "time_taken": 16.329524966000008 + "accuracy": 0.43827515840530396, + "time_taken": 18.99325110800055 }, "3": { - "accuracy": 0.5127052068710327, - "time_taken": 16.15010204400005 + "accuracy": 0.43787533044815063, + "time_taken": 15.95066889699956 }, "4": { - 
"accuracy": 0.539915919303894, - "time_taken": 17.088837638000086 + "accuracy": 0.4330735206604004, + "time_taken": 19.170973098000104 } }, "50": { "0": { - "accuracy": 0.6299993395805359, - "time_taken": 17.985974345999978 + "accuracy": 0.5270997285842896, + "time_taken": 8.63385688299968 }, "1": { - "accuracy": 0.6076997518539429, - "time_taken": 10.080095013000005 + "accuracy": 0.523399829864502, + "time_taken": 5.941639144999499 }, "2": { - "accuracy": 0.6290996670722961, - "time_taken": 10.87603085700016 + "accuracy": 0.5230996608734131, + "time_taken": 8.082218694000403 }, "3": { - "accuracy": 0.6168995499610901, - "time_taken": 12.538436354000169 + "accuracy": 0.5010999441146851, + "time_taken": 7.190972131000308 }, "4": { - "accuracy": 0.5985994935035706, - "time_taken": 9.817184417000135 + "accuracy": 0.5044997930526733, + "time_taken": 7.761757675999434 } }, "100": { "0": { - "accuracy": 0.7302002310752869, - "time_taken": 13.26493101899996 + "accuracy": 0.6518000364303589, + "time_taken": 6.5709225259997766 }, "1": { - "accuracy": 0.7350001931190491, - "time_taken": 7.117937374000007 + "accuracy": 0.5968999862670898, + "time_taken": 4.8999388629999885 }, "2": { - "accuracy": 0.723000168800354, - "time_taken": 6.991755536000028 + "accuracy": 0.5978002548217773, + "time_taken": 5.410862422999344 }, "3": { - "accuracy": 0.7188998460769653, - "time_taken": 6.8836636179999005 + "accuracy": 0.6083003878593445, + "time_taken": 4.623381014000188 }, "4": { - "accuracy": 0.716400146484375, - "time_taken": 5.2314696280000135 + "accuracy": 0.6084000468254089, + "time_taken": 5.331589719000476 } }, "500": { "0": { - "accuracy": 0.8127003312110901, - "time_taken": 8.29185179000001 + "accuracy": 0.7951722741127014, + "time_taken": 5.391921381999964 }, "1": { - "accuracy": 0.8148037195205688, - "time_taken": 3.6316947710000704 + "accuracy": 0.8112980723381042, + "time_taken": 3.77097700899958 }, "2": { - "accuracy": 0.8196113705635071, - "time_taken": 3.252079743999957 
+ "accuracy": 0.8093950152397156, + "time_taken": 3.864875149999534 }, "3": { - "accuracy": 0.8093950152397156, - "time_taken": 3.455513597000163 + "accuracy": 0.8002804517745972, + "time_taken": 3.4463949179998963 }, "4": { - "accuracy": 0.8156049847602844, - "time_taken": 3.3011748870001156 + "accuracy": 0.7914663553237915, + "time_taken": 3.883257230999334 } }, "1000": { "0": { - "accuracy": 0.8576722741127014, - "time_taken": 8.381767883000009 + "accuracy": 0.846754789352417, + "time_taken": 5.342803989999993 }, "1": { - "accuracy": 0.8568710088729858, - "time_taken": 3.56256258399992 + "accuracy": 0.859375, + "time_taken": 3.8688295120000475 }, "2": { - "accuracy": 0.8465545177459717, - "time_taken": 3.135869393999883 + "accuracy": 0.8543670177459717, + "time_taken": 3.8199609699995563 }, "3": { - "accuracy": 0.860276460647583, - "time_taken": 3.4447939570000017 + "accuracy": 0.8598757982254028, + "time_taken": 3.771100228999785 }, "4": { - "accuracy": 0.8583734035491943, - "time_taken": 3.61459772000012 + "accuracy": 0.8578726053237915, + "time_taken": 3.5506639280001764 } }, "5000": { "0": { - "accuracy": 0.9268830418586731, - "time_taken": 9.059691469000086 + "accuracy": 0.9211738705635071, + "time_taken": 8.47471116900033 }, "1": { - "accuracy": 0.9315905570983887, - "time_taken": 5.148555465000072 + "accuracy": 0.922776460647583, + "time_taken": 4.937656685000547 }, "2": { - "accuracy": 0.930588960647583, - "time_taken": 5.6641410270001415 + "accuracy": 0.9282852411270142, + "time_taken": 6.115159175999906 }, "3": { - "accuracy": 0.9274839758872986, - "time_taken": 5.935124415000018 + "accuracy": 0.9300881624221802, + "time_taken": 5.731699566000316 }, "4": { - "accuracy": 0.9286859035491943, - "time_taken": 4.589270231 + "accuracy": 0.9243789911270142, + "time_taken": 4.7779110970004695 } } }, "SteinThinning": { "25": { "0": { - "accuracy": 0.36124444007873535, - "time_taken": 32.886642288000075 + "accuracy": 0.402760773897171, + "time_taken": 
19.745292868000433 }, "1": { - "accuracy": 0.3338334560394287, - "time_taken": 41.821553507999965 + "accuracy": 0.36034372448921204, + "time_taken": 19.371978351000507 }, "2": { - "accuracy": 0.31372541189193726, - "time_taken": 34.356144773000096 + "accuracy": 0.3401360809803009, + "time_taken": 14.285965998000393 }, "3": { - "accuracy": 0.27040842175483704, - "time_taken": 33.386861657000054 + "accuracy": 0.3896558880805969, + "time_taken": 16.36114198499945 }, "4": { - "accuracy": 0.35444176197052, - "time_taken": 36.694613785 + "accuracy": 0.3565424978733063, + "time_taken": 20.385888099999647 } }, "50": { "0": { - "accuracy": 0.46370017528533936, - "time_taken": 37.51419123599999 + "accuracy": 0.4324001669883728, + "time_taken": 13.217264007999802 }, "1": { - "accuracy": 0.43250009417533875, - "time_taken": 33.460962724999945 + "accuracy": 0.42600005865097046, + "time_taken": 9.966807724000319 }, "2": { - "accuracy": 0.4118999242782593, - "time_taken": 29.25043843399999 + "accuracy": 0.4104001522064209, + "time_taken": 10.416085636999924 }, "3": { - "accuracy": 0.40100017189979553, - "time_taken": 27.707170819999874 + "accuracy": 0.41020023822784424, + "time_taken": 10.325945335999677 }, "4": { - "accuracy": 0.45050016045570374, - "time_taken": 28.452544991999957 + "accuracy": 0.3909001350402832, + "time_taken": 13.138384072000008 } }, "100": { "0": { - "accuracy": 0.4580000340938568, - "time_taken": 23.420479623999995 + "accuracy": 0.4612000286579132, + "time_taken": 9.312922237000748 }, "1": { - "accuracy": 0.4966999590396881, - "time_taken": 26.137647441000013 + "accuracy": 0.469699889421463, + "time_taken": 8.583355914000094 }, "2": { - "accuracy": 0.500999927520752, - "time_taken": 24.4042449760002 + "accuracy": 0.4544999897480011, + "time_taken": 8.941393133999554 }, "3": { - "accuracy": 0.4515998661518097, - "time_taken": 22.571522415000118 + "accuracy": 0.4674999415874481, + "time_taken": 8.999052006000056 }, "4": { - "accuracy": 0.4860999286174774, - 
"time_taken": 27.187762750000047 + "accuracy": 0.451499879360199, + "time_taken": 9.512783743 } }, "500": { "0": { - "accuracy": 0.5588942170143127, - "time_taken": 20.343176127999982 + "accuracy": 0.5759214758872986, + "time_taken": 8.666155873000207 }, "1": { - "accuracy": 0.5828325152397156, - "time_taken": 19.380917395999973 + "accuracy": 0.5520833134651184, + "time_taken": 8.722324781999305 }, "2": { - "accuracy": 0.5852363705635071, - "time_taken": 19.87485118199993 + "accuracy": 0.5831330418586731, + "time_taken": 9.471579476999977 }, "3": { - "accuracy": 0.5499799847602844, - "time_taken": 22.195402222999974 + "accuracy": 0.5356570482254028, + "time_taken": 9.300596547000168 }, "4": { - "accuracy": 0.5618990659713745, - "time_taken": 19.422155341999996 + "accuracy": 0.5519831776618958, + "time_taken": 9.42204293100076 } }, "1000": { "0": { - "accuracy": 0.5892428159713745, - "time_taken": 21.69759744099997 + "accuracy": 0.5844351053237915, + "time_taken": 10.180833551000433 }, "1": { - "accuracy": 0.5992588400840759, - "time_taken": 22.000713124999947 + "accuracy": 0.5734174847602844, + "time_taken": 9.615570485999342 }, "2": { - "accuracy": 0.5953525900840759, - "time_taken": 23.080259337999905 + "accuracy": 0.5765224695205688, + "time_taken": 9.886270799000158 }, "3": { - "accuracy": 0.5842347741127014, - "time_taken": 20.16238621299999 + "accuracy": 0.5474759936332703, + "time_taken": 9.45628960400063 }, "4": { - "accuracy": 0.5821314454078674, - "time_taken": 20.503732982999963 + "accuracy": 0.5767227411270142, + "time_taken": 13.535986196000522 } }, "5000": { "0": { - "accuracy": 0.6593549847602844, - "time_taken": 29.339991162000047 + "accuracy": 0.6347155570983887, + "time_taken": 29.896120647000316 + }, + "1": { + "accuracy": 0.6311097741127014, + "time_taken": 29.06243791899942 + }, + "2": { + "accuracy": 0.650240421295166, + "time_taken": 33.30571946100008 + }, + "3": { + "accuracy": 0.6276041865348816, + "time_taken": 30.51324089100035 + }, + 
"4": { + "accuracy": 0.626802921295166, + "time_taken": 29.449209794000126 + } + } + }, + "KernelThinning": { + "25": { + "0": { + "accuracy": 0.4526808559894562, + "time_taken": 79.91361540300022 + }, + "1": { + "accuracy": 0.40636226534843445, + "time_taken": 11.441792544000236 + }, + "2": { + "accuracy": 0.48499375581741333, + "time_taken": 28.38083182599985 + }, + "3": { + "accuracy": 0.418967604637146, + "time_taken": 31.02636636399984 + }, + "4": { + "accuracy": 0.4386753439903259, + "time_taken": 22.148064952000823 + } + }, + "50": { + "0": { + "accuracy": 0.6550991535186768, + "time_taken": 40.07802116200037 }, "1": { - "accuracy": 0.6642628312110901, - "time_taken": 29.7060228900001 + "accuracy": 0.6222995519638062, + "time_taken": 7.831840456000464 }, "2": { - "accuracy": 0.6545472741127014, - "time_taken": 30.334544582000035 + "accuracy": 0.638399600982666, + "time_taken": 10.783056123999813 }, "3": { - "accuracy": 0.6499398946762085, - "time_taken": 29.475541733 + "accuracy": 0.6058999300003052, + "time_taken": 9.712099656000646 }, "4": { - "accuracy": 0.6581530570983887, - "time_taken": 30.22798960299997 + "accuracy": 0.5935994386672974, + "time_taken": 11.725060616000519 + } + }, + "100": { + "0": { + "accuracy": 0.7074002027511597, + "time_taken": 23.023189279999315 + }, + "1": { + "accuracy": 0.6904999613761902, + "time_taken": 5.541293954999674 + }, + "2": { + "accuracy": 0.7206001281738281, + "time_taken": 5.422012566999911 + }, + "3": { + "accuracy": 0.6836000680923462, + "time_taken": 3.8105385219996606 + }, + "4": { + "accuracy": 0.7140999436378479, + "time_taken": 5.527283147000162 + } + }, + "500": { + "0": { + "accuracy": 0.8555689454078674, + "time_taken": 12.997218095999415 + }, + "1": { + "accuracy": 0.8521634936332703, + "time_taken": 3.805464752000262 + }, + "2": { + "accuracy": 0.8616787195205688, + "time_taken": 4.052976539000156 + }, + "3": { + "accuracy": 0.859375, + "time_taken": 3.713409908000358 + }, + "4": { + "accuracy": 
0.8414463400840759, + "time_taken": 3.7632819319996997 + } + }, + "1000": { + "0": { + "accuracy": 0.8792067170143127, + "time_taken": 9.468509635999908 + }, + "1": { + "accuracy": 0.8872195482254028, + "time_taken": 3.8311882420002803 + }, + "2": { + "accuracy": 0.8777043223381042, + "time_taken": 3.494454018000397 + }, + "3": { + "accuracy": 0.8869190812110901, + "time_taken": 3.8445811410001625 + }, + "4": { + "accuracy": 0.8827123641967773, + "time_taken": 3.8723056460003136 + } + }, + "5000": { + "0": { + "accuracy": 0.9207732677459717, + "time_taken": 14.704426218000663 + }, + "1": { + "accuracy": 0.9328926205635071, + "time_taken": 6.878256106000663 + }, + "2": { + "accuracy": 0.9287860989570618, + "time_taken": 6.707938498000658 + }, + "3": { + "accuracy": 0.9293870329856873, + "time_taken": 4.895929524000167 + }, + "4": { + "accuracy": 0.9291867017745972, + "time_taken": 5.907913719000135 } } } diff --git a/benchmark/mnist_benchmark_visualiser.py b/benchmark/mnist_benchmark_visualiser.py index 76287a45..3684e679 100644 --- a/benchmark/mnist_benchmark_visualiser.py +++ b/benchmark/mnist_benchmark_visualiser.py @@ -42,6 +42,10 @@ def compute_statistics( """ Compute statistical summary (mean, min, max). + The parameter data_by_solver contains time and accuracy data for different solvers + for different seeds. The data for the first run is skipped because it is much slower + due to JIT compilation. + :param data_by_solver: A dictionary where each key is an algorithm name, and each value is a dictionary mapping coreset size to benchmark results. 
Benchmark results include multiple @@ -77,7 +81,9 @@ def compute_statistics( size_str = str(size) accuracies, times = [], [] if size_str in sizes: - for run_data in sizes[size_str].values(): + # Skip the first run as it is much slower due to JIT compilation + run_list = list(sizes[size_str].values())[1:] + for run_data in run_list: accuracies.append(run_data["accuracy"]) times.append(run_data["time_taken"]) accuracy_stats[algo]["points"][size].append(run_data["accuracy"]) @@ -98,6 +104,9 @@ def compute_time_statistics(data: dict, coreset_sizes: list[int]) -> dict: """ Compute statistical summary (mean, min, max) for standalone time data. + The data for the first run is skipped because it is much slower due to JIT + compilation. + :param data: A dictionary containing time data for different algorithms and coreset sizes. :param coreset_sizes: A list of integer coreset sizes to evaluate. @@ -119,7 +128,9 @@ def compute_time_statistics(data: dict, coreset_sizes: list[int]) -> dict: size_str = str(size) times = [] if size_str in sizes: - for time in sizes[size_str].values(): + # Skip the first run as it is much slower due to JIT compilation + run_list = list(sizes[size_str].values())[1:] + for time in run_list: times.append(time) stats[algo]["points"][size].append(time) diff --git a/benchmark/mnist_time_results.json b/benchmark/mnist_time_results.json index 4feb60ac..4406fc12 100644 --- a/benchmark/mnist_time_results.json +++ b/benchmark/mnist_time_results.json @@ -1,150 +1,187 @@ { "RandomSample": { "25": { - "0": 1.5215764229997149, - "1": 0.001615036000202963, - "2": 0.0016977929999484331, - "3": 0.0017370190007568453, - "4": 0.0016495150002810988 + "0": 0.6588367329986795, + "1": 0.0016759669997554738, + "2": 0.001854901998740388, + "3": 0.0016810239994811127, + "4": 0.0016000750001694541 }, "50": { - "0": 1.1865117110000938, - "1": 0.0015167020001172205, - "2": 0.001471213000058924, - "3": 0.00159685000016907, - "4": 0.0014185949994498515 + "0": 0.6135869649988308, 
+ "1": 0.0016016179997677682, + "2": 0.001721464001093409, + "3": 0.00153965900062758, + "4": 0.0016626109991193516 }, "100": { - "0": 1.190416921999713, - "1": 0.001415693000126339, - "2": 0.001445034999960626, - "3": 0.001560416000756959, - "4": 0.0014196720003383234 + "0": 0.6118492949990468, + "1": 0.0015750410002510762, + "2": 0.0018720730004133657, + "3": 0.0015939309996610973, + "4": 0.0015289340008166619 }, "500": { - "0": 1.1969940730000417, - "1": 0.0014159980000840733, - "2": 0.0014652069999101514, - "3": 0.0014559060000465252, - "4": 0.0013854159997208626 + "0": 0.6197238580007252, + "1": 0.0018075179996230872, + "2": 0.001730617999783135, + "3": 0.0015197210013866425, + "4": 0.0016761769984441344 }, "1000": { - "0": 1.20574632700027, - "1": 0.0014430999999603955, - "2": 0.0013995819999763626, - "3": 0.0015107440003703232, - "4": 0.0014337950005938183 + "0": 0.6255511849994946, + "1": 0.0015996750007616356, + "2": 0.001622727999347262, + "3": 0.0015904940009932034, + "4": 0.0015960620003170334 } }, "RPCholesky": { "25": { - "0": 1.3980586309999126, - "1": 0.001971067999875231, - "2": 0.0020426159999260562, - "3": 0.002098918999763555, - "4": 0.0019817629990939167 + "0": 1.519872093000231, + "1": 0.009014817998831859, + "2": 0.009139535000940668, + "3": 0.008957120000559371, + "4": 0.008898485999452532 }, "50": { - "0": 1.329528728999776, - "1": 0.002450966000196786, - "2": 0.002407561999916652, - "3": 0.0025540699998600758, - "4": 0.0024392029999944498 + "0": 1.458239367000715, + "1": 0.017449393999413587, + "2": 0.017646361000515753, + "3": 0.017497655000624945, + "4": 0.01741153100010706 }, "100": { - "0": 1.3276212349996968, - "1": 0.0057891330002348695, - "2": 0.00594377200013696, - "3": 0.005744803000197862, - "4": 0.005901627000639564 + "0": 1.48360984700048, + "1": 0.03753473100005067, + "2": 0.037234362000162946, + "3": 0.037164757999562426, + "4": 0.037121964000107255 }, "500": { - "0": 1.9051166680001188, - "1": 0.34004028900017147, - "2": 
0.33999469499985935, - "3": 0.34010550800030614, - "4": 0.3408227230002012 + "0": 1.9741088339997077, + "1": 0.3905957330007368, + "2": 0.3912712890014518, + "3": 0.3908946850006032, + "4": 0.39045989899932465 }, "1000": { - "0": 2.522816674000296, - "1": 1.2190807510000923, - "2": 1.216899468000065, - "3": 1.215633935000369, - "4": 1.2168376090003221 + "0": 2.7680352509996737, + "1": 1.2777003790015442, + "2": 1.278581608999957, + "3": 1.27806329699888, + "4": 1.2789005979993817 } }, "KernelHerding": { "25": { - "0": 4.611737683000229, - "1": 0.37138685099989743, - "2": 0.37333088700006556, - "3": 0.37363806999928784, - "4": 0.37325262000013026 + "0": 2.8221302519996243, + "1": 0.2740416950000508, + "2": 0.2751288489998842, + "3": 0.27401620499949786, + "4": 0.2768986720002431 }, "50": { - "0": 3.9535803169997052, - "1": 0.19524956099985502, - "2": 0.19619354599990402, - "3": 0.19783811200068158, - "4": 0.1954313500000353 + "0": 1.577164325000922, + "1": 0.27728562999982387, + "2": 0.2791873059995851, + "3": 0.27608360599879234, + "4": 0.2759777829996892 }, "100": { - "0": 3.9510098840000865, - "1": 0.18832070500002374, - "2": 0.18669702299985147, - "3": 0.18900510200001008, - "4": 0.1869027909997385 + "0": 1.5832591270009289, + "1": 0.2820007230002375, + "2": 0.28537204100030067, + "3": 0.2851871380007651, + "4": 0.2826830610010802 }, "500": { - "0": 3.754979469999853, - "1": 0.23399434600014501, - "2": 0.2340401080000447, - "3": 0.23858422699959192, - "4": 0.2348375669998859 + "0": 1.7189907629999652, + "1": 0.3274141089987097, + "2": 0.3350357380004425, + "3": 0.3309756380003819, + "4": 0.3338158980004664 }, "1000": { - "0": 3.2994891039998038, - "1": 0.31192064400011077, - "2": 0.30946771800017814, - "3": 0.3112291079996794, - "4": 0.3077905899999678 + "0": 1.8001215259992023, + "1": 0.3800924019997183, + "2": 0.3808283049984311, + "3": 0.3794397999990906, + "4": 0.3793001079993701 } }, "SteinThinning": { "25": { - "0": 32.84552058700001, - "1": 
23.36532163999982, - "2": 23.79629538800009, - "3": 24.90528073400037, - "4": 24.946374798999386 + "0": 4.062421253998764, + "1": 3.7067944970003737, + "2": 3.592493141999512, + "3": 3.622419968000031, + "4": 3.645126594999965 }, "50": { - "0": 21.58282834199963, - "1": 21.647645435999948, - "2": 22.063991701000305, - "3": 20.429498262999914, - "4": 20.436032122000142 + "0": 3.842947594999714, + "1": 3.841391066000142, + "2": 3.75092935199973, + "3": 3.7667867290001595, + "4": 3.7816334729996015 }, "100": { - "0": 24.78626846000043, - "1": 20.212115823000204, - "2": 19.087178812000275, - "3": 21.241850555000383, - "4": 21.746380099999442 + "0": 4.318612201999713, + "1": 3.9970832929993776, + "2": 3.9448298860006616, + "3": 3.9489083890002803, + "4": 3.923231038999802 }, "500": { - "0": 28.629314329999943, - "1": 18.76713966499983, - "2": 20.139166873999784, - "3": 18.26489885399951, - "4": 18.16690275899964 + "0": 5.439173460999882, + "1": 5.453762462000668, + "2": 5.412860956999793, + "3": 5.417045615000461, + "4": 5.401182251000137 }, "1000": { - "0": 24.427439505000166, - "1": 19.43167951400028, - "2": 19.28958952099947, - "3": 19.266749262999838, - "4": 19.204127647999485 + "0": 7.2683715949988255, + "1": 9.998894977999953, + "2": 6.865873652999653, + "3": 6.937165109000489, + "4": 6.851976205998653 + } + }, + "KernelThinning": { + "25": { + "0": 57.40046906799944, + "1": 0.4691025020001689, + "2": 0.4655119300005026, + "3": 0.4667137439992075, + "4": 0.4651397059988085 + }, + "50": { + "0": 31.640804945000127, + "1": 0.4824182379998092, + "2": 0.4758363860000827, + "3": 0.4767543710004247, + "4": 0.4776206539991108 + }, + "100": { + "0": 19.31581526300033, + "1": 0.4937812940006552, + "2": 0.4880024880003475, + "3": 0.49104342400096357, + "4": 0.48902785900099843 + }, + "500": { + "0": 6.845564752999053, + "1": 0.47625808600059827, + "2": 0.47526046200073324, + "3": 0.47945132499989995, + "4": 0.4767048789999535 + }, + "1000": { + "0": 6.7132718520006165, + 
"1": 0.5458353630001511, + "2": 0.5415765449997707, + "3": 0.5430132919991593, + "4": 0.5429429139985587 } } } diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py index ce01ea49..41f119a7 100644 --- a/benchmark/pounce_benchmark.py +++ b/benchmark/pounce_benchmark.py @@ -30,7 +30,7 @@ import umap from jax import random -from benchmark.mnist_benchmark import get_solver_name, initialise_solvers +from coreax.benchmark_util import get_solver_name, initialise_solvers from coreax.data import Data from coreax.solvers import MapReduce diff --git a/coreax/benchmark_util.py b/coreax/benchmark_util.py new file mode 100644 index 00000000..e30e036c --- /dev/null +++ b/coreax/benchmark_util.py @@ -0,0 +1,202 @@ +# © Crown Copyright GCHQ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Core functionality for setting up and managing coreset solvers. + +This module provides functions to initialise solvers Kernel Thinning, Kernel Herding, +Stein Thinning, Random Sampling, and Randomised Cholesky. It also defines helper +functions for computing solver parameters and retrieving solver names. 
+""" + +from collections.abc import Callable + +import jax.numpy as jnp +import numpy as np +from jaxtyping import Array, Float + +from coreax import Data +from coreax.kernels import SquaredExponentialKernel, SteinKernel, median_heuristic +from coreax.score_matching import KernelDensityMatching +from coreax.solvers import ( + KernelHerding, + KernelThinning, + MapReduce, + RandomSample, + RPCholesky, + Solver, + SteinThinning, +) +from coreax.util import KeyArrayLike + + +def calculate_delta(n: int) -> Float[Array, "1"]: + r""" + Calculate the delta parameter for kernel thinning. + + This function evaluates the following cases: + + 1. If :math:`\log n` is positive: + - Further evaluates :math:`\log (\log n)`. + * If this is also positive, returns :math:`\frac{1}{n \log (\log n)}`. + * Otherwise, returns :math:`\frac{1}{n \log n}`. + 2. If :math:`\log n` is not positive: + - Returns :math:`\frac{1}{n}`. + + The recommended value is :math:`\frac{1}{n \log (\log n)}`, but for small + values of :math:`n`, this may be negative or even undefined. Therefore, + alternative values are used in such cases. + + :param n: The size of the dataset we wish to reduce. + :return: The calculated delta value based on the described conditions. + """ + log_n = jnp.log(n) + if log_n > 0: + log_log_n = jnp.log(log_n) + if log_log_n > 0: + return 1 / (n * log_log_n) + return 1 / (n * log_n) + return jnp.array(1 / n) + + +def initialise_solvers( + train_data_umap: Data, key: KeyArrayLike +) -> list[Callable[[int], Solver]]: + """ + Initialise and return a list of solvers for various coreset algorithms. + + Set up solvers for Kernel Thinning, Kernel Herding, Stein Thinning, Random + Sampling, and Randomised Cholesky methods. Each solver has different parameter + requirements. Some solvers can utilise MapReduce, while others cannot, and some + require specific kernels. This setup allows them to be called by passing only the + coreset size, enabling easy integration in a loop for benchmarking.
+ + :param train_data_umap: The UMAP-transformed training data used for + length scale estimation for ``SquaredExponentialKernel``. + :param key: The random key for initialising random solvers. + :return: A list of solver functions for different coreset algorithms. + """ + # Set up kernel using median heuristic + num_data_points = len(train_data_umap) + num_samples_length_scale = min(num_data_points, 300) + random_seed = 45 + generator = np.random.default_rng(random_seed) + idx = generator.choice(num_data_points, num_samples_length_scale, replace=False) + length_scale = median_heuristic(jnp.asarray(train_data_umap[idx])) + kernel = SquaredExponentialKernel(length_scale=length_scale) + sqrt_kernel = kernel.get_sqrt_kernel(16) + + def _get_thinning_solver(_size: int) -> MapReduce: + """ + Set up KernelThinning to use ``MapReduce``. + + Create a KernelThinning solver with the specified size and wrap + it in a MapReduce object for reducing a large dataset like the + MNIST dataset. + + :param _size: The size of the coreset to be generated. + :return: MapReduce solver with KernelThinning as the base solver. + """ + thinning_solver = KernelThinning( + coreset_size=_size, + kernel=kernel, + random_key=key, + delta=calculate_delta(num_data_points).item(), + sqrt_kernel=sqrt_kernel, + ) + + return MapReduce(thinning_solver, leaf_size=3 * _size) + + def _get_herding_solver(_size: int) -> MapReduce: + """ + Set up KernelHerding to use ``MapReduce``. + + Create a KernelHerding solver with the specified size and wrap + it in a MapReduce object for reducing a large dataset like the + MNIST dataset. + + :param _size: The size of the coreset to be generated. + :return: MapReduce solver with KernelHerding as the base solver. + """ + herding_solver = KernelHerding(_size, kernel) + return MapReduce(herding_solver, leaf_size=3 * _size) + + def _get_stein_solver(_size: int) -> MapReduce: + """ + Set up Stein Thinning to use ``MapReduce``.
+ + Create a SteinThinning solver with the specified coreset size, + using a ``KernelDensityMatching`` score function for matching on + a subset of the dataset. + + :param _size: The size of the coreset to be generated. + :return: MapReduce solver with SteinThinning as the base solver. + """ + # Fit the score function for the Stein kernel on the small subsample of the data + + score_function = KernelDensityMatching(length_scale=length_scale.item()).match( + train_data_umap[idx] + ) + stein_kernel = SteinKernel(kernel, score_function) + stein_solver = SteinThinning( + coreset_size=_size, kernel=stein_kernel, regularise=False + ) + return MapReduce(stein_solver, leaf_size=3 * _size) + + def _get_random_solver(_size: int) -> RandomSample: + """ + Set up Random Sampling to generate a coreset. + + :param _size: The size of the coreset to be generated. + :return: A RandomSample solver. + """ + random_solver = RandomSample(_size, key) + return random_solver + + def _get_rp_solver(_size: int) -> RPCholesky: + """ + Set up Randomised Cholesky solver. + + :param _size: The size of the coreset to be generated. + :return: An RPCholesky solver. + """ + rp_solver = RPCholesky(coreset_size=_size, kernel=kernel, random_key=key) + return rp_solver + + return [ + _get_random_solver, + _get_rp_solver, + _get_herding_solver, + _get_stein_solver, + _get_thinning_solver, + ] + + +def get_solver_name(solver: Callable[[int], Solver]) -> str: + """ + Get the name of the solver. + + This function calls ``solver`` and returns the class name of the solver it builds. + If that solver is an instance of :class:`~coreax.solvers.MapReduce`, it retrieves + the name of the :class:`~coreax.solvers.MapReduce.base_solver` class instead. + + :param solver: A function mapping a coreset size to a solver instance. + :return: The name of the solver class.
+ """ + # Evaluate solver function to get an instance to interrogate + # Don't just inspect type annotations, as they may be incorrect - not robust + solver_instance = solver(1) + if isinstance(solver_instance, MapReduce): + return type(solver_instance.base_solver).__name__ + return type(solver_instance).__name__ diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index a8838a51..b7433a90 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -882,8 +882,9 @@ class KernelThinning(CoresubsetSolver[_Data, None], ExplicitSizeSolver): :param random_key: Key for random number generation, enabling reproducibility of probabilistic components in the algorithm. :param delta: A float between 0 and 1 used to compute the swapping probability - during the splitting process. A recommended value is :math:`1 / \log(\log(n))`, - where :math:`n` is the length of the original dataset. + during the splitting process. A recommended value is + :math:`\frac{1}{n \log (\log n)}`, where :math:`n` is the length of the original + dataset. :param sqrt_kernel: A `~coreax.kernels.ScalarValuedKernel` instance representing the square root kernel used for splitting the original dataset. 
""" @@ -1099,11 +1100,12 @@ def probabilistic_swap( """ key1, key2 = jax.random.split(random_key) - prob = jax.random.uniform(key1) + swap_probability = 1 / 2 * (1 - alpha / a) + should_swap = jax.random.uniform(key1) <= swap_probability return lax.cond( - prob > 1 / 2 * (1 - alpha / a), - lambda _: (2 * i, 2 * i + 1), # first case: val1 = x1, val2 = x2 - lambda _: (2 * i + 1, 2 * i), # second case: val1 = x2, val2 = x1 + should_swap, + lambda _: (2 * i + 1, 2 * i), # do swap: val1 = x2, val2 = x1 + lambda _: (2 * i, 2 * i + 1), # don't swap: val1 = x1, val2 = x2 None, ), key2 diff --git a/documentation/source/benchmark.rst b/documentation/source/benchmark.rst index 06a051f6..bf7303ae 100644 --- a/documentation/source/benchmark.rst +++ b/documentation/source/benchmark.rst @@ -3,9 +3,10 @@ Benchmarking Coreset Algorithms In this benchmark, we assess the performance of four different coreset algorithms: :class:`~coreax.solvers.KernelHerding`, :class:`~coreax.solvers.SteinThinning`, -:class:`~coreax.solvers.RandomSample`, and :class:`~coreax.solvers.RPCholesky`. -Each of these algorithms is evaluated across four different tests, providing a -comparison of their performance and applicability to various datasets. +:class:`~coreax.solvers.RandomSample`, :class:`~coreax.solvers.RPCholesky` and +:class:`~coreax.solvers.KernelThinning`. Each of these algorithms is evaluated across +four different tests, providing a comparison of their performance and applicability to +various datasets. Test 1: Benchmarking Coreset Algorithms on the MNIST Dataset ------------------------------------------------------------ @@ -34,7 +35,7 @@ these steps: measured on the test set of 10,000 images. 6. **Evaluation**: Due to randomness in the coreset algorithms and training process, - the experiment is repeated 5 times with different random seeds. The benchmark is run + the experiment is repeated 4 times with different random seeds. 
The benchmark is run on an **Amazon Web Services EC2 g4dn.12xlarge instance** with 4 NVIDIA T4 Tensor Core GPUs, 48 vCPUs, and 192 GiB memory. @@ -84,7 +85,7 @@ The tables below show the performance metrics (Unweighted MMD, Unweighted KSD, Weighted MMD, Weighted KSD, and Time) for each coreset algorithm and each coreset size. For each metric and coreset size, the best performance score is highlighted in bold. -.. list-table:: Coreset Size 10 (Original Sample Size 1,000) +.. list-table:: Coreset Size 25 (Original Sample Size 1,000) :header-rows: 1 :widths: 20 15 15 15 15 15 @@ -95,29 +96,35 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_KSD - Time * - KernelHerding - - **0.071504** - - 0.087505 - - 0.037931 - - 0.082903 - - 5.884511 + - 0.026319 + - 0.071420 + - 0.008461 + - 0.072526 + - 1.836664 * - RandomSample - - 0.275138 - - 0.106468 - - 0.080327 - - **0.082597** - - **2.705248** + - 0.105940 + - 0.081013 + - 0.038174 + - *0.077431* + - *1.281091* * - RPCholesky - - 0.182342 - - 0.079254 - - **0.032423** - - 0.085621 - - 3.177700 + - 0.121869 + - *0.059722* + - *0.003283* + - 0.072288 + - 1.576841 * - SteinThinning - - 0.186064 - - **0.078773** - - 0.087347 - - 0.085194 - - 4.450125 + - 0.161923 + - 0.077394 + - 0.030987 + - 0.074365 + - 1.821020 + * - KernelThinning + - *0.014111* + - 0.072134 + - 0.006634 + - 0.072531 + - 9.144707 .. 
list-table:: Coreset Size 50 (Original Sample Size 1,000) :header-rows: 1 @@ -130,29 +137,35 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_KSD - Time * - KernelHerding - - **0.016602** - - 0.080800 - - 0.003821 - - **0.079875** - - 5.309067 + - 0.012574 + - 0.072600 + - 0.003843 + - *0.072351* + - 1.863356 * - RandomSample - - 0.083658 - - 0.084844 - - 0.005009 - - 0.079948 - - **2.636160** + - 0.083379 + - 0.079031 + - 0.008653 + - 0.072867 + - *1.329118* * - RPCholesky - - 0.133182 - - **0.061976** - - **0.001859** - - 0.079935 - - 3.201798 + - 0.154799 + - *0.056437* + - *0.001347* + - 0.072359 + - 1.564009 * - SteinThinning - - 0.079028 - - 0.074763 - - 0.009652 - - 0.080119 - - 3.735810 + - 0.122605 + - 0.079683 + - 0.012048 + - 0.072424 + - 1.849748 + * - KernelThinning + - *0.005397* + - 0.072051 + - 0.002191 + - 0.072453 + - 5.524234 .. list-table:: Coreset Size 100 (Original Sample Size 1,000) :header-rows: 1 @@ -165,29 +178,35 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_KSD - Time * - KernelHerding - - **0.007747** - - 0.080280 - - 0.001582 - - 0.080024 - - 5.425807 + - 0.007651 + - *0.071999* + - 0.001814 + - 0.072364 + - 2.185324 * - RandomSample - - 0.032532 - - 0.077081 - - 0.001638 - - 0.080073 - - **3.009871** + - 0.052402 + - 0.077454 + - 0.001630 + - 0.072480 + - *1.359826* * - RPCholesky - - 0.069909 - - **0.072023** - - **0.000977** - - 0.079995 - - 3.497632 + - 0.087236 + - 0.063822 + - *0.000910* + - 0.072433 + - 1.721290 * - SteinThinning - - 0.118452 - - 0.081853 - - 0.002652 - - **0.079836** - - 3.766622 + - 0.128295 + - 0.082733 + - 0.006041 + - *0.072182* + - 1.893099 + * - KernelThinning + - *0.002591* + - 0.072293 + - 0.001207 + - 0.072394 + - 3.519274 .. 
list-table:: Coreset Size 200 (Original Sample Size 1,000) :header-rows: 1 @@ -200,29 +219,35 @@ For each metric and coreset size, the best performance score is highlighted in b - Weighted_KSD - Time * - KernelHerding - - **0.003937** - - 0.079932 - - 0.001064 - - 0.080012 - - 5.786333 + - 0.004310 + - 0.072341 + - 0.000777 + - 0.072422 + - 1.837929 * - RandomSample - - 0.048701 - - 0.077522 - - 0.000913 - - 0.080059 - - **2.964436** + - 0.036624 + - 0.072870 + - *0.000584* + - 0.072441 + - *1.367439* * - RPCholesky - - 0.052085 - - **0.075708** - - **0.000772** - - 0.080050 - - 3.722556 + - 0.041140 + - *0.068655* + - 0.000751 + - 0.072430 + - 2.106838 * - SteinThinning - - 0.129073 - - 0.084883 - - 0.002329 - - **0.079847** - - 4.004353 + - 0.148525 + - 0.087512 + - 0.003799 + - *0.072164* + - 1.910560 + * - KernelThinning + - *0.001330* + - 0.072348 + - 0.001014 + - 0.072428 + - 2.565189 **Visualisation**: The results in this table can be visualised as follows: @@ -311,7 +336,7 @@ Conclusion In this benchmark, we evaluated four coreset algorithms across various datasets and tasks, including image classification, synthetic datasets, and pixel/frame data -processing. Based on the results, **kernel herding** emerges as the preferred choice +processing. Based on the results, **kernel thinning** emerges as the preferred choice for most tasks due to its consistent performance. For larger datasets, combining kernel herding with distributed frameworks like **map reduce** is recommended to ensure scalability and efficiency. 
diff --git a/examples/benchmarking_images/KernelThinning_coreset.gif b/examples/benchmarking_images/KernelThinning_coreset.gif new file mode 100644 index 00000000..eca296c5 --- /dev/null +++ b/examples/benchmarking_images/KernelThinning_coreset.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:093a0f1c5309307cd35352d388f153fe0e27718da4eaf11d394c4dc754eae10a +size 336313 diff --git a/examples/benchmarking_images/blobs_benchmark_results.png b/examples/benchmarking_images/blobs_benchmark_results.png index 52bfb64e..8878c186 100644 --- a/examples/benchmarking_images/blobs_benchmark_results.png +++ b/examples/benchmarking_images/blobs_benchmark_results.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:61d4b5bc3f4e2d1244b84b594b0dc6214f08228240d76bc764d32231353032e0 -size 209002 +oid sha256:d144b777b2cddc86191bfdef932c8bf55f9e08f22e165650c6941a4733174b59 +size 152391 diff --git a/examples/benchmarking_images/david_benchmark_results.png b/examples/benchmarking_images/david_benchmark_results.png index 07e2bb44..44b265ec 100644 --- a/examples/benchmarking_images/david_benchmark_results.png +++ b/examples/benchmarking_images/david_benchmark_results.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6f77bbf822a5d20d661e62c9b8c7626aa44823a674035e418c1e7899374d448f -size 626273 +oid sha256:44b968f02e3ab99502592227f452baf8c870dd73039f4b45df68ae9e9c475417 +size 757481 diff --git a/examples/benchmarking_images/mnist_benchmark_accuracy.png b/examples/benchmarking_images/mnist_benchmark_accuracy.png index f2366e97..ca35c12c 100644 --- a/examples/benchmarking_images/mnist_benchmark_accuracy.png +++ b/examples/benchmarking_images/mnist_benchmark_accuracy.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4bcc4aee35615600d8d04c8300571cb53dceab23e77b3a6bead5566ae47ed455 -size 67261 +oid sha256:2c705227205915da97a4c80cabee1048caba54995928c0b623271724a00d41c7 +size 65234 diff --git 
a/examples/benchmarking_images/mnist_benchmark_time_taken.png b/examples/benchmarking_images/mnist_benchmark_time_taken.png index cef48e39..1df5c846 100644 --- a/examples/benchmarking_images/mnist_benchmark_time_taken.png +++ b/examples/benchmarking_images/mnist_benchmark_time_taken.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c6bf579a51b8c2a934e286d3293bafc5d620119e14980cb15eb3d93ed4ce9e60 -size 61307 +oid sha256:5dfe92ce3124097afb476f2827620c05109f4c8352a0a8e96f73c79522bbc4bd +size 52228 diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index aa8a1660..c617f2d9 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -32,6 +32,15 @@ convert_to_jax_arrays, train_and_evaluate, ) +from coreax import Data +from coreax.benchmark_util import calculate_delta, get_solver_name, initialise_solvers +from coreax.kernel import SquaredExponentialKernel +from coreax.solvers import ( + KernelHerding, + MapReduce, + RandomSample, + RPCholesky, +) class MockDataset(Dataset): @@ -121,5 +130,54 @@ def test_train_and_evaluate() -> None: assert 0.0 <= result["final_test_accuracy"] <= 1.0 +def test_initialise_solvers() -> None: + """ + Test the :func:`initialise_solvers`. + + Verify that the returned list contains callable functions that produce + valid solver instances. + """ + # Create a mock dataset (UMAP-transformed) with arbitrary values + mock_data = Data(jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])) + key = random.PRNGKey(42) + + solvers = initialise_solvers(mock_data, key) + for solver in solvers: + solver_instance = solver(1) # Instantiate with a coreset size of 1 + assert isinstance(solver_instance, (MapReduce, RandomSample, RPCholesky)), ( + f"Unexpected solver type: {type(solver_instance)}" + ) + + +def test_get_solver_name(): + """ + Test `get_solver_name` function to ensure it returns correct solver names. 
+ """ + # Create a KernelHerding solver + herding_solver = KernelHerding(coreset_size=5, kernel=SquaredExponentialKernel()) + + # Wrap it in MapReduce + map_reduce_solver = MapReduce(base_solver=herding_solver, leaf_size=15) + + assert get_solver_name(lambda _: herding_solver) == "KernelHerding", ( + "Expected 'KernelHerding' but got something else." + ) + + assert get_solver_name(lambda _: map_reduce_solver) == "KernelHerding", ( + "Expected 'KernelHerding' from MapReduce solver but got something else." + ) + + +@pytest.mark.parametrize("n", [10, 100, 1000]) +def test_calculate_delta(n): + """ + Test the `calculate_delta` function. + + Ensure that the function produces a positive delta value for different values of n. + """ + delta = calculate_delta(n) + assert delta > 0, f"Delta should be positive but got {delta} for n={n}" + + if __name__ == "__main__": pytest.main() diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py index ed33906f..c3466010 100644 --- a/tests/unit/test_solvers.py +++ b/tests/unit/test_solvers.py @@ -2193,8 +2193,10 @@ def test_kt_half_analytic(self) -> None: that determines whether :math:`x` goes to S1 and :math:`y` to S2, or vice versa. In either case, both :math:`x` and :math:`y` are added to S. - If swap probability is less than 0.5, we add the x and y to S1 and S2 - respectively, otherwise we swap x and y and then add x to S1 and y to S2. + We swap x and y with probability equal to swap probability and then add x and y + to S1 and S2 respectively. For the purpose of analytic test, if swap probability + is less than 0.5, we do not swap and add x and y to S1 and S2 respectively, + otherwise we swap x and y and then add x to S1 and y to S2. The process is as follows: @@ -2249,14 +2251,21 @@ def test_kt_half_analytic(self) -> None: - Inputs: S=[], S1=[], S2=[], sigma=0, delta=1/8. - Compute b: - - b(0.7, 0.55) = 0.2109442800283432. - - Compute alpha: alpha = 0 (as S and S1 are empty). + - .. 
math:: + b(0.7, 0.55) = \sqrt{k(0.7, 0.7) + k(0.55, 0.55) - 2k(0.7, 0.55)} + = 0.2109442800283432. + - Compute alpha: + - Since S and S1 are empty, alpha = 0. - Compute a: - - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.04449748992919922. + - .. math:: + a = \max(b \cdot \sigma \cdot \sqrt{2 \ln(2 / \delta)}, b^2) + = 0.04449748992919922. - Update sigma: - new_sigma = 0.2109442800283432. - Compute probability: - - p = 0.5 * (1 - alpha / a) = 0.5. + - .. math:: + p = 0.5 \cdot \left(1 - \frac{\alpha}{a}\right) + = 0.5. - Assign: - Since p >= 0.5, assign x=0.7 to S2, y=0.55 to S1, and add both to S. - S1 = [0.55], S2 = [0.7], S = [0.7, 0.55]. @@ -2267,11 +2276,13 @@ def test_kt_half_analytic(self) -> None: - Inputs: S=[0.7, 0.55], S1=[0.55], S2=[0.7], sigma=0.2109442800283432. - Compute b: - - b(0.6, 0.65) = 0.07066679745912552. + - .. math:: + b(0.6, 0.65) = \sqrt{k(0.6, 0.6) + k(0.65, 0.65) - 2k(0.6, 0.65)} + = 0.07066679745912552. - Compute alpha: - alpha = -0.014906525611877441. - Compute a: - - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.035102729200688874. + - a = 0.035102729200688874. - Update sigma: - new_sigma = 0.2109442800283432. - Compute probability: @@ -2287,11 +2298,13 @@ def test_kt_half_analytic(self) -> None: - Inputs: S=[0.7, 0.55, 0.6, 0.65], S1=[0.55, 0.65], S2=[0.7, 0.6], sigma=0.2109442800283432. - Compute b: - - b(0.9, 0.1) = 0.9723246097564697. + - .. math:: + b(0.9, 0.1) = \sqrt{k(0.9, 0.9) + k(0.1, 0.1) - 2k(0.9, 0.1)} + = 0.9723246097564697. - Compute alpha: - alpha = 0.12977957725524902. - Compute a: - - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.9454151391983032. + - a = 0.9454151391983032. - Update sigma: - new_sigma = 0.9723246097564697. - Compute probability: @@ -2308,11 +2321,13 @@ def test_kt_half_analytic(self) -> None: - Inputs: S=[0.7, 0.55, 0.6, 0.65, 0.9, 0.1], S1=[0.55, 0.65, 0.9], S2=[0.7, 0.6, 0.1], sigma=0.9723246097564697. - Compute b: - - b(0.11, 0.12) = 0.014143308624625206. + - .. 
math:: + b(0.11, 0.12) = \sqrt{k(0.11, 0.11) + k(0.12, 0.12) - 2k(0.11, 0.12)} + = 0.014143308624625206. - Compute alpha: - alpha = 0.008038222789764404. - Compute a: - - a = max(b * sigma * sqrt(2 * log(2/delta)), b^2) = 0.03238321865838572. + - a = 0.03238321865838572. - Update sigma: - new_sigma = 0.9723246097564697. - Compute probability: @@ -2320,12 +2335,11 @@ def test_kt_half_analytic(self) -> None: - Assign: - Since p < 0.5, assign x=0.11 to S1 and y=0.12 to S2, and add both to S. - S1 = [0.55, 0.65, 0.9, 0.11], S2 = [0.7, 0.6, 0.1, 0.12], - S = [0.7, 0.55, 0.6, 0.65, 0.9, 0.1, 0.11, 0.12]. + S = [0.7, 0.55, 0.6, 0.65, 0.9, 0.1, 0.11, 0.12]. --- **Final result:** - S1 = [0.55, 0.65, 0.9, 0.11], S2 = [0.7, 0.6, 0.1, 0.12]. """ # noqa: E501 # pylint: enable=line-too-long