Skip to content

Commit 7d4dd03

Browse files
committed
Final changes for waypoints
1 parent a9ef0f1 commit 7d4dd03

File tree

7 files changed

+144
-46
lines changed

7 files changed

+144
-46
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ git checkout main
5252
```
5353

5454
## :video_game: Instructions to record waypoints (use joystick to move robot around)
55-
- Before running scripts on the robot, waypoints should be recording. These waypoints exist inside file `spot-sim2real/spot_rl_experiments/configs/waypoints.yaml`
55+
- Before running scripts on the robot, waypoints should be recorded. These waypoints exist inside file `spot-sim2real/spot_rl_experiments/configs/waypoints.yaml`
5656

5757
- Before recording receptacles, make the robot sit at home position then run following command
5858
```bash

spot_rl_experiments/configs/waypoints_apartment.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ kitchen_island:
112112

113113

114114
nav_targets:
115+
dock: [1.5, 0.0, 0.0]
115116
hall_table:
116117
- 4.973838631994881
117118
- 3.1629789195672372

spot_rl_experiments/configs/waypoints_microkitchen.yaml

+28
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,31 @@ trash:
7272
- 6.184285849819101
7373
- -0.7815312691658676
7474
- 179.66070876534218
75+
dock: [1.5, 0.0, 0.0]
76+
77+
nav_targets:
78+
dock: [1.5, 0.0, 0.0]
79+
table1:
80+
- 3.299641079508338
81+
- -2.543085839668663
82+
- -166.60155928991725
83+
couch:
84+
- 2.23906958193421
85+
- -0.56702771730329
86+
- 88.29794273106572
87+
table2:
88+
- 0.8877704664933876
89+
- -5.601562538611034
90+
- -149.53198213686227
91+
sink:
92+
- 6.174181697206729
93+
- -3.563298461337695
94+
- -3.3523308120884794
95+
coffee_counter:
96+
- 5.967649279863249
97+
- -1.2784570676598292
98+
- 2.5817470348872544
99+
trash:
100+
- 6.184285849819101
101+
- -0.7815312691658676
102+
- 179.66070876534218

