File tree 1 file changed +3
-2
lines changed
1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -128,7 +128,7 @@ import d3rlpy
128
128
dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')
129
129
130
130
# prepare algorithm
131
- cql = d3rlpy.algos.CQLConfig().create(device='cuda:0')
131
+ cql = d3rlpy.algos.CQLConfig(compile_graph=True).create(device='cuda:0')
132
132
133
133
# train
134
134
cql.fit(
@@ -157,6 +157,7 @@ dataset, env = d3rlpy.datasets.get_atari_transitions(
157
157
cql = d3rlpy.algos.DiscreteCQLConfig(
158
158
observation_scaler = d3rlpy.preprocessing.PixelObservationScaler(),
159
159
reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0),
160
+ compile_graph=True,
160
161
).create(device='cuda:0')
161
162
162
163
# start training
@@ -180,7 +181,7 @@ env = gym.make('Hopper-v3')
180
181
eval_env = gym.make('Hopper-v3')
181
182
182
183
# prepare algorithm
183
- sac = d3rlpy.algos.SACConfig().create(device='cuda:0')
184
+ sac = d3rlpy.algos.SACConfig(compile_graph=True).create(device='cuda:0')
184
185
185
186
# prepare replay buffer
186
187
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)
You can’t perform that action at this time.
0 commit comments