
Added support for MPS on Apple silicon devices for faster inference. #38

Open · wants to merge 3 commits into main
Conversation

JunkyByte

Hello! Thanks for your work. I have modified your code to support device switching via a --device flag, moving tensors to the correct device accordingly.

These changes also fix a bug in the current version: when using vit_t, the model itself can run on the CPU, but other tensors are still moved to the GPU through hard-coded .cuda() calls, which raises an error if torch is not compiled with CUDA and makes CPU inference impossible:

$ python persam_f.py --outdir ./outputs/ --sam_type vit_t
Traceback (most recent call last):
  File "/Users/junkybyte/Desktop/Personalize-SAM/persam_f.py", line 74, in persam_f
    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
AssertionError: Torch not compiled with CUDA enabled
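
The fix is to route calls like the one in the traceback through the selected device instead of hard-coding .cuda(). A minimal sketch of the idea, not the exact diff; the `device` variable is assumed to be set once at startup from the new --device flag:

```python
# Sketch: move the tensor to whichever device was selected at startup,
# instead of the hard-coded .cuda() that fails on CPU-only torch builds.
gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(device)
```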

Since SAM works out of the box with MPS on Apple silicon devices, I chose the default device to be cuda or mps when available, falling back to cpu otherwise. I tested MPS on an M2 MacBook Air with torch==2.0.1 installed.
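
For reference, the default-device selection described above could look roughly like this. This is a minimal sketch rather than the code in the PR, and the helper name `default_device` is my own:

```python
import argparse

import torch


def default_device() -> str:
    # Prefer CUDA, then MPS on Apple silicon, otherwise fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default=default_device())
args = parser.parse_args()
device = torch.device(args.device)
```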

I updated the README to mirror these changes; let me know if you are interested in a merge. Thanks!

@JunkyByte (Author)

The other changes you see are just autopep8 style corrections applied automatically by my editor's default settings.
