diff --git a/.changeset/curly-buses-kick.md b/.changeset/curly-buses-kick.md new file mode 100644 index 00000000..9bde3373 --- /dev/null +++ b/.changeset/curly-buses-kick.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/chat-ui": patch +--- + +Expose ReactMarkdown `components` prop diff --git a/packages/chat-ui/src/chat/message-parts/parts/markdown.tsx b/packages/chat-ui/src/chat/message-parts/parts/markdown.tsx index 0ec0ac19..505c777b 100644 --- a/packages/chat-ui/src/chat/message-parts/parts/markdown.tsx +++ b/packages/chat-ui/src/chat/message-parts/parts/markdown.tsx @@ -1,20 +1,19 @@ -import { ComponentType } from 'react' import { cn } from '../../../lib/utils.js' import { - CitationComponentProps, - LanguageRendererProps, Markdown, preprocessSourceNodes, + type MarkdownProps, } from '../../../widgets/index.js' import { useChatMessage } from '../../chat-message.context.js' -import { SourcesPartType, TextPartType } from '../types.js' import { usePart } from '../context.js' +import { SourcesPartType, TextPartType } from '../types.js' import { getParts } from '../utils.js' interface ChatMarkdownProps extends React.PropsWithChildren { - citationComponent?: ComponentType + components?: MarkdownProps['components'] + citationComponent?: MarkdownProps['citationComponent'] className?: string - languageRenderers?: Record> + languageRenderers?: MarkdownProps['languageRenderers'] } /** @@ -41,6 +40,7 @@ export function MarkdownPartUI(props: ChatMarkdownProps) { = memo( ReactMarkdown, @@ -110,21 +110,38 @@ export interface LanguageRendererProps { className?: string } +type ReactStyleMarkdownComponents = { + // Extract pulls out the ComponentType side of unions like ComponentType | string + // react-markdown supports passing "h1" for example, which is difficult to + [K in keyof Components]?: Extract> +} + +// Simple function to render a component if provided, otherwise use fallback +function combineComponent( + component: FC | undefined, + fallback: FC +): FC { + return props => component?.(props) || fallback(props) +} + +export interface MarkdownProps { + content: string + sources?: SourceData + backend?: string + components?: ReactStyleMarkdownComponents + citationComponent?: ComponentType + className?: string + languageRenderers?: Record> +} export function Markdown({ content, sources, backend, citationComponent: CitationComponent, className: customClassName, + components, languageRenderers, -}: { - content: string - sources?: SourceData - backend?: string - citationComponent?: ComponentType - className?: string - languageRenderers?: Record> -}) { +}: MarkdownProps) { const processedContent = preprocessContent(content) return ( @@ -137,49 +154,53 @@ export function Markdown({ remarkPlugins={[remarkGfm, remarkMath]} rehypePlugins={[rehypeKatex as any]} components={{ - p({ children }) { + ...components, + p: combineComponent(components?.p, ({ children }) => { return
{children}
- }, - code({ inline, className, children, ...props }) { - if (children.length) { - if (children[0] === '▍') { + }), + code: combineComponent( + components?.code, + ({ inline, className, children, ...props }) => { + if (children.length) { + if (children[0] === '▍') { + return ( + + ) + } + + children[0] = (children[0] as string).replace('`▍`', '▍') + } + + const match = /language-(\w+)/.exec(className || '') + const language = (match && match[1]) || '' + const codeValue = String(children).replace(/\n$/, '') + + if (inline) { return ( - + + {children} + ) } - children[0] = (children[0] as string).replace('`▍`', '▍') - } - - const match = /language-(\w+)/.exec(className || '') - const language = (match && match[1]) || '' - const codeValue = String(children).replace(/\n$/, '') + // Check for custom language renderer + if (languageRenderers?.[language]) { + const CustomRenderer = languageRenderers[language] + return + } - if (inline) { return ( - - {children} - + ) } - - // Check for custom language renderer - if (languageRenderers?.[language]) { - const CustomRenderer = languageRenderers[language] - return - } - - return ( - - ) - }, - a({ href, children }) { + ), + a: combineComponent(components?.a, ({ href, children }) => { // If href starts with `{backend}/api/files`, then it's a local document and we use DocumentInfo for rendering if (href?.startsWith(`${backend}/api/files`)) { // Check if the file is document file type @@ -231,7 +252,7 @@ export function Markdown({ {children} ) - }, + }), }} > {processedContent}