From e1b06549e55e72130f087f991e2e096a8154955b Mon Sep 17 00:00:00 2001 From: adityaomar3 Date: Sat, 15 Jun 2024 19:44:13 +0530 Subject: [PATCH] Chat: feature to cancel active runs --- MAVProxy/modules/mavproxy_chat/chat_openai.py | 43 ++++++++++++++----- MAVProxy/modules/mavproxy_chat/chat_window.py | 14 ++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/MAVProxy/modules/mavproxy_chat/chat_openai.py b/MAVProxy/modules/mavproxy_chat/chat_openai.py index 0fd5d3bd3c..f4b502874c 100644 --- a/MAVProxy/modules/mavproxy_chat/chat_openai.py +++ b/MAVProxy/modules/mavproxy_chat/chat_openai.py @@ -48,6 +48,7 @@ def __init__(self, mpstate, status_cb=None, wait_for_command_ack_fn=None): self.client = None self.assistant = None self.assistant_thread = None + self.latest_run = None # check connection to OpenAI assistant and connect if necessary # returns True if connection is good, False if not @@ -99,6 +100,25 @@ def set_api_key(self, api_key_str): self.assistant = None self.assistant_thread = None + # cancel the active run + def cancel_run(self): + # check the active thread and run + if (self.assistant_thread and self.latest_run is not None): + run_status = self.latest_run.status + if (run_status != "completed" and run_status != "cancelled" and + run_status != "cancelling"): + + # cancel the active run + self.client.beta.threads.runs.cancel( + thread_id=self.assistant_thread.id, + run_id=self.run.id + ) + else: + if (self.latest_run.status == "completed"): + print("Chat is completed, cannot be cancelled") + elif (self.latest_run.status == "cancelled"): + print("Chat is cancelled") + # send text to assistant def send_to_assistant(self, text): # get lock @@ -132,7 +152,7 @@ def send_to_assistant(self, text): time.sleep(0.1) # retrieve the run - latest_run = self.client.beta.threads.runs.retrieve( + self.latest_run = self.client.beta.threads.runs.retrieve( thread_id=self.assistant_thread.id, run_id=self.run.id ) @@ -141,22 +161,22 @@ def send_to_assistant(self, text): failure_message = None # check run status - if latest_run.status in ["queued", "in_progress", "cancelling"]: + if self.latest_run.status in ["queued", "in_progress", "cancelling"]: run_done = False - elif latest_run.status in ["cancelled", "completed", "expired"]: + elif self.latest_run.status in ["cancelled", "completed", "expired"]: run_done = True - elif latest_run.status in ["failed"]: - failure_message = latest_run.last_error.message + elif self.latest_run.status in ["failed"]: + failure_message = self.latest_run.last_error.message run_done = True - elif latest_run.status in ["requires_action"]: - self.handle_function_call(latest_run) + elif self.latest_run.status in ["requires_action"]: + self.handle_function_call(self.latest_run) run_done = False else: - print("chat: unrecognised run status" + latest_run.status) + print("chat: unrecognised run status" + self.latest_run.status) run_done = True # send status to status callback - status_message = latest_run.status + status_message = self.latest_run.status if failure_message is not None: status_message = status_message + ": " + failure_message self.send_status(status_message) @@ -165,7 +185,10 @@ def send_to_assistant(self, text): reply_messages = self.client.beta.threads.messages.list(self.assistant_thread.id, order="asc", after=input_message.id) - if reply_messages is None: + + if (self.latest_run.status == "cancelled"): + return "cancelled successfully" + elif reply_messages is None: return "chat: failed to retrieve messages" # concatenate all messages into a single reply skipping the first which is our question diff --git a/MAVProxy/modules/mavproxy_chat/chat_window.py b/MAVProxy/modules/mavproxy_chat/chat_window.py index 45e0d1b9d7..6f5fc44381 100644 --- a/MAVProxy/modules/mavproxy_chat/chat_window.py +++ b/MAVProxy/modules/mavproxy_chat/chat_window.py @@ -70,6 +70,12 @@ def __init__(self, mpstate, wait_for_command_ack_fn): self.frame.Bind(wx.EVT_BUTTON, self.send_button_click, self.send_button) self.horiz_sizer.Add(self.send_button, proportion=0, flag=wx.ALIGN_TOP | wx.ALL, border=5) + # add a cancel button + self.cancel_button = wx.Button(self.frame, id=-1, label="cancel", size=(75, 25)) + self.frame.Bind(wx.EVT_BUTTON, self.cancel_button_click , self.cancel_button) + self.horiz_sizer.Add(self.cancel_button, proportion=0, flag=wx.ALIGN_TOP | wx.ALL, border=5) + wx.CallAfter(self.cancel_button.Disable) + # set size hints and add sizer to frame self.vert_sizer.Add(self.text_reply, proportion=1, flag=wx.EXPAND, border=5) self.vert_sizer.Add(self.text_status, proportion=0, flag=wx.EXPAND, border=5) @@ -139,6 +145,10 @@ def record_button_click_execute(self, event): self.set_status_text("sending text to assistasnt") self.send_text_to_assistant() + # cancel button clicked + def cancel_button_click(self, event): + self.chat_openai.cancel_run() + # send button clicked def send_button_click(self, event): self.text_input_change(event) @@ -161,6 +171,8 @@ def send_text_to_assistant(self): focus = self.text_input # disable buttons and text input to stop multiple inputs (can't be done from a thread or must use CallAfter) + # enable the cancel button to cancel the current run + wx.CallAfter(self.cancel_button.Enable) wx.CallAfter(self.record_button.Disable) wx.CallAfter(self.text_input.Disable) wx.CallAfter(self.send_button.Disable) @@ -181,6 +193,8 @@ def send_text_to_assistant(self): wx.CallAfter(self.text_reply.AppendText, reply + "\n\n") # reenable buttons and text input (can't be done from a thread or must use CallAfter) + # disable the cancel button + wx.CallAfter(self.cancel_button.Disable) wx.CallAfter(self.record_button.Enable) wx.CallAfter(self.text_input.Enable) wx.CallAfter(self.send_button.Enable)