diff --git a/benchmark/david_benchmark.py b/benchmark/david_benchmark.py index fbbdaf62..0606480d 100644 --- a/benchmark/david_benchmark.py +++ b/benchmark/david_benchmark.py @@ -29,6 +29,7 @@ """ import os +import sys import time from pathlib import Path from typing import Optional @@ -38,12 +39,15 @@ import matplotlib.pyplot as plt import numpy as np from jax import random +from mnist_benchmark import get_solver_name, initialise_solvers -from benchmark.mnist_benchmark import get_solver_name, initialise_solvers from coreax import Data from coreax.solvers import MapReduce from examples.david_map_reduce_weighted import downsample_opencv +sys.path.append(str(Path(__file__).parent.parent)) + + MAX_8BIT = 255 @@ -51,7 +55,7 @@ def benchmark_coreset_algorithms( in_path: Path = Path("../examples/data/david_orig.png"), out_path: Optional[Path] = Path("david_benchmark_results.png"), - downsampling_factor: int = 1, + downsampling_factor: int = 6, ): """ Benchmark the performance of coreset algorithms on a downsampled greyscale image. @@ -66,7 +70,6 @@ def benchmark_coreset_algorithms( """ # Base directory of the current script base_dir = os.path.dirname(os.path.abspath(__file__)) - # Convert to absolute paths using os.path.join if not in_path.is_absolute(): in_path = Path(os.path.join(base_dir, in_path))