Skip to content

Commit 5d6e65c

Browse files
committed
Improve mechanism of saving waypoints
1 parent 8ee2b7f commit 5d6e65c

File tree

8 files changed

+265
-68
lines changed

8 files changed

+265
-68
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ git checkout main
5252
```
5353

5454
## :video_game: Instructions to record waypoints (use joystick to move robot around)
55-
- Before running scripts on the robot, waypoints should be recording. These waypoints exist inside file `spot-sim2real/spot_rl_experiments/configs/waypoints.yaml`
55+
- Before running scripts on the robot, waypoints should be recorded. These waypoints exist inside file `spot-sim2real/spot_rl_experiments/configs/waypoints.yaml`
5656

5757
- Before recording receptacles, make the robot sit at home position then run following command
5858
```bash

spot_rl_experiments/configs/waypoints_apartment.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ kitchen_island:
112112

113113

114114
nav_targets:
115+
dock: [1.5, 0.0, 0.0]
115116
hall_table:
116117
- 4.973838631994881
117118
- 3.1629789195672372

spot_rl_experiments/configs/waypoints_microkitchen.yaml

+28
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,31 @@ trash:
7272
- 6.184285849819101
7373
- -0.7815312691658676
7474
- 179.66070876534218
75+
dock: [1.5, 0.0, 0.0]
76+
77+
nav_targets:
78+
dock: [1.5, 0.0, 0.0]
79+
table1:
80+
- 3.299641079508338
81+
- -2.543085839668663
82+
- -166.60155928991725
83+
couch:
84+
- 2.23906958193421
85+
- -0.56702771730329
86+
- 88.29794273106572
87+
table2:
88+
- 0.8877704664933876
89+
- -5.601562538611034
90+
- -149.53198213686227
91+
sink:
92+
- 6.174181697206729
93+
- -3.563298461337695
94+
- -3.3523308120884794
95+
coffee_counter:
96+
- 5.967649279863249
97+
- -1.2784570676598292
98+
- 2.5817470348872544
99+
trash:
100+
- 6.184285849819101
101+
- -0.7815312691658676
102+
- 179.66070876534218

