-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathego_gmm_clsft.yaml
43 lines (37 loc) · 1.05 KB
/
ego_gmm_clsft.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# @package _global_

# Experiment config: fine-tunes the `ego_gmm` model (`action: finetune`),
# starting from the checkpoint named in `ckpt_path`.
#
# Floats in scientific notation are written with an explicit mantissa dot
# (`1.0e-4`, not `1e-4`) so that YAML 1.1 loaders such as plain PyYAML also
# resolve them as floats instead of strings.

defaults:
  # - override /trainer: ddp
  - override /model: ego_gmm

model:
  model_config:
    lr: 1.0e-4
    lr_min_ratio: 0.05
    token_processor:
      map_token_sampling: # open-loop
        num_k: 1 # for k nearest neighbors
        temp: 1.0 # uniform sampling
      agent_token_sampling: # closed-loop
        num_k: 1 # for k nearest neighbors
        temp: 1.0
    validation_rollout_sampling:
      criterium: topk_prob # {topk_prob, topk_prob_sampled_with_dist}
      num_k: 3 # for k most likely
      temp_mode: 1.0e-3
      temp_cov: 1.0e-3
    training_rollout_sampling:
      criterium: topk_prob_sampled_with_dist # {topk_prob, topk_prob_sampled_with_dist}
      num_k: 3 # for k nearest neighbors, set to -1 to turn off closed-loop training
      temp_mode: 1.0e-3
      temp_cov: 1.0e-3
    # NOTE(review): assumed to belong to model_config (not the root) - confirm
    # against the base `ego_gmm` model config.
    finetune: true

# Placeholder file name - presumably must be replaced with the path to the
# pretrained checkpoint to fine-tune from; verify before launching.
ckpt_path: BC_PRETRAINED_MODEL.ckpt

trainer:
  limit_train_batches: 1.0
  limit_val_batches: 0.1
  check_val_every_n_epoch: 1

data:
  train_batch_size: 10
  val_batch_size: 10
  test_batch_size: 10
  num_workers: 10

action: finetune