From c1f582fa32f256f41cc78977ab2054316af439bf Mon Sep 17 00:00:00 2001 From: Sander Vandenhaute Date: Fri, 19 Jul 2024 17:00:02 -0400 Subject: [PATCH] make sure model evaluation runs on GPU --- psiflow/functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/psiflow/functions.py b/psiflow/functions.py index 5173061..d3d382b 100644 --- a/psiflow/functions.py +++ b/psiflow/functions.py @@ -218,6 +218,8 @@ def __post_init__(self): from mace.tools import torch_tools, utils torch_tools.set_default_dtype(self.dtype) + if self.device == 'gpu': # when it's not a specific GPU, use any + self.device = 'cuda' self.device = torch_tools.init_device(self.device) torch.set_num_threads(self.ncores)