Skip to content

Commit 4c4428b

Browse files
committed
Added support for Cuda GPU
1 parent 747e92c commit 4c4428b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

Code/training.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
# Define Training Class
1515
class Trainer():
1616
def __init__(self, model, loss_function, model_save_path):
17-
# Define the model
18-
self.model = model
17+
# Define the device
18+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19+
# Define the model and move it to the device
20+
self.model = model.to(self.device)
1921
# Define the loss function
2022
self.loss_function = loss_function
2123
# Define the optimizer

0 commit comments

Comments
 (0)