
Commit 19ed6cf

fix grad scaler import
1 parent d20e810 commit 19ed6cf

File tree

1 file changed (+18, -2 lines)


hivemind/optim/grad_scaler.py

Lines changed: 18 additions & 2 deletions
@@ -1,16 +1,32 @@
 import contextlib
 import threading
 from copy import deepcopy
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
 from hivemind.utils.logging import get_logger
 
+if torch.cuda.is_available():
+    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+else:
+    # on cpu the import is not working, so just copy pasting the code here as it is simple
+    # code taken from here : https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/amp/grad_scaler.py#L44
+
+    from enum import Enum
+
+    class OptState(Enum):
+        READY = 0
+        UNSCALED = 1
+        STEPPED = 2
+
+    def _refresh_per_optimizer_state() -> Dict[str, Any]:
+        return {"stage": OptState.READY, "found_inf_per_device": {}}
+
+
 logger = get_logger(__name__)
 
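For context, a minimal, self-contained sketch of the conditional-import fallback this commit introduces. It mirrors the diff above; it assumes a PyTorch build where torch.cuda.amp.grad_scaler still exposes OptState and _refresh_per_optimizer_state on the CUDA path, and the check at the end is only illustrative, not part of the commit.

    from enum import Enum
    from typing import Any, Dict

    import torch

    if torch.cuda.is_available():
        # CUDA builds: reuse PyTorch's internal GradScaler bookkeeping helpers directly.
        from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
    else:
        # CPU-only builds: the import above is unavailable, so re-create the two small pieces locally,
        # matching the upstream definitions linked in the diff.
        class OptState(Enum):
            READY = 0
            UNSCALED = 1
            STEPPED = 2

        def _refresh_per_optimizer_state() -> Dict[str, Any]:
            return {"stage": OptState.READY, "found_inf_per_device": {}}

    # Either branch yields the same per-optimizer state dict that GradScaler expects.
    state = _refresh_per_optimizer_state()
    assert state["stage"] is OptState.READY and state["found_inf_per_device"] == {}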
