diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index ae8fea9a..ccb51ee6 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -498,7 +498,7 @@ def _get_thinning_solver(_size: int) -> MapReduce: sqrt_kernel=sqrt_kernel, ) - return thinning_solver + return MapReduce(thinning_solver, leaf_size=15_000) def _get_herding_solver(_size: int) -> MapReduce: """ @@ -512,7 +512,7 @@ def _get_herding_solver(_size: int) -> MapReduce: :return: MapReduce solver with KernelHerding as the base solver. """ herding_solver = KernelHerding(_size, kernel) - return MapReduce(herding_solver, leaf_size=3 * _size) + return MapReduce(herding_solver, leaf_size=15_000) def _get_stein_solver(_size: int) -> MapReduce: """ @@ -534,7 +534,7 @@ def _get_stein_solver(_size: int) -> MapReduce: stein_solver = SteinThinning( coreset_size=_size, kernel=stein_kernel, regularise=False ) - return MapReduce(stein_solver, leaf_size=3 * _size) + return MapReduce(stein_solver, leaf_size=15_000) def _get_random_solver(_size: int) -> RandomSample: """