spot_rl_experiments/spot_rl/envs/lang_env.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
from spot_rl.real_policy import GazePolicy, MixerPolicy, NavPolicy, PlacePolicy
2121
from spot_rl.utils.remote_spot import RemoteSpot
2222
from spot_rl.utils.utils import (
23-
WAYPOINTS,
2423
closest_clutter,
2524
construct_config,
2625
get_clutter_amounts,
2726
get_default_parser,
27+
get_waypoint_yaml,
2828
nav_target_from_waypoints,
2929
object_id_to_nav_waypoint,
3030
place_target_from_waypoints,
@@ -65,6 +65,9 @@ def main(spot, use_mixer, config, out_path=None):
6565
# Check if robot should return to base
6666
return_to_base = config.RETURN_TO_BASE
6767

68+
# Get the waypoints from waypoints.yaml
69+
waypoints = get_waypoint_yaml()
70+
6871
audio_to_text = WhisperTranslator()
6972
sentence_similarity = SentenceSimilarity()
7073
with initialize(config_path="../llm/src/conf"):
@@ -87,10 +90,10 @@ def main(spot, use_mixer, config, out_path=None):
8790

8891
# Find closest nav_targets to the ones robot knows locations of
8992
nav_1 = sentence_similarity.get_most_similar_in_list(
90-
nav_1, list(WAYPOINTS["nav_targets"].keys())
93+
nav_1, list(waypoints["nav_targets"].keys())
9194
)
9295
nav_2 = sentence_similarity.get_most_similar_in_list(
93-
nav_2, list(WAYPOINTS["nav_targets"].keys())
96+
nav_2, list(waypoints["nav_targets"].keys())
9497
)
9598
print("MOST SIMILAR: ", nav_1, pick, nav_2)
9699

spot_rl_experiments/spot_rl/envs/mobile_manipulation_env.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
from spot_rl.real_policy import GazePolicy, MixerPolicy, NavPolicy, PlacePolicy
1717
from spot_rl.utils.remote_spot import RemoteSpot
1818
from spot_rl.utils.utils import (
19-
WAYPOINTS,
2019
closest_clutter,
2120
construct_config,
2221
get_clutter_amounts,
2322
get_default_parser,
23+
get_waypoint_yaml,
2424
nav_target_from_waypoints,
2525
object_id_to_nav_waypoint,
2626
place_target_from_waypoints,
@@ -64,9 +64,12 @@ def main(spot, use_mixer, config, out_path=None):
6464
# Check if robot should return to base
6565
return_to_base = config.RETURN_TO_BASE
6666

67+
# Get the waypoints from waypoints.yaml
68+
waypoints = get_waypoint_yaml()
69+
6770
objects_to_look = []
68-
for waypoint in WAYPOINTS["object_targets"]:
69-
objects_to_look.append(WAYPOINTS["object_targets"][waypoint][0])
71+
for waypoint in waypoints["object_targets"]:
72+
objects_to_look.append(waypoints["object_targets"][waypoint][0])
7073
rospy.set_param("object_target", ",".join(objects_to_look))
7174

7275
env.power_robot()
@@ -78,7 +81,7 @@ def main(spot, use_mixer, config, out_path=None):
7881
if trip_idx < NUM_OBJECTS:
7982
# 2 objects per receptacle
8083
clutter_blacklist = [
81-
i for i in WAYPOINTS["clutter"] if count[i] >= CLUTTER_AMOUNTS[i]
84+
i for i in waypoints["clutter"] if count[i] >= CLUTTER_AMOUNTS[i]
8285
]
8386
waypoint_name, waypoint = closest_clutter(
8487
env.x, env.y, clutter_blacklist=clutter_blacklist

spot_rl_experiments/spot_rl/utils/generate_place_goal.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def get_global_place_target(spot: Spot):
1616
position, rotation = spot.get_base_transform_to("link_wr1")
1717
position = [position.x, position.y, position.z]
1818
rotation = [rotation.x, rotation.y, rotation.z, rotation.w]
19+
20+
# Spot2Habitat transform SHOULD NOT BE A PART OF SpotBaseEnv Class. It is an unnecessary dependency.
1921
wrist_T_base = SpotBaseEnv.spot2habitat_transform(position, rotation)
2022
gripper_T_base = wrist_T_base @ mn.Matrix4.translation(
2123
mn.Vector3(EE_GRIPPER_OFFSET)

spot_rl_experiments/spot_rl/utils/utils.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@
1818
configs_dir = osp.join(spot_rl_experiments_dir, "configs")
1919
DEFAULT_CONFIG = osp.join(configs_dir, "config.yaml")
2020
WAYPOINTS_YAML = osp.join(configs_dir, "waypoints.yaml")
21-
with open(WAYPOINTS_YAML) as f:
22-
WAYPOINTS = yaml.safe_load(f)
2321

2422
ROS_TOPICS = osp.join(configs_dir, "ros_topic_names.yaml")
2523
ros_topics = CN()
2624
ros_topics.set_new_allowed(True)
2725
ros_topics.merge_from_file(ROS_TOPICS)
2826

2927

28+
def get_waypoint_yaml(waypoint_file=WAYPOINTS_YAML):
29+
with open(waypoint_file) as f:
30+
return yaml.safe_load(f)
31+
32+
3033
def get_default_parser():
3134
parser = argparse.ArgumentParser()
3235
parser.add_argument("-o", "--opts", nargs="*", default=[])
@@ -54,20 +57,23 @@ def construct_config(opts=None):
5457

5558

5659
def nav_target_from_waypoints(waypoint):
57-
goal_x, goal_y, goal_heading = WAYPOINTS[waypoint]
60+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
61+
goal_x, goal_y, goal_heading = waypoints_yaml["nav_targets"][waypoint]
5862
return goal_x, goal_y, np.deg2rad(goal_heading)
5963

6064

6165
def place_target_from_waypoints(waypoint):
62-
return np.array(WAYPOINTS["place_targets"][waypoint])
66+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
67+
return np.array(waypoints_yaml["place_targets"][waypoint])
6368

6469

6570
def closest_clutter(x, y, clutter_blacklist=None):
71+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
6672
if clutter_blacklist is None:
6773
clutter_blacklist = []
6874
clutter_locations = [
6975
(np.array(nav_target_from_waypoints(w)[:2]), w)
70-
for w in WAYPOINTS["clutter"]
76+
for w in waypoints_yaml["clutter"]
7177
if w not in clutter_blacklist
7278
]
7379
xy = np.array([x, y])
@@ -77,23 +83,26 @@ def closest_clutter(x, y, clutter_blacklist=None):
7783

7884

7985
def object_id_to_nav_waypoint(object_id):
86+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
8087
if isinstance(object_id, str):
81-
for k, v in WAYPOINTS["object_targets"].items():
88+
for k, v in waypoints_yaml["object_targets"].items():
8289
if v[0] == object_id:
8390
object_id = int(k)
8491
break
8592
if isinstance(object_id, str):
8693
KeyError(f"{object_id} not a valid class name!")
87-
place_nav_target_name = WAYPOINTS["object_targets"][object_id][1]
94+
place_nav_target_name = waypoints_yaml["object_targets"][object_id][1]
8895
return place_nav_target_name, nav_target_from_waypoints(place_nav_target_name)
8996

9097

9198
def object_id_to_object_name(object_id):
92-
return WAYPOINTS["object_targets"][object_id][0]
99+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
100+
return waypoints_yaml["object_targets"][object_id][0]
93101

94102

95103
def get_clutter_amounts():
96-
return WAYPOINTS["clutter_amounts"]
104+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
105+
return waypoints_yaml["clutter_amounts"]
97106

98107

99108
def arr2str(arr):

0 commit comments

Comments
 (0)