-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy path demo.py
More file actions
50 lines (43 loc) · 1.78 KB
/
Copy path demo.py
File metadata and controls
50 lines (43 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import time
from tqdm import tqdm
import pickle
from experiments.robot.libero.run_libero_eval import GenerateConfig
from experiments.robot.openvla_utils import get_action_head, get_processor, get_proprio_projector, get_vla, get_vla_action
from prismatic.vla.constants import NUM_ACTIONS_CHUNK, PROPRIO_DIM
# Minimal inference demo: build a GenerateConfig, load the policy and its
# auxiliary modules, run it on one saved LIBERO observation, and print the
# predicted action chunk.

# Instantiate config (see class GenerateConfig in experiments/robot/libero/run_libero_eval.py for definitions)
# NOTE: "/path/to/ckpt" is a placeholder -- point it at a real checkpoint before running.
model_path = "/path/to/ckpt"
cfg = GenerateConfig(
    pretrained_checkpoint=model_path,
    use_l1_regression=True,  # actions come from the MLP action head (L1 regression), loaded below
    use_diffusion=False,
    num_images_in_input=2,  # presumably "full_image" + "wrist_image" from the observation -- confirm
    use_proprio=True,  # robot state is fed in via the proprio projector loaded below
    load_in_8bit=False,
    load_in_4bit=False,
    center_crop=True,
    num_open_loop_steps=NUM_ACTIONS_CHUNK,
    unnorm_key="libero_spatial_no_noops",
)

# Load policy and inputs processor
vla = get_vla(cfg)
processor = get_processor(cfg)

# Load MLP action head to generate continuous actions (via L1 regression)
action_head = get_action_head(cfg, llm_dim=vla.llm_dim)

# Load proprio projector to map proprio to language embedding space
proprio_projector = get_proprio_projector(cfg, llm_dim=vla.llm_dim, proprio_dim=PROPRIO_DIM)

# Load sample observation:
# observation (dict): {
#     "full_image": primary third-person image,
#     "wrist_image": wrist-mounted camera image,
#     "state": robot proprioceptive state,
#     "task_description": task description,
# }
# NOTE: pickle.load can execute arbitrary code embedded in the file -- only load
# observation files from a trusted source.
# The path is relative, so it resolves against the current working directory:
# run this script from the directory that contains `experiments/`.
with open("experiments/robot/libero/sample_libero_spatial_observation.pkl", "rb") as file:
    observation = pickle.load(file)

# Generate robot action chunk (sequence of future actions)
actions = get_vla_action(
    cfg, vla, processor, observation, observation["task_description"], action_head, proprio_projector
)

print("Generated action chunk:")
for act in actions:
    print(act)