Skip to content

Commit 1659466

Browse files
committed
fix: requirements & formatting
1 parent 720d808 commit 1659466

File tree

4 files changed

+21
-33
lines changed

4 files changed

+21
-33
lines changed

Diff for: dmlcloud/util/data.py

+18-27
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
from typing import Iterable
2+
23
import numpy as np
3-
import xarray as xr
44
import torch.distributed as dist
5+
import xarray as xr
56

67

78
def shard_indices(
8-
n: int,
9-
rank: int,
10-
world_size: int,
11-
shuffle: bool=False,
12-
drop_remainder: bool=True,
13-
seed: int=0
14-
) -> list[int]:
9+
n: int, rank: int, world_size: int, shuffle: bool = False, drop_remainder: bool = True, seed: int = 0
10+
) -> list[int]:
1511
indices = np.arange(n)
1612

1713
if shuffle:
@@ -24,32 +20,27 @@ def shard_indices(
2420

2521

2622
def chunked_xr_dataset(
27-
ds: xr.Dataset | xr.DataArray,
28-
chunk_size: int,
29-
dim: str,
30-
shuffle: bool=False,
31-
drop_remainder: bool=True,
32-
seed: int=0,
33-
rank: int|None=None,
34-
world_size: int|None=None,
35-
process_group: dist.ProcessGroup|None=None,
36-
load: bool = True,
37-
) -> Iterable[xr.Dataset | xr.DataArray]:
23+
ds: xr.Dataset | xr.DataArray,
24+
chunk_size: int,
25+
dim: str,
26+
shuffle: bool = False,
27+
drop_remainder: bool = True,
28+
seed: int = 0,
29+
rank: int | None = None,
30+
world_size: int | None = None,
31+
process_group: dist.ProcessGroup | None = None,
32+
load: bool = True,
33+
) -> Iterable[xr.Dataset | xr.DataArray]:
3834
num_total_elements = len(ds[dim])
3935
num_chunks = num_total_elements // chunk_size
40-
36+
4137
if rank is None:
4238
rank = dist.get_rank(process_group)
4339
if world_size is None:
4440
world_size = dist.get_world_size(process_group)
4541

4642
chunk_indices = shard_indices(
47-
num_chunks,
48-
rank,
49-
world_size,
50-
shuffle=shuffle,
51-
drop_remainder=drop_remainder,
52-
seed=seed
43+
num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=drop_remainder, seed=seed
5344
)
5445

5546
for chunk_idx in chunk_indices:
@@ -58,4 +49,4 @@ def chunked_xr_dataset(
5849
chunk = ds.isel({dim: slice(start, end)})
5950
if load:
6051
chunk.load()
61-
yield chunk
52+
yield chunk

Diff for: dmlcloud/util/distributed.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
from contextlib import contextmanager
33

4-
import numpy as np
54
import torch.distributed as dist
65

76
from .tcp import find_free_port, get_local_ips

Diff for: requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
torch
22
numpy
3+
xarray
34
progress_table>=0.1.20,<1.0.0
45
omegaconf

Diff for: test/test_data.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import sys
22

3-
import xarray as xr
43
import numpy as np
54
import pytest
6-
from dmlcloud.util.data import shard_indices, chunked_xr_dataset
5+
import xarray as xr
6+
from dmlcloud.util.data import chunked_xr_dataset, shard_indices
77
from numpy.testing import assert_array_equal
88

99

1010
class TestSharding:
11-
1211
def test_types(self):
1312
indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False)
1413
assert isinstance(indices, list)
@@ -43,7 +42,6 @@ def test_shuffling(self):
4342

4443

4544
class TestChunking:
46-
4745
def test_basic(self):
4846
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
4947
world_size = 3
@@ -73,7 +71,6 @@ def test_basic(self):
7371
assert_array_equal(chunks_2[1]['var'], np.arange(60, 75))
7472
assert_array_equal(chunks_3[1]['var'], np.arange(75, 90))
7573

76-
7774
def test_shuffled(self):
7875
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
7976
world_size = 3

0 commit comments

Comments
 (0)