diff --git a/frontend/__tests__/components/chat/chat-interface.test.tsx b/frontend/__tests__/components/chat/chat-interface.test.tsx index 4a5c80b26e22..e4f5c1b0f0c6 100644 --- a/frontend/__tests__/components/chat/chat-interface.test.tsx +++ b/frontend/__tests__/components/chat/chat-interface.test.tsx @@ -17,8 +17,8 @@ import type { Message } from "#/message"; import { SUGGESTIONS } from "#/utils/suggestions"; import { ChatInterface } from "#/components/features/chat/chat-interface"; import { useWsClient } from "#/context/ws-client-provider"; +import { useErrorMessageStore } from "#/stores/error-message-store"; import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store"; -import { useWSErrorMessage } from "#/hooks/use-ws-error-message"; import { useConfig } from "#/hooks/query/use-config"; import { useGetTrajectory } from "#/hooks/mutation/use-get-trajectory"; import { useUploadFiles } from "#/hooks/mutation/use-upload-files"; @@ -26,8 +26,8 @@ import { OpenHandsAction } from "#/types/core/actions"; // Mock the hooks vi.mock("#/context/ws-client-provider"); +vi.mock("#/stores/error-message-store"); vi.mock("#/stores/optimistic-user-message-store"); -vi.mock("#/hooks/use-ws-error-message"); vi.mock("#/hooks/query/use-config"); vi.mock("#/hooks/mutation/use-get-trajectory"); vi.mock("#/hooks/mutation/use-upload-files"); @@ -61,7 +61,6 @@ vi.mock("#/hooks/use-conversation-name-context-menu", () => ({ }), })); - // Helper function to render with Router context const renderChatInterfaceWithRouter = () => renderWithProviders( @@ -114,8 +113,9 @@ describe("ChatInterface - Chat Suggestions", () => { setOptimisticUserMessage: vi.fn(), getOptimisticUserMessage: vi.fn(() => null), }); - (useWSErrorMessage as unknown as ReturnType).mockReturnValue({ - getErrorMessage: vi.fn(() => null), + ( + useErrorMessageStore as unknown as ReturnType + ).mockReturnValue({ setErrorMessage: vi.fn(), removeErrorMessage: vi.fn(), }); @@ -251,8 +251,9 @@ describe("ChatInterface - Empty state", () => { setOptimisticUserMessage: vi.fn(), getOptimisticUserMessage: vi.fn(() => null), }); - (useWSErrorMessage as unknown as ReturnType).mockReturnValue({ - getErrorMessage: vi.fn(() => null), + ( + useErrorMessageStore as unknown as ReturnType + ).mockReturnValue({ setErrorMessage: vi.fn(), removeErrorMessage: vi.fn(), }); diff --git a/frontend/src/components/features/chat/chat-interface.tsx b/frontend/src/components/features/chat/chat-interface.tsx index a61201d64779..d4eb108a23e5 100644 --- a/frontend/src/components/features/chat/chat-interface.tsx +++ b/frontend/src/components/features/chat/chat-interface.tsx @@ -22,8 +22,8 @@ import { useAgentStore } from "#/stores/agent-store"; import { ScrollToBottomButton } from "#/components/shared/buttons/scroll-to-bottom-button"; import { LoadingSpinner } from "#/components/shared/loading-spinner"; import { displayErrorToast } from "#/utils/custom-toast-handlers"; +import { useErrorMessageStore } from "#/stores/error-message-store"; import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store"; -import { useWSErrorMessage } from "#/hooks/use-ws-error-message"; import { ErrorMessageBanner } from "./error-message-banner"; import { hasUserEvent, @@ -46,7 +46,7 @@ function getEntryPoint( export function ChatInterface() { const { setMessageToSend } = useConversationStore(); - const { getErrorMessage } = useWSErrorMessage(); + const { errorMessage } = useErrorMessageStore(); const { send, isLoadingMessages, parsedEvents } = useWsClient(); const { setOptimisticUserMessage, getOptimisticUserMessage } = useOptimisticUserMessageStore(); @@ -73,7 +73,6 @@ export function ChatInterface() { const { mutateAsync: uploadFiles } = useUploadFiles(); const optimisticUserMessage = getOptimisticUserMessage(); - const errorMessage = getErrorMessage(); const events = parsedEvents.filter(shouldRenderEvent); diff --git a/frontend/src/context/ws-client-provider.tsx b/frontend/src/context/ws-client-provider.tsx index 9fa523206456..13312f9e30a3 100644 --- a/frontend/src/context/ws-client-provider.tsx +++ b/frontend/src/context/ws-client-provider.tsx @@ -26,8 +26,8 @@ import { isStatusUpdate, isUserMessage, } from "#/types/core/guards"; +import { useErrorMessageStore } from "#/stores/error-message-store"; import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store"; -import { useWSErrorMessage } from "#/hooks/use-ws-error-message"; export type WebSocketStatus = "CONNECTING" | "CONNECTED" | "DISCONNECTED"; @@ -131,8 +131,8 @@ export function WsClientProvider({ conversationId, children, }: React.PropsWithChildren) { + const { setErrorMessage, removeErrorMessage } = useErrorMessageStore(); const { removeOptimisticUserMessage } = useOptimisticUserMessageStore(); - const { setErrorMessage, removeErrorMessage } = useWSErrorMessage(); const queryClient = useQueryClient(); const sioRef = React.useRef(null); const [webSocketStatus, setWebSocketStatus] = diff --git a/frontend/src/hooks/use-ws-error-message.ts b/frontend/src/hooks/use-ws-error-message.ts deleted file mode 100644 index 370804b7b0f6..000000000000 --- a/frontend/src/hooks/use-ws-error-message.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { useQueryClient } from "@tanstack/react-query"; - -export const useWSErrorMessage = () => { - const queryClient = useQueryClient(); - - const setErrorMessage = (message: string) => { - queryClient.setQueryData(["error_message"], message); - }; - - const getErrorMessage = () => - queryClient.getQueryData(["error_message"]); - - const removeErrorMessage = () => { - queryClient.removeQueries({ queryKey: ["error_message"] }); - }; - - return { - setErrorMessage, - getErrorMessage, - removeErrorMessage, - }; -}; diff --git a/frontend/src/stores/error-message-store.ts b/frontend/src/stores/error-message-store.ts new file mode 100644 index 000000000000..4416814ed869 --- /dev/null +++ b/frontend/src/stores/error-message-store.ts @@ -0,0 +1,30 @@ +import { create } from "zustand"; + +interface ErrorMessageState { + errorMessage: string | null; +} + +interface ErrorMessageActions { + setErrorMessage: (message: string) => void; + removeErrorMessage: () => void; +} + +type ErrorMessageStore = ErrorMessageState & ErrorMessageActions; + +const initialState: ErrorMessageState = { + errorMessage: null, +}; + +export const useErrorMessageStore = create((set) => ({ + ...initialState, + + setErrorMessage: (message: string) => + set(() => ({ + errorMessage: message, + })), + + removeErrorMessage: () => + set(() => ({ + errorMessage: null, + })), +}));