Skip to content

Commit 76407ec

Browse files
authored
added a warning when a PyTorch model is in train mode (#78)
* added a warning when a PyTorch model is in train mode * added missing import
1 parent fcbbc6f commit 76407ec

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

foolbox/models/pytorch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import warnings
23

34
from .base import DifferentiableModel
45

@@ -43,6 +44,12 @@ def __init__(
4344
self._model = model
4445
self.cuda = cuda
4546

47+
if model.training:
48+
warnings.warn(
49+
'The PyTorch model is in training mode and therefore might'
50+
' not be deterministic. Call the eval() method to set it in'
51+
' evaluation mode if this is not intended.')
52+
4653
def batch_predictions(self, images):
4754
# lazy import
4855
import torch

0 commit comments

Comments
 (0)