@@ -16,24 +16,54 @@ def make(cfg):
16
16
return make_rccar_envs (cfg )
17
17
18
18
19
def prepare_randomization_fn(key, num_envs, cfg, task_name):
    """Build the domain-randomization callables for one batch of environments.

    The randomizer registered under ``task_name`` in ``randomization_fns`` is
    invoked as ``randomizer(sys, rng, cfg)``, where ``rng`` is ``key`` split
    into ``num_envs`` sub-keys.

    Args:
        key: JAX PRNG key; split once, when this function is called.
        num_envs: Number of sub-keys to derive from ``key``.
        cfg: Configuration object forwarded unchanged to the randomizer.
        task_name: Lookup key into ``randomization_fns``.

    Returns:
        A pair ``(vf_randomization_fn, params_fn)`` of one-argument callables
        taking ``sys``. The first returns every element of the randomizer's
        result except the last; the second returns only the last element.
        Each callable runs the randomizer on its own.
    """
    # Split eagerly so both callables share the exact same per-env keys.
    env_keys = jax.random.split(key, num_envs)

    def _randomize(sys):
        # The registry lookup happens at call time, not at construction time.
        return randomization_fns[task_name](sys, env_keys, cfg)

    def vf_randomization_fn(sys):
        return _randomize(sys)[:-1]  # type: ignore

    def params_fn(sys):
        return _randomize(sys)[-1]

    return vf_randomization_fn, params_fn
27
+
28
+
19
29
def make_rccar_envs(cfg):
    """Create the wrapped training and evaluation RC-car environments.

    The task configuration is copied into a plain dict. The ``domain_name``,
    ``task_name``, ``train_car_params`` and ``eval_car_params`` entries are
    removed from it, and whatever remains is passed to ``rccar.RCCar`` as
    keyword arguments. Each car-parameter entry is indexed with ``"nominal"``
    (used to construct the environment) and ``"bounds"`` (handed to
    ``prepare_randomization_fn``).

    Args:
        cfg: Experiment configuration. This function reads
            ``cfg.training.seed``, ``num_envs``, ``num_eval_envs``,
            ``episode_length``, ``action_repeat`` and
            ``train_domain_randomization``, plus ``cfg.environment.task_name``.

    Returns:
        ``(train_env, eval_env, params_fn)``. ``params_fn`` is ``None`` when
        ``cfg.training.train_domain_randomization`` is falsy.
    """
    training = cfg.training
    task_name = cfg.environment.task_name

    env_kwargs = dict(get_task_config(cfg))
    env_kwargs.pop("domain_name")
    env_kwargs.pop("task_name")
    train_params = env_kwargs.pop("train_car_params")
    eval_params = env_kwargs.pop("eval_car_params")

    train_key, eval_key = jax.random.split(jax.random.PRNGKey(training.seed))

    # Settings shared by the training and the evaluation wrapper.
    wrap_kwargs = {
        "episode_length": training.episode_length,
        "action_repeat": training.action_repeat,
    }

    # Training environment: randomization is optional.
    train_env = rccar.RCCar(train_params["nominal"], **env_kwargs)
    if training.train_domain_randomization:
        train_randomization_fn, params_fn = prepare_randomization_fn(
            train_key,
            training.num_envs,
            train_params["bounds"],
            task_name,
        )
    else:
        train_randomization_fn, params_fn = None, None
    train_env = envs.training.wrap(
        train_env,
        randomization_fn=train_randomization_fn,
        **wrap_kwargs,
    )

    # Evaluation environment: always randomized, whatever the training flag.
    eval_env = rccar.RCCar(eval_params["nominal"], **env_kwargs)
    eval_randomization_fn, _ = prepare_randomization_fn(
        eval_key,
        training.num_eval_envs,
        eval_params["bounds"],
        task_name,
    )
    eval_env = envs.training.wrap(
        eval_env,
        randomization_fn=eval_randomization_fn,
        **wrap_kwargs,
    )

    return train_env, eval_env, params_fn
37
67
38
68
39
69
def make_brax_envs (cfg ):
@@ -43,20 +73,10 @@ def make_brax_envs(cfg):
43
73
)
44
74
eval_env = envs .get_environment (task_cfg .task_name , backend = cfg .environment .backend )
45
75
train_key , eval_key = jax .random .split (jax .random .PRNGKey (cfg .training .seed ))
46
-
47
- def prepare_randomization_fn (key , num_envs ):
48
- randomize_fn = lambda sys , rng : randomization_fns [task_cfg .task_name ](
49
- sys , rng , task_cfg
50
- )
51
- v_randomization_fn = functools .partial (
52
- randomize_fn , rng = jax .random .split (key , num_envs )
53
- )
54
- vf_randomization_fn = lambda sys : v_randomization_fn (sys )[:- 1 ] # type: ignore
55
- params_fn = lambda sys : v_randomization_fn (sys )[- 1 ]
56
- return vf_randomization_fn , params_fn
57
-
58
76
train_randomization_fn , params_fn = (
59
- prepare_randomization_fn (train_key , cfg .training .num_envs )
77
+ prepare_randomization_fn (
78
+ train_key , cfg .training .num_envs , task_cfg , task_cfg .task_name
79
+ )
60
80
if cfg .training .train_domain_randomization
61
81
else (None , None )
62
82
)
@@ -67,7 +87,7 @@ def prepare_randomization_fn(key, num_envs):
67
87
randomization_fn = train_randomization_fn ,
68
88
)
69
89
eval_randomization_fn , _ = prepare_randomization_fn (
70
- eval_key , cfg .training .num_eval_envs
90
+ eval_key , cfg .training .num_eval_envs , task_cfg , task_cfg . task_name
71
91
)
72
92
eval_env = envs .training .wrap (
73
93
eval_env ,
0 commit comments