
CUDA not used when device=torch.device("cuda") ? #83

Open
sjfleming opened this issue Feb 10, 2025 · 0 comments

@sjfleming

Here is a small demo that appears to show that `pymde.preserve_neighbors` uses CUDA only when the user passes `device="cuda"`, and not when the user passes `device=torch.device("cuda")`. I see that the type hint for that argument is `str`, but I still find this behavior unexpected. Example:

```python
import pymde
import torch

device = "cuda"  # device passed as a plain string

mnist = pymde.datasets.MNIST()
embedding = pymde.preserve_neighbors(mnist.data, embedding_dim=2, verbose=True, device=device).embed()
pymde.plot(embedding, color_by=mnist.attributes['digits'])
```

This works fine and seems to run on CUDA.

But this

```python
import pymde
import torch

device = torch.device("cuda")  # same device, passed as a torch.device object

mnist = pymde.datasets.MNIST()
embedding = pymde.preserve_neighbors(mnist.data, embedding_dim=2, verbose=True, device=device).embed()
pymde.plot(embedding, color_by=mnist.attributes['digits'])
```

leads to `ArpackError: ARPACK error -9: Starting vector is zero.` (see #82; this error appears when ARPACK is used on the CPU), so this does not seem to be running on CUDA.

I think a potential fix would be to change this line

```python
cg = device == "cuda"
```

so that it also accepts `torch.device("cuda")`.

If you think I'm on the right track, I'd be happy to write a PR.
