Commit 97a69cf

Faster AMG (#69)
1 parent 2ed7d1a commit 97a69cf

File tree

8 files changed: +155 −39 lines changed

amg_example/amg_example.py (+39 −7)

@@ -2,6 +2,17 @@
 import torch
 import matplotlib.pyplot as plt
 import cv2
+import torch.utils.benchmark as benchmark
+
+def profiler_runner(path, fn, *args, **kwargs):
+    with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True) as prof:
+        result = fn(*args, **kwargs)
+    print(f"Saving trace under {path}")
+    prof.export_chrome_trace(path)
+    return result
 
 def show_anns(anns):
     if len(anns) == 0:
@@ -22,25 +33,46 @@ def show_anns(anns):
 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
-from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator
-from segment_anything_fast.tools import apply_eval_dtype_predictor
+from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator
 
 sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
 model_type = "vit_h"
-
 device = "cuda"
 
-sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
 sam.to(device=device)
+mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)
 
-mask_generator = SamAutomaticMaskGenerator(sam)
-mask_generator.predictor = apply_eval_dtype_predictor(mask_generator.predictor, torch.bfloat16)
-
+# Run thrice for warmup
+masks = mask_generator.generate(image)
+masks = mask_generator.generate(image)
 masks = mask_generator.generate(image)
 
+# Save an example
 plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
 plt.imshow(image)
 show_anns(masks)
 plt.axis('off')
 plt.tight_layout()
 plt.savefig('dog_mask_fast.png', format='png')
+
+# Benchmark
+torch.cuda.synchronize()
+start_event = torch.cuda.Event(enable_timing=True)
+end_event = torch.cuda.Event(enable_timing=True)
+start_event.record()
+for _ in range(10):
+    masks = mask_generator.generate(image)
+end_event.record()
+torch.cuda.synchronize()
+print(start_event.elapsed_time(end_event) / 10.)
+
+# Save a GPU trace
+profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)
+
+# Write out memory usage
+max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
+_, total_memory = torch.cuda.mem_get_info()
+max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
+max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
+print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")

amg_example/amg_example_trace.json.gz (331 KB)

Binary file not shown.

amg_example/dog_mask_fast.png (629 Bytes)

Binary file not shown.

segment_anything_fast/automatic_mask_generator.py (+52 −14)

@@ -24,6 +24,7 @@
     generate_crop_boxes,
     is_box_near_crop_edge,
     mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
     remove_small_regions,
     rle_to_mask,
     uncrop_boxes_xyxy,
@@ -49,6 +50,7 @@ def __init__(
         point_grids: Optional[List[np.ndarray]] = None,
         min_mask_region_area: int = 0,
         output_mode: str = "binary_mask",
+        process_batch_size: Optional[int] = None,
     ) -> None:
         """
         Using a SAM model, generates masks for the entire image.
@@ -93,6 +95,10 @@ def __init__(
             'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
             For large resolutions, 'binary_mask' may consume large amounts of
             memory.
+          process_batch_size (int or None): Set a batch size for the decoding step.
+            If None, all points will be batched up at once. Set a small number here
+            to decrease the memory footprint. A smaller number will likely increase
+            latency, but also decrease memory usage.
         """
 
         assert (points_per_side is None) != (
@@ -132,6 +138,7 @@ def __init__(
         self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
         self.min_mask_region_area = min_mask_region_area
         self.output_mode = output_mode
+        self.process_batch_size = process_batch_size
 
     @torch.no_grad()
     def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
@@ -241,10 +248,13 @@ def _process_crop(
 
         # Generate masks for this crop in batches
         data = MaskData()
-        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+        all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)]
+        process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size
+        for i in range(0, len(all_points), process_batch_size):
+            some_points = all_points[i:i+process_batch_size]
+            batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
             data.cat(batch_data)
-            del batch_data
+        data["rles"] = mask_to_rle_pytorch_2(data["masks"])
         self.predictor.reset_image()
 
         # Remove duplicates within this crop.
@@ -265,24 +275,50 @@ def _process_crop(
 
     def _process_batch(
         self,
-        points: np.ndarray,
+        all_points: List[np.ndarray],
         im_size: Tuple[int, ...],
         crop_box: List[int],
         orig_size: Tuple[int, ...],
     ) -> MaskData:
         orig_h, orig_w = orig_size
-
-        # Run model on this batch
-        transformed_points = self.predictor.transform.apply_coords(points, im_size)
-        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
-        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
-        masks, iou_preds, _ = self.predictor.predict_torch(
-            in_points[:, None, :],
-            in_labels[:, None],
+        nt_in_points = []
+        for points in all_points:
+            # Run model on this batch
+            transformed_points = self.predictor.transform.apply_coords(points, im_size)
+            in_points = torch.as_tensor(transformed_points) #, device=self.predictor.device)
+            nt_in_points.append(in_points)
+
+        nt_in_points = torch.nested.nested_tensor(nt_in_points, layout=torch.jagged, pin_memory=True).to(device=self.predictor.device, non_blocking=True)
+        # The call to prod is a workaround to share jagged sizes between two NestedTensors.
+        nt_in_labels = torch.ones_like(nt_in_points, dtype=torch.int).prod(dim=-1, keepdim=True)
+        nt_in_points = nt_in_points.unsqueeze(2)
+
+        self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
+        self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
+        nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
+            point_coords=nt_in_points,
+            point_labels=nt_in_labels,
             multimask_output=True,
             return_logits=True,
         )
 
+        data = MaskData()
+        for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):
+            batch_data = self._process_batch_2(masks, iou_preds, points, im_size, crop_box, orig_size)
+            data.cat(batch_data)
+        return data
+
+    # TODO: Batch this up
+    def _process_batch_2(
+        self,
+        masks: torch.Tensor,
+        iou_preds: torch.Tensor,
+        points: torch.Tensor,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
         # Serialize predictions and store in MaskData
         data = MaskData(
             masks=masks.flatten(0, 1),
@@ -315,8 +351,10 @@ def _process_batch(
 
         # Compress to RLE
         data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
-        data["rles"] = mask_to_rle_pytorch(data["masks"])
-        del data["masks"]
+        # Doing this once at the end across all masks.
+        # data["rles"] = mask_to_rle_pytorch(data["masks"].cpu())
+        # Keeping the masks around is faster, even though it uses more memory.
+        # del data["masks"]
 
         return data
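Note: the restructured _process_batch packs several variable-length point batches into a single jagged NestedTensor so that one predict_torch call decodes them all. A standalone sketch of the size-sharing trick the diff comments on, assuming a recent PyTorch with torch.nested and layout=torch.jagged (the tensor sizes are illustrative):

import torch

# Two point batches of different lengths: 3 and 5 (x, y) coordinate pairs.
pts = torch.nested.nested_tensor(
    [torch.rand(3, 2), torch.rand(5, 2)], layout=torch.jagged)

# ones_like keeps pts' jagged structure; prod over the coordinate dim
# collapses each (x, y) pair into one foreground label per point, so the
# labels NestedTensor shares its jagged sizes with the points NestedTensor.
labels = torch.ones_like(pts, dtype=torch.int).prod(dim=-1, keepdim=True)

process_batch_size then controls how many of these per-crop point batches are fused into each predict_torch call; amg_example.py above uses 8.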

segment_anything_fast/build_sam.py (+10 −10)

@@ -51,7 +51,7 @@ def build_sam_vit_b(checkpoint=None):
     "vit_b": build_sam_vit_b,
 }
 
-def _apply_eval_dtype_sam(model, dtype=None):
+def _apply_eval_dtype_sam(model, dtype):
 
     def prep_model(model, dtype):
         if dtype is not None:
@@ -64,24 +64,24 @@ def prep_model(model, dtype):
 
     return model
 
-def build_sam_fast_vit_h(checkpoint=None):
+def build_sam_fast_vit_h(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_h(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam
 
 build_sam_fast = build_sam_fast_vit_h
 
-def build_sam_fast_vit_l(checkpoint=None):
+def build_sam_fast_vit_l(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_l(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam
 
-def build_sam_fast_vit_b(checkpoint=None):
+def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_b(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam
 
 sam_model_fast_registry = {
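Note: the builders now take the compile mode and eval dtype as keyword arguments instead of hard-coding them. A hedged usage sketch (the checkpoint path is illustrative; 'reduce-overhead' is a standard torch.compile mode that trades some peak speed for much shorter compile times):

import torch
from segment_anything_fast import sam_model_fast_registry

# Defaults reproduce the old behavior: mode='max-autotune', dtype=torch.bfloat16.
sam = sam_model_fast_registry["vit_b"](
    checkpoint="checkpoints/sam_vit_b_01ec64.pth",
    compile_mode="reduce-overhead",
    dtype=torch.bfloat16,
)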

segment_anything_fast/modeling/prompt_encoder.py (+1 −7)

@@ -157,13 +157,10 @@ def forward(
           torch.Tensor: dense embeddings for the masks, in the shape
             Bx(embed_dim)x(embed_H)x(embed_W)
         """
-        return_dtype = None
         bs = self._get_batch_size(points, boxes, masks)
         if points is not None:
             coords, labels = points
             sparse_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
-            if sparse_embeddings.dtype != coords.dtype:
-                return_dtype = coords.dtype
         if boxes is not None:
             sparse_embeddings = self._embed_boxes(boxes)
 
@@ -183,10 +180,7 @@ def forward(
         dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
             bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
 
-        r0, r1 = sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings
-        if return_dtype is None:
-            return r0, r1
-        return r0.to(return_dtype), r1.to(return_dtype)
+        return sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings
 
 
 class PositionEmbeddingRandom(nn.Module):
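Note: with the eval dtype now applied to the whole model at build time via _apply_eval_dtype_sam (see build_sam.py above), the prompt encoder no longer needs to remember the incoming coordinate dtype and cast its outputs back; both embeddings are simply returned in dense_embeddings' dtype.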

segment_anything_fast/utils/amg.py (+35 −1)

@@ -72,7 +72,7 @@ def cat(self, new_stats: "MaskData") -> None:
     def to_numpy(self) -> None:
         for k, v in self._stats.items():
             if isinstance(v, torch.Tensor):
-                self._stats[k] = v.detach().cpu().numpy()
+                self._stats[k] = v.detach().cpu().float().numpy()
 
 
 def is_box_near_crop_edge(
@@ -103,6 +103,40 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
     for b in range(n_batches):
         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
 
+def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycocotools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    a = torch.tensor([[True]]).pin_memory().cuda().expand_as(diff.narrow(1, 0, 1))
+    diff = torch.cat([a, diff, a], dim=1)
+    change_indices = diff.nonzero()
+
+    alt_lens = diff.sum(dim=1).tolist()
+
+    all_cur_idx = change_indices[:, 1]
+    all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
+    all_btw_idx = all_btw_idx.detach().cpu().tolist()
+
+    # Encode run length
+    out = []
+    counts_init = (tensor[:, 0] == 0).tolist()
+    offset = 0
+    for i, ci in zip(range(b), counts_init):
+        btw_idxs = all_btw_idx[offset:offset + alt_lens[i]][:-1]
+        offset += alt_lens[i]
+        counts = [] if ci else [0]
+        counts.extend(btw_idxs)
+        out.append({"size": [h, w], "counts": counts})
+
+    return out
+
 
 def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
     """

test/test_mask_to_rle.py (+18)

@@ -0,0 +1,18 @@
+import torch
+import itertools
+from segment_anything_fast.utils.amg import (
+    mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
+)
+
+def test_masks(masks):
+    rles_0 = mask_to_rle_pytorch(masks)
+    rles_2 = mask_to_rle_pytorch_2(masks)
+
+    for i in range(len(rles_0)):
+        torch.testing.assert_close(torch.tensor(rles_0[i]['counts']), torch.tensor(rles_2[i]['counts']))
+
+for b, w, h in itertools.product([1, 5], [50, 128], [50, 128]):
+    test_masks(torch.randn(b, w, h).clamp(min=0).bool().cuda())
+    test_masks(torch.randn(b, w, h).mul(0).bool().cuda())
+    test_masks(torch.randn(b, w, h).mul(0).add(1).bool().cuda())
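Note: the test cross-checks the vectorized encoder against the original mask_to_rle_pytorch over several batch sizes and non-square shapes, including the all-zero and all-one edge cases. It runs at import time rather than under a test runner, so executing python test/test_mask_to_rle.py on a CUDA machine is enough to run it.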
