1
1
import dataclasses
2
2
from abc import ABCMeta , abstractmethod
3
- from typing import Dict , Union
3
+ from typing import Callable , Dict , Union
4
4
5
5
import torch
6
6
from torch .optim import Optimizer
18
18
compute_stochastic_imitation_loss ,
19
19
)
20
20
from ....optimizers import OptimizerWrapper
21
- from ....torch_utility import Modules , TorchMiniBatch
21
+ from ....torch_utility import CudaGraphWrapper , Modules , TorchMiniBatch
22
22
from ....types import Shape , TorchObservation
23
23
from ..base import QLearningAlgoImplBase
24
24
@@ -32,12 +32,14 @@ class BCBaseModules(Modules):
32
32
33
33
class BCBaseImpl (QLearningAlgoImplBase , metaclass = ABCMeta ):
34
34
_modules : BCBaseModules
35
+ _compute_imitator_grad : Callable [[TorchMiniBatch ], ImitationLoss ]
35
36
36
37
def __init__ (
37
38
self ,
38
39
observation_shape : Shape ,
39
40
action_size : int ,
40
41
modules : BCBaseModules ,
42
+ compiled : bool ,
41
43
device : str ,
42
44
):
43
45
super ().__init__ (
@@ -46,15 +48,21 @@ def __init__(
46
48
modules = modules ,
47
49
device = device ,
48
50
)
51
+ self ._compute_imitator_grad = (
52
+ CudaGraphWrapper (self .compute_imitator_grad )
53
+ if compiled
54
+ else self .compute_imitator_grad
55
+ )
49
56
50
- def update_imitator (self , batch : TorchMiniBatch ) -> Dict [ str , float ] :
57
def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss:
    """Compute the imitation loss for ``batch`` and backpropagate it.

    Gradients held by the optimizer are cleared first, then the loss is
    evaluated on the batch's observations and actions and ``backward()``
    is called on its ``loss`` field. No optimizer step is taken here;
    that is left to the caller.

    Args:
        batch: Mini-batch providing ``observations`` and ``actions``.

    Returns:
        The loss object produced by ``compute_loss``.
    """
    optimizer = self._modules.optim
    # Clear gradients before the forward/backward pass.
    optimizer.zero_grad()
    imitation_loss = self.compute_loss(batch.observations, batch.actions)
    imitation_loss.loss.backward()
    return imitation_loss
57
62
63
def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
    """Run one imitator update on ``batch`` and report the loss values.

    Gradient computation is delegated to ``self._compute_imitator_grad``,
    which ``__init__`` sets to either ``compute_imitator_grad`` itself or
    a ``CudaGraphWrapper`` around it, depending on ``compiled``. The
    optimizer step is applied here, outside that callable.

    Args:
        batch: Mini-batch passed to the gradient computation.

    Returns:
        The loss converted by ``asdict_as_float``.
    """
    imitation_loss = self._compute_imitator_grad(batch)
    # Apply the parameter update after gradients have been computed.
    self._modules.optim.step()
    metrics = asdict_as_float(imitation_loss)
    return metrics
59
67
60
68
@abstractmethod
@@ -92,12 +100,14 @@ def __init__(
92
100
action_size : int ,
93
101
modules : BCModules ,
94
102
policy_type : str ,
103
+ compiled : bool ,
95
104
device : str ,
96
105
):
97
106
super ().__init__ (
98
107
observation_shape = observation_shape ,
99
108
action_size = action_size ,
100
109
modules = modules ,
110
+ compiled = compiled ,
101
111
device = device ,
102
112
)
103
113
self ._policy_type = policy_type
@@ -145,12 +155,14 @@ def __init__(
145
155
action_size : int ,
146
156
modules : DiscreteBCModules ,
147
157
beta : float ,
158
+ compiled : bool ,
148
159
device : str ,
149
160
):
150
161
super ().__init__ (
151
162
observation_shape = observation_shape ,
152
163
action_size = action_size ,
153
164
modules = modules ,
165
+ compiled = compiled ,
154
166
device = device ,
155
167
)
156
168
self ._beta = beta
0 commit comments