diff --git a/examples/vite/src/App.tsx b/examples/vite/src/App.tsx index 4fd893c68..f617e1f1c 100644 --- a/examples/vite/src/App.tsx +++ b/examples/vite/src/App.tsx @@ -1,5 +1,6 @@ import { ChannelFilters, ChannelOptions, ChannelSort } from 'stream-chat'; import { + AIStateIndicator, Channel, ChannelAvatar, ChannelHeader, @@ -88,6 +89,7 @@ const App = () => { + diff --git a/examples/vite/src/index.scss b/examples/vite/src/index.scss index 036d9caa6..a5b013cd4 100644 --- a/examples/vite/src/index.scss +++ b/examples/vite/src/index.scss @@ -29,6 +29,7 @@ body, width: 0; flex-shrink: 0; box-shadow: 0 0 8px rgba(0, 0, 0, 0.15); + max-width: 1000px; &--open { width: 30%; diff --git a/package.json b/package.json index c6440dd1b..4c8c5e922 100644 --- a/package.json +++ b/package.json @@ -145,7 +145,7 @@ "emoji-mart": "^5.4.0", "react": "^18.0.0 || ^17.0.0 || ^16.8.0", "react-dom": "^18.0.0 || ^17.0.0 || ^16.8.0", - "stream-chat": "^8.45.0" + "stream-chat": "^8.46.0" }, "peerDependenciesMeta": { "@breezystack/lamejs": { @@ -186,7 +186,7 @@ "@semantic-release/changelog": "^6.0.2", "@semantic-release/git": "^10.0.1", "@stream-io/rollup-plugin-node-builtins": "^2.1.5", - "@stream-io/stream-chat-css": "^5.4.0", + "@stream-io/stream-chat-css": "^5.5.0", "@testing-library/jest-dom": "^6.1.4", "@testing-library/react": "^13.1.1", "@testing-library/react-hooks": "^8.0.0", @@ -255,7 +255,7 @@ "react-dom": "^18.1.0", "react-test-renderer": "^18.1.0", "semantic-release": "^19.0.5", - "stream-chat": "^8.45.0", + "stream-chat": "^8.46.0", "ts-jest": "^29.1.4", "typescript": "^5.4.5" }, diff --git a/src/components/AIStateIndicator/AIStateIndicator.tsx b/src/components/AIStateIndicator/AIStateIndicator.tsx new file mode 100644 index 000000000..2bbe572f1 --- /dev/null +++ b/src/components/AIStateIndicator/AIStateIndicator.tsx @@ -0,0 +1,37 @@ +import React from 'react'; + +import { Channel } from 'stream-chat'; + +import { AIStates, useAIState } from './hooks/useAIState'; + +import { useChannelStateContext, useTranslationContext } from '../../context'; +import type { DefaultStreamChatGenerics } from '../../types/types'; + +export type AIStateIndicatorProps< + StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics +> = { + channel?: Channel; +}; + +export const AIStateIndicator = < + StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics +>({ + channel: channelFromProps, +}: AIStateIndicatorProps) => { + const { t } = useTranslationContext(); + const { channel: channelFromContext } = useChannelStateContext( + 'AIStateIndicator', + ); + const channel = channelFromProps || channelFromContext; + const { aiState } = useAIState(channel); + const allowedStates = { + [AIStates.Thinking]: t('Thinking...'), + [AIStates.Generating]: t('Generating...'), + }; + + return aiState in allowedStates ? ( +
+

{allowedStates[aiState]}

+
+ ) : null; +}; diff --git a/src/components/AIStateIndicator/hooks/useAIState.ts b/src/components/AIStateIndicator/hooks/useAIState.ts new file mode 100644 index 000000000..54e955a6f --- /dev/null +++ b/src/components/AIStateIndicator/hooks/useAIState.ts @@ -0,0 +1,57 @@ +import { useEffect, useState } from 'react'; + +import { AIState, Channel, Event } from 'stream-chat'; + +import type { DefaultStreamChatGenerics } from '../../../types/types'; + +export const AIStates = { + Error: 'AI_STATE_ERROR', + ExternalSources: 'AI_STATE_EXTERNAL_SOURCES', + Generating: 'AI_STATE_GENERATING', + Idle: 'AI_STATE_IDLE', + Thinking: 'AI_STATE_THINKING', +}; + +/** + * A hook that returns the current state of the AI. + * @param {Channel} channel - The channel for which we want to know the AI state. + * @returns {{ aiState: AIState }} The current AI state for the given channel. + */ +export const useAIState = < + StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics +>( + channel?: Channel, +): { aiState: AIState } => { + const [aiState, setAiState] = useState(AIStates.Idle); + + useEffect(() => { + if (!channel) { + return; + } + + const indicatorChangedListener = channel.on( + 'ai_indicator.update', + (event: Event) => { + const { cid } = event; + const state = event.ai_state as AIState; + if (channel.cid === cid) { + setAiState(state); + } + }, + ); + + const indicatorClearedListener = channel.on('ai_indicator.clear', (event) => { + const { cid } = event; + if (channel.cid === cid) { + setAiState(AIStates.Idle); + } + }); + + return () => { + indicatorChangedListener.unsubscribe(); + indicatorClearedListener.unsubscribe(); + }; + }, [channel]); + + return { aiState }; +}; diff --git a/src/components/AIStateIndicator/index.ts b/src/components/AIStateIndicator/index.ts new file mode 100644 index 000000000..1bae577b7 --- /dev/null +++ b/src/components/AIStateIndicator/index.ts @@ -0,0 +1,2 @@ +export * from './AIStateIndicator'; +export * from './hooks/useAIState'; diff --git a/src/components/Channel/Channel.tsx b/src/components/Channel/Channel.tsx index 419663eed..e468e79b8 100644 --- a/src/components/Channel/Channel.tsx +++ b/src/components/Channel/Channel.tsx @@ -161,6 +161,8 @@ type ChannelPropsForwardedToComponentContext< | 'UnreadMessagesNotification' | 'UnreadMessagesSeparator' | 'VirtualMessage' + | 'StopAIGenerationButton' + | 'StreamedMessageText' >; const isUserResponseArray = < @@ -1273,6 +1275,8 @@ const ChannelInner = < ReactionsList: props.ReactionsList, SendButton: props.SendButton, StartRecordingAudioButton: props.StartRecordingAudioButton, + StopAIGenerationButton: props.StopAIGenerationButton, + StreamedMessageText: props.StreamedMessageText, ThreadHead: props.ThreadHead, ThreadHeader: props.ThreadHeader, ThreadStart: props.ThreadStart, @@ -1339,6 +1343,8 @@ const ChannelInner = < props.UnreadMessagesNotification, props.UnreadMessagesSeparator, props.VirtualMessage, + props.StopAIGenerationButton, + props.StreamedMessageText, props.emojiSearchIndex, props.reactionOptions, ], diff --git a/src/components/ChannelPreview/utils.tsx b/src/components/ChannelPreview/utils.tsx index 06205e679..85184f21a 100644 --- a/src/components/ChannelPreview/utils.tsx +++ b/src/components/ChannelPreview/utils.tsx @@ -77,7 +77,9 @@ export const getLatestMessagePreview = < } if (previewTextToRender) { - return renderPreviewText(previewTextToRender); + return latestMessage.ai_generated + ? previewTextToRender + : renderPreviewText(previewTextToRender); } if (latestMessage.command) { diff --git a/src/components/Message/MessageSimple.tsx b/src/components/Message/MessageSimple.tsx index 9d08785e1..fdd3cfff0 100644 --- a/src/components/Message/MessageSimple.tsx +++ b/src/components/Message/MessageSimple.tsx @@ -34,6 +34,7 @@ import { MessageEditedTimestamp } from './MessageEditedTimestamp'; import type { MessageUIComponentProps } from './types'; import type { DefaultStreamChatGenerics } from '../../types/types'; +import { StreamedMessageText as DefaultStreamedMessageText } from './StreamedMessageText'; type MessageSimpleWithContextProps< StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics @@ -81,6 +82,7 @@ const MessageSimpleWithContext = < MessageStatus = DefaultMessageStatus, MessageTimestamp = DefaultMessageTimestamp, ReactionsList = DefaultReactionList, + StreamedMessageText = DefaultStreamedMessageText, PinIndicator, } = useComponentContext('MessageSimple'); @@ -185,7 +187,11 @@ const MessageSimpleWithContext = < {message.attachments?.length && !message.quoted_message ? ( ) : null} - + {message.ai_generated ? ( + + ) : ( + + )} {message.mml && ( = Pick, 'message' | 'renderText'> & { + renderingLetterCount?: number; + streamingLetterIntervalMs?: number; +}; + +export const StreamedMessageText = < + StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics +>( + props: StreamedMessageTextProps, +) => { + const { + message: messageFromProps, + renderingLetterCount, + renderText, + streamingLetterIntervalMs, + } = props; + const { message: messageFromContext } = useMessageContext( + 'StreamedMessageText', + ); + const message = messageFromProps || messageFromContext; + const { text = '' } = message; + const { streamedMessageText } = useMessageTextStreaming({ + renderingLetterCount, + streamingLetterIntervalMs, + text, + }); + + return ( + + ); +}; diff --git a/src/components/Message/hooks/index.ts b/src/components/Message/hooks/index.ts index 833c6e1c8..14e403552 100644 --- a/src/components/Message/hooks/index.ts +++ b/src/components/Message/hooks/index.ts @@ -12,3 +12,4 @@ export * from './useRetryHandler'; export * from './useUserHandler'; export * from './useUserRole'; export * from './useReactionsFetcher'; +export * from './useMessageTextStreaming'; diff --git a/src/components/Message/hooks/useMessageTextStreaming.ts b/src/components/Message/hooks/useMessageTextStreaming.ts new file mode 100644 index 000000000..d791988c9 --- /dev/null +++ b/src/components/Message/hooks/useMessageTextStreaming.ts @@ -0,0 +1,52 @@ +import { useEffect, useRef, useState } from 'react'; + +import type { DefaultStreamChatGenerics } from '../../../types/types'; +import type { StreamedMessageTextProps } from '../StreamedMessageText'; + +export type UseMessageTextStreamingProps< + StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics +> = Pick< + StreamedMessageTextProps, + 'streamingLetterIntervalMs' | 'renderingLetterCount' +> & { text: string }; + +const DEFAULT_LETTER_INTERVAL = 30; +const DEFAULT_RENDERING_LETTER_COUNT = 2; + +/** + * A hook that returns text in a streamed, typewriter fashion. The speed of streaming is + * configurable. + * @param {number} [streamingLetterIntervalMs=30] - The timeout between each typing animation in milliseconds. + * @param {number} [renderingLetterCount=2] - The number of letters to be rendered each time we update. + * @param {string} text - The text that we want to render in a typewriter fashion. + * @returns {{ streamedMessageText: string }} - A substring of the text property, up until we've finished rendering the typewriter animation. + */ +export const useMessageTextStreaming = < + StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics +>({ + streamingLetterIntervalMs = DEFAULT_LETTER_INTERVAL, + renderingLetterCount = DEFAULT_RENDERING_LETTER_COUNT, + text, +}: UseMessageTextStreamingProps): { streamedMessageText: string } => { + const [streamedMessageText, setStreamedMessageText] = useState(text); + const textCursor = useRef(text.length); + + useEffect(() => { + const textLength = text.length; + const interval = setInterval(() => { + if (!text || textCursor.current >= textLength) { + clearInterval(interval); + } + const newCursorValue = textCursor.current + renderingLetterCount; + const newText = text.substring(0, newCursorValue); + textCursor.current += newText.length - textCursor.current; + setStreamedMessageText(newText); + }, streamingLetterIntervalMs); + + return () => { + clearInterval(interval); + }; + }, [streamingLetterIntervalMs, renderingLetterCount, text]); + + return { streamedMessageText }; +}; diff --git a/src/components/Message/index.ts b/src/components/Message/index.ts index efb921fc3..222b4a8bc 100644 --- a/src/components/Message/index.ts +++ b/src/components/Message/index.ts @@ -13,4 +13,5 @@ export * from './QuotedMessage'; export * from './renderText'; export * from './types'; export * from './utils'; +export * from './StreamedMessageText'; export type { TimestampProps } from './Timestamp'; diff --git a/src/components/Message/renderText/renderText.tsx b/src/components/Message/renderText/renderText.tsx index 64be840d6..8e2e0ace5 100644 --- a/src/components/Message/renderText/renderText.tsx +++ b/src/components/Message/renderText/renderText.tsx @@ -32,6 +32,13 @@ export const defaultAllowedTagNames: Array( - message: Pick, 'message_text_updated_at'>, -) => !!message.message_text_updated_at; + message: Pick, 'message_text_updated_at'> & + Partial, 'ai_generated'>>, +) => !!message.message_text_updated_at && !message.ai_generated; diff --git a/src/components/MessageInput/MessageInputFlat.tsx b/src/components/MessageInput/MessageInputFlat.tsx index 71aabbb72..334298936 100644 --- a/src/components/MessageInput/MessageInputFlat.tsx +++ b/src/components/MessageInput/MessageInputFlat.tsx @@ -9,6 +9,7 @@ import { import { AttachmentPreviewList as DefaultAttachmentPreviewList } from './AttachmentPreviewList'; import { CooldownTimer as DefaultCooldownTimer } from './CooldownTimer'; import { SendButton as DefaultSendButton } from './SendButton'; +import { StopAIGenerationButton as DefaultStopAIGenerationButton } from './StopAIGenerationButton'; import { AudioRecorder as DefaultAudioRecorder, RecordingPermissionDeniedNotification as DefaultRecordingPermissionDeniedNotification, @@ -32,6 +33,7 @@ import { useMessageInputContext } from '../../context/MessageInputContext'; import { useComponentContext } from '../../context/ComponentContext'; import type { DefaultStreamChatGenerics } from '../../types/types'; +import { AIStates, useAIState } from '../AIStateIndicator'; export const MessageInputFlat = < StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics @@ -66,6 +68,7 @@ export const MessageInputFlat = < RecordingPermissionDeniedNotification = DefaultRecordingPermissionDeniedNotification, SendButton = DefaultSendButton, StartRecordingAudioButton = DefaultStartRecordingAudioButton, + StopAIGenerationButton: StopAIGenerationButtonOverride, EmojiPicker, } = useComponentContext('MessageInputFlat'); const { @@ -76,6 +79,10 @@ export const MessageInputFlat = < const { setQuotedMessage } = useChannelActionContext('MessageInputFlat'); const { channel } = useChatContext('MessageInputFlat'); + const { aiState } = useAIState(channel); + + const stopGenerating = useCallback(() => channel?.stopAIResponse(), [channel]); + const [ showRecordingPermissionDeniedNotification, setShowRecordingPermissionDeniedNotification, @@ -133,6 +140,18 @@ export const MessageInputFlat = < const recordingEnabled = !!(recordingController.recorder && navigator.mediaDevices); // account for requirement on iOS as per this bug report: https://bugs.webkit.org/show_bug.cgi?id=252303 const isRecording = !!recordingController.recordingState; + /* This bit here is needed to make sure that we can get rid of the default behaviour + * if need be. Essentially this allows us to pass StopAIGenerationButton={null} and + * completely circumvent the default logic if it's not what we want. We need it as a + * prop because there is no other trivial way to override the SendMessage button otherwise. + */ + const StopAIGenerationButton = + StopAIGenerationButtonOverride === undefined + ? DefaultStopAIGenerationButton + : StopAIGenerationButtonOverride; + const shouldDisplayStopAIGeneration = + [AIStates.Thinking, AIStates.Generating].includes(aiState) && !!StopAIGenerationButton; + return ( <>
@@ -174,41 +193,45 @@ export const MessageInputFlat = < {EmojiPicker && }
- {!hideSendButton && ( - <> - {cooldownRemaining ? ( - - ) : ( - <> - + ) : ( + !hideSendButton && ( + <> + {cooldownRemaining ? ( + - {recordingEnabled && ( - + a.type === RecordingAttachmentType.VOICE_RECORDING, - )) + !numberOfUploads && + !text.length && + attachments.length - failedUploadsCount === 0 } - onClick={() => { - recordingController.recorder?.start(); - setShowRecordingPermissionDeniedNotification(true); - }} + sendMessage={handleSubmit} /> - )} - - )} - + {recordingEnabled && ( + a.type === RecordingAttachmentType.VOICE_RECORDING, + )) + } + onClick={() => { + recordingController.recorder?.start(); + setShowRecordingPermissionDeniedNotification(true); + }} + /> + )} + + )} + + ) )} diff --git a/src/components/MessageInput/StopAIGenerationButton.tsx b/src/components/MessageInput/StopAIGenerationButton.tsx new file mode 100644 index 000000000..ff566cba3 --- /dev/null +++ b/src/components/MessageInput/StopAIGenerationButton.tsx @@ -0,0 +1,17 @@ +import React from 'react'; +import { useTranslationContext } from '../../context'; + +export type StopAIGenerationButtonProps = React.ComponentProps<'button'>; + +export const StopAIGenerationButton = ({ onClick, ...restProps }: StopAIGenerationButtonProps) => { + const { t } = useTranslationContext(); + return ( +