Skip to content

Commit

Permalink
Fix mypy issues for Python 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
nunofachada committed Jun 20, 2023
1 parent a3c7974 commit e763366
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
4 changes: 3 additions & 1 deletion pyclugen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,9 @@ def pt_from_proj_fn(projs, lat_disp, length, clu_dir, clu_ctr, rng=rng):
cumsum_points = concatenate((asarray([0]), cumsum(cluster_sizes)))

# Pre-allocate data structures for holding cluster info and points
point_clusters = empty(num_points, dtype=int32) # Cluster indices of each point
point_clusters: NDArray = empty(
num_points, dtype=int32
) # Cluster indices of each point
point_projections = empty((num_points, num_dims)) # Point projections on
# # cluster-supporting lines
points = empty((num_points, num_dims)) # Final points to be generated
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _ptoff_ones(

def _csz_equi_size(nclu: int, tpts: int, ae: bool, rng: Generator) -> NDArray:
"""Alternative cluster sizing function for testing purposes."""
cs = zeros(nclu, dtype=int)
cs: NDArray = zeros(nclu, dtype=int)
for i in range(tpts):
cs[i % nclu] += 1
return cs
Expand Down
10 changes: 6 additions & 4 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,12 +986,14 @@ def test_clumerge_fields(
with warnings.catch_warnings():
# Check that the function runs without warnings
warnings.simplefilter("error")
mds = clumerge(*datasets, fields=("sizes", "centers", "directions", "angles", "lengths"),
clusters_field=None,
)
mds = clumerge(
*datasets,
fields=("sizes", "centers", "directions", "angles", "lengths"),
clusters_field=None,
)

# Check that the cluster-related fields have the correct sizes
expect_shape = (tclu_i,) if ndims == 1 else (tclu_i, ndims)
expect_shape = (tclu_i,) if ndims == 1 else (tclu_i, ndims)
assert len(mds["sizes"]) == tclu_i
assert can_cast(mds["sizes"], int64)
assert mds["centers"].shape == expect_shape
Expand Down

0 comments on commit e763366

Please sign in to comment.