Skip to content

Commit

Permalink
format tests, remove print
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 24, 2024
1 parent 034c647 commit fde215c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
10 changes: 4 additions & 6 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,10 @@ def test_weighted_sampling(self) -> None:
def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomBatchGeoSampler(
ds, 1, 1, generator=torch.Generator().manual_seed(0)
)
sampler2 = RandomBatchGeoSampler(
ds, 1, 1, generator=torch.Generator().manual_seed(0)
)
generator1 = torch.Generator().manual_seed(0)
generator2 = torch.Generator().manual_seed(0)
sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=generator1)
sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=generator2)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2
Expand Down
20 changes: 8 additions & 12 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,10 @@ def test_weighted_sampling(self) -> None:
def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler1 = RandomGeoSampler(
ds, 1, 1, generator=torch.Generator().manual_seed(0)
)
sampler2 = RandomGeoSampler(
ds, 1, 1, generator=torch.Generator().manual_seed(0)
)
generator1 = torch.Generator().manual_seed(0)
generator2 = torch.Generator().manual_seed(0)
sampler1 = RandomGeoSampler(ds, 1, 1, generator=generator1)
sampler2 = RandomGeoSampler(ds, 1, 1, generator=generator2)
sample1 = next(iter(sampler1))
sample2 = next(iter(sampler2))
assert sample1 == sample2
Expand Down Expand Up @@ -306,13 +304,11 @@ def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
sampler1 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.Generator().manual_seed(0)
)
generator1 = torch.Generator().manual_seed(0)
generator2 = torch.Generator().manual_seed(0)
sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator1)
sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator2)
sample1 = next(iter(sampler1))
sampler2 = PreChippedGeoSampler(
ds, shuffle=True, generator=torch.Generator().manual_seed(0)
)
sample2 = next(iter(sampler2))
assert sample1 == sample2

Expand Down
1 change: 0 additions & 1 deletion torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def __iter__(self) -> Iterator[BoundingBox]:
generator = partial(torch.randperm, generator=self.generator)

for idx in generator(len(self)):
print(idx)
yield BoundingBox(*self.hits[idx].bounds)

def __len__(self) -> int:
Expand Down

0 comments on commit fde215c

Please sign in to comment.