Skip to content

Commit c90e152

Browse files
committed
Implement class Pick as a new skill
1 parent 7f85583 commit c90e152

File tree

8 files changed

+385
-270
lines changed

8 files changed

+385
-270
lines changed

spot_rl_experiments/spot_rl/envs/gaze_env.py

+1-210
Original file line numberDiff line numberDiff line change
@@ -3,194 +3,15 @@
33
# LICENSE file in the root directory of this source tree.
44

55

6-
import os
76
import sys
8-
import time
97
from typing import Dict, List
108

11-
import numpy as np
129
import rospy
1310
from spot_rl.envs.base_env import SpotBaseEnv
14-
from spot_rl.real_policy import GazePolicy, MobileGazePolicy
15-
from spot_rl.utils.utils import (
16-
construct_config,
17-
get_default_parser,
18-
map_user_input_to_boolean,
19-
)
2011
from spot_wrapper.spot import Spot
2112

22-
23-
def parse_arguments(args=sys.argv[1:]):
24-
parser = get_default_parser()
25-
parser.add_argument(
26-
"-t", "--target-object", type=str, help="name of the target object"
27-
)
28-
parser.add_argument(
29-
"-dp",
30-
"--dont_pick_up",
31-
action="store_true",
32-
help="robot should attempt pick but not actually pick",
33-
)
34-
parser.add_argument(
35-
"-ms", "--max_episode_steps", type=int, help="max episode steps"
36-
)
37-
args = parser.parse_args(args=args)
38-
39-
if args.max_episode_steps is not None:
40-
args.max_episode_steps = int(args.max_episode_steps)
41-
return args
42-
43-
44-
def construct_config_for_gaze(
45-
file_path=None, opts=[], dont_pick_up=False, max_episode_steps=None
46-
):
47-
"""
48-
Constructs and updates the config for gaze
49-
50-
Args:
51-
file_path (str): Path to the config file
52-
opts (list): List of options to update the config
53-
54-
Returns:
55-
config (Config): Updated config object
56-
"""
57-
config = None
58-
if file_path is None:
59-
config = construct_config(opts=opts)
60-
else:
61-
config = construct_config(file_path=file_path, opts=opts)
62-
63-
# Don't need head cameras for Gaze
64-
config.USE_HEAD_CAMERA = False
65-
66-
# Update the config based on the input argument
67-
if dont_pick_up != config.DONT_PICK_UP:
68-
print(
69-
f"WARNING: Overriding dont_pick_up in config from {config.DONT_PICK_UP} to {dont_pick_up}"
70-
)
71-
config.DONT_PICK_UP = dont_pick_up
72-
73-
# Update max episode steps based on the input argument
74-
if max_episode_steps is not None:
75-
print(
76-
f"WARNING: Overriding max_espisode_steps in config from {config.MAX_EPISODE_STEPS} to {max_episode_steps}"
77-
)
78-
config.MAX_EPISODE_STEPS = max_episode_steps
79-
return config
80-
81-
82-
class GazeController:
83-
"""
84-
GazeController is used to gaze at, and pick given objects.
85-
86-
Args:
87-
config (Config): Config object
88-
spot (Spot): Spot object
89-
90-
How to use:
91-
1. Create a GazeController object
92-
2. Call execute() method with the target object list
93-
94-
Example:
95-
config = construct_config_for_gaze(opts=[])
96-
spot = Spot("spot_client_name")
97-
with spot.get_lease(hijack=True):
98-
spot.power_robot()
99-
100-
gaze_target_list = ["apple", "banana"]
101-
gaze_controller = GazeController(config, spot)
102-
gaze_results = gaze_controller.execute(gaze_target_list)
103-
104-
spot.shutdown(should_dock=True)
105-
"""
106-
107-
def __init__(self, config, spot, use_mobile_pick=False):
108-
self.config = config
109-
self.spot = spot
110-
self._use_mobile_pick = use_mobile_pick
111-
112-
if use_mobile_pick:
113-
self.policy = MobileGazePolicy(
114-
config.WEIGHTS.MOBILE_GAZE, device=config.DEVICE, config=config
115-
)
116-
else:
117-
self.policy = GazePolicy(config.WEIGHTS.GAZE, device=config.DEVICE)
118-
self.policy.reset()
119-
120-
self.gaze_env = SpotGazeEnv(config, spot, use_mobile_pick)
121-
122-
def reset_env_and_policy(self, target_obj_name):
123-
"""
124-
Resets the gaze_env and policy
125-
126-
Args:
127-
target_obj_name (str): Name of the target object
128-
129-
Returns:
130-
observations: observations from the gaze_env
131-
132-
"""
133-
observations = self.gaze_env.reset(target_obj_name=target_obj_name)
134-
self.policy.reset()
135-
136-
return observations
137-
138-
def execute(self, target_object_list, take_user_input=False):
139-
"""
140-
Gaze at the target object list and pick up the objects if specified in the config
141-
142-
CAUTION: The robot will drop the object after picking it, please use objects that are not fragile
143-
144-
Args:
145-
target_object_list (list): List of target objects to gaze at
146-
take_user_input (bool): Whether to take user input for the success of the gaze
147-
148-
Returns:
149-
gaze_success_list (list): List of dictionaries containing the target object name, time taken and success
150-
"""
151-
gaze_success_list = []
152-
print(f"Target object list : {target_object_list}")
153-
for target_object in target_object_list:
154-
observations = self.reset_env_and_policy(target_obj_name=target_object)
155-
done = False
156-
start_time = time.time()
157-
self.gaze_env.say(f"Gaze at target object - {target_object}")
158-
159-
while not done:
160-
action = self.policy.act(observations)
161-
if self._use_mobile_pick:
162-
arm_action, base_action = None, None
163-
# first 4 are arm actions, then 2 are base actions & last bit is unused
164-
arm_action = action[0:4]
165-
base_action = action[4:6]
166-
167-
observations, _, done, _ = self.gaze_env.step(
168-
arm_action=arm_action, base_action=base_action
169-
)
170-
else:
171-
observations, _, done, _ = self.gaze_env.step(arm_action=action)
172-
self.gaze_env.say("Gaze finished")
173-
# Ask user for feedback about the success of the gaze and update the "success" flag accordingly
174-
success_status_from_user_feedback = True
175-
if take_user_input:
176-
user_prompt = f"Did the robot successfully pick the right object - {target_object}?"
177-
success_status_from_user_feedback = map_user_input_to_boolean(
178-
user_prompt
179-
)
180-
181-
gaze_success_list.append(
182-
{
183-
"target_object": target_object,
184-
"time_taken": time.time() - start_time,
185-
"success": self.gaze_env.grasp_attempted
186-
and success_status_from_user_feedback,
187-
}
188-
)
189-
return gaze_success_list
190-
191-
19213
class SpotGazeEnv(SpotBaseEnv):
193-
def __init__(self, config, spot, use_mobile_pick=False):
14+
def __init__(self, config, spot: Spot, use_mobile_pick=False):
19415
# Select suitable keys
19516
max_joint_movement_key = (
19617
"MAX_JOINT_MOVEMENT_MOBILE_GAZE"
@@ -273,33 +94,3 @@ def get_observations(self):
27394

27495
def get_success(self, observations):
27596
return self.grasp_attempted
276-
277-
278-
if __name__ == "__main__":
279-
spot = Spot("RealGazeEnv")
280-
args = parse_arguments()
281-
config = construct_config_for_gaze(
282-
opts=args.opts,
283-
dont_pick_up=args.dont_pick_up,
284-
max_episode_steps=args.max_episode_steps,
285-
)
286-
287-
target_objects_list = []
288-
if args.target_object is not None:
289-
target_objects_list = [
290-
target
291-
for target in args.target_object.replace(" ,", ",")
292-
.replace(", ", ",")
293-
.split(",")
294-
if target.strip() is not None
295-
]
296-
297-
print(f"Target_objects list - {target_objects_list}")
298-
with spot.get_lease(hijack=True):
299-
spot.power_robot()
300-
gaze_controller = GazeController(config, spot)
301-
try:
302-
gaze_result = gaze_controller.execute(target_objects_list)
303-
print(gaze_result)
304-
finally:
305-
spot.shutdown(should_dock=True)

spot_rl_experiments/spot_rl/envs/skill_manager.py

+8-31
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
import numpy as np
99
from multimethod import multimethod
10-
from spot_rl.envs.gaze_env import GazeController, construct_config_for_gaze
1110
from spot_rl.envs.place_env import PlaceController, construct_config_for_place
12-
from spot_rl.skills.atomic_skills import Navigation
13-
from spot_rl.utils.construct_configs import construct_config_for_nav
11+
from spot_rl.skills.atomic_skills import Navigation, Pick
12+
from spot_rl.utils.construct_configs import (
13+
construct_config_for_gaze,
14+
construct_config_for_nav,
15+
)
1416
from spot_rl.utils.geometry_utils import is_position_within_bounds
1517
from spot_rl.utils.heuristic_nav import (
1618
ImageSearch,
@@ -168,9 +170,9 @@ def __initiate_controllers(self, use_policies: bool = True):
168170
config=self.nav_config,
169171
record_robot_trajectories=True,
170172
)
171-
self.gaze_controller = GazeController(
172-
config=self.pick_config,
173+
self.gaze_controller = Pick(
173174
spot=self.spot,
175+
config=self.pick_config,
174176
use_mobile_pick=self._use_mobile_pick,
175177
)
176178
self.place_controller = PlaceController(
@@ -293,32 +295,7 @@ def pick(self, pick_target: str = None) -> Tuple[bool, str]:
293295
bool: True if pick was successful, False otherwise
294296
str: Message indicating the status of the pick
295297
"""
296-
conditional_print(
297-
message=f"Received pick target request for - {pick_target}",
298-
verbose=self.verbose,
299-
)
300-
301-
if pick_target is None:
302-
message = "No pick target specified, skipping pick"
303-
conditional_print(message=message, verbose=self.verbose)
304-
return False, message
305-
306-
conditional_print(message=f"Picking {pick_target}", verbose=self.verbose)
307-
308-
result = None
309-
try:
310-
result = self.gaze_controller.execute([pick_target])
311-
except Exception:
312-
message = "Error encountered while picking"
313-
conditional_print(message=message, verbose=self.verbose)
314-
return False, message
315-
316-
# Check for success and return appropriately
317-
status = False
318-
message = "Pick failed to pick the target object"
319-
if result[0].get("success"):
320-
status = True
321-
message = "Successfully picked the target object"
298+
status, message = self.gaze_controller.execute(pick_target)
322299
conditional_print(message=message, verbose=self.verbose)
323300
return status, message
324301

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import sys
2+
3+
from spot_rl.skills.atomic_skills import Pick
4+
from spot_rl.utils.construct_configs import construct_config_for_gaze
5+
from spot_rl.utils.utils import get_default_parser
6+
from spot_wrapper.spot import Spot
7+
8+
9+
def parse_arguments(args=sys.argv[1:]):
10+
parser = get_default_parser()
11+
parser.add_argument(
12+
"-t", "--target-object", type=str, help="name of the target object"
13+
)
14+
parser.add_argument(
15+
"-dp",
16+
"--dont_pick_up",
17+
action="store_true",
18+
help="robot should attempt pick but not actually pick",
19+
)
20+
parser.add_argument(
21+
"-ms", "--max_episode_steps", type=int, help="max episode steps"
22+
)
23+
parser.add_argument(
24+
"-mg",
25+
"--mobile_gaze",
26+
action="store_true",
27+
help="whether to use mobile gaze or static",
28+
)
29+
args = parser.parse_args(args=args)
30+
31+
if args.max_episode_steps is not None:
32+
args.max_episode_steps = int(args.max_episode_steps)
33+
return args
34+
35+
36+
if __name__ == "__main__":
37+
spot = Spot("RealGazeEnv")
38+
args = parse_arguments()
39+
config = construct_config_for_gaze(
40+
opts=args.opts,
41+
dont_pick_up=args.dont_pick_up,
42+
max_episode_steps=args.max_episode_steps,
43+
)
44+
45+
target_objects_list = []
46+
if args.target_object is not None:
47+
print(args.target_object)
48+
target_objects_list = [
49+
target
50+
for target in args.target_object.replace(" ,", ",")
51+
.replace(", ", ",")
52+
.split(",")
53+
if target.strip() is not None
54+
]
55+
56+
print(f"Target_objects list - {target_objects_list}")
57+
with spot.get_lease(hijack=True):
58+
spot.power_robot()
59+
gaze_controller = Pick(spot=spot, config=config)
60+
try:
61+
gaze_result = gaze_controller.execute_pick(
62+
target_objects_list, take_user_input=True
63+
)
64+
print(gaze_result)
65+
finally:
66+
spot.shutdown(should_dock=False)

0 commit comments

Comments
 (0)