@@ -24,6 +24,7 @@
     generate_crop_boxes,
     is_box_near_crop_edge,
     mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
     remove_small_regions,
     rle_to_mask,
     uncrop_boxes_xyxy,
@@ -49,6 +50,7 @@ def __init__(
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
+       process_batch_size: Optional[int] = None,
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
@@ -93,6 +95,10 @@ def __init__(
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
+           process_batch_size (int or None): Set a batch size for the decoding step.
+             If None, all points will be batched up at once. Set a small number here
+             to decrease memory footprint. A smaller number will likely increase
+             latency, but also decrease memory usage.
        """

        assert (points_per_side is None) != (
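For illustration, the new knob is set at construction time. A minimal usage sketch, assuming the package mirrors the upstream segment_anything entry points (the checkpoint path and values below are hypothetical):

    import numpy as np
    from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator

    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")  # hypothetical path
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_batch=64,     # points sent to the model per forward pass
        process_batch_size=4,    # decode at most 4 point batches at a time
    )
    # image: HxWx3 uint8 array
    masks = mask_generator.generate(np.zeros((1024, 1024, 3), dtype=np.uint8))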
@@ -132,6 +138,7 @@ def __init__(
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
+       self.process_batch_size = process_batch_size

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
@@ -241,10 +248,13 @@ def _process_crop(
        # Generate masks for this crop in batches
        data = MaskData()
-       for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-           batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+       all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)]
+       process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size
+       for i in range(0, len(all_points), process_batch_size):
+           some_points = all_points[i:i + process_batch_size]
+           batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
-           del batch_data
+       data["rles"] = mask_to_rle_pytorch_2(data["masks"])
        self.predictor.reset_image()

        # Remove duplicates within this crop.
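A minimal sketch of the chunked decode loop above, with stand-in values: points_per_batch fixes how many points form one model batch, while process_batch_size caps how many of those batches are decoded together (the strings below are stand-ins for the real point arrays):

    # Stand-ins for the per-batch point arrays yielded by batch_iterator.
    all_points = [f"points_{i}" for i in range(5)]
    process_batch_size = 2  # hypothetical setting
    for i in range(0, len(all_points), process_batch_size):
        print(all_points[i:i + process_batch_size])
    # ['points_0', 'points_1']
    # ['points_2', 'points_3']
    # ['points_4']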
@@ -265,24 +275,50 @@ def _process_crop(
    def _process_batch(
        self,
-       points: np.ndarray,
+       all_points: List[np.ndarray],
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        orig_h, orig_w = orig_size
-
-       # Run model on this batch
-       transformed_points = self.predictor.transform.apply_coords(points, im_size)
-       in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
-       in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
-       masks, iou_preds, _ = self.predictor.predict_torch(
-           in_points[:, None, :],
-           in_labels[:, None],
+       nt_in_points = []
+       for points in all_points:
+           # Run model on this batch
+           transformed_points = self.predictor.transform.apply_coords(points, im_size)
+           in_points = torch.as_tensor(transformed_points)  # , device=self.predictor.device)
+           nt_in_points.append(in_points)
+
+       nt_in_points = torch.nested.nested_tensor(nt_in_points, layout=torch.jagged, pin_memory=True).to(device=self.predictor.device, non_blocking=True)
+       # The call to prod is a workaround to share jagged sizes between two NestedTensors.
+       nt_in_labels = torch.ones_like(nt_in_points, dtype=torch.int).prod(dim=-1, keepdim=True)
+       nt_in_points = nt_in_points.unsqueeze(2)
+
+       self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
+       self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
+       nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
+           point_coords=nt_in_points,
+           point_labels=nt_in_labels,
            multimask_output=True,
            return_logits=True,
        )

+       data = MaskData()
+       for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):
+           batch_data = self._process_batch_2(masks, iou_preds, points, im_size, crop_box, orig_size)
+           data.cat(batch_data)
+       return data
+
+   # TODO: Batch this up
+   def _process_batch_2(
+       self,
+       masks: torch.Tensor,
+       iou_preds: torch.Tensor,
+       points: torch.Tensor,
+       im_size: Tuple[int, ...],
+       crop_box: List[int],
+       orig_size: Tuple[int, ...],
+   ) -> MaskData:
+       orig_h, orig_w = orig_size
        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
@@ -315,8 +351,10 @@ def _process_batch(
        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
-       data["rles"] = mask_to_rle_pytorch(data["masks"])
-       del data["masks"]
+       # Doing this once at the end across all masks.
+       # data["rles"] = mask_to_rle_pytorch(data["masks"].cpu())
+       # Keeping the masks around is faster, even though it uses more memory.
+       # del data["masks"]

        return data
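The jagged NestedTensor pattern in _process_batch can be exercised on its own. A minimal sketch, assuming a PyTorch build that supports layout=torch.jagged (shapes are made up):

    import torch

    a = torch.randn(3, 2)  # 3 points as (x, y)
    b = torch.randn(5, 2)  # 5 points
    nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)

    # ones_like().prod(dim=-1, keepdim=True) collapses the coordinate dim
    # while reusing nt's jagged sizes, yielding one label per point.
    labels = torch.ones_like(nt, dtype=torch.int).prod(dim=-1, keepdim=True)

    for pts, lbl in zip(nt.unbind(), labels.unbind()):
        print(pts.shape, lbl.shape)
    # torch.Size([3, 2]) torch.Size([3, 1])
    # torch.Size([5, 2]) torch.Size([5, 1])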