Skip to content

Commit

Permalink
chat: fix handling of multiple fn calls
Browse files Browse the repository at this point in the history
  • Loading branch information
rmackay9 committed Dec 8, 2023
1 parent 590b684 commit e556790
Showing 1 changed file with 99 additions and 59 deletions.
158 changes: 99 additions & 59 deletions MAVProxy/modules/mavproxy_chat/chat_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,109 +160,147 @@ def handle_function_call(self, run):
print("chat::handle_function_call: submit tools outputs empty")
return None

# display details if more than one function call
tool_num = 0
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
print("chat::handle_function_call: call:" +str(tool_num) + " fn:" + tool_call.function.name + " tcid:" + str(tool_call.id))
tool_num = tool_num + 1

tool_outputs = []
tool_num = 0
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
# init output to None
output = "invalid function call"
recognised_function = False

# get current date and time
if tool_call.function.name == "get_current_datetime":
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=[{
"tool_call_id": tool_call.id,
"output": self.get_formatted_date()
}]
)
return
recognised_function = True
output = self.get_formatted_date()

# get vehicle type
if tool_call.function.name == "get_vehicle_type":
recognised_function = True
output = json.dumps(self.get_vehicle_type())

# get vehicle state including armed, mode
if tool_call.function.name == "get_vehicle_state":
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=[{
"tool_call_id": tool_call.id,
"output": json.dumps(self.get_vehicle_state())
}]
)
return
recognised_function = True
output = json.dumps(self.get_vehicle_state())

# get vehicle location
if tool_call.function.name == "get_vehicle_location":
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=[{
"tool_call_id": tool_call.id,
"output": json.dumps(self.get_vehicle_location())
}]
)
return
recognised_function = True
output = json.dumps(self.get_vehicle_location())

# send mavlink command_int
if tool_call.function.name == "send_mavlink_command_int":
recognised_function = True
print("send_mavlink_command_int received:")
print(tool_call)
try:
arguments = json.loads(tool_call.function.arguments)
print(arguments)
output = self.send_mavlink_command_int(arguments)
except:
print("chat::handle_function_call: failed to parse arguments")
return
print(arguments)
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=[{
"tool_call_id": tool_call.id,
"output": self.send_mavlink_command_int(arguments)
}]
)
return

# send mavlink set_position_target_global_int
if tool_call.function.name == "send_mavlink_set_position_target_global_int":
recognised_function = True
print("send_mavlink_set_position_target_global_int received:")
print(tool_call)
arguments = json.loads(tool_call.function.arguments)
print(arguments)
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=[{
"tool_call_id": tool_call.id,
"output": self.send_mavlink_set_position_target_global_int(arguments)
}]
)
return

# unknown function
print("chat::handle_function_call: assistant_function_call is unknown function " + tool_call.function.name)
return
output = self.send_mavlink_set_position_target_global_int(arguments)

if not recognised_function:
print("chat::handle_function_call: unrecognised function call: " + tool_call.function.name)
output = "unrecognised function call: " + tool_call.function.name

# debug of reply
print("chat::handle_function_call: replying to call:" +str(tool_num) + " fn:" + tool_call.function.name + " tcid:" + str(tool_call.id))
tool_num = tool_num + 1

# append output to list of outputs
tool_outputs.append({"tool_call_id": tool_call.id, "output": output})

# send function replies to assistant
run_reply = self.client.beta.threads.runs.submit_tool_outputs(
thread_id=run.thread_id,
run_id=run.id,
tool_outputs=tool_outputs
)

# get the current date and time in the format, Saturday, June 24, 2023 6:14:14 PM
def get_formatted_date(self):
return datetime.now().strftime("%A, %B %d, %Y %I:%M:%S %p")

# get vehicle vehicle type (e.g. "Copter", "Plane", "Rover", "Boat", etc)
def get_vehicle_type(self):
# get vehicle type from latest HEARTBEAT message
hearbeat_msg = self.mpstate.master().messages.get('HEARTBEAT', None)
vehicle_type_str = "unknown"
if hearbeat_msg is not None:
if hearbeat_msg.type in [mavutil.mavlink.MAV_TYPE_FIXED_WING,
mavutil.mavlink.MAV_TYPE_VTOL_DUOROTOR,
mavutil.mavlink.MAV_TYPE_VTOL_QUADROTOR,
mavutil.mavlink.MAV_TYPE_VTOL_TILTROTOR]:
vehicle_type_str = 'Plane'
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_GROUND_ROVER:
vehicle_type_str = 'Rover'
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_SURFACE_BOAT:
vehicle_type_str = 'Boat'
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_SUBMARINE:
vehicle_type_str = 'Sub'
if hearbeat_msg.type in [mavutil.mavlink.MAV_TYPE_QUADROTOR,
mavutil.mavlink.MAV_TYPE_COAXIAL,
mavutil.mavlink.MAV_TYPE_HEXAROTOR,
mavutil.mavlink.MAV_TYPE_OCTOROTOR,
mavutil.mavlink.MAV_TYPE_TRICOPTER,
mavutil.mavlink.MAV_TYPE_DODECAROTOR]:
vehicle_type_str = "Copter"
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_HELICOPTER:
vehicle_type_str = "Heli"
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_ANTENNA_TRACKER:
vehicle_type_str = "Tracker"
if hearbeat_msg.type == mavutil.mavlink.MAV_TYPE_AIRSHIP:
vehicle_type_str = "Blimp"
return {
"vehicle_type": vehicle_type_str
}

# get vehicle state including armed, mode
def get_vehicle_state(self):
# get mode from latest HEARTBEAT message
hearbeat_msg = self.mpstate.master().messages.get('HEARTBEAT', None)
if hearbeat_msg is None:
return "vehicle mode is unknown"
mode_number = 0
print ("get_vehicle_state: vehicle mode is unknown")
else:
mode_number = hearbeat_msg.custom_mode
return {
"armed": self.mpstate.master().motors_armed(),
"mode": hearbeat_msg.custom_mode
"mode": mode_number
}

# return a string of the vehicle's location
def get_vehicle_location(self):
lat_deg = 0
lon_deg = 0
alt_amsl_m = 0
alt_rel_m = 0
gpi = self.mpstate.master().messages.get('GLOBAL_POSITION_INT', None)
if gpi is None or (gpi.lat == 0 and gpi.lon == 0):
return "vehicle position unknown"
if gpi:
lat_deg = gpi.lat * 1e-7,
lon_deg = gpi.lon * 1e-7,
alt_amsl_m = gpi.alt * 1e-3,
alt_rel_m = gpi.relative_alt * 1e-3
location = {
"latitude": gpi.lat * 1e-7,
"longitude": gpi.lon * 1e-7,
"altitude_amsl": gpi.alt * 1e-3,
"altitude_above_home": gpi.relative_alt * 1e-3
"latitude": lat_deg,
"longitude": lon_deg,
"altitude_amsl": alt_amsl_m,
"altitude_above_home": alt_rel_m
}
return location

Expand All @@ -289,6 +327,8 @@ def send_mavlink_command_int(self, arguments):

# send a mavlink send_mavlink_set_position_target_global_int message to the vehicle
def send_mavlink_set_position_target_global_int(self, arguments):
if arguments is None:
return "send_mavlink_set_position_target_global_int: arguments is None"
print("send_mavlink_set_position_target_global_int: arguments:" + str(arguments))
time_boot_ms = arguments.get("time_boot_ms", 0)
target_system = arguments.get("target_system", 1)
Expand Down

0 comments on commit e556790

Please sign in to comment.