Skip to content

Commit

Permalink
More batching and improved furious accuracy/performance (#1253)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Nov 8, 2024
1 parent 17a0a96 commit e34c83f
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 156 deletions.
28 changes: 18 additions & 10 deletions examples/sam2_amg_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,29 @@ xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rl

Experiments run on H100 and with batch size 1

| mode | mIoU | mask count mismatch | avg. ms per request | batch size | points per batch |
| --- | --- | ------------------- | ------------------- | ---------- | ---------------- |
| baseline | 1.0 | 0 | 786 | 1 | 64 |
| baseline | N/A | N/A | N/A | 32 | 1024 |
| ao | 1.0 | 0 | 738 | 1 | 64 |
| ao | 0.9999994993636996 | 0 | 564 | 32 | 1024 |
| fast | 0.95 | 190 | 563 | 1 | 64 |
| fast | 0.9527849197435295 | 191 | 460 | 32 | 1024 |
| furious | 0 | 1000 | 204 | 1 | 64 |
| furious | 0 | 1000 | 210 | 32 | 1024 |
| mode | mIoU | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch |
| -------------- | ----------------- | ------------------- | ------------------- | --------------------- | ---------- | ---------------- |
| baseline | 1.0 | 0 | 863 | 4013MiB (4%) | 1 | 64 |
| ao | 1.0 | 0 | 840 | 4350MiB (4%) | 1 | 64 |
| fast | 0.9897813200950623 | 191 | 661 | 3916MiB (4%) | 1 | 64 |
| fast | 0.9897371530532837 | 192 | 388 | 50787MiB (52%) | 16 | 1024 |
| fast + furious | 0.974319338798523 | 209 | 461 | 3453MiB (3%) | 1 | 64 |
| fast + furious | 0.9702069759368896 | 196 | 195 | 48298MiB (49%) | 16 | 1024 |

mask count mismatch counts the number of requests where the number of masks differ from the baseline.
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
We exclude these examples from the mIoU calculation.

The 'ao' mode is a copy of the baseline with modifications to make the code compile-able and improve the performance of fast.

### 0. Download checkpoints and install requirements

```
pip install -r requirements.txt
```

Download `sam2.1_hiera_large.pt` from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints and put it into `~/checkpoints/sam2`

### 1. Create a random subset of 1000 images
```
find sav_val -type f > sav_val_image_paths
Expand Down
43 changes: 37 additions & 6 deletions examples/sam2_amg_server/compare_rle_lists.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import fire
import torch
import json
from sam2.utils.amg import rle_to_mask
from torchao._models.sam2.utils.amg import rle_to_mask

"""
Script to calculate mIoU given two lists of rles from upload_rle endpoint
Expand All @@ -16,6 +16,39 @@ def iou(mask1, mask2):
union = torch.logical_or(mask1, mask2)
return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))

def compare_masks(masks, ref_masks, order_by_area=False, verbose=False):
from torchao._models.sam2.utils.amg import rle_to_mask
v0_areas = []
v1_areas = []
v0_masks = []
v1_masks = []
for k0 in ref_masks:
assert k0 in masks, f"Expected {k0} to be in return data"
from torchao._models.sam2.utils.amg import area_from_rle
v0_area = area_from_rle(ref_masks[k0])
v1_area = area_from_rle(masks[k0])
v0_areas.append(v0_area)
v1_areas.append(v1_area)
if (v0_area != v1_area) and verbose:
print(f"v0 area {v0_area} doesn't match v1 area {v1_area}")
v0_mask = torch.from_numpy(rle_to_mask(ref_masks[k0]))
v1_mask = torch.from_numpy(rle_to_mask(masks[k0]))
v0_masks.append((v0_mask, v0_area))
v1_masks.append((v1_mask, v1_area))

if order_by_area:
v0_masks = sorted(v0_masks, key=(lambda x: x[1]), reverse=True)
v1_masks = sorted(v1_masks, key=(lambda x: x[1]), reverse=True)
miou_sum = 0.0
miou_count = 0
for ((v0_mask, _), (v1_mask, _)) in zip(v0_masks, v1_masks):
miou_sum += iou(v0_mask, v1_mask)
miou_count += 1
if verbose:
print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")

return miou_sum, miou_count


def main(path0, path1):
fail_count = 0
Expand All @@ -28,11 +61,9 @@ def main(path0, path1):
if masks0.keys() != masks1.keys():
fail_count += 1
continue
for mask0, mask1 in zip(masks0.values(), masks1.values()):
mask0 = torch.from_numpy(rle_to_mask(mask0))
mask1 = torch.from_numpy(rle_to_mask(mask1))
miou_sum += iou(mask0, mask1).item()
miou_count += 1
s, c = compare_masks(masks0, masks1, order_by_area=True)
miou_sum += s
miou_count += c

print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")

Expand Down
Loading

0 comments on commit e34c83f

Please sign in to comment.