Skip to content

Commit

Permalink
chore: initial pyright fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
qh681248 committed Jan 30, 2025
1 parent 4b4aed5 commit 15fabb6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion benchmark/blobs_benchmark_visualiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def plot_benchmarking_results(data):

# Adjust layout to avoid overlap
plt.subplots_adjust(hspace=15.0, wspace=1.0)
plt.tight_layout(pad=3.0, rect=[0, 0, 1, 0.96])
plt.tight_layout(pad=3.0, rect=(0.0, 0.0, 1.0, 0.96))
plt.show()


Expand Down
20 changes: 11 additions & 9 deletions benchmark/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
Solver,
SteinThinning,
)
from coreax.util import KeyArrayLike


# Convert PyTorch dataset to JAX arrays
Expand All @@ -77,7 +78,8 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarr
:return: Tuple of JAX arrays (data, targets).
"""
# Load all data in one batch
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data))
# pyright flags len() here because the base Dataset class does not declare
# __len__; this assumes the dataset passed in is a sized (map-style) dataset.
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data)) # type: ignore
# Grab the first batch, which is all data
_data, _targets = next(iter(data_loader))
# Convert to NumPy first, then JAX array
Expand Down Expand Up @@ -149,8 +151,8 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
class TrainState(train_state.TrainState):
"""Custom train state with batch statistics and dropout RNG."""

batch_stats: Optional[dict[str, jnp.ndarray]] = None
dropout_rng: Optional[jnp.ndarray] = None
batch_stats: Optional[dict[str, jnp.ndarray]]
dropout_rng: KeyArrayLike


class Metrics(NamedTuple):
Expand All @@ -161,7 +163,7 @@ class Metrics(NamedTuple):


def create_train_state(
rng: jnp.ndarray, _model: nn.Module, learning_rate: float, weight_decay: float
rng: KeyArrayLike, _model: nn.Module, learning_rate: float, weight_decay: float
) -> TrainState:
"""
Create and initialise the train state.
Expand Down Expand Up @@ -323,7 +325,7 @@ def train_and_evaluate(
train_set: DataSet,
test_set: DataSet,
_model: nn.Module,
rng: jnp.ndarray,
rng: KeyArrayLike,
config: dict[str, Any],
) -> dict[str, float]:
"""
Expand Down Expand Up @@ -427,7 +429,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr


def initialise_solvers(
train_data_umap: Data, key: jax.random.PRNGKey
train_data_umap: Data, key: KeyArrayLike
) -> list[Callable[[int], Solver]]:
"""
Initialise and return a list of solvers for various coreset algorithms.
Expand All @@ -449,7 +451,7 @@ def initialise_solvers(
random_seed = 45
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
length_scale = median_heuristic(train_data_umap[idx])
length_scale = median_heuristic(jnp.asarray(train_data_umap[idx]))
kernel = SquaredExponentialKernel(length_scale=length_scale)

def _get_herding_solver(_size: int) -> MapReduce:
Expand Down Expand Up @@ -479,7 +481,7 @@ def _get_stein_solver(_size: int) -> MapReduce:
"""
# Generate small dataset for ScoreMatching for Stein Kernel

score_function = KernelDensityMatching(length_scale=length_scale).match(
score_function = KernelDensityMatching(length_scale=length_scale.item()).match(
train_data_umap[idx]
)
stein_kernel = SteinKernel(kernel, score_function)
Expand Down Expand Up @@ -513,7 +515,7 @@ def _get_rp_solver(_size: int) -> RPCholesky:

def train_model(
data_bundle: dict[str, jnp.ndarray],
key: jax.random.PRNGKey,
key: KeyArrayLike,
config: dict[str, Union[int, float]],
) -> dict[str, float]:
"""
Expand Down
5 changes: 4 additions & 1 deletion benchmark/mnist_benchmark_visualiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def plot_performance(
if log_scale:
plt.yscale("log")
plt.title(title)
plt.xticks(index + bar_width * (n_algorithms - 1) / 2, coreset_sizes)
plt.xticks(
index + bar_width * (n_algorithms - 1) / 2,
[str(size) for size in coreset_sizes],
)
plt.legend()
plt.grid(True, linestyle="--", alpha=0.7)
plt.tight_layout()
Expand Down
6 changes: 3 additions & 3 deletions benchmark/pounce_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def benchmark_coreset_algorithms(
reshaped_data = raw_data.reshape(raw_data.shape[0], -1)

umap_model = umap.UMAP(densmap=True, n_components=25)
umap_data = umap_model.fit_transform(reshaped_data)
umap_data = jnp.asarray(umap_model.fit_transform(reshaped_data))

solver_factories = initialise_solvers(umap_data, random.PRNGKey(45))
solver_factories = initialise_solvers(Data(umap_data), random.PRNGKey(45))
for solver_creator in solver_factories:
solver = solver_creator(coreset_size)

Expand All @@ -83,7 +83,7 @@ def benchmark_coreset_algorithms(
# Extract corresponding frames from original data and save GIF
coreset_frames = raw_data[selected_indices]
output_gif_path = out_dir / f"{solver_name}_coreset.gif"
imageio.mimsave(output_gif_path, coreset_frames, loop=0)
imageio.mimsave(output_gif_path, list(coreset_frames), loop=0)
print(f"Saved {solver_name} coreset GIF to {output_gif_path}")
print(f"time taken: {solver_name:<25} {duration:<30.4f}")

Expand Down

0 comments on commit 15fabb6

Please sign in to comment.