Skip to content

Commit

Permalink
make sure model evaluation runs on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Jul 19, 2024
1 parent 496caee commit c1f582f
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions psiflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def __post_init__(self):
from mace.tools import torch_tools, utils

torch_tools.set_default_dtype(self.dtype)
if self.device == 'gpu': # when it's not a specific GPU, use any
self.device = 'cuda'
self.device = torch_tools.init_device(self.device)

torch.set_num_threads(self.ncores)
Expand Down

0 comments on commit c1f582f

Please sign in to comment.