spot_rl_experiments/spot_rl/envs/lang_env.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
from spot_rl.real_policy import GazePolicy, MixerPolicy, NavPolicy, PlacePolicy
2121
from spot_rl.utils.remote_spot import RemoteSpot
2222
from spot_rl.utils.utils import (
23-
WAYPOINTS,
2423
closest_clutter,
2524
construct_config,
2625
get_clutter_amounts,
2726
get_default_parser,
27+
get_waypoint_yaml,
2828
nav_target_from_waypoints,
2929
object_id_to_nav_waypoint,
3030
place_target_from_waypoints,
@@ -65,6 +65,9 @@ def main(spot, use_mixer, config, out_path=None):
6565
# Check if robot should return to base
6666
return_to_base = config.RETURN_TO_BASE
6767

68+
# Get the waypoints from waypoints.yaml
69+
waypoints = get_waypoint_yaml()
70+
6871
audio_to_text = WhisperTranslator()
6972
sentence_similarity = SentenceSimilarity()
7073
with initialize(config_path="../llm/src/conf"):
@@ -87,10 +90,10 @@ def main(spot, use_mixer, config, out_path=None):
8790

8891
# Find closest nav_targets to the ones robot knows locations of
8992
nav_1 = sentence_similarity.get_most_similar_in_list(
90-
nav_1, list(WAYPOINTS["nav_targets"].keys())
93+
nav_1, list(waypoints["nav_targets"].keys())
9194
)
9295
nav_2 = sentence_similarity.get_most_similar_in_list(
93-
nav_2, list(WAYPOINTS["nav_targets"].keys())
96+
nav_2, list(waypoints["nav_targets"].keys())
9497
)
9598
print("MOST SIMILAR: ", nav_1, pick, nav_2)
9699

spot_rl_experiments/spot_rl/envs/mobile_manipulation_env.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
from spot_rl.real_policy import GazePolicy, MixerPolicy, NavPolicy, PlacePolicy
1717
from spot_rl.utils.remote_spot import RemoteSpot
1818
from spot_rl.utils.utils import (
19-
WAYPOINTS,
2019
closest_clutter,
2120
construct_config,
2221
get_clutter_amounts,
2322
get_default_parser,
23+
get_waypoint_yaml,
2424
nav_target_from_waypoints,
2525
object_id_to_nav_waypoint,
2626
place_target_from_waypoints,
@@ -64,9 +64,12 @@ def main(spot, use_mixer, config, out_path=None):
6464
# Check if robot should return to base
6565
return_to_base = config.RETURN_TO_BASE
6666

67+
# Get the waypoints from waypoints.yaml
68+
waypoints = get_waypoint_yaml()
69+
6770
objects_to_look = []
68-
for waypoint in WAYPOINTS["object_targets"]:
69-
objects_to_look.append(WAYPOINTS["object_targets"][waypoint][0])
71+
for waypoint in waypoints["object_targets"]:
72+
objects_to_look.append(waypoints["object_targets"][waypoint][0])
7073
rospy.set_param("object_target", ",".join(objects_to_look))
7174

7275
env.power_robot()
@@ -78,7 +81,7 @@ def main(spot, use_mixer, config, out_path=None):
7881
if trip_idx < NUM_OBJECTS:
7982
# 2 objects per receptacle
8083
clutter_blacklist = [
81-
i for i in WAYPOINTS["clutter"] if count[i] >= CLUTTER_AMOUNTS[i]
84+
i for i in waypoints["clutter"] if count[i] >= CLUTTER_AMOUNTS[i]
8285
]
8386
waypoint_name, waypoint = closest_clutter(
8487
env.x, env.y, clutter_blacklist=clutter_blacklist

spot_rl_experiments/spot_rl/utils/utils.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@
1818
configs_dir = osp.join(spot_rl_experiments_dir, "configs")
1919
DEFAULT_CONFIG = osp.join(configs_dir, "config.yaml")
2020
WAYPOINTS_YAML = osp.join(configs_dir, "waypoints.yaml")
21-
with open(WAYPOINTS_YAML) as f:
22-
WAYPOINTS = yaml.safe_load(f)
2321

2422
ROS_TOPICS = osp.join(configs_dir, "ros_topic_names.yaml")
2523
ros_topics = CN()
2624
ros_topics.set_new_allowed(True)
2725
ros_topics.merge_from_file(ROS_TOPICS)
2826

2927

28+
def get_waypoint_yaml(waypoint_file=WAYPOINTS_YAML):
29+
with open(waypoint_file) as f:
30+
return yaml.safe_load(f)
31+
32+
3033
def get_default_parser():
3134
parser = argparse.ArgumentParser()
3235
parser.add_argument("-o", "--opts", nargs="*", default=[])
@@ -54,20 +57,23 @@ def construct_config(opts=None):
5457

5558

5659
def nav_target_from_waypoints(waypoint):
57-
goal_x, goal_y, goal_heading = WAYPOINTS[waypoint]
60+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
61+
goal_x, goal_y, goal_heading = waypoints_yaml["nav_targets"][waypoint]
5862
return goal_x, goal_y, np.deg2rad(goal_heading)
5963

6064

6165
def place_target_from_waypoints(waypoint):
62-
return np.array(WAYPOINTS["place_targets"][waypoint])
66+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
67+
return np.array(waypoints_yaml["place_targets"][waypoint])
6368

6469

6570
def closest_clutter(x, y, clutter_blacklist=None):
71+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
6672
if clutter_blacklist is None:
6773
clutter_blacklist = []
6874
clutter_locations = [
6975
(np.array(nav_target_from_waypoints(w)[:2]), w)
70-
for w in WAYPOINTS["clutter"]
76+
for w in waypoints_yaml["clutter"]
7177
if w not in clutter_blacklist
7278
]
7379
xy = np.array([x, y])
@@ -77,23 +83,26 @@ def closest_clutter(x, y, clutter_blacklist=None):
7783

7884

7985
def object_id_to_nav_waypoint(object_id):
86+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
8087
if isinstance(object_id, str):
81-
for k, v in WAYPOINTS["object_targets"].items():
88+
for k, v in waypoints_yaml["object_targets"].items():
8289
if v[0] == object_id:
8390
object_id = int(k)
8491
break
8592
if isinstance(object_id, str):
8693
KeyError(f"{object_id} not a valid class name!")
87-
place_nav_target_name = WAYPOINTS["object_targets"][object_id][1]
94+
place_nav_target_name = waypoints_yaml["object_targets"][object_id][1]
8895
return place_nav_target_name, nav_target_from_waypoints(place_nav_target_name)
8996

9097

9198
def object_id_to_object_name(object_id):
92-
return WAYPOINTS["object_targets"][object_id][0]
99+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
100+
return waypoints_yaml["object_targets"][object_id][0]
93101

94102

95103
def get_clutter_amounts():
96-
return WAYPOINTS["clutter_amounts"]
104+
waypoints_yaml = get_waypoint_yaml(WAYPOINTS_YAML)
105+
return waypoints_yaml["clutter_amounts"]
97106

98107

99108
def arr2str(arr):

spot_rl_experiments/spot_rl/utils/waypoint_recorder.py

+83-29
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,73 @@ def parse_arguments(args):
3030
return args
3131

3232

33-
class WaypointRecorder:
34-
def __init__(self, spot: Spot, waypoint_file_path: str = WAYPOINT_YAML):
35-
self.spot = spot
36-
37-
self.waypoint_file = waypoint_file_path
38-
39-
# Local copy of waypoints.yaml which keeps getting updated as new waypoints are added
40-
self.yaml_dict = self.read_yaml(self.waypoint_file)
41-
print("Yaml: ", self.yaml_dict)
42-
self.initialize_yaml(self.yaml_dict)
43-
return
33+
class YamlHandler:
34+
"""
35+
Class to handle reading and writing to yaml files
36+
37+
How to use:
38+
1. Create a yaml file with the following format:
39+
40+
place_targets:
41+
test_place_target:
42+
- 3.0
43+
- 0.0
44+
- 0.8
45+
clutter:
46+
- test_place_target
47+
clutter_amounts:
48+
test_place_target: 1
49+
object_targets:
50+
0: [penguin, test_place_target]
51+
nav_targets:
52+
dock:
53+
- 1.5
54+
- 0.0
55+
- 0.0
56+
test_place_target:
57+
- 2.5
58+
- 0.0
59+
- 0.0
60+
61+
2. Create an instance of this class
62+
3. Read the yaml file using the read_yaml method
63+
4. Modify the yaml_dict as needed
64+
5. Write the yaml file using the write_yaml method
65+
"""
66+
67+
def __init__(self):
68+
pass
4469

4570
def create_yaml(self, waypoint_file: str):
71+
init_yaml_dict = """
72+
place_targets: # i.e., where an object needs to be placed (x,y,z)
73+
test_place_target:
74+
- 3.0
75+
- 0.0
76+
- 0.8
77+
clutter: # i.e., where an object is currently placed
78+
# <receptacle where clutter exists>
79+
- test_place_target
80+
clutter_amounts: # i.e., how much clutter exists in each receptacle
81+
# <receptacle where clutter exists>: <number of objects in that receptacle>
82+
test_place_target: 1
83+
object_targets: # i.e., where an object belongs / needs to be placed
84+
# <Class_id>: [<object's name>, <which place_target it belongs to>]
85+
0: [penguin, test_place_target]
86+
nav_targets: # i.e., where the robot needs to navigate to (x,y,yaw)
87+
dock:
88+
- 1.5
89+
- 0.0
90+
- 0.0
91+
test_place_target:
92+
- 2.5
93+
- 0.0
94+
- 0.0"""
4695
with open(waypoint_file, "w+") as f:
47-
f.write("")
48-
f.close()
49-
50-
def initialize_yaml(self, yaml_dict):
51-
# TODO: What if the yaml_dict is NONE???
52-
53-
# Create a templated waypoints.yaml file here, by creating all the keys that are not present
54-
yaml_dict["dock"] = [1.5, 0.0, 0.0]
55-
# TODO: Add other keys here
96+
yaml = ruamel.yaml.YAML() # defaults to round-trip if no parameters given
97+
ruamel.yaml.dump(
98+
yaml.load(init_yaml_dict), f, Dumper=ruamel.yaml.RoundTripDumper
99+
)
56100

57101
def read_yaml(self, waypoint_file: str):
58102
# Create waypoint file if it does not exist
@@ -62,16 +106,26 @@ def read_yaml(self, waypoint_file: str):
62106
with open(waypoint_file, "r") as f:
63107
yaml = ruamel.yaml.YAML() # defaults to round-trip if no parameters given
64108
yaml_dict = yaml.load(f.read())
65-
print(type(yaml_dict))
66109
return yaml_dict
67110

68-
def write_yaml(self, yaml_dict):
69-
with open(self.waypoint_file, "w") as f:
111+
def write_yaml(self, waypoint_file: str, yaml_dict):
112+
with open(waypoint_file, "w") as f:
70113
yaml = ruamel.yaml.YAML()
71114
yaml.dump(yaml_dict, f)
72115

116+
117+
class WaypointRecorder:
118+
def __init__(self, spot: Spot, waypoint_file_path: str = WAYPOINT_YAML):
119+
self.spot = spot
120+
121+
# Local copy of waypoints.yaml which keeps getting updated as new waypoints are added
122+
self.waypoint_file = waypoint_file_path
123+
self.yaml_handler = YamlHandler()
124+
self.yaml_dict = self.yaml_handler.read_yaml(waypoint_file=self.waypoint_file)
125+
return
126+
73127
def save_yaml(self):
74-
self.write_yaml(self.yaml_dict)
128+
self.yaml_handler.write_yaml(self.waypoint_file, self.yaml_dict)
75129
print(f"Successfully saved all waypoints to file at {self.waypoint_file}:\n")
76130

77131
def unmark_clutter(self, clutter_target_name: str):
@@ -104,7 +158,9 @@ def record_nav_target(self, nav_target_name: str):
104158

105159
# Add nav_targets list if not present
106160
if "nav_targets" not in self.yaml_dict:
107-
self.yaml_dict["nav_targets"] = {}
161+
self.yaml_dict["nav_targets"] = {
162+
"dock": "[1.5, 0.0, 0.0]",
163+
}
108164

109165
# Erase existing waypoint data if present
110166
if nav_target_name in self.yaml_dict.get("nav_targets"):
@@ -128,7 +184,7 @@ def record_clutter_target(self, clutter_target_name: str):
128184

129185
# Add waypoint as clutter_amounts if it does not exist
130186
if clutter_target_name not in self.yaml_dict.get("clutter_amounts"):
131-
self.yaml_dict["clutter_amounts"].update({clutter_target_name: 0})
187+
self.yaml_dict["clutter_amounts"].update({clutter_target_name: 1})
132188
print(
133189
f"Added {clutter_target_name} in 'clutter_amounts' => ({clutter_target_name}:{self.yaml_dict.get('clutter_amounts').get(clutter_target_name)})"
134190
)
@@ -167,7 +223,6 @@ def main(spot: Spot):
167223
len([i for i in arg_bools if i]) == 1
168224
), "Must pass in either -c, -p, or -n as an arg, and not more than one."
169225

170-
# try:
171226
# Create WaypointRecorder object with default waypoint file
172227
waypoint_recorder = WaypointRecorder(spot=spot)
173228

@@ -179,8 +234,7 @@ def main(spot: Spot):
179234
waypoint_recorder.record_place_target(args.waypoint_name)
180235
else:
181236
raise NotImplementedError
182-
# finally:
183-
# Save updated yaml file
237+
184238
waypoint_recorder.save_yaml()
185239

186240

0 commit comments

Comments
 (0)