Skip to content

Commit 5397ab9

Browse files
Implement classes Navigation, Pick and Place as a new skills (#131)
1 parent d095a0c commit 5397ab9

File tree

16 files changed

+1208
-883
lines changed

16 files changed

+1208
-883
lines changed

spot_rl_experiments/configs/config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
WEIGHTS:
22
# Regular Nav
3-
NAV: "weights/torchscript/CUTOUT_WT_True_SD_200_ckpt.99.pvp_combined_net.torchscript" #"weights/CUTOUT_WT_True_SD_200_ckpt.99.pvp.pth"
3+
NAV: "weights/torchscript/CUTOUT_WT_True_SD_200_ckpt.99.pvp_combined_net.torchscript"
44

55
# Static gaze
6-
GAZE: "weights/torchscript/gaze_normal_32_seed100_1649708902_ckpt.38_combined_net.torchscript" #"weights/final_paper/gaze_normal_32_seed100_1649708902_ckpt.38.pth"
6+
GAZE: "weights/torchscript/gaze_normal_32_seed100_1649708902_ckpt.38_combined_net.torchscript"
77

88
# Mobile Gaze torchscript module files path
99
MOBILE_GAZE: "weights/torchscript/mg97_2_latest_combined_net.torchscript"
1010

1111
# Static place
12-
PLACE: "weights/torchscript/place_10deg_32_seed300_1649709235_ckpt.75_combined_net.torchscript" #"weights/final_paper/place_10deg_32_seed300_1649709235_ckpt.75.pth"
12+
PLACE: "weights/torchscript/place_10deg_32_seed300_1649709235_ckpt.75_combined_net.torchscript"
1313

1414
# ASC
1515
MIXER: "weights/final_paper/final_moe_rnn_60_1.0_SD_100_1652120928_ckpt.16_copy.pth"

spot_rl_experiments/experiments/skill_test/test_heurisitic_gaze.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import sys
77

88
import numpy as np
9-
from spot_rl.envs.nav_env import construct_config_for_nav
109
from spot_rl.envs.skill_manager import SpotSkillManager
10+
from spot_rl.utils.construct_configs import construct_config_for_nav
1111
from spot_rl.utils.heuristic_nav import heurisitic_object_search_and_navigation
1212
from spot_rl.utils.utils import get_default_parser, map_user_input_to_boolean
1313

spot_rl_experiments/spot_rl/envs/gaze_env.py

+1-209
Original file line numberDiff line numberDiff line change
@@ -3,194 +3,16 @@
33
# LICENSE file in the root directory of this source tree.
44

55

6-
import os
76
import sys
8-
import time
97
from typing import Dict, List
108

11-
import numpy as np
129
import rospy
1310
from spot_rl.envs.base_env import SpotBaseEnv
14-
from spot_rl.real_policy import GazePolicy, MobileGazePolicy
15-
from spot_rl.utils.utils import (
16-
construct_config,
17-
get_default_parser,
18-
map_user_input_to_boolean,
19-
)
2011
from spot_wrapper.spot import Spot
2112

2213

23-
def parse_arguments(args=sys.argv[1:]):
24-
parser = get_default_parser()
25-
parser.add_argument(
26-
"-t", "--target-object", type=str, help="name of the target object"
27-
)
28-
parser.add_argument(
29-
"-dp",
30-
"--dont_pick_up",
31-
action="store_true",
32-
help="robot should attempt pick but not actually pick",
33-
)
34-
parser.add_argument(
35-
"-ms", "--max_episode_steps", type=int, help="max episode steps"
36-
)
37-
args = parser.parse_args(args=args)
38-
39-
if args.max_episode_steps is not None:
40-
args.max_episode_steps = int(args.max_episode_steps)
41-
return args
42-
43-
44-
def construct_config_for_gaze(
45-
file_path=None, opts=[], dont_pick_up=False, max_episode_steps=None
46-
):
47-
"""
48-
Constructs and updates the config for gaze
49-
50-
Args:
51-
file_path (str): Path to the config file
52-
opts (list): List of options to update the config
53-
54-
Returns:
55-
config (Config): Updated config object
56-
"""
57-
config = None
58-
if file_path is None:
59-
config = construct_config(opts=opts)
60-
else:
61-
config = construct_config(file_path=file_path, opts=opts)
62-
63-
# Don't need head cameras for Gaze
64-
config.USE_HEAD_CAMERA = False
65-
66-
# Update the config based on the input argument
67-
if dont_pick_up != config.DONT_PICK_UP:
68-
print(
69-
f"WARNING: Overriding dont_pick_up in config from {config.DONT_PICK_UP} to {dont_pick_up}"
70-
)
71-
config.DONT_PICK_UP = dont_pick_up
72-
73-
# Update max episode steps based on the input argument
74-
if max_episode_steps is not None:
75-
print(
76-
f"WARNING: Overriding max_espisode_steps in config from {config.MAX_EPISODE_STEPS} to {max_episode_steps}"
77-
)
78-
config.MAX_EPISODE_STEPS = max_episode_steps
79-
return config
80-
81-
82-
class GazeController:
83-
"""
84-
GazeController is used to gaze at, and pick given objects.
85-
86-
Args:
87-
config (Config): Config object
88-
spot (Spot): Spot object
89-
90-
How to use:
91-
1. Create a GazeController object
92-
2. Call execute() method with the target object list
93-
94-
Example:
95-
config = construct_config_for_gaze(opts=[])
96-
spot = Spot("spot_client_name")
97-
with spot.get_lease(hijack=True):
98-
spot.power_robot()
99-
100-
gaze_target_list = ["apple", "banana"]
101-
gaze_controller = GazeController(config, spot)
102-
gaze_results = gaze_controller.execute(gaze_target_list)
103-
104-
spot.shutdown(should_dock=True)
105-
"""
106-
107-
def __init__(self, config, spot, use_mobile_pick=False):
108-
self.config = config
109-
self.spot = spot
110-
self._use_mobile_pick = use_mobile_pick
111-
112-
if use_mobile_pick:
113-
self.policy = MobileGazePolicy(
114-
config.WEIGHTS.MOBILE_GAZE, device=config.DEVICE, config=config
115-
)
116-
else:
117-
self.policy = GazePolicy(config.WEIGHTS.GAZE, device=config.DEVICE)
118-
self.policy.reset()
119-
120-
self.gaze_env = SpotGazeEnv(config, spot, use_mobile_pick)
121-
122-
def reset_env_and_policy(self, target_obj_name):
123-
"""
124-
Resets the gaze_env and policy
125-
126-
Args:
127-
target_obj_name (str): Name of the target object
128-
129-
Returns:
130-
observations: observations from the gaze_env
131-
132-
"""
133-
observations = self.gaze_env.reset(target_obj_name=target_obj_name)
134-
self.policy.reset()
135-
136-
return observations
137-
138-
def execute(self, target_object_list, take_user_input=False):
139-
"""
140-
Gaze at the target object list and pick up the objects if specified in the config
141-
142-
CAUTION: The robot will drop the object after picking it, please use objects that are not fragile
143-
144-
Args:
145-
target_object_list (list): List of target objects to gaze at
146-
take_user_input (bool): Whether to take user input for the success of the gaze
147-
148-
Returns:
149-
gaze_success_list (list): List of dictionaries containing the target object name, time taken and success
150-
"""
151-
gaze_success_list = []
152-
print(f"Target object list : {target_object_list}")
153-
for target_object in target_object_list:
154-
observations = self.reset_env_and_policy(target_obj_name=target_object)
155-
done = False
156-
start_time = time.time()
157-
self.gaze_env.say(f"Gaze at target object - {target_object}")
158-
159-
while not done:
160-
action = self.policy.act(observations)
161-
if self._use_mobile_pick:
162-
arm_action, base_action = None, None
163-
# first 4 are arm actions, then 2 are base actions & last bit is unused
164-
arm_action = action[0:4]
165-
base_action = action[4:6]
166-
167-
observations, _, done, _ = self.gaze_env.step(
168-
arm_action=arm_action, base_action=base_action
169-
)
170-
else:
171-
observations, _, done, _ = self.gaze_env.step(arm_action=action)
172-
self.gaze_env.say("Gaze finished")
173-
# Ask user for feedback about the success of the gaze and update the "success" flag accordingly
174-
success_status_from_user_feedback = True
175-
if take_user_input:
176-
user_prompt = f"Did the robot successfully pick the right object - {target_object}?"
177-
success_status_from_user_feedback = map_user_input_to_boolean(
178-
user_prompt
179-
)
180-
181-
gaze_success_list.append(
182-
{
183-
"target_object": target_object,
184-
"time_taken": time.time() - start_time,
185-
"success": self.gaze_env.grasp_attempted
186-
and success_status_from_user_feedback,
187-
}
188-
)
189-
return gaze_success_list
190-
191-
19214
class SpotGazeEnv(SpotBaseEnv):
193-
def __init__(self, config, spot, use_mobile_pick=False):
15+
def __init__(self, config, spot: Spot, use_mobile_pick: bool = False):
19416
# Select suitable keys
19517
max_joint_movement_key = (
19618
"MAX_JOINT_MOVEMENT_MOBILE_GAZE"
@@ -273,33 +95,3 @@ def get_observations(self):
27395

27496
def get_success(self, observations):
27597
return self.grasp_attempted
276-
277-
278-
if __name__ == "__main__":
279-
spot = Spot("RealGazeEnv")
280-
args = parse_arguments()
281-
config = construct_config_for_gaze(
282-
opts=args.opts,
283-
dont_pick_up=args.dont_pick_up,
284-
max_episode_steps=args.max_episode_steps,
285-
)
286-
287-
target_objects_list = []
288-
if args.target_object is not None:
289-
target_objects_list = [
290-
target
291-
for target in args.target_object.replace(" ,", ",")
292-
.replace(", ", ",")
293-
.split(",")
294-
if target.strip() is not None
295-
]
296-
297-
print(f"Target_objects list - {target_objects_list}")
298-
with spot.get_lease(hijack=True):
299-
spot.power_robot()
300-
gaze_controller = GazeController(config, spot)
301-
try:
302-
gaze_result = gaze_controller.execute(target_objects_list)
303-
print(gaze_result)
304-
finally:
305-
spot.shutdown(should_dock=True)

0 commit comments

Comments
 (0)