Added support for MPS on apple silicon devices for faster inference. #38

Open · wants to merge 3 commits into main

Changes from all commits

2 changes: 2 additions & 0 deletions README.md
@@ -11,6 +11,7 @@ Official implementation of ['Personalize Segment Anything Model with One Shot'](


## News
+* MPS (Metal Performance Shaders) support added 🔥 for faster inference on Apple Silicon devices.
* Support [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) 🔥 with significant efficiency improvement. Thanks for their wonderful work!
* **TODO**: Release the PerSAM-assisted [Dreambooth](https://arxiv.org/pdf/2208.12242.pdf) for better fine-tuning [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 📌.
* We release the code of PerSAM and PerSAM-F 🔥. Check our [video](https://www.youtube.com/watch?v=QlunvXpYQXM) here!
@@ -91,6 +92,7 @@ For **Multi-Object** segmentation of the same category by PerSAM-F (Great thanks
python persam_f_multi_obj.py --sam_type <sam module type> --outdir <output filename>
```

+Specify the device to use with `--device`; currently supported values are `cpu`, `cuda`, and `mps` (Apple Silicon). When omitted, it defaults to `cuda` if available, then `mps`, falling back to `cpu`.
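For example, on an Apple Silicon machine (a usage sketch; `persam_mps` is only an illustrative output name):

```
python persam.py --device mps --outdir persam_mps
```

If PyTorch reports an operator as not yet implemented for MPS, setting the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` before running lets it fall back to the CPU for that operator.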
After running, the output masks and visualizations will be stored at `outputs/<output filename>`.

### Evaluation
53 changes: 27 additions & 26 deletions persam.py
@@ -10,21 +10,25 @@
import warnings
warnings.filterwarnings('ignore')

-from show import *
from per_segment_anything import sam_model_registry, SamPredictor
+from show import *

+# Priority is cuda > mps > cpu
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+'mps' if torch.backends.mps.is_available() else
+'cpu')
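
A quick way to sanity-check which device the chain above resolves to on a given machine (a standalone sketch; it assumes PyTorch >= 1.12, which introduced `torch.backends.mps`; on macOS below 12.3 or on non-MPS builds, `mps.is_available()` simply returns `False`, so the chain falls back to `cpu`):

```python
import torch

# Mirrors the DEFAULT_DEVICE chain above: cuda > mps > cpu.
device = ('cuda' if torch.cuda.is_available() else
          'mps' if torch.backends.mps.is_available() else
          'cpu')
print(f'Resolved device: {device}')

# A tiny tensor op confirms the device is actually usable.
x = torch.ones(2, 2, device=device)
print(x @ x)
```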


def get_arguments():

parser = argparse.ArgumentParser()

parser.add_argument('--data', type=str, default='./data')
parser.add_argument('--outdir', type=str, default='persam')
parser.add_argument('--ckpt', type=str, default='sam_vit_h_4b8939.pth')
+parser.add_argument('--device', type=str, default=DEFAULT_DEVICE)
parser.add_argument('--ref_idx', type=str, default='00')
parser.add_argument('--sam_type', type=str, default='vit_h')

args = parser.parse_args()
return args

@@ -40,7 +44,7 @@ def main():

if not os.path.exists('./outputs/'):
os.mkdir('./outputs/')

for obj_name in os.listdir(images_path):
if ".DS" not in obj_name:
persam(args, obj_name, images_path, masks_path, output_path)
@@ -49,7 +53,7 @@
def persam(args, obj_name, images_path, masks_path, output_path):

print("\n------------> Segment " + obj_name)

# Path preparation
ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg')
ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png')
@@ -64,21 +68,19 @@ def persam(args, obj_name, images_path, masks_path, output_path):

ref_mask = cv2.imread(ref_mask_path)
ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)


print("======> Load SAM" )
print("======> Load SAM")
if args.sam_type == 'vit_h':
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(args.device)
elif args.sam_type == 'vit_t':
sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=args.device)
sam.eval()

predictor = SamPredictor(sam)

print("======> Obtain Location Prior" )
print("======> Obtain Location Prior")
# Image features encoding
ref_mask = predictor.set_image(ref_image, ref_mask)
ref_feat = predictor.features.squeeze().permute(1, 2, 0)
@@ -92,10 +94,9 @@ def persam(args, obj_name, images_path, masks_path, output_path):
target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
target_embedding = target_embedding.unsqueeze(0)


print('======> Start Testing')
for test_idx in tqdm(range(len(os.listdir(test_images_path)))):

# Load test image
test_idx = '%02d' % test_idx
test_image_path = test_images_path + '/' + test_idx + '.jpg'
@@ -115,9 +116,9 @@
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
-sim,
-input_size=predictor.input_size,
-original_size=predictor.original_size).squeeze()
+sim,
+input_size=predictor.input_size,
+original_size=predictor.original_size).squeeze()
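# Note: `sim` is the cosine similarity between the normalized target feature
# and the test-image features; the reshape to the feature grid, the 4x
# bilinear upsample, and postprocess_masks map it back to the original
# image resolution.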

# Positive-negative location prior
topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
@@ -131,8 +132,8 @@ def persam(args, obj_name, images_path, masks_path, output_path):

# First-step prediction
masks, scores, logits, _ = predictor.predict(
-point_coords=topk_xy,
-point_labels=topk_label,
+point_coords=topk_xy,
+point_labels=topk_label,
multimask_output=False,
attn_sim=attn_sim, # Target-guided Attention
target_embedding=target_embedding # Target-semantic Prompting
Expand All @@ -141,10 +142,10 @@ def persam(args, obj_name, images_path, masks_path, output_path):

# Cascaded Post-refinement-1
masks, scores, logits, _ = predictor.predict(
-point_coords=topk_xy,
-point_labels=topk_label,
-mask_input=logits[best_idx: best_idx + 1, :, :],
-multimask_output=True)
+point_coords=topk_xy,
+point_labels=topk_label,
+mask_input=logits[best_idx: best_idx + 1, :, :],
+multimask_output=True)
best_idx = np.argmax(scores)

# Cascaded Post-refinement-2
@@ -158,7 +159,7 @@ def persam(args, obj_name, images_path, masks_path, output_path):
point_coords=topk_xy,
point_labels=topk_label,
box=input_box[None, :],
-mask_input=logits[best_idx: best_idx + 1, :, :],
+mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)

@@ -189,17 +190,17 @@ def point_selection(mask_sim, topk=1):
topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
topk_label = np.array([1] * topk)
topk_xy = topk_xy.cpu().numpy()

# Top-last point selection
last_xy = mask_sim.flatten(0).topk(topk, largest=False)[1]
last_x = (last_xy // h).unsqueeze(0)
last_y = (last_xy - last_x * h)
last_xy = torch.cat((last_y, last_x), dim=0).permute(1, 0)
last_label = np.array([0] * topk)
last_xy = last_xy.cpu().numpy()

return topk_xy, topk_label, last_xy, last_label


if __name__ == "__main__":
main()
70 changes: 34 additions & 36 deletions persam_f.py
@@ -11,25 +11,29 @@
import warnings
warnings.filterwarnings('ignore')

-from show import *
from per_segment_anything import sam_model_registry, SamPredictor
+from show import *


+# Priority is cuda > mps > cpu
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+'mps' if torch.backends.mps.is_available() else
+'cpu')

def get_arguments():

parser = argparse.ArgumentParser()

parser.add_argument('--data', type=str, default='./data')
parser.add_argument('--outdir', type=str, default='persam_f')
+parser.add_argument('--device', type=str, default=DEFAULT_DEVICE)
parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth')
parser.add_argument('--sam_type', type=str, default='vit_h')

-parser.add_argument('--lr', type=float, default=1e-3)
+parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--train_epoch', type=int, default=1000)
parser.add_argument('--log_epoch', type=int, default=200)
parser.add_argument('--ref_idx', type=str, default='00')

args = parser.parse_args()
return args

@@ -45,16 +49,16 @@ def main():

if not os.path.exists('./outputs/'):
os.mkdir('./outputs/')

for obj_name in os.listdir(images_path):
if ".DS" not in obj_name:
persam_f(args, obj_name, images_path, masks_path, output_path)


def persam_f(args, obj_name, images_path, masks_path, output_path):

print("\n------------> Segment " + obj_name)

# Path preparation
ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg')
ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png')
@@ -70,27 +74,23 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
ref_mask = cv2.imread(ref_mask_path)
ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB)

-gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
-gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
+gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
+gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(args.device)
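# Note: the reference mask is binarized from its first channel and flattened
# to shape (1, H*W), so it can be compared pixel-wise with the predicted mask
# logits by the dice and focal losses during training.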


print("======> Load SAM" )
print("======> Load SAM")
if args.sam_type == 'vit_h':
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(args.device)
elif args.sam_type == 'vit_t':
sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=args.device)
sam.eval()



for name, param in sam.named_parameters():
param.requires_grad = False
predictor = SamPredictor(sam)


print("======> Obtain Self Location Prior" )
print("======> Obtain Self Location Prior")
# Image features encoding
ref_mask = predictor.set_image(ref_image, ref_mask)
ref_feat = predictor.features.squeeze().permute(1, 2, 0)
@@ -114,19 +114,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
-sim,
-input_size=predictor.input_size,
-original_size=predictor.original_size).squeeze()
+sim,
+input_size=predictor.input_size,
+original_size=predictor.original_size).squeeze()

# Positive location prior
topk_xy, topk_label = point_selection(sim, topk=1)


print('======> Start Training')
# Learnable mask weights
-mask_weights = Mask_Weights().cuda()
+mask_weights = Mask_Weights().to(args.device)
mask_weights.train()

optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch)
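
The parameters being optimized here come from `Mask_Weights`, whose definition is not shown in this diff. A plausible sketch, reconstructed from the weighted-sum code below (an assumption for illustration, not the PR's code):

```python
import torch
import torch.nn as nn

class Mask_Weights(nn.Module):
    """Two learnable scalars for fusing three-scale mask logits; the first
    scale's weight is recovered as 1 - weights.sum() at fusion time."""
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(2, 1) / 3)
```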

@@ -158,7 +157,6 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
current_lr = scheduler.get_last_lr()[0]
print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))


mask_weights.eval()
weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
weights_np = weights.detach().cpu().numpy()
@@ -186,18 +184,18 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
sim = sim.reshape(1, 1, h, w)
sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
sim = predictor.model.postprocess_masks(
-sim,
-input_size=predictor.input_size,
-original_size=predictor.original_size).squeeze()
+sim,
+input_size=predictor.input_size,
+original_size=predictor.original_size).squeeze()

# Positive location prior
topk_xy, topk_label = point_selection(sim, topk=1)

# First-step prediction
masks, scores, logits, logits_high = predictor.predict(
-point_coords=topk_xy,
-point_labels=topk_label,
-multimask_output=True)
+point_coords=topk_xy,
+point_labels=topk_label,
+multimask_output=True)

# Weighted sum three-scale masks
logits_high = logits_high * weights.unsqueeze(-1)
@@ -236,7 +234,7 @@ def persam_f(args, obj_name, images_path, masks_path, output_path):
mask_input=logits[best_idx: best_idx + 1, :, :],
multimask_output=True)
best_idx = np.argmax(scores)

# Save masks
plt.figure(figsize=(10, 10))
plt.imshow(test_image)
@@ -270,11 +268,11 @@ def point_selection(mask_sim, topk=1):
topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0)
topk_label = np.array([1] * topk)
topk_xy = topk_xy.cpu().numpy()

return topk_xy, topk_label


-def calculate_dice_loss(inputs, targets, num_masks = 1):
+def calculate_dice_loss(inputs, targets, num_masks=1):
"""
Compute the DICE loss, similar to generalized IOU for masks
Args:
@@ -292,7 +290,7 @@ def calculate_dice_loss(inputs, targets, num_masks = 1):
return loss.sum() / num_masks


-def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2):
+def calculate_sigmoid_focal_loss(inputs, targets, num_masks=1, alpha: float = 0.25, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
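
Both loss bodies are truncated in the diff view above. For reference, a minimal sketch of the standard formulations their docstrings describe (an assumption based on the common DETR/RetinaNet implementations; `inputs` are raw logits of shape `(num_masks, H*W)` and `targets` are binary masks of the same shape):

```python
import torch
import torch.nn.functional as F

def dice_loss_sketch(inputs, targets, num_masks=1):
    # Soft Dice: sigmoid probabilities, then 1 minus the Dice coefficient.
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks

def sigmoid_focal_loss_sketch(inputs, targets, num_masks=1, alpha=0.25, gamma=2):
    # Focal loss: binary cross-entropy reweighted toward hard pixels.
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean(1).sum() / num_masks
```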