Skip to content

Commit

Permalink
Fix Dont combine option for SEGS Detector for AnimateDiff (#656)
Browse files Browse the repository at this point in the history
Co-authored-by: maratz <[email protected]>
  • Loading branch information
aganoob and maratz authored Jun 28, 2024
1 parent f7df6e4 commit cbc8384
Showing 1 changed file with 5 additions and 22 deletions.
27 changes: 5 additions & 22 deletions modules/impact/detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,12 @@ def get_pivot_segs():
merged_mask = get_whole_merged_mask()
return segs_nodes.MaskToSEGS.doit(merged_mask, False, crop_factor, False, drop_size, contour_fill=True)[0]

def get_merged_neighboring_segs():
def get_segs(merged_neighboring=False):
pivot_segs = get_pivot_segs()

masks_by_frame = get_masked_frames()
masks_by_frame = get_merged_neighboring_mask(masks_by_frame)
if merged_neighboring:
masks_by_frame = get_merged_neighboring_mask(masks_by_frame)

new_segs = []
for seg in pivot_segs[1]:
Expand All @@ -435,33 +436,15 @@ def get_merged_neighboring_segs():

return pivot_segs[0], new_segs

def get_separated_segs():
pivot_segs = get_pivot_segs()

masks_by_frame = get_masked_frames()

new_segs = []
for seg in pivot_segs[1]:
cropped_mask = torch.zeros(seg.cropped_mask.shape, dtype=torch.float32, device="cpu").unsqueeze(0)
x1, y1, x2, y2 = seg.crop_region
for mask in masks_by_frame:
cropped_mask_at_frame = mask[y1:y2, x1:x2]
cropped_mask = torch.cat((cropped_mask, cropped_mask_at_frame), dim=0)

new_seg = SEG(seg.cropped_image, cropped_mask, seg.confidence, seg.crop_region, seg.bbox, seg.label, seg.control_net_wrapper)
new_segs.append(new_seg)

return pivot_segs[0], new_segs

# create result mask
if masking_mode == "Pivot SEGS":
return (get_pivot_segs(), )

elif masking_mode == "Combine neighboring frames":
return (get_merged_neighboring_segs(), )
return (get_segs(merged_neighboring=True), )

else: # elif masking_mode == "Don't combine":
return (get_separated_segs(), )
return (get_segs(merged_neighboring=False), )

def doit(self, bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
Expand Down

0 comments on commit cbc8384

Please sign in to comment.