1 file changed: +18 −2 lines changed

 import contextlib
 import threading
 from copy import deepcopy
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer

 import hivemind
 from hivemind.utils.logging import get_logger

+if torch.cuda.is_available():
+    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+else:
+    # On CPU this import fails, so we copy the implementation here, since it is simple.
+    # Code taken from: https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/amp/grad_scaler.py#L44
+
+    from enum import Enum
+
+    class OptState(Enum):
+        READY = 0
+        UNSCALED = 1
+        STEPPED = 2
+
+    def _refresh_per_optimizer_state() -> Dict[str, Any]:
+        return {"stage": OptState.READY, "found_inf_per_device": {}}
+
+
 logger = get_logger(__name__)
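For context on why these two symbols are needed: torch's GradScaler keeps one such state dict per optimizer and advances its "stage" through the OptState values as unscale_(), step(), and update() run. Below is a minimal, CPU-runnable sketch of that lifecycle; it reuses the same fallback definitions as the patch, and the names per_optimizer_states and opt_id are illustrative, not hivemind's code.

    from collections import defaultdict
    from enum import Enum
    from typing import Any, Dict

    # Same definitions as the CPU fallback in the patch above.
    class OptState(Enum):
        READY = 0
        UNSCALED = 1
        STEPPED = 2

    def _refresh_per_optimizer_state() -> Dict[str, Any]:
        return {"stage": OptState.READY, "found_inf_per_device": {}}

    # torch's GradScaler stores one state dict per optimizer like this:
    per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    opt_id = 0  # stand-in for id(optimizer)
    assert per_optimizer_states[opt_id]["stage"] is OptState.READY

    # unscale_() advances READY -> UNSCALED; step() advances UNSCALED -> STEPPED.
    per_optimizer_states[opt_id]["stage"] = OptState.UNSCALED
    per_optimizer_states[opt_id]["stage"] = OptState.STEPPED

    # update() then resets all per-optimizer states for the next iteration:
    per_optimizer_states = defaultdict(_refresh_per_optimizer_state)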