Skip to content

Commit

Permalink
chore: more pyright fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
qh681248 committed Jan 30, 2025
1 parent 654ac24 commit b512f38
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
13 changes: 5 additions & 8 deletions benchmark/blobs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import json
import os
import time
from typing import TypeVar

import jax
import jax.numpy as jnp
Expand All @@ -55,10 +54,8 @@
)
from coreax.weights import MMDWeightsOptimiser

_Solver = TypeVar("_Solver", bound=Solver)


def setup_kernel(x: jnp.array, random_seed: int = 45) -> SquaredExponentialKernel:
def setup_kernel(x: jax.Array, random_seed: int = 45) -> SquaredExponentialKernel:
"""
Set up a squared exponential kernel using the median heuristic.
Expand Down Expand Up @@ -105,7 +102,7 @@ def setup_solvers(
stein_kernel: SteinKernel,
delta: float,
random_seed: int = 45,
) -> list[tuple[str, _Solver]]:
) -> list[tuple[str, Solver]]:
"""
Set up and return a list of solver configurations for reducing a dataset.
Expand Down Expand Up @@ -160,7 +157,7 @@ def setup_solvers(


def compute_solver_metrics(
solver: _Solver,
solver: Solver,
dataset: Data,
mmd_metric: MMD,
ksd_metric: KSD,
Expand Down Expand Up @@ -203,7 +200,7 @@ def compute_solver_metrics(


def compute_metrics(
solvers: list[tuple[str, _Solver]],
solvers: list[tuple[str, Solver]],
dataset: Data,
mmd_metric: MMD,
ksd_metric: KSD,
Expand Down Expand Up @@ -279,7 +276,7 @@ def main() -> None: # pylint: disable=too-many-locals
aggregated_results[size][solver_name][metric].append(value)

# Average results across seeds
final_results = {"n_samples": n_samples}
final_results: dict = {"n_samples": n_samples}
for size, solvers in aggregated_results.items():
final_results[size] = {}
for solver_name, metrics in solvers.items():
Expand Down
10 changes: 6 additions & 4 deletions benchmark/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

from coreax import Data
from coreax.benchmark_util import get_solver_name, initialise_solvers
from coreax.util import KeyArrayLike


# Convert PyTorch dataset to JAX arrays
Expand All @@ -67,7 +68,8 @@ def convert_to_jax_arrays(pytorch_data: Dataset) -> tuple[jnp.ndarray, jnp.ndarr
:return: Tuple of JAX arrays (data, targets).
"""
# Load all data in one batch
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data))
# pyright is wrong here, a Dataset object does have __len__ method
data_loader = DataLoader(pytorch_data, batch_size=len(pytorch_data)) # type: ignore
# Grab the first batch, which is all data
_data, _targets = next(iter(data_loader))
# Convert to NumPy first, then JAX array
Expand Down Expand Up @@ -139,8 +141,8 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
class TrainState(train_state.TrainState):
"""Custom train state with batch statistics and dropout RNG."""

batch_stats: Optional[dict[str, jnp.ndarray]] = None
dropout_rng: Optional[jnp.ndarray] = None
batch_stats: Optional[dict[str, jnp.ndarray]]
dropout_rng: KeyArrayLike


class Metrics(NamedTuple):
Expand Down Expand Up @@ -418,7 +420,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr

def train_model(
data_bundle: dict[str, jnp.ndarray],
key: jax.random.PRNGKey,
key: KeyArrayLike,
config: dict[str, Union[int, float]],
) -> dict[str, float]:
"""
Expand Down

0 comments on commit b512f38

Please sign in to comment.