diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts index cc844658..7ae5af7c 100644 --- a/app/api/chat/route.ts +++ b/app/api/chat/route.ts @@ -1,3 +1,4 @@ +import { createGoogleGenerativeAI } from "@ai-sdk/google" import { APICallError, convertToModelMessages, @@ -217,9 +218,45 @@ async function handleChatRequest(req: Request): Promise { // Read minimal style preference from header const minimalStyle = req.headers.get("x-minimal-style") === "true" + // Read image generation config from headers + const imageGenerationEnabled = + req.headers.get("x-image-generation") === "true" + const imageResolution = req.headers.get("x-image-resolution") || "1K" + const imageAspectRatio = req.headers.get("x-image-aspect-ratio") || "1:1" + // Get AI model with optional client overrides - const { model, providerOptions, headers, modelId } = - getAIModel(clientOverrides) + // If image generation is enabled, use gemini-3-pro-image-preview + let model: any + let providerOptions: any = {} + let headers: any = {} + let modelId: string + + if (imageGenerationEnabled) { + // Use Google Gemini for image generation + const googleApiKey = process.env.GOOGLE_GENERATIVE_AI_API_KEY + if (!googleApiKey) { + return Response.json( + { error: "Google API key not configured for image generation" }, + { status: 500 }, + ) + } + const googleProvider = createGoogleGenerativeAI({ + apiKey: googleApiKey, + }) + model = googleProvider("gemini-3-pro-image-preview", { + imageConfig: { + aspectRatio: imageAspectRatio, + imageSize: imageResolution, + }, + }) + modelId = "gemini-3-pro-image-preview" + } else { + const result = getAIModel(clientOverrides) + model = result.model + providerOptions = result.providerOptions + headers = result.headers + modelId = result.modelId + } // Check if model supports prompt caching const shouldCache = supportsPromptCaching(modelId) @@ -288,6 +325,40 @@ ${userInputText} msg.content && Array.isArray(msg.content) && msg.content.length > 0, ) + // In image generation mode, filter out images from assistant messages + // Gemini API doesn't support images in assistant messages + if (imageGenerationEnabled) { + enhancedMessages = enhancedMessages.map((msg: any) => { + if (msg.role === "assistant" && Array.isArray(msg.content)) { + const filteredContent = msg.content.filter((part: any) => { + // Remove image parts from assistant messages (multiple checks for different formats) + if ( + part.type === "image" || + part.image || + part.url || + part.mimeType?.startsWith("image/") || + (part.experimental_providerMetadata && + part.experimental_providerMetadata.anthropic + ?.type === "image") + ) { + console.log( + "[route.ts] Filtering out image from assistant message", + part.type, + ) + return false + } + return true + }) + return { ...msg, content: filteredContent } + } + return msg + }) + // Remove messages with empty content after filtering + enhancedMessages = enhancedMessages.filter( + (msg: any) => msg.content && msg.content.length > 0, + ) + } + // Filter out tool-calls with invalid inputs (from failed repair or interrupted streaming) // Bedrock API rejects messages where toolUse.input is not a valid JSON object enhancedMessages = enhancedMessages @@ -393,28 +464,35 @@ ${userInputText} // - Breakpoint 2: Current XML context - changes per diagram, but constant within a conversation turn // This allows: if only user message changes, both system caches are reused // if XML changes, instruction cache is still reused - const systemMessages = [ - // Cache breakpoint 1: Instructions (rarely change) - { - role: "system" as const, - content: systemMessage, - ...(shouldCache && { - providerOptions: { - bedrock: { cachePoint: { type: "default" } }, - }, - }), - }, - // Cache breakpoint 2: Previous and Current diagram XML context - { - role: "system" as const, - content: `${previousXml ? `Previous diagram XML (before user's last message):\n"""xml\n${previousXml}\n"""\n\n` : ""}Current diagram XML (AUTHORITATIVE - the source of truth):\n"""xml\n${xml || ""}\n"""\n\nIMPORTANT: The "Current diagram XML" is the SINGLE SOURCE OF TRUTH for what's on the canvas right now. The user can manually add, delete, or modify shapes directly in draw.io. Always count and describe elements based on the CURRENT XML, not on what you previously generated. If both previous and current XML are shown, compare them to understand what the user changed. When using edit_diagram, COPY search patterns exactly from the CURRENT XML - attribute order matters!`, - ...(shouldCache && { - providerOptions: { - bedrock: { cachePoint: { type: "default" } }, - }, - }), - }, - ] + const systemMessages = imageGenerationEnabled + ? [ + { + role: "system" as const, + content: `你是一个 AI 图片生成器。根据用户的描述创建高质量的图片。请充满创意和细节。`, + }, + ] + : [ + // Cache breakpoint 1: Instructions (rarely change) + { + role: "system" as const, + content: systemMessage, + ...(shouldCache && { + providerOptions: { + bedrock: { cachePoint: { type: "default" } }, + }, + }), + }, + // Cache breakpoint 2: Previous and Current diagram XML context + { + role: "system" as const, + content: `${previousXml ? `Previous diagram XML (before user's last message):\n"""xml\n${previousXml}\n"""\n\n` : ""}Current diagram XML (AUTHORITATIVE - the source of truth):\n"""xml\n${xml || ""}\n"""\n\nIMPORTANT: The "Current diagram XML" is the SINGLE SOURCE OF TRUTH for what's on the canvas right now. The user can manually add, delete, or modify shapes directly in draw.io. Always count and describe elements based on the CURRENT XML, not on what you previously generated. If both previous and current XML are shown, compare them to understand what the user changed. When using edit_diagram, COPY search patterns exactly from the CURRENT XML - attribute order matters!`, + ...(shouldCache && { + providerOptions: { + bedrock: { cachePoint: { type: "default" } }, + }, + }), + }, + ] const allMessages = [...systemMessages, ...enhancedMessages] @@ -423,72 +501,10 @@ ${userInputText} ...(process.env.MAX_OUTPUT_TOKENS && { maxOutputTokens: parseInt(process.env.MAX_OUTPUT_TOKENS, 10), }), - stopWhen: stepCountIs(5), - // Repair truncated tool calls when maxOutputTokens is reached mid-JSON - experimental_repairToolCall: async ({ toolCall, error }) => { - // DEBUG: Log what we're trying to repair - console.log(`[repairToolCall] Tool: ${toolCall.toolName}`) - console.log( - `[repairToolCall] Error: ${error.name} - ${error.message}`, - ) - console.log(`[repairToolCall] Input type: ${typeof toolCall.input}`) - console.log(`[repairToolCall] Input value:`, toolCall.input) - - // Only attempt repair for invalid tool input (broken JSON from truncation) - if ( - error instanceof InvalidToolInputError || - error.name === "AI_InvalidToolInputError" - ) { - try { - // Pre-process to fix common LLM JSON errors that jsonrepair can't handle - let inputToRepair = toolCall.input - if (typeof inputToRepair === "string") { - // Fix `:=` instead of `: ` (LLM sometimes generates this) - inputToRepair = inputToRepair.replace(/:=/g, ": ") - // Fix `= "` instead of `: "` - inputToRepair = inputToRepair.replace(/=\s*"/g, ': "') - } - // Use jsonrepair to fix truncated JSON - const repairedInput = jsonrepair(inputToRepair) - console.log( - `[repairToolCall] Repaired truncated JSON for tool: ${toolCall.toolName}`, - ) - return { ...toolCall, input: repairedInput } - } catch (repairError) { - console.warn( - `[repairToolCall] Failed to repair JSON for tool: ${toolCall.toolName}`, - repairError, - ) - // Return a placeholder input to avoid API errors in multi-step - // The tool will fail gracefully on client side - if (toolCall.toolName === "edit_diagram") { - return { - ...toolCall, - input: { - operations: [], - _error: "JSON repair failed - no operations to apply", - }, - } - } - if (toolCall.toolName === "display_diagram") { - return { - ...toolCall, - input: { - xml: "", - _error: "JSON repair failed - empty diagram", - }, - } - } - return null - } - } - // Don't attempt to repair other errors (like NoSuchToolError) - return null - }, + stopWhen: imageGenerationEnabled ? undefined : stepCountIs(5), messages: allMessages, - ...(providerOptions && { providerOptions }), // This now includes all reasoning configs + ...(providerOptions && { providerOptions }), ...(headers && { headers }), - // Langfuse telemetry config (returns undefined if not configured) ...(getTelemetryConfig({ sessionId: validSessionId, userId }) && { experimental_telemetry: getTelemetryConfig({ sessionId: validSessionId, @@ -496,16 +512,85 @@ ${userInputText} }), }), onFinish: ({ text, usage }) => { - // Pass usage to Langfuse (Bedrock streaming doesn't auto-report tokens to telemetry) setTraceOutput(text, { promptTokens: usage?.inputTokens, completionTokens: usage?.outputTokens, }) }, - tools: { - // Client-side tool that will be executed on the client - display_diagram: { - description: `Display a diagram on draw.io. Pass ONLY the mxCell elements - wrapper tags and root cells are added automatically. + // Only add repair and tools for diagram mode + ...(!imageGenerationEnabled && { + experimental_repairToolCall: async ({ toolCall, error }) => { + // DEBUG: Log what we're trying to repair + console.log(`[repairToolCall] Tool: ${toolCall.toolName}`) + console.log( + `[repairToolCall] Error: ${error.name} - ${error.message}`, + ) + console.log( + `[repairToolCall] Input type: ${typeof toolCall.input}`, + ) + console.log(`[repairToolCall] Input value:`, toolCall.input) + + // Only attempt repair for invalid tool input (broken JSON from truncation) + if ( + error instanceof InvalidToolInputError || + error.name === "AI_InvalidToolInputError" + ) { + try { + // Pre-process to fix common LLM JSON errors that jsonrepair can't handle + let inputToRepair = toolCall.input + if (typeof inputToRepair === "string") { + // Fix `:=` instead of `: ` (LLM sometimes generates this) + inputToRepair = inputToRepair.replace(/:=/g, ": ") + // Fix `= "` instead of `: "` + inputToRepair = inputToRepair.replace( + /=\s*"/g, + ': "', + ) + } + // Use jsonrepair to fix truncated JSON + const repairedInput = jsonrepair(inputToRepair) + console.log( + `[repairToolCall] Repaired truncated JSON for tool: ${toolCall.toolName}`, + ) + return { ...toolCall, input: repairedInput } + } catch (repairError) { + console.warn( + `[repairToolCall] Failed to repair JSON for tool: ${toolCall.toolName}`, + repairError, + ) + // Return a placeholder input to avoid API errors in multi-step + // The tool will fail gracefully on client side + if (toolCall.toolName === "edit_diagram") { + return { + ...toolCall, + input: { + operations: [], + _error: "JSON repair failed - no operations to apply", + }, + } + } + if (toolCall.toolName === "display_diagram") { + return { + ...toolCall, + input: { + xml: "", + _error: "JSON repair failed - empty diagram", + }, + } + } + return null + } + } + // Don't attempt to repair other errors (like NoSuchToolError) + return null + }, + }), + // Tools - only for diagram mode + ...(!imageGenerationEnabled && { + tools: { + // Client-side tool that will be executed on the client + display_diagram: { + description: `Display a diagram on draw.io. Pass ONLY the mxCell elements - wrapper tags and root cells are added automatically. VALIDATION RULES (XML will be rejected if violated): 1. Generate ONLY mxCell elements - NO wrapper tags (, , ) @@ -536,14 +621,14 @@ Notes: - For AWS diagrams, use **AWS 2025 icons**. - For animated connectors, add "flowAnimation=1" to edge style. `, - inputSchema: z.object({ - xml: z - .string() - .describe("XML string to be displayed on draw.io"), - }), - }, - edit_diagram: { - description: `Edit the current diagram by ID-based operations (update/add/delete cells). + inputSchema: z.object({ + xml: z + .string() + .describe("XML string to be displayed on draw.io"), + }), + }, + edit_diagram: { + description: `Edit the current diagram by ID-based operations (update/add/delete cells). Operations: - update: Replace an existing cell by its id. Provide cell_id and complete new_xml. @@ -553,31 +638,31 @@ Operations: For update/add, new_xml must be a complete mxCell element including mxGeometry. ⚠️ JSON ESCAPING: Every " inside new_xml MUST be escaped as \\". Example: id=\\"5\\" value=\\"Label\\"`, - inputSchema: z.object({ - operations: z - .array( - z.object({ - type: z - .enum(["update", "add", "delete"]) - .describe("Operation type"), - cell_id: z - .string() - .describe( - "The id of the mxCell. Must match the id attribute in new_xml.", - ), - new_xml: z - .string() - .optional() - .describe( - "Complete mxCell XML element (required for update/add)", - ), - }), - ) - .describe("Array of operations to apply"), - }), - }, - append_diagram: { - description: `Continue generating diagram XML when previous display_diagram output was truncated due to length limits. + inputSchema: z.object({ + operations: z + .array( + z.object({ + type: z + .enum(["update", "add", "delete"]) + .describe("Operation type"), + cell_id: z + .string() + .describe( + "The id of the mxCell. Must match the id attribute in new_xml.", + ), + new_xml: z + .string() + .optional() + .describe( + "Complete mxCell XML element (required for update/add)", + ), + }), + ) + .describe("Array of operations to apply"), + }), + }, + append_diagram: { + description: `Continue generating diagram XML when previous display_diagram output was truncated due to length limits. WHEN TO USE: Only call this tool after display_diagram was truncated (you'll see an error message about truncation). @@ -588,15 +673,30 @@ CRITICAL INSTRUCTIONS: 4. If still truncated, call append_diagram again with the next fragment Example: If previous output ended with '...' and complete the remaining elements.`, - inputSchema: z.object({ - xml: z - .string() - .describe( - "Continuation XML fragment to append (NO wrapper tags)", - ), - }), + inputSchema: z.object({ + xml: z + .string() + .describe( + "Continuation XML fragment to append (NO wrapper tags)", + ), + }), + }, + display_image: { + description: `在 draw.io 画布上显示生成的图片。此工具接收 base64 编码的图片数据并将其显示在画布上。`, + inputSchema: z.object({ + imageData: z + .string() + .describe( + "Base64 编码的图片数据(不包含 data:image 前缀)", + ), + description: z + .string() + .optional() + .describe("图片的可选描述"), + }), + }, }, - }, + }), ...(process.env.TEMPERATURE !== undefined && { temperature: parseFloat(process.env.TEMPERATURE), }), diff --git a/components/chat-message-display.tsx b/components/chat-message-display.tsx index b2dafadb..804694d9 100644 --- a/components/chat-message-display.tsx +++ b/components/chat-message-display.tsx @@ -8,6 +8,8 @@ import { ChevronUp, Copy, Cpu, + Download, + Eye, FileCode, FileText, Pencil, @@ -26,6 +28,7 @@ import { ReasoningContent, ReasoningTrigger, } from "@/components/ai-elements/reasoning" +import { ImagePreviewModal } from "@/components/image-preview-modal" import { ScrollArea } from "@/components/ui/scroll-area" import { applyDiagramOperations, @@ -234,6 +237,10 @@ export function ChatMessageDisplay({ const [editingMessageId, setEditingMessageId] = useState( null, ) + const [previewImage, setPreviewImage] = useState<{ + url: string + alt: string + } | null>(null) const editTextareaRef = useRef(null) const [editText, setEditText] = useState("") // Track which PDF sections are expanded (key: messageId-sectionIndex) @@ -1084,6 +1091,170 @@ export function ChatMessageDisplay({ part, partIndex, ) => { + // Handle image parts + if ( + part.type === + "image" || + ( + part as any + ).image + ) { + const imageUrl = + ( + part as any + ) + .image || + ( + part as any + ) + .url + + const handleDownload = + async () => { + try { + // Convert to PNG + const img = + new Image() + img.crossOrigin = + "anonymous" + + await new Promise( + ( + resolve, + reject, + ) => { + img.onload = + resolve + img.onerror = + reject + img.src = + imageUrl + }, + ) + + const canvas = + document.createElement( + "canvas", + ) + canvas.width = + img.width + canvas.height = + img.height + const ctx = + canvas.getContext( + "2d", + ) + + if ( + ctx + ) { + ctx.drawImage( + img, + 0, + 0, + ) + canvas.toBlob( + ( + blob, + ) => { + if ( + blob + ) { + const url = + URL.createObjectURL( + blob, + ) + const link = + document.createElement( + "a", + ) + link.href = + url + link.download = `ai-generated-image-${Date.now()}.png` + document.body.appendChild( + link, + ) + link.click() + document.body.removeChild( + link, + ) + URL.revokeObjectURL( + url, + ) + toast.success( + "图片已下载为 PNG 格式", + ) + } + }, + "image/png", + ) + } + } catch (error) { + console.error( + "Download failed:", + error, + ) + toast.error( + "下载失败", + ) + } + } + + const handlePreview = + () => { + setPreviewImage( + { + url: imageUrl, + alt: "AI生成图片", + }, + ) + } + + return ( +
+ AI生成图片 +
+ + +
+
+ ) + } + if ( part.type === "text" @@ -1358,6 +1529,12 @@ export function ChatMessageDisplay({ )}
+ !open && setPreviewImage(null)} + imageUrl={previewImage?.url || ""} + imageAlt={previewImage?.alt} + /> ) } diff --git a/components/chat-panel.tsx b/components/chat-panel.tsx index ea383924..53a6463b 100644 --- a/components/chat-panel.tsx +++ b/components/chat-panel.tsx @@ -18,6 +18,7 @@ import { FaGithub } from "react-icons/fa" import { Toaster, toast } from "sonner" import { ButtonWithTooltip } from "@/components/button-with-tooltip" import { ChatInput } from "@/components/chat-input" +import { ImageGenerationConfig } from "@/components/image-generation-config" import { ResetWarningModal } from "@/components/reset-warning-modal" import { SettingsDialog } from "@/components/settings-dialog" import { useDiagram } from "@/contexts/diagram-context" @@ -34,6 +35,10 @@ const STORAGE_MESSAGES_KEY = "next-ai-draw-io-messages" const STORAGE_XML_SNAPSHOTS_KEY = "next-ai-draw-io-xml-snapshots" const STORAGE_SESSION_ID_KEY = "next-ai-draw-io-session-id" export const STORAGE_DIAGRAM_XML_KEY = "next-ai-draw-io-diagram-xml" +const STORAGE_IMAGE_GENERATION_ENABLED_KEY = + "next-ai-draw-io-image-generation-enabled" +const STORAGE_IMAGE_RESOLUTION_KEY = "next-ai-draw-io-image-resolution" +const STORAGE_IMAGE_ASPECT_RATIO_KEY = "next-ai-draw-io-image-aspect-ratio" // sessionStorage keys const SESSION_STORAGE_INPUT_KEY = "next-ai-draw-io-input" @@ -150,12 +155,39 @@ export default function ChatPanel({ const [showNewChatDialog, setShowNewChatDialog] = useState(false) const [minimalStyle, setMinimalStyle] = useState(false) + // Image generation configuration states + const [imageGenerationEnabled, setImageGenerationEnabled] = useState(false) + const [imageResolution, setImageResolution] = useState("1K") + const [imageAspectRatio, setImageAspectRatio] = useState("1:1") + // Restore input from sessionStorage on mount (when ChatPanel remounts due to key change) useEffect(() => { const savedInput = sessionStorage.getItem(SESSION_STORAGE_INPUT_KEY) if (savedInput) { setInput(savedInput) } + + // Restore image generation config from localStorage + const savedImageEnabled = localStorage.getItem( + STORAGE_IMAGE_GENERATION_ENABLED_KEY, + ) + if (savedImageEnabled !== null) { + setImageGenerationEnabled(savedImageEnabled === "true") + } + + const savedResolution = localStorage.getItem( + STORAGE_IMAGE_RESOLUTION_KEY, + ) + if (savedResolution) { + setImageResolution(savedResolution) + } + + const savedAspectRatio = localStorage.getItem( + STORAGE_IMAGE_ASPECT_RATIO_KEY, + ) + if (savedAspectRatio) { + setImageAspectRatio(savedAspectRatio) + } }, []) // Check config on mount @@ -241,6 +273,40 @@ export default function ChatPanel({ ) } + if (toolCall.toolName === "display_image") { + const { imageData, description } = toolCall.input as { + imageData: string + description?: string + } + + // Create an mxCell with the image + const imageId = `img-${Date.now()}` + const imageXml = ` + +` + + try { + const validatedXml = validateAndFixXml(imageXml) + onDisplayChart(wrapWithMxFile(validatedXml), true) + + addToolOutput({ + tool: "display_image", + toolCallId: toolCall.toolCallId, + state: "output-available", + output: "图片已成功显示在画布上。", + }) + } catch (error) { + console.error("[display_image] Error:", error) + addToolOutput({ + tool: "display_image", + toolCallId: toolCall.toolCallId, + state: "output-error", + errorText: `显示图片失败: ${error instanceof Error ? error.message : String(error)}`, + }) + } + return + } + if (toolCall.toolName === "display_diagram") { const { xml } = toolCall.input as { xml: string } @@ -619,6 +685,49 @@ Continue from EXACTLY where you stopped.`, // DEBUG: Log finish reason to diagnose truncation console.log("[onFinish] finishReason:", metadata?.finishReason) console.log("[onFinish] metadata:", metadata) + console.log("[onFinish] message parts:", message?.parts) + + // Check if image generation mode produced an image + if (imageGenerationEnabled && message?.parts) { + for (const part of message.parts) { + // Check for image data in the part + if (part.type === "image" || (part as any).image) { + console.log("[onFinish] Found image in response:", part) + // Extract base64 image data + const imageUrl = + (part as any).image || (part as any).url + if (imageUrl && typeof imageUrl === "string") { + // Remove data URL prefix if present + const base64Data = imageUrl.replace( + /^data:image\/[^;]+;base64,/, + "", + ) + + // Create an mxCell with the image + const imageId = `img-${Date.now()}` + const imageXml = ` + +` + + try { + const validatedXml = validateAndFixXml(imageXml) + onDisplayChart( + wrapWithMxFile(validatedXml), + true, + ) + console.log( + "[onFinish] Image displayed on canvas", + ) + } catch (error) { + console.error( + "[onFinish] Error displaying image:", + error, + ) + } + } + } + } + } if (metadata) { // Use Number.isFinite to guard against NaN (typeof NaN === 'number' is true) @@ -954,6 +1063,25 @@ Continue from EXACTLY where you stopped.`, sessionStorage.setItem(SESSION_STORAGE_INPUT_KEY, input) } + // Image generation config handlers + const handleImageGenerationEnabledChange = (enabled: boolean) => { + setImageGenerationEnabled(enabled) + localStorage.setItem( + STORAGE_IMAGE_GENERATION_ENABLED_KEY, + String(enabled), + ) + } + + const handleImageResolutionChange = (resolution: string) => { + setImageResolution(resolution) + localStorage.setItem(STORAGE_IMAGE_RESOLUTION_KEY, resolution) + } + + const handleImageAspectRatioChange = (aspectRatio: string) => { + setImageAspectRatio(aspectRatio) + localStorage.setItem(STORAGE_IMAGE_ASPECT_RATIO_KEY, aspectRatio) + } + // Helper functions for message actions (regenerate/edit) // Extract previous XML snapshot before a given message index const getPreviousXml = (beforeIndex: number): string => { @@ -1036,6 +1164,11 @@ Continue from EXACTLY where you stopped.`, ...(minimalStyle && { "x-minimal-style": "true", }), + ...(imageGenerationEnabled && { + "x-image-generation": "true", + "x-image-resolution": imageResolution, + "x-image-aspect-ratio": imageAspectRatio, + }), }, }, ) @@ -1335,6 +1468,16 @@ Continue from EXACTLY where you stopped.`, /> + {/* Image Generation Config */} + + {/* Input */}