|
3 | 3 | # LICENSE file in the root directory of this source tree.
|
4 | 4 |
|
5 | 5 |
|
6 |
| -import os |
7 | 6 | import sys
|
8 |
| -import time |
9 | 7 | from typing import Dict, List
|
10 | 8 |
|
11 |
| -import numpy as np |
12 | 9 | import rospy
|
13 | 10 | from spot_rl.envs.base_env import SpotBaseEnv
|
14 |
| -from spot_rl.real_policy import GazePolicy, MobileGazePolicy |
15 |
| -from spot_rl.utils.utils import ( |
16 |
| - construct_config, |
17 |
| - get_default_parser, |
18 |
| - map_user_input_to_boolean, |
19 |
| -) |
20 | 11 | from spot_wrapper.spot import Spot
|
21 | 12 |
|
22 |
| - |
23 |
| -def parse_arguments(args=sys.argv[1:]): |
24 |
| - parser = get_default_parser() |
25 |
| - parser.add_argument( |
26 |
| - "-t", "--target-object", type=str, help="name of the target object" |
27 |
| - ) |
28 |
| - parser.add_argument( |
29 |
| - "-dp", |
30 |
| - "--dont_pick_up", |
31 |
| - action="store_true", |
32 |
| - help="robot should attempt pick but not actually pick", |
33 |
| - ) |
34 |
| - parser.add_argument( |
35 |
| - "-ms", "--max_episode_steps", type=int, help="max episode steps" |
36 |
| - ) |
37 |
| - args = parser.parse_args(args=args) |
38 |
| - |
39 |
| - if args.max_episode_steps is not None: |
40 |
| - args.max_episode_steps = int(args.max_episode_steps) |
41 |
| - return args |
42 |
| - |
43 |
| - |
44 |
| -def construct_config_for_gaze( |
45 |
| - file_path=None, opts=[], dont_pick_up=False, max_episode_steps=None |
46 |
| -): |
47 |
| - """ |
48 |
| - Constructs and updates the config for gaze |
49 |
| -
|
50 |
| - Args: |
51 |
| - file_path (str): Path to the config file |
52 |
| - opts (list): List of options to update the config |
53 |
| -
|
54 |
| - Returns: |
55 |
| - config (Config): Updated config object |
56 |
| - """ |
57 |
| - config = None |
58 |
| - if file_path is None: |
59 |
| - config = construct_config(opts=opts) |
60 |
| - else: |
61 |
| - config = construct_config(file_path=file_path, opts=opts) |
62 |
| - |
63 |
| - # Don't need head cameras for Gaze |
64 |
| - config.USE_HEAD_CAMERA = False |
65 |
| - |
66 |
| - # Update the config based on the input argument |
67 |
| - if dont_pick_up != config.DONT_PICK_UP: |
68 |
| - print( |
69 |
| - f"WARNING: Overriding dont_pick_up in config from {config.DONT_PICK_UP} to {dont_pick_up}" |
70 |
| - ) |
71 |
| - config.DONT_PICK_UP = dont_pick_up |
72 |
| - |
73 |
| - # Update max episode steps based on the input argument |
74 |
| - if max_episode_steps is not None: |
75 |
| - print( |
76 |
| - f"WARNING: Overriding max_espisode_steps in config from {config.MAX_EPISODE_STEPS} to {max_episode_steps}" |
77 |
| - ) |
78 |
| - config.MAX_EPISODE_STEPS = max_episode_steps |
79 |
| - return config |
80 |
| - |
81 |
| - |
82 |
| -class GazeController: |
83 |
| - """ |
84 |
| - GazeController is used to gaze at, and pick given objects. |
85 |
| -
|
86 |
| - Args: |
87 |
| - config (Config): Config object |
88 |
| - spot (Spot): Spot object |
89 |
| -
|
90 |
| - How to use: |
91 |
| - 1. Create a GazeController object |
92 |
| - 2. Call execute() method with the target object list |
93 |
| -
|
94 |
| - Example: |
95 |
| - config = construct_config_for_gaze(opts=[]) |
96 |
| - spot = Spot("spot_client_name") |
97 |
| - with spot.get_lease(hijack=True): |
98 |
| - spot.power_robot() |
99 |
| -
|
100 |
| - gaze_target_list = ["apple", "banana"] |
101 |
| - gaze_controller = GazeController(config, spot) |
102 |
| - gaze_results = gaze_controller.execute(gaze_target_list) |
103 |
| -
|
104 |
| - spot.shutdown(should_dock=True) |
105 |
| - """ |
106 |
| - |
107 |
| - def __init__(self, config, spot, use_mobile_pick=False): |
108 |
| - self.config = config |
109 |
| - self.spot = spot |
110 |
| - self._use_mobile_pick = use_mobile_pick |
111 |
| - |
112 |
| - if use_mobile_pick: |
113 |
| - self.policy = MobileGazePolicy( |
114 |
| - config.WEIGHTS.MOBILE_GAZE, device=config.DEVICE, config=config |
115 |
| - ) |
116 |
| - else: |
117 |
| - self.policy = GazePolicy(config.WEIGHTS.GAZE, device=config.DEVICE) |
118 |
| - self.policy.reset() |
119 |
| - |
120 |
| - self.gaze_env = SpotGazeEnv(config, spot, use_mobile_pick) |
121 |
| - |
122 |
| - def reset_env_and_policy(self, target_obj_name): |
123 |
| - """ |
124 |
| - Resets the gaze_env and policy |
125 |
| -
|
126 |
| - Args: |
127 |
| - target_obj_name (str): Name of the target object |
128 |
| -
|
129 |
| - Returns: |
130 |
| - observations: observations from the gaze_env |
131 |
| -
|
132 |
| - """ |
133 |
| - observations = self.gaze_env.reset(target_obj_name=target_obj_name) |
134 |
| - self.policy.reset() |
135 |
| - |
136 |
| - return observations |
137 |
| - |
138 |
| - def execute(self, target_object_list, take_user_input=False): |
139 |
| - """ |
140 |
| - Gaze at the target object list and pick up the objects if specified in the config |
141 |
| -
|
142 |
| - CAUTION: The robot will drop the object after picking it, please use objects that are not fragile |
143 |
| -
|
144 |
| - Args: |
145 |
| - target_object_list (list): List of target objects to gaze at |
146 |
| - take_user_input (bool): Whether to take user input for the success of the gaze |
147 |
| -
|
148 |
| - Returns: |
149 |
| - gaze_success_list (list): List of dictionaries containing the target object name, time taken and success |
150 |
| - """ |
151 |
| - gaze_success_list = [] |
152 |
| - print(f"Target object list : {target_object_list}") |
153 |
| - for target_object in target_object_list: |
154 |
| - observations = self.reset_env_and_policy(target_obj_name=target_object) |
155 |
| - done = False |
156 |
| - start_time = time.time() |
157 |
| - self.gaze_env.say(f"Gaze at target object - {target_object}") |
158 |
| - |
159 |
| - while not done: |
160 |
| - action = self.policy.act(observations) |
161 |
| - if self._use_mobile_pick: |
162 |
| - arm_action, base_action = None, None |
163 |
| - # first 4 are arm actions, then 2 are base actions & last bit is unused |
164 |
| - arm_action = action[0:4] |
165 |
| - base_action = action[4:6] |
166 |
| - |
167 |
| - observations, _, done, _ = self.gaze_env.step( |
168 |
| - arm_action=arm_action, base_action=base_action |
169 |
| - ) |
170 |
| - else: |
171 |
| - observations, _, done, _ = self.gaze_env.step(arm_action=action) |
172 |
| - self.gaze_env.say("Gaze finished") |
173 |
| - # Ask user for feedback about the success of the gaze and update the "success" flag accordingly |
174 |
| - success_status_from_user_feedback = True |
175 |
| - if take_user_input: |
176 |
| - user_prompt = f"Did the robot successfully pick the right object - {target_object}?" |
177 |
| - success_status_from_user_feedback = map_user_input_to_boolean( |
178 |
| - user_prompt |
179 |
| - ) |
180 |
| - |
181 |
| - gaze_success_list.append( |
182 |
| - { |
183 |
| - "target_object": target_object, |
184 |
| - "time_taken": time.time() - start_time, |
185 |
| - "success": self.gaze_env.grasp_attempted |
186 |
| - and success_status_from_user_feedback, |
187 |
| - } |
188 |
| - ) |
189 |
| - return gaze_success_list |
190 |
| - |
191 |
| - |
192 | 13 | class SpotGazeEnv(SpotBaseEnv):
|
193 |
| - def __init__(self, config, spot, use_mobile_pick=False): |
| 14 | + def __init__(self, config, spot: Spot, use_mobile_pick=False): |
194 | 15 | # Select suitable keys
|
195 | 16 | max_joint_movement_key = (
|
196 | 17 | "MAX_JOINT_MOVEMENT_MOBILE_GAZE"
|
@@ -273,33 +94,3 @@ def get_observations(self):
|
273 | 94 |
|
274 | 95 | def get_success(self, observations):
|
275 | 96 | return self.grasp_attempted
|
276 |
| - |
277 |
| - |
278 |
| -if __name__ == "__main__": |
279 |
| - spot = Spot("RealGazeEnv") |
280 |
| - args = parse_arguments() |
281 |
| - config = construct_config_for_gaze( |
282 |
| - opts=args.opts, |
283 |
| - dont_pick_up=args.dont_pick_up, |
284 |
| - max_episode_steps=args.max_episode_steps, |
285 |
| - ) |
286 |
| - |
287 |
| - target_objects_list = [] |
288 |
| - if args.target_object is not None: |
289 |
| - target_objects_list = [ |
290 |
| - target |
291 |
| - for target in args.target_object.replace(" ,", ",") |
292 |
| - .replace(", ", ",") |
293 |
| - .split(",") |
294 |
| - if target.strip() is not None |
295 |
| - ] |
296 |
| - |
297 |
| - print(f"Target_objects list - {target_objects_list}") |
298 |
| - with spot.get_lease(hijack=True): |
299 |
| - spot.power_robot() |
300 |
| - gaze_controller = GazeController(config, spot) |
301 |
| - try: |
302 |
| - gaze_result = gaze_controller.execute(target_objects_list) |
303 |
| - print(gaze_result) |
304 |
| - finally: |
305 |
| - spot.shutdown(should_dock=True) |
0 commit comments