@@ -53,8 +53,8 @@ def __init__(
5353 cropped_shape : Sequence [int ] | None = None ,
5454 device : torch .device | str | None = None ,
5555 ) -> None :
56- self .merged_shape = merged_shape
57- self .cropped_shape = self .merged_shape if cropped_shape is None else cropped_shape
56+ self .merged_shape : tuple [ int , ...] = tuple ( merged_shape )
57+ self .cropped_shape : tuple [ int , ...] = tuple ( self .merged_shape if cropped_shape is None else cropped_shape )
5858 self .device = device
5959 self .is_finalized = False
6060
@@ -231,9 +231,9 @@ def __init__(
231231 dtype : np .dtype | str = "float32" ,
232232 value_dtype : np .dtype | str = "float32" ,
233233 count_dtype : np .dtype | str = "uint8" ,
234- store : zarr .storage .Store | str = "merged.zarr" ,
235- value_store : zarr .storage .Store | str | None = None ,
236- count_store : zarr .storage .Store | str | None = None ,
234+ store : zarr .storage .Store | str = "merged.zarr" , # type: ignore
235+ value_store : zarr .storage .Store | str | None = None , # type: ignore
236+ count_store : zarr .storage .Store | str | None = None , # type: ignore
237237 compressor : str | None = None ,
238238 value_compressor : str | None = None ,
239239 count_compressor : str | None = None ,
@@ -251,18 +251,18 @@ def __init__(
251251 if version_geq (get_package_version ("zarr" ), "3.0.0" ):
252252 if value_store is None :
253253 self .tmpdir = TemporaryDirectory ()
254- self .value_store = zarr .storage .LocalStore (self .tmpdir .name )
254+ self .value_store = zarr .storage .LocalStore (self .tmpdir .name ) # type: ignore
255255 else :
256- self .value_store = value_store
256+ self .value_store = value_store # type: ignore
257257 if count_store is None :
258258 self .tmpdir = TemporaryDirectory ()
259- self .count_store = zarr .storage .LocalStore (self .tmpdir .name )
259+ self .count_store = zarr .storage .LocalStore (self .tmpdir .name ) # type: ignore
260260 else :
261- self .count_store = count_store
261+ self .count_store = count_store # type: ignore
262262 else :
263263 self .tmpdir = None
264- self .value_store = zarr .storage .TempStore () if value_store is None else value_store
265- self .count_store = zarr .storage .TempStore () if count_store is None else count_store
264+ self .value_store = zarr .storage .TempStore () if value_store is None else value_store # type: ignore
265+ self .count_store = zarr .storage .TempStore () if count_store is None else count_store # type: ignore
266266 self .chunks = chunks
267267 self .compressor = compressor
268268 self .value_compressor = value_compressor
@@ -314,7 +314,7 @@ def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
314314 map_slice = ensure_tuple_size (map_slice , values .ndim , pad_val = slice (None ), pad_from_start = True )
315315 with self .lock :
316316 self .values [map_slice ] += values .numpy ()
317- self .counts [map_slice ] += 1
317+ self .counts [map_slice ] += 1 # type: ignore[operator]
318318
319319 def finalize (self ) -> zarr .Array :
320320 """
@@ -332,7 +332,7 @@ def finalize(self) -> zarr.Array:
332332 if not self .is_finalized :
333333 # use chunks for division to fit into memory
334334 for chunk in iterate_over_chunks (self .values .chunks , self .values .cdata_shape ):
335- self .output [chunk ] = self .values [chunk ] / self .counts [chunk ]
335+ self .output [chunk ] = self .values [chunk ] / self .counts [chunk ] # type: ignore[operator]
336336 # finalize the shape
337337 self .output .resize (self .cropped_shape )
338338 # set finalize flag to protect performing in-place division again
0 commit comments