diff --git a/.changeset/cerebras-16k-max-tokens.md b/.changeset/cerebras-16k-max-tokens.md deleted file mode 100644 index 4dff805bf44..00000000000 --- a/.changeset/cerebras-16k-max-tokens.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -"kilo-code": patch ---- - -Update Cerebras maxTokens from 8192 to 16384 for all models diff --git a/.changeset/cli-image-paste-support.md b/.changeset/cli-image-paste-support.md new file mode 100644 index 00000000000..6352b455d90 --- /dev/null +++ b/.changeset/cli-image-paste-support.md @@ -0,0 +1,9 @@ +--- +"@kilocode/cli": patch +--- + +Add image paste support to CLI + +- Allow Ctrl+V in the CLI to paste clipboard images, attach them as [Image #N], and send them with messages (macOS only, with status feedback and cleanup) +- Add image mention parsing (@path and [Image #N]) so pasted or referenced images are included when sending messages +- Split media code into a dedicated module with platform-specific clipboard handlers and image utilities diff --git a/.changeset/fix-empty-checkpoints.md b/.changeset/fix-empty-checkpoints.md new file mode 100644 index 00000000000..031cea3f177 --- /dev/null +++ b/.changeset/fix-empty-checkpoints.md @@ -0,0 +1,5 @@ +--- +"kilo-code": patch +--- + +Prevent empty checkpoints from being created on every tool use diff --git a/.changeset/gorgeous-carrots-check.md b/.changeset/gorgeous-carrots-check.md deleted file mode 100644 index 78c1960c755..00000000000 --- a/.changeset/gorgeous-carrots-check.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -"kilo-code": minor ---- - -Add support for skills diff --git a/.changeset/short-hats-appear.md b/.changeset/short-hats-appear.md new file mode 100644 index 00000000000..94ff0e79e64 --- /dev/null +++ b/.changeset/short-hats-appear.md @@ -0,0 +1,5 @@ +--- +"kilo-code": patch +--- + +Jetbrains IDEs - Improve intialization process diff --git a/CHANGELOG.md b/CHANGELOG.md index d92e6ca68a9..4a3a62d0c6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # kilo-code +## 4.141.0 + +### Minor Changes + +- [#4702](https://github.com/Kilo-Org/kilocode/pull/4702) [`b84a66f`](https://github.com/Kilo-Org/kilocode/commit/b84a66f5923cf2600a6d5c8e2b5fd49759406696) Thanks [@chrarnoldus](https://github.com/chrarnoldus)! - Add support for skills + +### Patch Changes + +- [#4710](https://github.com/Kilo-Org/kilocode/pull/4710) [`c128319`](https://github.com/Kilo-Org/kilocode/commit/c1283192df1b0e59fef8b9ab2d3442bf4a07abde) Thanks [@sebastiand-cerebras](https://github.com/sebastiand-cerebras)! - Update Cerebras maxTokens from 8192 to 16384 for all models + +- [#4718](https://github.com/Kilo-Org/kilocode/pull/4718) [`9a465b0`](https://github.com/Kilo-Org/kilocode/commit/9a465b06fe401f70dd166fb5b320a8070f07c727) Thanks [@marius-kilocode](https://github.com/marius-kilocode)! - Fix terminal scroll-flicker in CLI by disabling streaming output and enabling Ink incremental rendering + +- [#4719](https://github.com/Kilo-Org/kilocode/pull/4719) [`57b0873`](https://github.com/Kilo-Org/kilocode/commit/57b08737788cd504954563d46eb1e6323d619301) Thanks [@marius-kilocode](https://github.com/marius-kilocode)! - Confirm before exiting the CLI on Ctrl+C/Cmd+C. + ## 4.140.3 ### Patch Changes diff --git a/apps/kilocode-docs/docs/advanced-usage/cloud-agent.md b/apps/kilocode-docs/docs/advanced-usage/cloud-agent.md index 58cb7d6970b..bd0b3d1ebe9 100644 --- a/apps/kilocode-docs/docs/advanced-usage/cloud-agent.md +++ b/apps/kilocode-docs/docs/advanced-usage/cloud-agent.md @@ -91,6 +91,16 @@ You can customize each Cloud Agent session by defining: --- +## Skills + +Cloud Agents support project-level [skills](../cli#skills) stored in your repository. When your repo is cloned, any skills in `.kilocode/skills/` are automatically available. + +:::note +Global skills (`~/.kilocode/skills/`) are not available in Cloud Agents since there is no persistent user home directory. +::: + +--- + ## Perfect For Cloud Agents are great for: diff --git a/apps/kilocode-docs/docs/cli.md b/apps/kilocode-docs/docs/cli.md index cb6f17bcd7c..f7b0eefc8ba 100644 --- a/apps/kilocode-docs/docs/cli.md +++ b/apps/kilocode-docs/docs/cli.md @@ -38,6 +38,7 @@ Upgrade the Kilo CLI package: - **Switch between hundreds of LLMs without constraints.** Other CLI tools only work with one model or curate opinionated lists. With Kilo, you can switch models without booting up another tool. - **Choose the right mode for the task in your workflow.** Select between Architect, Ask, Debug, Orchestrator, or custom agent modes. - **Automate tasks.** Get AI assistance writing shell scripts for tasks like renaming all of the files in a folder or transforming sizes for a set of images. +- **Extend capabilities with skills.** Add domain expertise and repeatable workflows through [Agent Skills](#skills). ## CLI reference @@ -68,6 +69,65 @@ Upgrade the Kilo CLI package: | `/help` | List available commands and how to use them | | | `/exit` | Exit the CLI | | +## Skills + +The CLI supports [Agent Skills](https://agentskills.io/), a lightweight format for extending AI capabilities with specialized knowledge and workflows. + +Skills are discovered from: + +- **Global skills**: `~/.kilocode/skills/` (available in all projects) +- **Project skills**: `.kilocode/skills/` (project-specific) + +Skills can be: + +- **Generic** - Available in all modes +- **Mode-specific** - Only loaded when using a particular mode (e.g., `code`, `architect`) + +For example: + +``` +your-project/ +└── .kilocode/ + ├── skills/ # Generic skills for this project + │ └── project-conventions/ + │ └── SKILL.md + └── skills-code/ # Code mode skills for this project + └── linting-rules/ + └── SKILL.md +``` + +### Adding a Skill + +1. Create the skill directory: + + ```bash + mkdir -p ~/.kilocode/skills/api-design + ``` + +2. Create a `SKILL.md` file with YAML frontmatter: + + ```markdown + --- + name: api-design + description: REST API design best practices and conventions + --- + + # API Design Guidelines + + When designing REST APIs, follow these conventions... + ``` + + The `name` field must match the directory name exactly. + +3. Start a new CLI session to load the skill + +#### Finding skills + +There are community efforts to build and share agent skills. Some resources include: + +- [Skills Marketplace](https://skillsmp.com/) - Community marketplace of skills +- [Skill Specification](https://agentskills.io/home) - Agent Skills specification + ## Checkpoint Management Kilo Code automatically creates checkpoints as you work, allowing you to revert to previous states in your project's history. diff --git a/apps/kilocode-docs/docs/features/skills.md b/apps/kilocode-docs/docs/features/skills.md new file mode 100644 index 00000000000..ed32388413f --- /dev/null +++ b/apps/kilocode-docs/docs/features/skills.md @@ -0,0 +1,290 @@ +# Skills + +Kilo Code implements [Agent Skills](https://agentskills.io/), a lightweight, open format for extending AI agent capabilities with specialized knowledge and workflows. + +## What Are Agent Skills? + +Agent Skills package domain expertise, new capabilities, and repeatable workflows that agents can use. At its core, a skill is a folder containing a `SKILL.md` file with metadata and instructions that tell an agent how to perform a specific task. + +This approach keeps agents fast while giving them access to more context on demand. When a task matches a skill's description, the agent reads the full instructions into context and follows them—optionally loading referenced files or executing bundled code as needed. + +### Key Benefits + +- **Self-documenting**: A skill author or user can read a `SKILL.md` file and understand what it does, making skills easy to audit and improve +- **Interoperable**: Skills work across any agent that implements the [Agent Skills specification](https://agentskills.io/specification) +- **Extensible**: Skills can range in complexity from simple text instructions to bundled scripts, templates, and reference materials +- **Shareable**: Skills are portable and can be easily shared between projects and developers + +## How Skills Work in Kilo Code + +Skills can be: + +- **Generic** - Available in all modes +- **Mode-specific** - Only loaded when using a particular mode (e.g., `code`, `architect`) + +The workflow is: + +1. **Discovery**: Skills are scanned from designated directories when Kilo Code initializes +2. **Activation**: When a mode is active, relevant skills are included in the system prompt +3. **Execution**: The AI agent follows the skill's instructions for applicable tasks + +## Skill Locations + +Skills are loaded from multiple locations, allowing both personal skills and project-specific instructions. + +### Global Skills (User-Level) + +Located in `~/.kilocode/skills/`: + +``` +~/.kilocode/ +├── skills/ # Generic skills (all modes) +│ ├── my-skill/ +│ │ └── SKILL.md +│ └── another-skill/ +│ └── SKILL.md +├── skills-code/ # Code mode only +│ └── refactoring/ +│ └── SKILL.md +└── skills-architect/ # Architect mode only + └── system-design/ + └── SKILL.md +``` + +### Project Skills (Workspace-Level) + +Located in `.kilocode/skills/` within your project: + +``` +your-project/ +└── .kilocode/ + ├── skills/ # Generic skills for this project + │ └── project-conventions/ + │ └── SKILL.md + └── skills-code/ # Code mode skills for this project + └── linting-rules/ + └── SKILL.md +``` + +## SKILL.md Format + +The `SKILL.md` file uses YAML frontmatter followed by Markdown content containing the instructions: + +```markdown +--- +name: my-skill-name +description: A brief description of what this skill does and when to use it +--- + +# Instructions + +Your detailed instructions for the AI agent go here. + +These instructions will be included in the system prompt when: + +1. The skill is discovered in a valid location +2. The current mode matches (or the skill is generic) + +## Example Usage + +You can include examples, guidelines, code snippets, etc. +``` + +### Frontmatter Fields + +Per the [Agent Skills specification](https://agentskills.io/specification): + +| Field | Required | Description | +| --------------- | -------- | ----------------------------------------------------------------------------------------------------- | +| `name` | Yes | Max 64 characters. Lowercase letters, numbers, and hyphens only. Must not start or end with a hyphen. | +| `description` | Yes | Max 1024 characters. Describes what the skill does and when to use it. | +| `license` | No | License name or reference to a bundled license file | +| `compatibility` | No | Environment requirements (intended product, system packages, network access, etc.) | +| `metadata` | No | Arbitrary key-value mapping for additional metadata | + +### Example with Optional Fields + +```markdown +--- +name: pdf-processing +description: Extract text and tables from PDF files, fill forms, merge documents. +license: Apache-2.0 +metadata: + author: example-org + version: 1.0.0 +--- + +## How to extract text + +1. Use pdfplumber for text extraction... + +## How to fill forms + +... +``` + +### Name Matching Rule + +In Kilo Code, the `name` field **must match** the parent directory name: + +``` +✅ Correct: +skills/ +└── frontend-design/ + └── SKILL.md # name: frontend-design + +❌ Incorrect: +skills/ +└── frontend-design/ + └── SKILL.md # name: my-frontend-skill (doesn't match!) +``` + +## Optional Bundled Resources + +While `SKILL.md` is the only required file, you can optionally include additional directories to support your skill: + +``` +my-skill/ +├── SKILL.md # Required: instructions + metadata +├── scripts/ # Optional: executable code +├── references/ # Optional: documentation +└── assets/ # Optional: templates, resources +``` + +These additional files can be referenced from your skill's instructions, allowing the agent to read documentation, execute scripts, or use templates as needed. + +## Priority and Overrides + +When multiple skills share the same name, Kilo Code uses these priority rules: + +1. **Project skills override global skills** - A project skill with the same name takes precedence +2. **Mode-specific skills override generic skills** - A skill in `skills-code/` overrides the same skill in `skills/` when in Code mode + +This allows you to: + +- Define global skills for personal use +- Override them per-project when needed +- Customize behavior for specific modes + +## When Skills Are Loaded + +Skills are discovered when Kilo Code initializes: + +- When VSCode starts +- When you reload the VSCode window (`Cmd+Shift+P` → "Developer: Reload Window") + +Skills directories are monitored for changes to `SKILL.md` files. However, the most reliable way to pick up new skills is to reload VS or the Kilo Code extension. + +**Adding or modifying skills requires reloading VSCode for changes to take effect.** + +## Example: Creating a Skill + +1. Create the skill directory: + + ```bash + mkdir -p ~/.kilocode/skills/api-design + ``` + +2. Create `SKILL.md`: + + ```markdown + --- + name: api-design + description: REST API design best practices and conventions + --- + + # API Design Guidelines + + When designing REST APIs, follow these conventions: + + ## URL Structure + + - Use plural nouns for resources: `/users`, `/orders` + - Use kebab-case for multi-word resources: `/order-items` + - Nest related resources: `/users/{id}/orders` + + ## HTTP Methods + + - GET: Retrieve resources + - POST: Create new resources + - PUT: Replace entire resource + - PATCH: Partial update + - DELETE: Remove resource + + ## Response Codes + + - 200: Success + - 201: Created + - 400: Bad Request + - 404: Not Found + - 500: Server Error + ``` + +3. Reload VSCode to load the skill + +4. The skill will now be available in all modes + +## Mode-Specific Skills + +To create a skill that only appears in a specific mode: + +```bash +# For Code mode only +mkdir -p ~/.kilocode/skills-code/typescript-patterns + +# For Architect mode only +mkdir -p ~/.kilocode/skills-architect/microservices +``` + +The directory naming pattern is `skills-{mode-slug}` where `{mode-slug}` matches the mode's identifier (e.g., `code`, `architect`, `ask`, `debug`). + +## Using Symlinks + +You can symlink skills directories to share skills across machines or from a central repository: + +```bash +# Symlink entire skills directory +ln -s /path/to/shared/skills ~/.kilocode/skills + +# Or symlink individual skills +ln -s /path/to/shared/api-design ~/.kilocode/skills/api-design +``` + +When using symlinks, the skill's `name` field must match the **symlink name**, not the target directory name. + +## Finding Skills + +There are community efforts to build and share agent skills. Some resources include: + +- [Skills Marketplace](https://skillsmp.com/) - Community marketplace of skills +- [Skill Specification](https://agentskills.io/home) - Agent Skills specification + +### Creating Your Own + +Skills are simple Markdown files with frontmatter. Start with your existing prompt templates or instructions and convert them to the skill format. + +## Troubleshooting + +### Skill Not Loading? + +1. **Check the Output panel**: Open `View` → `Output` → Select "Kilo Code" from dropdown. Look for skill-related errors. + +2. **Verify frontmatter**: Ensure `name` exactly matches the directory name and `description` is present. + +3. **Reload VSCode**: Skills are loaded at startup. Use `Cmd+Shift+P` → "Developer: Reload Window". + +4. **Check file location**: Ensure `SKILL.md` is directly inside the skill directory, not nested further. + +### Common Errors + +| Error | Cause | Solution | +| ------------------------------- | -------------------------------------------- | ------------------------------------------------ | +| "missing required 'name' field" | No `name` in frontmatter | Add `name: your-skill-name` | +| "name doesn't match directory" | Mismatch between frontmatter and folder name | Make `name` match exactly | +| Skill not appearing | Wrong directory structure | Verify path follows `skills/skill-name/SKILL.md` | + +## Related + +- [Custom Modes](custom-modes) - Create custom modes that can use specific skills +- [Custom Instructions](../advanced-usage/custom-instructions) - Global instructions vs. skill-based instructions +- [Custom Rules](../advanced-usage/custom-rules) - Project-level rules complementing skills diff --git a/apps/kilocode-docs/sidebars.ts b/apps/kilocode-docs/sidebars.ts index 1b40ad47756..809579dcdaf 100644 --- a/apps/kilocode-docs/sidebars.ts +++ b/apps/kilocode-docs/sidebars.ts @@ -171,7 +171,7 @@ const sidebars: SidebarsConfig = { { type: "category", label: "Customization", - items: ["features/settings-management", "features/custom-modes"], + items: ["features/settings-management", "features/custom-modes", "features/skills"], }, { type: "category", diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 0e343b8f160..ac5a59ec2a8 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -1,5 +1,5 @@ import { basename } from "node:path" -import { render, Instance } from "ink" +import { render, Instance, type RenderOptions } from "ink" import React from "react" import { createStore } from "jotai" import { createExtensionService, ExtensionService } from "./services/extension.js" @@ -33,6 +33,7 @@ import { getSelectedModelId } from "./utils/providers.js" import { KiloCodePathProvider, ExtensionMessengerAdapter } from "./services/session-adapters.js" import { getKiloToken } from "./config/persistence.js" import { SessionManager } from "../../src/shared/kilocode/cli-sessions/core/SessionManager.js" +import { triggerExitConfirmationAtom } from "./state/atoms/keyboard.js" /** * Main application class that orchestrates the CLI lifecycle @@ -330,6 +331,13 @@ export class CLI { // Disable stdin for Ink when in CI mode or when stdin is piped (not a TTY) // This prevents the "Raw mode is not supported" error const shouldDisableStdin = this.options.jsonInteractive || this.options.ci || !process.stdin.isTTY + const renderOptions: RenderOptions = { + // Enable Ink's incremental renderer to avoid redrawing the entire screen on every update. + // This reduces flickering for frequently updating UIs. + incrementalRendering: true, + exitOnCtrlC: false, + ...(shouldDisableStdin ? { stdout: process.stdout, stderr: process.stderr } : {}), + } this.ui = render( React.createElement(App, { @@ -349,12 +357,7 @@ export class CLI { }, onExit: () => this.dispose(), }), - shouldDisableStdin - ? { - stdout: process.stdout, - stderr: process.stderr, - } - : undefined, + renderOptions, ) // Wait for UI to exit @@ -671,6 +674,31 @@ export class CLI { return this.store } + /** + * Returns true if the CLI should show an exit confirmation prompt for SIGINT. + */ + shouldConfirmExitOnSigint(): boolean { + return ( + !!this.store && + !this.options.ci && + !this.options.json && + !this.options.jsonInteractive && + process.stdin.isTTY + ) + } + + /** + * Trigger the exit confirmation prompt. Returns true if handled. + */ + requestExitConfirmation(): boolean { + if (!this.shouldConfirmExitOnSigint()) { + return false + } + + this.store?.set(triggerExitConfirmationAtom) + return true + } + /** * Check if the application is initialized */ diff --git a/cli/src/index.ts b/cli/src/index.ts index a2f5d239620..03514c94357 100644 --- a/cli/src/index.ts +++ b/cli/src/index.ts @@ -275,6 +275,10 @@ program // Handle process termination signals process.on("SIGINT", async () => { + if (cli?.requestExitConfirmation()) { + return + } + if (cli) { await cli.dispose("SIGINT") } else { diff --git a/cli/src/media/__tests__/atMentionParser.test.ts b/cli/src/media/__tests__/atMentionParser.test.ts new file mode 100644 index 00000000000..144e2298777 --- /dev/null +++ b/cli/src/media/__tests__/atMentionParser.test.ts @@ -0,0 +1,271 @@ +import { + parseAtMentions, + extractImagePaths, + removeImageMentions, + reconstructText, + type ParsedSegment, +} from "../atMentionParser" + +describe("atMentionParser", () => { + describe("parseAtMentions", () => { + it("should parse simple @ mentions", () => { + const result = parseAtMentions("Check @./image.png please") + + expect(result.paths).toEqual(["./image.png"]) + expect(result.imagePaths).toEqual(["./image.png"]) + expect(result.otherPaths).toEqual([]) + expect(result.segments).toHaveLength(3) + }) + + it("should parse multiple @ mentions", () => { + const result = parseAtMentions("Look at @./first.png and @./second.jpg") + + expect(result.paths).toEqual(["./first.png", "./second.jpg"]) + expect(result.imagePaths).toEqual(["./first.png", "./second.jpg"]) + }) + + it("should distinguish image and non-image paths", () => { + const result = parseAtMentions("Check @./code.ts and @./screenshot.png") + + expect(result.paths).toEqual(["./code.ts", "./screenshot.png"]) + expect(result.imagePaths).toEqual(["./screenshot.png"]) + expect(result.otherPaths).toEqual(["./code.ts"]) + }) + + it("should handle quoted paths with spaces", () => { + const result = parseAtMentions('Look at @"path with spaces/image.png"') + + expect(result.paths).toEqual(["path with spaces/image.png"]) + expect(result.imagePaths).toEqual(["path with spaces/image.png"]) + }) + + it("should handle single-quoted paths", () => { + const result = parseAtMentions("Look at @'path with spaces/image.png'") + + expect(result.paths).toEqual(["path with spaces/image.png"]) + }) + + it("should handle escaped spaces in paths", () => { + const result = parseAtMentions("Look at @path\\ with\\ spaces/image.png") + + expect(result.paths).toEqual(["path with spaces/image.png"]) + }) + + it("should stop at path terminators", () => { + const result = parseAtMentions("Check @./image.png, then @./other.jpg") + + expect(result.paths).toEqual(["./image.png", "./other.jpg"]) + }) + + it("should handle @ at end of string", () => { + const result = parseAtMentions("End with @") + + expect(result.paths).toEqual([]) + expect(result.segments).toHaveLength(1) + }) + + it("should handle text without @ mentions", () => { + const result = parseAtMentions("Just regular text without mentions") + + expect(result.paths).toEqual([]) + expect(result.segments).toHaveLength(1) + expect(result.segments[0]).toMatchObject({ + type: "text", + content: "Just regular text without mentions", + }) + }) + + it("should handle absolute paths", () => { + const result = parseAtMentions("Check @/absolute/path/image.png") + + expect(result.paths).toEqual(["/absolute/path/image.png"]) + }) + + it("should handle relative paths with parent directory", () => { + const result = parseAtMentions("Check @../parent/image.png") + + expect(result.paths).toEqual(["../parent/image.png"]) + }) + + it("should preserve segment positions", () => { + const input = "Start @./image.png end" + const result = parseAtMentions(input) + + expect(result.segments[0]).toMatchObject({ + type: "text", + content: "Start ", + startIndex: 0, + endIndex: 6, + }) + expect(result.segments[1]).toMatchObject({ + type: "atPath", + content: "./image.png", + startIndex: 6, + endIndex: 18, + }) + expect(result.segments[2]).toMatchObject({ + type: "text", + content: " end", + startIndex: 18, + endIndex: 22, + }) + }) + + it("should handle @ in email addresses (not a file path)", () => { + // @ followed by typical email pattern should be parsed but not as an image + const result = parseAtMentions("Email: test@example.com") + + // It will try to parse but example.com is not an image + expect(result.imagePaths).toEqual([]) + }) + + it("should handle multiple @ mentions consecutively", () => { + const result = parseAtMentions("@./a.png@./b.png") + + // Without whitespace separator, @ is part of the path + // This is expected behavior - paths need whitespace separation + expect(result.paths).toHaveLength(1) + expect(result.paths[0]).toBe("./a.png@./b.png") + }) + + it("should ignore trailing punctuation when parsing image paths", () => { + const result = parseAtMentions("Check @./image.png? please and @./second.jpg.") + + expect(result.imagePaths).toEqual(["./image.png", "./second.jpg"]) + expect(result.otherPaths).toEqual([]) + }) + }) + + describe("extractImagePaths", () => { + it("should extract only image paths", () => { + const paths = extractImagePaths("Check @./code.ts and @./image.png and @./doc.md") + + expect(paths).toEqual(["./image.png"]) + }) + + it("should return empty array for text without images", () => { + const paths = extractImagePaths("No images here, just @./file.ts") + + expect(paths).toEqual([]) + }) + + it("should handle all supported image formats", () => { + const paths = extractImagePaths("@./a.png @./b.jpg @./c.jpeg @./d.webp") + + expect(paths).toEqual(["./a.png", "./b.jpg", "./c.jpeg", "./d.webp"]) + }) + }) + + describe("removeImageMentions", () => { + it("should remove image mentions from text", () => { + const result = removeImageMentions("Check @./image.png please") + + expect(result).toBe("Check please") + }) + + it("should preserve non-image mentions", () => { + const result = removeImageMentions("Check @./code.ts and @./image.png") + + expect(result).toBe("Check @./code.ts and ") + }) + + it("should use custom placeholder", () => { + const result = removeImageMentions("Check @./image.png please", "[image]") + + expect(result).toBe("Check [image] please") + }) + + it("should handle multiple image mentions", () => { + const result = removeImageMentions("@./a.png and @./b.jpg here") + + expect(result).toBe(" and here") + }) + + it("should not collapse newlines or indentation", () => { + const input = "Line1\n @./img.png\nLine3" + const result = removeImageMentions(input) + + expect(result).toBe("Line1\n \nLine3") + }) + }) + + describe("reconstructText", () => { + it("should reconstruct text from segments", () => { + const segments: ParsedSegment[] = [ + { type: "text", content: "Hello ", startIndex: 0, endIndex: 6 }, + { type: "atPath", content: "./image.png", startIndex: 6, endIndex: 18 }, + { type: "text", content: " world", startIndex: 18, endIndex: 24 }, + ] + + const result = reconstructText(segments) + + expect(result).toBe("Hello @./image.png world") + }) + + it("should apply transform function", () => { + const segments: ParsedSegment[] = [ + { type: "text", content: "Check ", startIndex: 0, endIndex: 6 }, + { type: "atPath", content: "./image.png", startIndex: 6, endIndex: 18 }, + ] + + const result = reconstructText(segments, (seg) => { + if (seg.type === "atPath") { + return `[IMG: ${seg.content}]` + } + return seg.content + }) + + expect(result).toBe("Check [IMG: ./image.png]") + }) + }) + + describe("edge cases", () => { + it("should handle empty string", () => { + const result = parseAtMentions("") + + expect(result.paths).toEqual([]) + expect(result.segments).toHaveLength(0) + }) + + it("should handle only @", () => { + const result = parseAtMentions("@") + + expect(result.paths).toEqual([]) + }) + + it("should handle @ followed by space", () => { + const result = parseAtMentions("@ space") + + expect(result.paths).toEqual([]) + }) + + it("should handle unclosed quotes", () => { + const result = parseAtMentions('Check @"unclosed quote') + + // Should still extract what it can + expect(result.paths).toHaveLength(1) + }) + + it("should handle escaped backslash in path", () => { + const result = parseAtMentions("@path\\\\with\\\\backslash.png") + + expect(result.paths).toEqual(["path\\with\\backslash.png"]) + }) + + it("should handle various path terminators", () => { + const tests = [ + { input: "@./img.png)", expected: "./img.png" }, + { input: "@./img.png]", expected: "./img.png" }, + { input: "@./img.png}", expected: "./img.png" }, + { input: "@./img.png>", expected: "./img.png" }, + { input: "@./img.png|", expected: "./img.png" }, + { input: "@./img.png&", expected: "./img.png" }, + ] + + for (const { input, expected } of tests) { + const result = parseAtMentions(input) + expect(result.paths).toEqual([expected]) + } + }) + }) +}) diff --git a/cli/src/media/__tests__/clipboard.test.ts b/cli/src/media/__tests__/clipboard.test.ts new file mode 100644 index 00000000000..41a8530fd58 --- /dev/null +++ b/cli/src/media/__tests__/clipboard.test.ts @@ -0,0 +1,170 @@ +import { + isClipboardSupported, + // Domain logic functions (exported for testing) + parseClipboardInfo, + detectImageFormat, + buildDataUrl, + getUnsupportedClipboardPlatformMessage, + getClipboardDir, + generateClipboardFilename, +} from "../clipboard" + +describe("clipboard utility", () => { + describe("parseClipboardInfo (macOS clipboard info parsing)", () => { + it("should detect PNG format", () => { + expect(parseClipboardInfo("«class PNGf», 1234")).toEqual({ hasImage: true, format: "png" }) + }) + + it("should detect JPEG format", () => { + expect(parseClipboardInfo("«class JPEG», 5678")).toEqual({ hasImage: true, format: "jpeg" }) + }) + + it("should detect TIFF format", () => { + expect(parseClipboardInfo("TIFF picture, 9012")).toEqual({ hasImage: true, format: "tiff" }) + }) + + it("should detect GIF format", () => { + expect(parseClipboardInfo("«class GIFf», 3456")).toEqual({ hasImage: true, format: "gif" }) + }) + + it("should return no image for text-only clipboard", () => { + expect(parseClipboardInfo("«class utf8», 100")).toEqual({ hasImage: false, format: null }) + }) + + it("should return no image for empty string", () => { + expect(parseClipboardInfo("")).toEqual({ hasImage: false, format: null }) + }) + + it("should handle multiple types and pick first image", () => { + expect(parseClipboardInfo("«class PNGf», 1234, «class utf8», 100")).toEqual({ + hasImage: true, + format: "png", + }) + }) + }) + + describe("detectImageFormat (format detection from bytes)", () => { + it("should detect PNG from magic bytes", () => { + const pngBytes = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]) + expect(detectImageFormat(pngBytes)).toBe("png") + }) + + it("should detect JPEG from magic bytes", () => { + const jpegBytes = Buffer.from([0xff, 0xd8, 0xff, 0xe0]) + expect(detectImageFormat(jpegBytes)).toBe("jpeg") + }) + + it("should detect GIF from magic bytes", () => { + const gifBytes = Buffer.from([0x47, 0x49, 0x46, 0x38, 0x39, 0x61]) // GIF89a + expect(detectImageFormat(gifBytes)).toBe("gif") + }) + + it("should detect WebP from magic bytes", () => { + const webpBytes = Buffer.from([0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50]) + expect(detectImageFormat(webpBytes)).toBe("webp") + }) + + it("should return null for unknown format", () => { + const unknownBytes = Buffer.from([0x00, 0x01, 0x02, 0x03]) + expect(detectImageFormat(unknownBytes)).toBe(null) + }) + + it("should return null for empty buffer", () => { + expect(detectImageFormat(Buffer.from([]))).toBe(null) + }) + }) + + describe("buildDataUrl", () => { + it("should build PNG data URL", () => { + const data = Buffer.from([0x89, 0x50, 0x4e, 0x47]) + const result = buildDataUrl(data, "png") + expect(result).toBe(`data:image/png;base64,${data.toString("base64")}`) + }) + + it("should build JPEG data URL", () => { + const data = Buffer.from([0xff, 0xd8, 0xff]) + const result = buildDataUrl(data, "jpeg") + expect(result).toBe(`data:image/jpeg;base64,${data.toString("base64")}`) + }) + + it("should handle arbitrary binary data", () => { + const data = Buffer.from("Hello, World!") + const result = buildDataUrl(data, "png") + expect(result).toMatch(/^data:image\/png;base64,/) + expect(result).toContain(data.toString("base64")) + }) + }) + + describe("getUnsupportedClipboardPlatformMessage", () => { + it("should mention macOS", () => { + const msg = getUnsupportedClipboardPlatformMessage() + expect(msg).toContain("macOS") + }) + + it("should mention @path/to/image.png alternative", () => { + const msg = getUnsupportedClipboardPlatformMessage() + expect(msg).toContain("@") + expect(msg.toLowerCase()).toContain("image") + }) + }) + + describe("isClipboardSupported (platform detection)", () => { + const originalPlatform = process.platform + + afterEach(() => { + Object.defineProperty(process, "platform", { value: originalPlatform }) + }) + + it("should return true for darwin", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }) + expect(await isClipboardSupported()).toBe(true) + }) + + it("should return false for win32", async () => { + Object.defineProperty(process, "platform", { value: "win32" }) + expect(await isClipboardSupported()).toBe(false) + }) + }) + + describe("getClipboardDir", () => { + it("should return clipboard directory in system temp", () => { + const result = getClipboardDir() + expect(result).toContain("kilocode-clipboard") + // Should be in temp directory, not a project directory + expect(result).not.toContain(".kilocode-clipboard") + }) + }) + + describe("generateClipboardFilename", () => { + it("should generate unique filenames", () => { + const filename1 = generateClipboardFilename("png") + const filename2 = generateClipboardFilename("png") + expect(filename1).not.toBe(filename2) + }) + + it("should include correct extension", () => { + const pngFilename = generateClipboardFilename("png") + const jpegFilename = generateClipboardFilename("jpeg") + expect(pngFilename).toMatch(/\.png$/) + expect(jpegFilename).toMatch(/\.jpeg$/) + }) + + it("should start with clipboard- prefix", () => { + const filename = generateClipboardFilename("png") + expect(filename).toMatch(/^clipboard-/) + }) + + it("should include timestamp", () => { + const before = Date.now() + const filename = generateClipboardFilename("png") + const after = Date.now() + + // Extract timestamp from filename (clipboard-TIMESTAMP-RANDOM.ext) + const match = filename.match(/^clipboard-(\d+)-/) + expect(match).toBeTruthy() + const timestamp = parseInt(match![1], 10) + expect(timestamp).toBeGreaterThanOrEqual(before) + expect(timestamp).toBeLessThanOrEqual(after) + }) + }) +}) diff --git a/cli/src/media/__tests__/images.test.ts b/cli/src/media/__tests__/images.test.ts new file mode 100644 index 00000000000..b453b375eb5 --- /dev/null +++ b/cli/src/media/__tests__/images.test.ts @@ -0,0 +1,251 @@ +import * as fs from "fs/promises" +import * as path from "path" +import * as os from "os" +import { + isImagePath, + getMimeType, + readImageAsDataUrl, + processImagePaths, + SUPPORTED_IMAGE_EXTENSIONS, + MAX_IMAGE_SIZE_BYTES, +} from "../images" + +describe("images utility", () => { + let tempDir: string + + beforeEach(async () => { + tempDir = await fs.mkdtemp(path.join(os.tmpdir(), "images-test-")) + }) + + afterEach(async () => { + await fs.rm(tempDir, { recursive: true, force: true }) + }) + + describe("isImagePath", () => { + it("should return true for supported image extensions", () => { + expect(isImagePath("image.png")).toBe(true) + expect(isImagePath("image.PNG")).toBe(true) + expect(isImagePath("image.jpg")).toBe(true) + expect(isImagePath("image.JPG")).toBe(true) + expect(isImagePath("image.jpeg")).toBe(true) + expect(isImagePath("image.JPEG")).toBe(true) + expect(isImagePath("image.webp")).toBe(true) + expect(isImagePath("image.WEBP")).toBe(true) + expect(isImagePath("image.gif")).toBe(true) + expect(isImagePath("image.GIF")).toBe(true) + expect(isImagePath("image.tiff")).toBe(true) + expect(isImagePath("image.TIFF")).toBe(true) + }) + + it("should return false for non-image extensions", () => { + expect(isImagePath("file.txt")).toBe(false) + expect(isImagePath("file.ts")).toBe(false) + expect(isImagePath("file.js")).toBe(false) + expect(isImagePath("file.pdf")).toBe(false) + expect(isImagePath("file.bmp")).toBe(false) // BMP not supported + expect(isImagePath("file")).toBe(false) + }) + + it("should handle paths with directories", () => { + expect(isImagePath("/path/to/image.png")).toBe(true) + expect(isImagePath("./relative/path/image.jpg")).toBe(true) + expect(isImagePath("../parent/image.webp")).toBe(true) + }) + + it("should handle paths with dots in filename", () => { + expect(isImagePath("my.file.name.png")).toBe(true) + expect(isImagePath("version.1.2.3.jpg")).toBe(true) + }) + }) + + describe("getMimeType", () => { + it("should return correct MIME type for PNG", () => { + expect(getMimeType("image.png")).toBe("image/png") + expect(getMimeType("image.PNG")).toBe("image/png") + }) + + it("should return correct MIME type for JPEG", () => { + expect(getMimeType("image.jpg")).toBe("image/jpeg") + expect(getMimeType("image.jpeg")).toBe("image/jpeg") + expect(getMimeType("image.JPG")).toBe("image/jpeg") + }) + + it("should return correct MIME type for WebP", () => { + expect(getMimeType("image.webp")).toBe("image/webp") + }) + + it("should return correct MIME type for GIF and TIFF", () => { + expect(getMimeType("image.gif")).toBe("image/gif") + expect(getMimeType("image.tiff")).toBe("image/tiff") + }) + + it("should throw for unsupported types", () => { + expect(() => getMimeType("image.bmp")).toThrow("Unsupported image type") + expect(() => getMimeType("image.svg")).toThrow("Unsupported image type") + }) + }) + + describe("readImageAsDataUrl", () => { + it("should read a PNG file and return data URL", async () => { + // Create a minimal valid PNG (1x1 red pixel) + const pngData = Buffer.from([ + 0x89, + 0x50, + 0x4e, + 0x47, + 0x0d, + 0x0a, + 0x1a, + 0x0a, // PNG signature + 0x00, + 0x00, + 0x00, + 0x0d, // IHDR length + 0x49, + 0x48, + 0x44, + 0x52, // IHDR type + 0x00, + 0x00, + 0x00, + 0x01, // width = 1 + 0x00, + 0x00, + 0x00, + 0x01, // height = 1 + 0x08, + 0x02, // bit depth 8, color type 2 (RGB) + 0x00, + 0x00, + 0x00, // compression, filter, interlace + 0x90, + 0x77, + 0x53, + 0xde, // CRC + 0x00, + 0x00, + 0x00, + 0x0c, // IDAT length + 0x49, + 0x44, + 0x41, + 0x54, // IDAT type + 0x08, + 0xd7, + 0x63, + 0xf8, + 0xcf, + 0xc0, + 0x00, + 0x00, + 0x01, + 0x01, + 0x01, + 0x00, // compressed data + 0x18, + 0xdd, + 0x8d, + 0xb5, // CRC + 0x00, + 0x00, + 0x00, + 0x00, // IEND length + 0x49, + 0x45, + 0x4e, + 0x44, // IEND type + 0xae, + 0x42, + 0x60, + 0x82, // CRC + ]) + + const imagePath = path.join(tempDir, "test.png") + await fs.writeFile(imagePath, pngData) + + const dataUrl = await readImageAsDataUrl(imagePath) + + expect(dataUrl).toMatch(/^data:image\/png;base64,/) + expect(dataUrl.length).toBeGreaterThan("data:image/png;base64,".length) + }) + + it("should resolve relative paths from basePath", async () => { + const pngData = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]) // Minimal PNG header + const imagePath = path.join(tempDir, "relative.png") + await fs.writeFile(imagePath, pngData) + + const dataUrl = await readImageAsDataUrl("relative.png", tempDir) + + expect(dataUrl).toMatch(/^data:image\/png;base64,/) + }) + + it("should throw for non-existent files", async () => { + await expect(readImageAsDataUrl("/non/existent/path.png")).rejects.toThrow("Image file not found") + }) + + it("should throw for non-image files", async () => { + const textPath = path.join(tempDir, "test.txt") + await fs.writeFile(textPath, "Hello, world!") + + await expect(readImageAsDataUrl(textPath)).rejects.toThrow("Not a supported image type") + }) + + it("should throw for files larger than the maximum size", async () => { + const largeBuffer = Buffer.alloc(MAX_IMAGE_SIZE_BYTES + 1, 0xff) + const largePath = path.join(tempDir, "too-big.png") + await fs.writeFile(largePath, largeBuffer) + + await expect(readImageAsDataUrl(largePath)).rejects.toThrow("Image file is too large") + }) + }) + + describe("processImagePaths", () => { + it("should process multiple image paths", async () => { + // Create test images + const pngData = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]) + const image1 = path.join(tempDir, "image1.png") + const image2 = path.join(tempDir, "image2.png") + await fs.writeFile(image1, pngData) + await fs.writeFile(image2, pngData) + + const result = await processImagePaths([image1, image2]) + + expect(result.images).toHaveLength(2) + expect(result.errors).toHaveLength(0) + expect(result.images[0]).toMatch(/^data:image\/png;base64,/) + expect(result.images[1]).toMatch(/^data:image\/png;base64,/) + }) + + it("should collect errors for failed paths", async () => { + const result = await processImagePaths(["/non/existent.png", "/another/missing.jpg"]) + + expect(result.images).toHaveLength(0) + expect(result.errors).toHaveLength(2) + expect(result.errors[0]).toMatchObject({ + path: "/non/existent.png", + }) + }) + + it("should partially succeed when some paths fail", async () => { + const pngData = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]) + const validPath = path.join(tempDir, "valid.png") + await fs.writeFile(validPath, pngData) + + const result = await processImagePaths([validPath, "/non/existent.png"]) + + expect(result.images).toHaveLength(1) + expect(result.errors).toHaveLength(1) + }) + }) + + describe("SUPPORTED_IMAGE_EXTENSIONS", () => { + it("should contain expected extensions", () => { + expect(SUPPORTED_IMAGE_EXTENSIONS).toContain(".png") + expect(SUPPORTED_IMAGE_EXTENSIONS).toContain(".jpg") + expect(SUPPORTED_IMAGE_EXTENSIONS).toContain(".jpeg") + expect(SUPPORTED_IMAGE_EXTENSIONS).toContain(".webp") + expect(SUPPORTED_IMAGE_EXTENSIONS).toContain(".gif") + expect(SUPPORTED_IMAGE_EXTENSIONS).toContain(".tiff") + }) + }) +}) diff --git a/cli/src/media/__tests__/processMessageImages.test.ts b/cli/src/media/__tests__/processMessageImages.test.ts new file mode 100644 index 00000000000..e5c47cadd41 --- /dev/null +++ b/cli/src/media/__tests__/processMessageImages.test.ts @@ -0,0 +1,144 @@ +import { removeImageReferences, extractImageReferences, processMessageImages } from "../processMessageImages" +import * as images from "../images" + +// Mock the images module +vi.mock("../images", () => ({ + readImageAsDataUrl: vi.fn(), +})) + +// Mock the logs module +vi.mock("../../services/logs", () => ({ + logs: { + debug: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }, +})) + +describe("processMessageImages helpers", () => { + describe("removeImageReferences", () => { + it("should remove image reference tokens without collapsing whitespace", () => { + const input = "Line1\n [Image #1]\nLine3" + const result = removeImageReferences(input) + + expect(result).toBe("Line1\n \nLine3") + }) + + it("should remove multiple image references", () => { + const input = "Hello [Image #1] world [Image #2] test" + const result = removeImageReferences(input) + + expect(result).toBe("Hello world test") + }) + + it("should handle text with no image references", () => { + const input = "Hello world" + const result = removeImageReferences(input) + + expect(result).toBe("Hello world") + }) + }) + + describe("extractImageReferences", () => { + it("should extract single image reference number", () => { + const input = "Hello [Image #1] world" + const result = extractImageReferences(input) + + expect(result).toEqual([1]) + }) + + it("should extract multiple image reference numbers", () => { + const input = "Hello [Image #1] world [Image #3] test [Image #2]" + const result = extractImageReferences(input) + + expect(result).toEqual([1, 3, 2]) + }) + + it("should return empty array when no references", () => { + const input = "Hello world" + const result = extractImageReferences(input) + + expect(result).toEqual([]) + }) + + it("should handle large reference numbers", () => { + const input = "[Image #999]" + const result = extractImageReferences(input) + + expect(result).toEqual([999]) + }) + }) + + describe("processMessageImages", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("should return original text when no images", async () => { + const result = await processMessageImages("Hello world") + + expect(result).toEqual({ + text: "Hello world", + images: [], + hasImages: false, + errors: [], + }) + }) + + it("should load images from [Image #N] references", async () => { + const mockDataUrl = "data:image/png;base64,abc123" + vi.mocked(images.readImageAsDataUrl).mockResolvedValue(mockDataUrl) + + const imageReferences = { 1: "/tmp/test.png" } + const result = await processMessageImages("Hello [Image #1] world", imageReferences) + + expect(images.readImageAsDataUrl).toHaveBeenCalledWith("/tmp/test.png") + expect(result.images).toEqual([mockDataUrl]) + expect(result.text).toBe("Hello world") + expect(result.hasImages).toBe(true) + expect(result.errors).toEqual([]) + }) + + it("should report error when image reference not found", async () => { + const imageReferences = { 2: "/tmp/other.png" } + const result = await processMessageImages("Hello [Image #1] world", imageReferences) + + expect(result.errors).toContain("Image #1 not found") + expect(result.images).toEqual([]) + }) + + it("should report error when image file fails to load", async () => { + vi.mocked(images.readImageAsDataUrl).mockRejectedValue(new Error("File not found")) + + const imageReferences = { 1: "/tmp/missing.png" } + const result = await processMessageImages("Hello [Image #1] world", imageReferences) + + expect(result.errors).toContain("Failed to load Image #1: File not found") + expect(result.images).toEqual([]) + }) + + it("should handle multiple image references", async () => { + const mockDataUrl1 = "data:image/png;base64,img1" + const mockDataUrl2 = "data:image/png;base64,img2" + vi.mocked(images.readImageAsDataUrl).mockResolvedValueOnce(mockDataUrl1).mockResolvedValueOnce(mockDataUrl2) + + const imageReferences = { + 1: "/tmp/test1.png", + 2: "/tmp/test2.png", + } + const result = await processMessageImages("[Image #1] and [Image #2]", imageReferences) + + expect(result.images).toEqual([mockDataUrl1, mockDataUrl2]) + expect(result.text).toBe(" and ") + expect(result.hasImages).toBe(true) + }) + + it("should process without imageReferences parameter", async () => { + const result = await processMessageImages("Hello world") + + expect(result.text).toBe("Hello world") + expect(result.images).toEqual([]) + expect(result.hasImages).toBe(false) + }) + }) +}) diff --git a/cli/src/media/atMentionParser.ts b/cli/src/media/atMentionParser.ts new file mode 100644 index 00000000000..f15b0cd6b0e --- /dev/null +++ b/cli/src/media/atMentionParser.ts @@ -0,0 +1,220 @@ +import { isImagePath } from "./images.js" +export interface ParsedSegment { + type: "text" | "atPath" + content: string + startIndex: number + endIndex: number +} + +export interface ParsedPrompt { + segments: ParsedSegment[] + paths: string[] + imagePaths: string[] + otherPaths: string[] +} + +const PATH_TERMINATORS = new Set([" ", "\t", "\n", "\r", ",", ";", ")", "]", "}", ">", "|", "&", "'", '"']) +const TRAILING_PUNCTUATION = new Set([".", ",", ":", ";", "!", "?"]) + +function isEscapedAt(input: string, index: number): boolean { + return index > 0 && input[index - 1] === "\\" +} + +function parseQuotedPath(input: string, startIndex: number): { path: string; endIndex: number } | null { + let i = startIndex + 2 // skip @ and opening quote + let path = "" + const quote = input[startIndex + 1] + + while (i < input.length) { + const char = input[i] + if (char === "\\" && i + 1 < input.length) { + const nextChar = input[i + 1] + if (nextChar === quote || nextChar === "\\") { + path += nextChar + i += 2 + continue + } + } + if (char === quote) { + return { path, endIndex: i + 1 } + } + path += char + i++ + } + + return path ? { path, endIndex: i } : null +} + +function stripTrailingPunctuation(path: string): { path: string; trimmed: boolean } { + let trimmed = path + let removed = false + while (trimmed.length > 0 && TRAILING_PUNCTUATION.has(trimmed[trimmed.length - 1]!)) { + trimmed = trimmed.slice(0, -1) + removed = true + } + return { path: trimmed, trimmed: removed } +} + +function parseUnquotedPath(input: string, startIndex: number): { path: string; endIndex: number } | null { + let i = startIndex + 1 + let path = "" + + while (i < input.length) { + const char = input[i]! + + if (char === "\\" && i + 1 < input.length) { + const nextChar = input[i + 1]! + if (nextChar === " " || nextChar === "\\" || PATH_TERMINATORS.has(nextChar)) { + path += nextChar + i += 2 + continue + } + } + + if (PATH_TERMINATORS.has(char)) { + break + } + + path += char + i++ + } + + if (!path) { + return null + } + + const { path: trimmedPath, trimmed } = stripTrailingPunctuation(path) + if (!trimmedPath) { + return null + } + + const endIndex = i - (trimmed ? path.length - trimmedPath.length : 0) + return { path: trimmedPath, endIndex } +} + +function extractPath(input: string, startIndex: number): { path: string; endIndex: number } | null { + if (startIndex + 1 >= input.length) { + return null + } + + const nextChar = input[startIndex + 1] + if (nextChar === '"' || nextChar === "'") { + return parseQuotedPath(input, startIndex) + } + + return parseUnquotedPath(input, startIndex) +} + +function pushTextSegment(segments: ParsedSegment[], input: string, textStart: number, currentIndex: number): void { + if (currentIndex > textStart) { + segments.push({ + type: "text", + content: input.slice(textStart, currentIndex), + startIndex: textStart, + endIndex: currentIndex, + }) + } +} + +function pushPathSegment( + segments: ParsedSegment[], + paths: string[], + imagePaths: string[], + otherPaths: string[], + currentIndex: number, + extracted: { path: string; endIndex: number }, +): void { + segments.push({ + type: "atPath", + content: extracted.path, + startIndex: currentIndex, + endIndex: extracted.endIndex, + }) + + paths.push(extracted.path) + + if (isImagePath(extracted.path)) { + imagePaths.push(extracted.path) + } else { + otherPaths.push(extracted.path) + } +} + +export function parseAtMentions(input: string): ParsedPrompt { + const segments: ParsedSegment[] = [] + const paths: string[] = [] + const imagePaths: string[] = [] + const otherPaths: string[] = [] + + let currentIndex = 0 + let textStart = 0 + + while (currentIndex < input.length) { + const char = input[currentIndex] + + if (char === "@" && !isEscapedAt(input, currentIndex)) { + const extracted = extractPath(input, currentIndex) + if (!extracted) { + currentIndex++ + continue + } + + pushTextSegment(segments, input, textStart, currentIndex) + pushPathSegment(segments, paths, imagePaths, otherPaths, currentIndex, extracted) + currentIndex = extracted.endIndex + textStart = currentIndex + continue + } + + currentIndex++ + } + + if (textStart < input.length) { + segments.push({ + type: "text", + content: input.slice(textStart), + startIndex: textStart, + endIndex: input.length, + }) + } + + return { segments, paths, imagePaths, otherPaths } +} + +export function extractImagePaths(input: string): string[] { + return parseAtMentions(input).imagePaths +} + +export function removeImageMentions(input: string, placeholder: string = ""): string { + const parsed = parseAtMentions(input) + + let result = "" + for (const segment of parsed.segments) { + if (segment.type === "text") { + result += segment.content + } else if (segment.type === "atPath") { + if (isImagePath(segment.content)) { + result += placeholder + } else { + result += `@${segment.content}` + } + } + } + + return result +} + +export function reconstructText(segments: ParsedSegment[], transform?: (segment: ParsedSegment) => string): string { + if (transform) { + return segments.map(transform).join("") + } + + return segments + .map((seg) => { + if (seg.type === "text") { + return seg.content + } + return `@${seg.content}` + }) + .join("") +} diff --git a/cli/src/media/clipboard-macos.ts b/cli/src/media/clipboard-macos.ts new file mode 100644 index 00000000000..3abdf143eec --- /dev/null +++ b/cli/src/media/clipboard-macos.ts @@ -0,0 +1,144 @@ +import * as fs from "fs" +import * as path from "path" +import { logs } from "../services/logs.js" +import { + buildDataUrl, + ensureClipboardDir, + execFileAsync, + generateClipboardFilename, + parseClipboardInfo, + type ClipboardImageResult, + type SaveClipboardResult, +} from "./clipboard-shared.js" + +export async function hasClipboardImageMacOS(): Promise { + const { stdout } = await execFileAsync("osascript", ["-e", "clipboard info"]) + return parseClipboardInfo(stdout).hasImage +} + +export async function readClipboardImageMacOS(): Promise { + const { stdout: info } = await execFileAsync("osascript", ["-e", "clipboard info"]) + const parsed = parseClipboardInfo(info) + + if (!parsed.hasImage || !parsed.format) { + return { + success: false, + error: "No image found in clipboard.", + } + } + + const formatToClass: Record = { + png: "PNGf", + jpeg: "JPEG", + tiff: "TIFF", + gif: "GIFf", + } + + const appleClass = formatToClass[parsed.format] + if (!appleClass) { + return { + success: false, + error: `Unsupported image format: ${parsed.format}`, + } + } + + const script = `set imageData to the clipboard as «class ${appleClass}» +return imageData` + + const { stdout } = await execFileAsync("osascript", ["-e", script], { + encoding: "buffer", + maxBuffer: 50 * 1024 * 1024, + }) + + const imageBuffer = Buffer.isBuffer(stdout) ? stdout : Buffer.from(stdout) + + if (imageBuffer.length === 0) { + return { + success: false, + error: "Failed to read image data from clipboard.", + } + } + + const mimeFormat = parsed.format === "tiff" ? "tiff" : parsed.format + + return { + success: true, + dataUrl: buildDataUrl(imageBuffer, mimeFormat), + } +} + +export async function saveClipboardImageMacOS(): Promise { + const { stdout: info } = await execFileAsync("osascript", ["-e", "clipboard info"]) + const parsed = parseClipboardInfo(info) + + if (!parsed.hasImage || !parsed.format) { + return { + success: false, + error: "No image found in clipboard.", + } + } + + const formatToClass: Record = { + png: "PNGf", + jpeg: "JPEG", + tiff: "TIFF", + gif: "GIFf", + } + + const appleClass = formatToClass[parsed.format] + if (!appleClass) { + return { + success: false, + error: `Unsupported image format: ${parsed.format}`, + } + } + + const clipboardDir = await ensureClipboardDir() + + const filename = generateClipboardFilename(parsed.format) + const filePath = path.join(clipboardDir, filename) + + // Escape backslashes and quotes for AppleScript string interpolation + const escapedPath = filePath.replace(/\\/g, "\\\\").replace(/"/g, '\\"') + + const script = ` +set imageData to the clipboard as «class ${appleClass}» +set filePath to POSIX file "${escapedPath}" +set fileRef to open for access filePath with write permission +write imageData to fileRef +close access fileRef +return "${escapedPath}" +` + + try { + await execFileAsync("osascript", ["-e", script], { + maxBuffer: 50 * 1024 * 1024, + }) + + const stats = await fs.promises.stat(filePath) + if (stats.size === 0) { + await fs.promises.unlink(filePath) + return { + success: false, + error: "Failed to write image data to file.", + } + } + + return { + success: true, + filePath, + } + } catch (error) { + try { + await fs.promises.unlink(filePath) + } catch (cleanupError) { + const err = cleanupError as NodeJS.ErrnoException + logs.debug("Failed to remove partial clipboard file after error", "clipboard", { + filePath, + error: err?.message ?? String(cleanupError), + code: err?.code, + }) + } + throw error + } +} diff --git a/cli/src/media/clipboard-shared.ts b/cli/src/media/clipboard-shared.ts new file mode 100644 index 00000000000..5592c9964cc --- /dev/null +++ b/cli/src/media/clipboard-shared.ts @@ -0,0 +1,109 @@ +import { execFile } from "child_process" +import * as fs from "fs" +import * as os from "os" +import * as path from "path" +import { promisify } from "util" + +export const execFileAsync = promisify(execFile) + +export const CLIPBOARD_DIR = "kilocode-clipboard" +export const MAX_CLIPBOARD_IMAGE_AGE_MS = 60 * 60 * 1000 + +export interface ClipboardImageResult { + success: boolean + dataUrl?: string + error?: string +} + +export interface ClipboardInfoResult { + hasImage: boolean + format: "png" | "jpeg" | "tiff" | "gif" | null +} + +export interface SaveClipboardResult { + success: boolean + filePath?: string + error?: string +} + +export function parseClipboardInfo(output: string): ClipboardInfoResult { + if (!output) { + return { hasImage: false, format: null } + } + + if (output.includes("PNGf") || output.includes("class PNGf")) { + return { hasImage: true, format: "png" } + } + if (output.includes("JPEG") || output.includes("class JPEG")) { + return { hasImage: true, format: "jpeg" } + } + if (output.includes("TIFF") || output.includes("TIFF picture")) { + return { hasImage: true, format: "tiff" } + } + if (output.includes("GIFf") || output.includes("class GIFf")) { + return { hasImage: true, format: "gif" } + } + + return { hasImage: false, format: null } +} + +export function detectImageFormat(buffer: Buffer): "png" | "jpeg" | "gif" | "webp" | null { + if (buffer.length < 4) { + return null + } + + if (buffer[0] === 0x89 && buffer[1] === 0x50 && buffer[2] === 0x4e && buffer[3] === 0x47) { + return "png" + } + + if (buffer[0] === 0xff && buffer[1] === 0xd8 && buffer[2] === 0xff) { + return "jpeg" + } + + if (buffer[0] === 0x47 && buffer[1] === 0x49 && buffer[2] === 0x46 && buffer[3] === 0x38) { + return "gif" + } + + if ( + buffer.length >= 12 && + buffer[0] === 0x52 && + buffer[1] === 0x49 && + buffer[2] === 0x46 && + buffer[3] === 0x46 && + buffer[8] === 0x57 && + buffer[9] === 0x45 && + buffer[10] === 0x42 && + buffer[11] === 0x50 + ) { + return "webp" + } + + return null +} + +export function buildDataUrl(data: Buffer, format: string): string { + return `data:image/${format};base64,${data.toString("base64")}` +} + +export function getUnsupportedClipboardPlatformMessage(): string { + return `Clipboard image paste is only supported on macOS. + +Alternative: + - Use @path/to/image.png to attach images` +} + +export function getClipboardDir(): string { + return path.join(os.tmpdir(), CLIPBOARD_DIR) +} + +export async function ensureClipboardDir(): Promise { + const clipboardDir = getClipboardDir() + await fs.promises.mkdir(clipboardDir, { recursive: true }) + return clipboardDir +} + +export function generateClipboardFilename(format: string): string { + const timestamp = Date.now() + const random = Math.random().toString(36).substring(2, 8) + return `clipboard-${timestamp}-${random}.${format}` +} diff --git a/cli/src/media/clipboard.ts b/cli/src/media/clipboard.ts new file mode 100644 index 00000000000..3f25406b6f8 --- /dev/null +++ b/cli/src/media/clipboard.ts @@ -0,0 +1,101 @@ +import * as fs from "fs" +import * as path from "path" +import { logs } from "../services/logs.js" +import { + buildDataUrl, + detectImageFormat, + generateClipboardFilename, + getClipboardDir, + parseClipboardInfo, + MAX_CLIPBOARD_IMAGE_AGE_MS, + getUnsupportedClipboardPlatformMessage, + type ClipboardImageResult, + type ClipboardInfoResult, + type SaveClipboardResult, +} from "./clipboard-shared.js" +import { hasClipboardImageMacOS, saveClipboardImageMacOS } from "./clipboard-macos.js" + +export { + buildDataUrl, + detectImageFormat, + generateClipboardFilename, + getClipboardDir, + getUnsupportedClipboardPlatformMessage, + parseClipboardInfo, + type ClipboardImageResult, + type ClipboardInfoResult, + type SaveClipboardResult, +} + +export async function isClipboardSupported(): Promise { + return process.platform === "darwin" +} + +export async function clipboardHasImage(): Promise { + try { + if (process.platform === "darwin") { + return await hasClipboardImageMacOS() + } + return false + } catch (error) { + const err = error as NodeJS.ErrnoException + logs.debug("clipboardHasImage failed, treating as no image", "clipboard", { + error: err?.message ?? String(error), + code: err?.code, + }) + return false + } +} + +export async function saveClipboardImage(): Promise { + if (process.platform !== "darwin") { + return { + success: false, + error: getUnsupportedClipboardPlatformMessage(), + } + } + + try { + return await saveClipboardImageMacOS() + } catch (error) { + return { + success: false, + error: error instanceof Error ? error.message : String(error), + } + } +} + +export async function cleanupOldClipboardImages(): Promise { + const clipboardDir = getClipboardDir() + + try { + const files = await fs.promises.readdir(clipboardDir) + const now = Date.now() + + for (const file of files) { + if (!file.startsWith("clipboard-")) continue + + const filePath = path.join(clipboardDir, file) + try { + const stats = await fs.promises.stat(filePath) + if (now - stats.mtimeMs > MAX_CLIPBOARD_IMAGE_AGE_MS) { + await fs.promises.unlink(filePath) + } + } catch (error) { + const err = error as NodeJS.ErrnoException + logs.debug("Failed to delete stale clipboard image", "clipboard", { + filePath, + error: err?.message ?? String(error), + code: err?.code, + }) + } + } + } catch (error) { + const err = error as NodeJS.ErrnoException + logs.debug("Skipping clipboard cleanup; directory not accessible", "clipboard", { + dir: clipboardDir, + error: err?.message ?? String(error), + code: err?.code, + }) + } +} diff --git a/cli/src/media/images.ts b/cli/src/media/images.ts new file mode 100644 index 00000000000..1cf4686571b --- /dev/null +++ b/cli/src/media/images.ts @@ -0,0 +1,99 @@ +import fs from "fs/promises" +import path from "path" +import { logs } from "../services/logs.js" + +export const MAX_IMAGE_SIZE_BYTES = 8 * 1024 * 1024 // 8MB + +export const SUPPORTED_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".gif", ".tiff"] as const +export type SupportedImageExtension = (typeof SUPPORTED_IMAGE_EXTENSIONS)[number] + +export function isImagePath(filePath: string): boolean { + const ext = path.extname(filePath).toLowerCase() + return SUPPORTED_IMAGE_EXTENSIONS.includes(ext as SupportedImageExtension) +} + +export function getMimeType(filePath: string): string { + const ext = path.extname(filePath).toLowerCase() + switch (ext) { + case ".png": + return "image/png" + case ".jpeg": + case ".jpg": + return "image/jpeg" + case ".webp": + return "image/webp" + case ".gif": + return "image/gif" + case ".tiff": + return "image/tiff" + default: + throw new Error(`Unsupported image type: ${ext}`) + } +} + +export async function readImageAsDataUrl(imagePath: string, basePath?: string): Promise { + // Resolve the path + const resolvedPath = path.isAbsolute(imagePath) ? imagePath : path.resolve(basePath || process.cwd(), imagePath) + + // Verify it's a supported image type + if (!isImagePath(resolvedPath)) { + throw new Error(`Not a supported image type: ${imagePath}`) + } + + // Check if file exists + try { + await fs.access(resolvedPath) + } catch { + throw new Error(`Image file not found: ${resolvedPath}`) + } + + // Enforce size limit before reading + const stats = await fs.stat(resolvedPath) + if (stats.size > MAX_IMAGE_SIZE_BYTES) { + const maxMb = (MAX_IMAGE_SIZE_BYTES / (1024 * 1024)).toFixed(1) + const actualMb = (stats.size / (1024 * 1024)).toFixed(1) + throw new Error(`Image file is too large (${actualMb} MB). Max allowed is ${maxMb} MB.`) + } + + // Read file and convert to base64 + const buffer = await fs.readFile(resolvedPath) + const base64 = buffer.toString("base64") + const mimeType = getMimeType(resolvedPath) + const dataUrl = `data:${mimeType};base64,${base64}` + + logs.debug(`Read image as data URL: ${path.basename(imagePath)}`, "images", { + path: resolvedPath, + size: buffer.length, + mimeType, + }) + + return dataUrl +} + +export interface ProcessedImageMentions { + text: string + images: string[] + errors: Array<{ path: string; error: string }> +} + +export async function processImagePaths(imagePaths: string[], basePath?: string): Promise { + const images: string[] = [] + const errors: Array<{ path: string; error: string }> = [] + + for (const imagePath of imagePaths) { + try { + const dataUrl = await readImageAsDataUrl(imagePath, basePath) + images.push(dataUrl) + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + errors.push({ path: imagePath, error: errorMessage }) + logs.warn(`Failed to load image: ${imagePath}`, "images", { error: errorMessage }) + } + } + + return { + text: "", // Will be set by the caller + images, + errors, + } +} diff --git a/cli/src/media/processMessageImages.ts b/cli/src/media/processMessageImages.ts new file mode 100644 index 00000000000..21b22c59953 --- /dev/null +++ b/cli/src/media/processMessageImages.ts @@ -0,0 +1,140 @@ +import { logs } from "../services/logs.js" +import { parseAtMentions, removeImageMentions } from "./atMentionParser.js" +import { readImageAsDataUrl } from "./images.js" + +export interface ProcessedMessage { + text: string + images: string[] + hasImages: boolean + errors: string[] +} + +const IMAGE_REFERENCE_REGEX = /\[Image #(\d+)\]/g + +export function extractImageReferences(text: string): number[] { + const refs: number[] = [] + let match + IMAGE_REFERENCE_REGEX.lastIndex = 0 + while ((match = IMAGE_REFERENCE_REGEX.exec(text)) !== null) { + const ref = match[1] + if (ref !== undefined) { + refs.push(parseInt(ref, 10)) + } + } + return refs +} + +export function removeImageReferences(text: string): string { + return text.replace(IMAGE_REFERENCE_REGEX, "") +} + +async function loadImage( + imagePath: string, + onSuccess: (dataUrl: string) => void, + onError: (error: string) => void, + successLog: string, + errorLog: { message: string; meta?: Record }, +): Promise { + try { + const dataUrl = await readImageAsDataUrl(imagePath) + onSuccess(dataUrl) + logs.debug(successLog, "processMessageImages") + } catch (error) { + const errorMsg = error instanceof Error ? error.message : String(error) + onError(errorMsg) + logs.warn(errorLog.message, "processMessageImages", { ...errorLog.meta, error: errorMsg }) + } +} + +async function loadReferenceImages( + refs: number[], + imageReferences: Record, + images: string[], + errors: string[], +): Promise { + logs.debug(`Found ${refs.length} image reference(s)`, "processMessageImages", { refs }) + + for (const refNum of refs) { + const filePath = imageReferences[refNum] + if (!filePath) { + errors.push(`Image #${refNum} not found`) + logs.warn(`Image reference #${refNum} not found in references map`, "processMessageImages") + continue + } + + await loadImage( + filePath, + (dataUrl) => images.push(dataUrl), + (errorMsg) => errors.push(`Failed to load Image #${refNum}: ${errorMsg}`), + `Loaded image #${refNum}: ${filePath}`, + { message: `Failed to load image #${refNum}: ${filePath}` }, + ) + } +} + +async function loadPathImages(imagePaths: string[], images: string[], errors: string[]): Promise { + logs.debug(`Found ${imagePaths.length} @path image mention(s)`, "processMessageImages", { + paths: imagePaths, + }) + + for (const imagePath of imagePaths) { + await loadImage( + imagePath, + (dataUrl) => images.push(dataUrl), + (errorMsg) => errors.push(`Failed to load image "${imagePath}": ${errorMsg}`), + `Loaded image: ${imagePath}`, + { message: `Failed to load image: ${imagePath}` }, + ) + } +} + +async function handleReferenceImages( + text: string, + imageReferences: Record, + images: string[], + errors: string[], +): Promise { + const refs = extractImageReferences(text) + if (refs.length === 0) { + return text + } + + await loadReferenceImages(refs, imageReferences, images, errors) + return removeImageReferences(text) +} + +async function handlePathMentions( + text: string, + images: string[], + errors: string[], +): Promise<{ cleanedText: string; hasImages: boolean }> { + const parsed = parseAtMentions(text) + if (parsed.imagePaths.length === 0) { + return { cleanedText: text, hasImages: images.length > 0 } + } + + await loadPathImages(parsed.imagePaths, images, errors) + return { cleanedText: removeImageMentions(text), hasImages: images.length > 0 } +} + +export async function processMessageImages( + text: string, + imageReferences?: Record, +): Promise { + const images: string[] = [] + const errors: string[] = [] + + let cleanedText = text + if (imageReferences) { + cleanedText = await handleReferenceImages(cleanedText, imageReferences, images, errors) + } + + const { cleanedText: finalText, hasImages } = await handlePathMentions(cleanedText, images, errors) + + return { + text: finalText, + images, + hasImages, + errors, + } +} diff --git a/cli/src/state/atoms/__tests__/keyboard.test.ts b/cli/src/state/atoms/__tests__/keyboard.test.ts index 8d92aa7fee5..fbd591672e1 100644 --- a/cli/src/state/atoms/__tests__/keyboard.test.ts +++ b/cli/src/state/atoms/__tests__/keyboard.test.ts @@ -9,7 +9,13 @@ import { fileMentionSuggestionsAtom, } from "../ui.js" import { textBufferStringAtom, textBufferStateAtom } from "../textBuffer.js" -import { keyboardHandlerAtom, submissionCallbackAtom, submitInputAtom } from "../keyboard.js" +import { + exitPromptVisibleAtom, + exitRequestCounterAtom, + keyboardHandlerAtom, + submissionCallbackAtom, + submitInputAtom, +} from "../keyboard.js" import { pendingApprovalAtom } from "../approval.js" import { historyDataAtom, historyModeAtom, historyIndexAtom as _historyIndexAtom } from "../history.js" import { chatMessagesAtom } from "../extension.js" @@ -1087,5 +1093,26 @@ describe("keypress atoms", () => { // When not streaming, ESC should clear the buffer (normal behavior) expect(store.get(textBufferStringAtom)).toBe("") }) + + it("should require confirmation before exiting on Ctrl+C", async () => { + const ctrlCKey: Key = { + name: "c", + sequence: "\u0003", + ctrl: true, + meta: false, + shift: false, + paste: false, + } + + await store.set(keyboardHandlerAtom, ctrlCKey) + + expect(store.get(exitPromptVisibleAtom)).toBe(true) + expect(store.get(exitRequestCounterAtom)).toBe(0) + + await store.set(keyboardHandlerAtom, ctrlCKey) + + expect(store.get(exitPromptVisibleAtom)).toBe(false) + expect(store.get(exitRequestCounterAtom)).toBe(1) + }) }) }) diff --git a/cli/src/state/atoms/__tests__/shell.test.ts b/cli/src/state/atoms/__tests__/shell.test.ts index b807ab78763..76f184131a7 100644 --- a/cli/src/state/atoms/__tests__/shell.test.ts +++ b/cli/src/state/atoms/__tests__/shell.test.ts @@ -12,10 +12,9 @@ import { } from "../shell.js" import { textBufferStringAtom, setTextAtom } from "../textBuffer.js" -// Mock child_process to avoid actual command execution +// Mock child_process to avoid actual command execution; provide exec and execFile for clipboard code vi.mock("child_process", () => ({ exec: vi.fn((command) => { - // Simulate successful command execution const stdout = `Mock output for: ${command}` const stderr = "" const process = { @@ -41,6 +40,9 @@ vi.mock("child_process", () => ({ } return process }), + execFile: vi.fn((..._args) => { + throw new Error("execFile mocked in shell tests") + }), })) describe("shell mode - comprehensive tests", () => { diff --git a/cli/src/state/atoms/keyboard.ts b/cli/src/state/atoms/keyboard.ts index 3e8b7236467..9f83e80a989 100644 --- a/cli/src/state/atoms/keyboard.ts +++ b/cli/src/state/atoms/keyboard.ts @@ -58,6 +58,8 @@ import { navigateShellHistoryDownAtom, executeShellCommandAtom, } from "./shell.js" +import { saveClipboardImage, clipboardHasImage, cleanupOldClipboardImages } from "../../media/clipboard.js" +import { logs } from "../../services/logs.js" // Export shell atoms for backward compatibility export { @@ -68,6 +70,65 @@ export { executeShellCommandAtom, } +// ============================================================================ +// Clipboard Image Atoms +// ============================================================================ + +/** + * Map of image reference numbers to file paths for current message + * e.g., { 1: "/tmp/kilocode-clipboard/clipboard-xxx.png", 2: "/tmp/..." } + */ +export const imageReferencesAtom = atom>(new Map()) + +/** + * Current image reference counter (increments with each paste) + */ +export const imageReferenceCounterAtom = atom(0) + +/** + * Add a clipboard image and get its reference number + * Returns the reference number assigned to this image + */ +export const addImageReferenceAtom = atom(null, (get, set, filePath: string): number => { + const counter = get(imageReferenceCounterAtom) + 1 + set(imageReferenceCounterAtom, counter) + + const refs = new Map(get(imageReferencesAtom)) + refs.set(counter, filePath) + set(imageReferencesAtom, refs) + + return counter +}) + +/** + * Clear image references (after message is sent) + */ +export const clearImageReferencesAtom = atom(null, (_get, set) => { + set(imageReferencesAtom, new Map()) + set(imageReferenceCounterAtom, 0) +}) + +/** + * Get all image references as an object for easier consumption + */ +export const getImageReferencesAtom = atom((get) => { + return Object.fromEntries(get(imageReferencesAtom)) +}) + +/** + * Status message for clipboard operations + */ +export const clipboardStatusAtom = atom(null) +let clipboardStatusTimer: NodeJS.Timeout | null = null + +function setClipboardStatusWithTimeout(set: Setter, message: string, timeoutMs: number): void { + if (clipboardStatusTimer) { + clearTimeout(clipboardStatusTimer) + } + set(clipboardStatusAtom, message) + clipboardStatusTimer = setTimeout(() => set(clipboardStatusAtom, null), timeoutMs) +} + // ============================================================================ // Core State Atoms // ============================================================================ @@ -92,6 +153,41 @@ export const kittyProtocolEnabledAtom = atom(false) */ export const debugKeystrokeLoggingAtom = atom(false) +// ============================================================================ +// Exit Confirmation State +// ============================================================================ + +const EXIT_CONFIRMATION_WINDOW_MS = 2000 + +type ExitPromptTimeout = ReturnType + +export const exitPromptVisibleAtom = atom(false) +const exitPromptTimeoutAtom = atom(null) +export const exitRequestCounterAtom = atom(0) + +export const triggerExitConfirmationAtom = atom(null, (get, set) => { + const exitPromptVisible = get(exitPromptVisibleAtom) + const existingTimeout = get(exitPromptTimeoutAtom) + + if (existingTimeout) { + clearTimeout(existingTimeout) + set(exitPromptTimeoutAtom, null) + } + + if (exitPromptVisible) { + set(exitPromptVisibleAtom, false) + set(exitRequestCounterAtom, (count) => count + 1) + return + } + + set(exitPromptVisibleAtom, true) + const timeout = setTimeout(() => { + set(exitPromptVisibleAtom, false) + set(exitPromptTimeoutAtom, null) + }, EXIT_CONFIRMATION_WINDOW_MS) + set(exitPromptTimeoutAtom, timeout) +}) + // ============================================================================ // Buffer Atoms // ============================================================================ @@ -792,10 +888,40 @@ function handleTextInputKeys(get: Getter, set: Setter, key: Key) { } function handleGlobalHotkeys(get: Getter, set: Setter, key: Key): boolean { + // Debug logging for key detection + if (key.ctrl || key.sequence === "\x16") { + logs.debug( + `Key detected: name=${key.name}, ctrl=${key.ctrl}, meta=${key.meta}, sequence=${JSON.stringify(key.sequence)}`, + "clipboard", + ) + } + + // Check for Ctrl+V by sequence first (ASCII 0x16 = SYN character) + // This is how Ctrl+V appears in most terminals + if (key.sequence === "\x16") { + logs.debug("Detected Ctrl+V via sequence \\x16", "clipboard") + handleClipboardImagePaste(get, set).catch((err) => + logs.error("Unhandled clipboard paste error", "clipboard", { error: err }), + ) + return true + } + switch (key.name) { case "c": if (key.ctrl) { - process.exit(0) + set(triggerExitConfirmationAtom) + return true + } + break + case "v": + // Ctrl+V - check for clipboard image + if (key.ctrl) { + logs.debug("Detected Ctrl+V via key.name", "clipboard") + // Handle clipboard image paste asynchronously + handleClipboardImagePaste(get, set).catch((err) => + logs.error("Unhandled clipboard paste error", "clipboard", { error: err }), + ) + return true } break case "x": @@ -849,6 +975,65 @@ function handleGlobalHotkeys(get: Getter, set: Setter, key: Key): boolean { return false } +/** + * Handle clipboard image paste (Ctrl+V) + * Saves clipboard image to a temp file and inserts @path reference into text buffer + */ +async function handleClipboardImagePaste(get: Getter, set: Setter): Promise { + logs.debug("handleClipboardImagePaste called", "clipboard") + try { + // Check if clipboard has an image + logs.debug("Checking clipboard for image...", "clipboard") + const hasImage = await clipboardHasImage() + logs.debug(`clipboardHasImage returned: ${hasImage}`, "clipboard") + if (!hasImage) { + setClipboardStatusWithTimeout(set, "No image in clipboard", 2000) + logs.debug("No image in clipboard", "clipboard") + return + } + + // Save the image to a file in temp directory + const result = await saveClipboardImage() + if (result.success && result.filePath) { + // Add image to references and get its number + const refNumber = set(addImageReferenceAtom, result.filePath) + + // Build the [Image #N] reference to insert + // Add space before and after if needed + const currentText = get(textBufferStringAtom) + let insertText = `[Image #${refNumber}]` + + // Check if we need spaces around the insertion + const charBefore = currentText.length > 0 ? currentText[currentText.length - 1] : "" + if (charBefore && charBefore !== " " && charBefore !== "\n") { + insertText = " " + insertText + } + insertText = insertText + " " + + // Insert at current cursor position + set(insertTextAtom, insertText) + + setClipboardStatusWithTimeout(set, `Image #${refNumber} attached`, 2000) + logs.debug(`Inserted clipboard image #${refNumber}: ${result.filePath}`, "clipboard") + + // Clean up old clipboard images in the background + cleanupOldClipboardImages().catch((cleanupError) => { + logs.debug("Clipboard cleanup failed", "clipboard", { + error: cleanupError instanceof Error ? cleanupError.message : String(cleanupError), + }) + }) + } else { + setClipboardStatusWithTimeout(set, result.error || "Failed to save clipboard image", 3000) + } + } catch (error) { + setClipboardStatusWithTimeout( + set, + `Clipboard error: ${error instanceof Error ? error.message : String(error)}`, + 3000, + ) + } +} + /** * Main keyboard handler that routes based on mode * This is the central keyboard handling atom that all key events go through diff --git a/cli/src/state/atoms/ui.ts b/cli/src/state/atoms/ui.ts index cbfe50456f8..1012d33522e 100644 --- a/cli/src/state/atoms/ui.ts +++ b/cli/src/state/atoms/ui.ts @@ -696,7 +696,7 @@ export const resetMessageCutoffAtom = atom(null, (get, set) => { */ export const splitMessagesAtom = atom((get) => { const allMessages = get(mergedMessagesAtom) - return splitMessages(allMessages) + return splitMessages(allMessages, { hidePartialMessages: true }) }) /** diff --git a/cli/src/state/hooks/useMessageHandler.ts b/cli/src/state/hooks/useMessageHandler.ts index b885c777b63..b14518009d5 100644 --- a/cli/src/state/hooks/useMessageHandler.ts +++ b/cli/src/state/hooks/useMessageHandler.ts @@ -3,14 +3,16 @@ * Provides a clean interface for sending user messages to the extension */ -import { useSetAtom } from "jotai" +import { useSetAtom, useAtomValue } from "jotai" import { useCallback, useState } from "react" import { addMessageAtom } from "../atoms/ui.js" +import { imageReferencesAtom, clearImageReferencesAtom } from "../atoms/keyboard.js" import { useWebviewMessage } from "./useWebviewMessage.js" import { useTaskState } from "./useTaskState.js" import type { CliMessage } from "../../types/cli.js" import { logs } from "../../services/logs.js" import { getTelemetryService } from "../../services/telemetry/index.js" +import { processMessageImages } from "../../media/processMessageImages.js" /** * Options for useMessageHandler hook @@ -34,7 +36,7 @@ export interface UseMessageHandlerReturn { * Hook that provides message sending functionality * * This hook handles sending regular user messages (non-commands) to the extension, - * including adding the message to the UI and handling errors. + * including processing @path image mentions and handling errors. * * @example * ```tsx @@ -58,51 +60,72 @@ export function useMessageHandler(options: UseMessageHandlerOptions = {}): UseMe const { ciMode = false } = options const [isSending, setIsSending] = useState(false) const addMessage = useSetAtom(addMessageAtom) + const imageReferences = useAtomValue(imageReferencesAtom) + const clearImageReferences = useSetAtom(clearImageReferencesAtom) const { sendMessage, sendAskResponse } = useWebviewMessage() const { hasActiveTask } = useTaskState() const sendUserMessage = useCallback( async (text: string): Promise => { const trimmedText = text.trim() - if (!trimmedText) { return } - // Don't add user message to CLI state - the extension will handle it - // This prevents duplicate messages in the UI - - // Set sending state setIsSending(true) try { - // Track user message + // Convert image references Map to object for processMessageImages + const imageRefsObject = Object.fromEntries(imageReferences) + + // Process any @path image mentions and [Image #N] references in the message + const processed = await processMessageImages(trimmedText, imageRefsObject) + + // Show any image loading errors to the user + if (processed.errors.length > 0) { + for (const error of processed.errors) { + const errorMessage: CliMessage = { + id: `img-err-${Date.now()}-${Math.random()}`, + type: "error", + content: error, + ts: Date.now(), + } + addMessage(errorMessage) + } + } + + // Track telemetry getTelemetryService().trackUserMessageSent( - trimmedText.length, - false, // hasImages - CLI doesn't support images yet + processed.text.length, + processed.hasImages, hasActiveTask, - undefined, // taskId - will be added when we have task tracking + undefined, ) - // Check if there's an active task to determine message type - // This matches the webview behavior in ChatView.tsx (lines 650-683) + // Build message payload + const payload = { + text: processed.text, + ...(processed.hasImages && { images: processed.images }), + } + + // Clear image references after processing + if (imageReferences.size > 0) { + clearImageReferences() + } + + // Send to extension - either as response to active task or as new task if (hasActiveTask) { - // Send as response to existing task (like webview does) - logs.debug("Sending message as response to active task", "useMessageHandler") - await sendAskResponse({ - response: "messageResponse", - text: trimmedText, + logs.debug("Sending message as response to active task", "useMessageHandler", { + hasImages: processed.hasImages, }) + await sendAskResponse({ response: "messageResponse", ...payload }) } else { - // Start new task (no active conversation) - logs.debug("Starting new task", "useMessageHandler") - await sendMessage({ - type: "newTask", - text: trimmedText, + logs.debug("Starting new task", "useMessageHandler", { + hasImages: processed.hasImages, }) + await sendMessage({ type: "newTask", ...payload }) } } catch (error) { - // Add error message if sending failed const errorMessage: CliMessage = { id: Date.now().toString(), type: "error", @@ -111,11 +134,10 @@ export function useMessageHandler(options: UseMessageHandlerOptions = {}): UseMe } addMessage(errorMessage) } finally { - // Reset sending state setIsSending(false) } }, - [addMessage, ciMode, sendMessage, sendAskResponse, hasActiveTask], + [addMessage, ciMode, sendMessage, sendAskResponse, hasActiveTask, imageReferences, clearImageReferences], ) return { diff --git a/cli/src/ui/UI.tsx b/cli/src/ui/UI.tsx index 9f05dfc358c..3ee7abb8fbe 100644 --- a/cli/src/ui/UI.tsx +++ b/cli/src/ui/UI.tsx @@ -36,6 +36,7 @@ import { generateNotificationMessage } from "../utils/notifications.js" import { notificationsAtom } from "../state/atoms/notifications.js" import { workspacePathAtom } from "../state/atoms/shell.js" import { useTerminal } from "../state/hooks/useTerminal.js" +import { exitRequestCounterAtom } from "../state/atoms/keyboard.js" // Initialize commands on module load initializeCommands() @@ -65,6 +66,7 @@ export const UI: React.FC = ({ options, onExit }) => { const setWorkspacePath = useSetAtom(workspacePathAtom) const taskResumedViaSession = useAtomValue(taskResumedViaContinueOrSessionAtom) const { hasActiveTask } = useTaskState() + const exitRequestCounter = useAtomValue(exitRequestCounterAtom) // Use specialized hooks for command and message handling const { executeCommand, isExecuting: isExecutingCommand } = useCommandHandler() @@ -94,6 +96,17 @@ export const UI: React.FC = ({ options, onExit }) => { onExit: onExit, }) + const handledExitRequestRef = useRef(exitRequestCounter) + + useEffect(() => { + if (exitRequestCounter === handledExitRequestRef.current) { + return + } + + handledExitRequestRef.current = exitRequestCounter + void executeCommand("/exit", onExit) + }, [exitRequestCounter, executeCommand, onExit]) + // Track if prompt has been executed and welcome message shown const promptExecutedRef = useRef(false) const welcomeShownRef = useRef(false) diff --git a/cli/src/ui/components/StatusIndicator.tsx b/cli/src/ui/components/StatusIndicator.tsx index 57d22cf9d28..acfc55aa31b 100644 --- a/cli/src/ui/components/StatusIndicator.tsx +++ b/cli/src/ui/components/StatusIndicator.tsx @@ -12,6 +12,7 @@ import { ThinkingAnimation } from "./ThinkingAnimation.js" import { useAtomValue } from "jotai" import { isStreamingAtom } from "../../state/atoms/ui.js" import { hasResumeTaskAtom } from "../../state/atoms/extension.js" +import { exitPromptVisibleAtom } from "../../state/atoms/keyboard.js" export interface StatusIndicatorProps { /** Whether the indicator is disabled */ @@ -34,6 +35,8 @@ export const StatusIndicator: React.FC = ({ disabled = fal const { hotkeys, shouldShow } = useHotkeys() const isStreaming = useAtomValue(isStreamingAtom) const hasResumeTask = useAtomValue(hasResumeTaskAtom) + const exitPromptVisible = useAtomValue(exitPromptVisibleAtom) + const exitModifierKey = process.platform === "darwin" ? "Cmd" : "Ctrl" // Don't render if no hotkeys to show or disabled if (!shouldShow || disabled) { @@ -44,8 +47,14 @@ export const StatusIndicator: React.FC = ({ disabled = fal {/* Status text on the left */} - {isStreaming && } - {hasResumeTask && Task ready to resume} + {exitPromptVisible ? ( + Press {exitModifierKey}+C again to exit. + ) : ( + <> + {isStreaming && } + {hasResumeTask && Task ready to resume} + + )} {/* Hotkeys on the right */} diff --git a/cli/src/ui/components/__tests__/StatusIndicator.test.tsx b/cli/src/ui/components/__tests__/StatusIndicator.test.tsx index 3d0a31d856c..98830a5b6c2 100644 --- a/cli/src/ui/components/__tests__/StatusIndicator.test.tsx +++ b/cli/src/ui/components/__tests__/StatusIndicator.test.tsx @@ -10,6 +10,7 @@ import { createStore } from "jotai" import { StatusIndicator } from "../StatusIndicator.js" import { showFollowupSuggestionsAtom } from "../../../state/atoms/ui.js" import { chatMessagesAtom } from "../../../state/atoms/extension.js" +import { exitPromptVisibleAtom } from "../../../state/atoms/keyboard.js" import type { ExtensionChatMessage } from "../../../types/messages.js" // Mock the hooks @@ -92,6 +93,19 @@ describe("StatusIndicator", () => { expect(output).toContain("for commands") }) + it("should show exit confirmation prompt when Ctrl+C is pressed once", () => { + store.set(exitPromptVisibleAtom, true) + + const { lastFrame } = render( + + + , + ) + + const output = lastFrame() + expect(output).toMatch(/Press (?:Ctrl|Cmd)\+C again to exit\./) + }) + it("should not show Thinking status when not streaming", () => { // Complete message = not streaming const completeMessage: ExtensionChatMessage = { diff --git a/cli/src/ui/messages/MessageDisplay.tsx b/cli/src/ui/messages/MessageDisplay.tsx index 0bdd30b4faf..38925ff372c 100644 --- a/cli/src/ui/messages/MessageDisplay.tsx +++ b/cli/src/ui/messages/MessageDisplay.tsx @@ -2,21 +2,15 @@ * MessageDisplay component - displays chat messages from both CLI and extension state * Uses Ink Static component to optimize rendering of completed messages * - * Performance Optimization: - * ------------------------ - * Messages are split into two sections: - * 1. Static section: Completed messages that won't change (rendered once with Ink Static) - * 2. Dynamic section: Incomplete/updating messages (re-rendered as needed) - * - * This prevents unnecessary re-renders of completed messages, improving performance - * especially in long conversations. + * Pure Static Mode: + * ----------------- + * Partial/streaming messages are filtered out at the atom level (see splitMessagesAtom), + * so this component only ever renders completed messages using Ink Static. * * Message Completion Logic: * ------------------------- - * A message is considered complete when: - * - CLI messages: partial !== true - * - Extension messages: depends on type (see messageCompletion.ts) - * - Sequential rule: A message can only be static if all previous messages are complete + * In pure static mode, any message with `partial === true` is hidden and everything else is + * treated as complete for display purposes. * * Key Generation Strategy: * ----------------------- @@ -40,16 +34,9 @@ import React from "react" import { Box, Static } from "ink" import { useAtomValue } from "jotai" -import { type UnifiedMessage, staticMessagesAtom, dynamicMessagesAtom } from "../../state/atoms/ui.js" +import { type UnifiedMessage, staticMessagesAtom } from "../../state/atoms/ui.js" import { MessageRow } from "./MessageRow.js" -interface MessageDisplayProps { - /** Optional filter to show only specific message types */ - filterType?: "ask" | "say" - /** Maximum number of messages to display (default: all) */ - maxMessages?: number -} - /** * Generate a unique key for a unified message * Uses a composite key strategy to ensure uniqueness even when messages @@ -79,34 +66,22 @@ function getMessageKey(msg: UnifiedMessage, index: number): string { return `${subtypeKey}-${index}` } -export const MessageDisplay: React.FC = () => { +export const MessageDisplay: React.FC = () => { const staticMessages = useAtomValue(staticMessagesAtom) - const dynamicMessages = useAtomValue(dynamicMessagesAtom) - if (staticMessages.length === 0 && dynamicMessages.length === 0) { + if (staticMessages.length === 0) { return null } return ( - {/* Static section for completed messages - won't re-render */} - {/* Key includes resetCounter to force re-mount when messages are replaced */} - {staticMessages.length > 0 && ( - - {(message, index) => ( - - - - )} - - )} - - {/* Dynamic section for incomplete/updating messages - will re-render */} - {dynamicMessages.map((unifiedMsg, index) => ( - - - - ))} + + {(message, index) => ( + + + + )} + ) } diff --git a/cli/src/ui/messages/utils/__tests__/messageCompletion.test.ts b/cli/src/ui/messages/utils/__tests__/messageCompletion.test.ts index 326546fd30f..3a7d2fe4090 100644 --- a/cli/src/ui/messages/utils/__tests__/messageCompletion.test.ts +++ b/cli/src/ui/messages/utils/__tests__/messageCompletion.test.ts @@ -438,4 +438,30 @@ describe("messageCompletion", () => { expect(result.dynamicMessages).toHaveLength(0) }) }) + + describe("splitMessages with hidePartialMessages option", () => { + it("should filter out all partial messages when hidePartialMessages is true", () => { + const messages: UnifiedMessage[] = [ + { + source: "cli", + message: { id: "1", type: "assistant", content: "A", ts: 1, partial: false }, + }, + { + source: "cli", + message: { id: "2", type: "assistant", content: "B", ts: 2, partial: true }, + }, + { + source: "cli", + message: { id: "3", type: "assistant", content: "C", ts: 3, partial: false }, + }, + ] + + const result = splitMessages(messages, { hidePartialMessages: true }) + + expect(result.staticMessages).toHaveLength(2) + expect(result.dynamicMessages).toHaveLength(0) + expect((result.staticMessages[0]?.message as CliMessage).id).toBe("1") + expect((result.staticMessages[1]?.message as CliMessage).id).toBe("3") + }) + }) }) diff --git a/cli/src/ui/messages/utils/messageCompletion.ts b/cli/src/ui/messages/utils/messageCompletion.ts index c60b476551d..343c40ebf77 100644 --- a/cli/src/ui/messages/utils/messageCompletion.ts +++ b/cli/src/ui/messages/utils/messageCompletion.ts @@ -110,15 +110,38 @@ function deduplicateCheckpointMessages(messages: UnifiedMessage[]): UnifiedMessa * - Visual jumping when messages complete out of order * * @param messages - Array of unified messages in chronological order + * @param options - Optional behavior flags * @returns Object with staticMessages (complete) and dynamicMessages (incomplete) */ -export function splitMessages(messages: UnifiedMessage[]): { +export interface SplitMessagesOptions { + /** + * When true, hides all partial messages and treats everything else as static. + * This enables a "pure static" mode where nothing streams to the terminal. + */ + hidePartialMessages?: boolean +} + +export function splitMessages( + messages: UnifiedMessage[], + options?: SplitMessagesOptions, +): { staticMessages: UnifiedMessage[] dynamicMessages: UnifiedMessage[] } { // First, deduplicate checkpoint messages const deduplicatedMessages = deduplicateCheckpointMessages(messages) + // hide any partial messages and treat everything else as static. + if (options?.hidePartialMessages) { + const filteredMessages = deduplicatedMessages.filter( + (msg) => (msg.message as { partial?: boolean }).partial !== true, + ) + return { + staticMessages: filteredMessages, + dynamicMessages: [], + } + } + let lastCompleteIndex = -1 const incompleteReasons: Array<{ index: number; reason: string; message: unknown }> = [] diff --git a/jetbrains/host/src/extension.ts b/jetbrains/host/src/extension.ts index 242d8e85657..cd19a60db81 100644 --- a/jetbrains/host/src/extension.ts +++ b/jetbrains/host/src/extension.ts @@ -56,8 +56,9 @@ if (pipeName) { // Reconnection related variables let isReconnecting = false let reconnectAttempts = 0 - const MAX_RECONNECT_ATTEMPTS = 5 - const RECONNECT_DELAY = 1000 // 1 second + const MAX_RECONNECT_ATTEMPTS = 10 // Increased from 5 to 10 for slower machines + const RECONNECT_DELAY = 2000 // Increased from 1s to 2s for slower machines + const RECONNECT_BACKOFF_MULTIPLIER = 1.5 // Exponential backoff // Override process.on process.on = function (event: string, listener: (...args: any[]) => void): any { @@ -264,9 +265,10 @@ if (pipeName) { console.log(`Attempting to reconnect (attempt ${reconnectAttempts}/${MAX_RECONNECT_ATTEMPTS})...`) - // Retry after waiting for a period of time - console.log(`Waiting ${RECONNECT_DELAY}ms before reconnecting...`) - await new Promise((resolve) => setTimeout(resolve, RECONNECT_DELAY)) + // Calculate delay with exponential backoff + const delay = RECONNECT_DELAY * Math.pow(RECONNECT_BACKOFF_MULTIPLIER, reconnectAttempts - 1) + console.log(`Waiting ${delay.toFixed(0)}ms before reconnecting (with exponential backoff)...`) + await new Promise((resolve) => setTimeout(resolve, delay)) console.log("Reconnection delay finished, attempting to connect...") // Reset reconnection state to allow new reconnection attempts diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/actions/GitCommitMessageAction.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/actions/GitCommitMessageAction.kt index c3332e6a413..c0aa9d66ed9 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/actions/GitCommitMessageAction.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/actions/GitCommitMessageAction.kt @@ -27,7 +27,6 @@ import kotlinx.coroutines.runBlocking class GitCommitMessageAction : AnAction(I18n.t("kilocode:commitMessage.ui.generateButton")) { private val logger: Logger = Logger.getInstance(GitCommitMessageAction::class.java) - private val commitMessageService = CommitMessageService.getInstance() private val fileDiscoveryService = FileDiscoveryService() override fun getActionUpdateThread(): ActionUpdateThread { @@ -106,7 +105,7 @@ class GitCommitMessageAction : AnAction(I18n.t("kilocode:commitMessage.ui.genera indicator.text = I18n.t("kilocode:commitMessage.progress.generating") val result = runBlocking { - commitMessageService.generateCommitMessage(project, workspacePath, files.ifEmpty { null }) + CommitMessageService.getInstance(project).generateCommitMessage(project, workspacePath, files.ifEmpty { null }) } ApplicationManager.getApplication().invokeLater { @@ -158,7 +157,7 @@ class GitCommitMessageAction : AnAction(I18n.t("kilocode:commitMessage.ui.genera indicator.text = I18n.t("kilocode:commitMessage.progress.generating") val result = runBlocking { - commitMessageService.generateCommitMessage(project, workspacePath, files.ifEmpty { null }) + CommitMessageService.getInstance(project).generateCommitMessage(project, workspacePath, files.ifEmpty { null }) } ApplicationManager.getApplication().invokeLater { diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionHostManager.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionHostManager.kt index bc0fc99dc52..9f087086dc2 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionHostManager.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionHostManager.kt @@ -1,8 +1,3 @@ -// Copyright 2009-2025 Weibo, Inc. -// SPDX-FileCopyrightText: 2025 Weibo, Inc. -// -// SPDX-License-Identifier: Apache-2.0 - package ai.kilocode.jetbrains.core import ai.kilocode.jetbrains.editor.EditorAndDocManager @@ -28,6 +23,11 @@ import kotlinx.coroutines.cancel import java.net.Socket import java.nio.channels.SocketChannel import java.nio.file.Paths +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock /** * Extension host manager, responsible for communication with extension processes. @@ -36,6 +36,7 @@ import java.nio.file.Paths class ExtensionHostManager : Disposable { companion object { val LOG = Logger.getInstance(ExtensionHostManager::class.java) + private const val INITIALIZATION_TIMEOUT_MS = 60000L // 60 seconds } private val project: Project @@ -62,6 +63,12 @@ class ExtensionHostManager : Disposable { private var projectPath: String? = null + // Initialization state management with state machine + val stateMachine = InitializationStateMachine() + private val messageQueue = ConcurrentLinkedQueue<() -> Unit>() + private val queueLock = ReentrantLock() + private var completionCheckTimer: java.util.Timer? = null + // Support Socket constructor constructor(clientSocket: Socket, projectPath: String, project: Project) { clientSocket.tcpNoDelay = true @@ -81,11 +88,14 @@ class ExtensionHostManager : Disposable { * Start communication with the extension process. */ fun start() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "start()") + try { // Initialize extension manager extensionManager = ExtensionManager() val extensionPath = PluginResourceUtil.getResourcePath(PluginConstants.PLUGIN_ID, PluginConstants.PLUGIN_CODE_DIR) rooCodeIdentifier = extensionPath?.let { extensionManager!!.registerExtension(it).identifier.value } + // Create protocol protocol = PersistentProtocol( PersistentProtocol.PersistentProtocolOptions( @@ -97,13 +107,55 @@ class ExtensionHostManager : Disposable { this::handleMessage, ) + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "Protocol created") LOG.info("ExtensionHostManager started successfully") } catch (e: Exception) { LOG.error("Failed to start ExtensionHostManager", e) + stateMachine.transitionTo(InitializationState.FAILED, "start() exception: ${e.message}") dispose() } } + /** + * Wait for extension host to be ready. + * @return CompletableFuture that completes when extension host is initialized. + */ + fun waitForReady(): CompletableFuture { + return stateMachine.waitForState(InitializationState.EXTENSION_ACTIVATED) + .thenApply { true } + .orTimeout(INITIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS) + .exceptionally { ex -> + LOG.error("Extension host initialization timeout or failure", ex) + false + } + } + + /** + * Queue a message to be sent after initialization. + * If already initialized, executes immediately. + * Uses lock to prevent race condition between checking state and adding to queue. + * @param message The message function to execute. + */ + fun queueMessage(message: () -> Unit) { + queueLock.withLock { + val currentState = stateMachine.getCurrentState() + + // Can execute immediately if extension is activated + if (currentState.ordinal >= InitializationState.EXTENSION_ACTIVATED.ordinal && + currentState != InitializationState.FAILED) { + try { + message() + } catch (e: Exception) { + LOG.error("Error executing message", e) + } + } else { + // Queue for later + messageQueue.offer(message) + LOG.debug("Message queued, total queued: ${messageQueue.size}, current state: $currentState") + } + } + } + /** * Get RPC responsive state. * @return Responsive state, or null if RPC manager is not initialized. @@ -164,6 +216,10 @@ class ExtensionHostManager : Disposable { * Handle Ready message, send initialization data. */ private fun handleReadyMessage() { + if (!stateMachine.transitionTo(InitializationState.READY_RECEIVED, "handleReadyMessage()")) { + return + } + LOG.info("Received Ready message from extension host") try { @@ -174,9 +230,12 @@ class ExtensionHostManager : Disposable { val jsonData = gson.toJson(initData).toByteArray() protocol?.send(jsonData) + + stateMachine.transitionTo(InitializationState.INIT_DATA_SENT, "Init data sent") LOG.info("Sent initialization data to extension host") } catch (e: Exception) { LOG.error("Failed to handle Ready message", e) + stateMachine.transitionTo(InitializationState.FAILED, "handleReadyMessage() exception: ${e.message}") } } @@ -184,37 +243,114 @@ class ExtensionHostManager : Disposable { * Handle Initialized message, create RPC manager and activate plugin. */ private fun handleInitializedMessage() { + if (!stateMachine.transitionTo(InitializationState.INITIALIZED_RECEIVED, "handleInitializedMessage()")) { + return + } + LOG.info("Received Initialized message from extension host") try { - // Get protocol val protocol = this.protocol ?: throw IllegalStateException("Protocol is not initialized") val extensionManager = this.extensionManager ?: throw IllegalStateException("ExtensionManager is not initialized") + stateMachine.transitionTo(InitializationState.RPC_CREATING, "Creating RPC manager") + // Create RPC manager rpcManager = RPCManager(protocol, extensionManager, null, project) + stateMachine.transitionTo(InitializationState.RPC_CREATED, "RPC manager created") + // Start initialization process rpcManager?.startInitialize() // Start file monitoring project.getService(WorkspaceFileChangeManager::class.java) -// WorkspaceFileChangeManager.getInstance() - project.getService(EditorAndDocManager::class.java).initCurrentIdeaEditor() + + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATING, "Activating extension") + // Activate RooCode plugin val rooCodeId = rooCodeIdentifier ?: throw IllegalStateException("RooCode identifier is not initialized") extensionManager.activateExtension(rooCodeId, rpcManager!!.getRPCProtocol()) .whenComplete { _, error -> if (error != null) { LOG.error("Failed to activate RooCode plugin", error) + stateMachine.transitionTo(InitializationState.FAILED, "Extension activation failed: ${error.message}") } else { LOG.info("RooCode plugin activated successfully") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATED, "Extension activated") + + // Process queued messages atomically + processQueuedMessages() + + // Now safe to initialize editors + project.getService(EditorAndDocManager::class.java).initCurrentIdeaEditor() + + // Schedule a check to transition to COMPLETE if webview isn't registered + // This handles cases where the extension doesn't use webviews + scheduleCompletionCheck() } } LOG.info("Initialized extension host") } catch (e: Exception) { LOG.error("Failed to handle Initialized message", e) + stateMachine.transitionTo(InitializationState.FAILED, "handleInitializedMessage() exception: ${e.message}") + } + } + + /** + * Process all queued messages atomically. + * This method is called after extension activation to ensure no messages are lost. + */ + private fun processQueuedMessages() { + queueLock.withLock { + val queueSize = messageQueue.size + LOG.info("Processing $queueSize queued messages") + var processedCount = 0 + + while (messageQueue.isNotEmpty()) { + messageQueue.poll()?.let { message -> + try { + message() + processedCount++ + } catch (e: Exception) { + LOG.error("Error processing queued message", e) + } + } + } + + LOG.info("Processed $processedCount/$queueSize queued messages") + } + } + + /** + * Schedule a check to transition to COMPLETE state if webview registration doesn't happen. + * This handles cases where the extension doesn't require webviews. + */ + private fun scheduleCompletionCheck() { + // Cancel any existing timer first + completionCheckTimer?.cancel() + + // Wait 10 seconds after extension activation (increased from 5s for slow machines) + // If still at EXTENSION_ACTIVATED state, transition to COMPLETE + completionCheckTimer = java.util.Timer().apply { + schedule(object : java.util.TimerTask() { + override fun run() { + val currentState = stateMachine.getCurrentState() + + // Only transition if still at EXTENSION_ACTIVATED + if (currentState == InitializationState.EXTENSION_ACTIVATED) { + LOG.info("No webview registration detected after extension activation, transitioning to COMPLETE") + stateMachine.transitionTo(InitializationState.COMPLETE, "Extension activated without webview") + } else if (currentState.ordinal < InitializationState.EXTENSION_ACTIVATED.ordinal) { + // State hasn't reached EXTENSION_ACTIVATED yet, this shouldn't happen + LOG.warn("Completion check fired but state is $currentState, expected EXTENSION_ACTIVATED or later") + } else { + // State has progressed past EXTENSION_ACTIVATED, which is expected + LOG.debug("Completion check skipped, current state: $currentState (already progressed)") + } + } + }, 10000) // 10 seconds delay (increased from 5s for slow machines) } } @@ -334,11 +470,63 @@ class ExtensionHostManager : Disposable { return URI.file(path) } + /** + * Get initialization report for diagnostics. + * @return String containing initialization state machine report. + */ + fun getInitializationReport(): String { + return stateMachine.generateReport() + } + + /** + * Restart initialization if stuck or failed. + * This method resets the state machine and clears the message queue, + * then restarts the initialization process. + */ + fun restartInitialization() { + LOG.warn("Restarting initialization") + + // Reset state machine + stateMachine.transitionTo(InitializationState.NOT_STARTED, "Manual restart") + + // Clear message queue + queueLock.withLock { + val queueSize = messageQueue.size + if (queueSize > 0) { + LOG.info("Clearing $queueSize queued messages") + messageQueue.clear() + } + } + + // Restart + start() + } + /** * Resource disposal. */ override fun dispose() { LOG.info("Disposing ExtensionHostManager") + + // Log final state before disposal + LOG.info("Final initialization state: ${stateMachine.getCurrentState()}") + if (LOG.isDebugEnabled) { + LOG.debug(getInitializationReport()) + } + + // Cancel completion check timer to prevent memory leak + completionCheckTimer?.let { timer -> + timer.cancel() + timer.purge() + } + completionCheckTimer = null + + // Clear message queue + val remainingMessages = messageQueue.size + if (remainingMessages > 0) { + LOG.warn("Disposing with $remainingMessages unprocessed messages in queue") + messageQueue.clear() + } // Cancel coroutines coroutineScope.cancel() diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionProcessManager.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionProcessManager.kt index e39f1562f1d..f9308c68cf4 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionProcessManager.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionProcessManager.kt @@ -20,6 +20,8 @@ import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.util.SystemInfo import java.io.File import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong /** * Extension process manager @@ -57,6 +59,13 @@ class ExtensionProcessManager : Disposable { // Whether running @Volatile private var isRunning = false + + // Crash recovery state + private val crashCount = AtomicInteger(0) + private val lastCrashTime = AtomicLong(0) + private val maxCrashesBeforeGiveUp = 3 + private val crashResetWindow = 300000L // 5 minutes + private var lastPortOrPath: Any? = null /** * Start extension process @@ -68,6 +77,10 @@ class ExtensionProcessManager : Disposable { LOG.info("Extension process is already running") return true } + + // Store for potential restart + lastPortOrPath = portOrPath + val isUds = portOrPath is String if (!ExtensionUtils.isValidPortOrPath(portOrPath)) { LOG.error("Invalid socket info: $portOrPath") @@ -219,12 +232,14 @@ class ExtensionProcessManager : Disposable { logThread.start() // Wait for process to end - try { - val exitCode = proc.waitFor() - LOG.info("Extension process exited with code: $exitCode") + val exitCode = try { + proc.waitFor() } catch (e: InterruptedException) { LOG.info("Process monitor interrupted") + -1 } + + LOG.info("Extension process exited with code: $exitCode") // Ensure log thread ends logThread.interrupt() @@ -233,6 +248,11 @@ class ExtensionProcessManager : Disposable { } catch (e: InterruptedException) { // Ignore } + + // Handle unexpected crashes + if (exitCode != 0 && !Thread.currentThread().isInterrupted) { + handleProcessCrash(exitCode) + } } catch (e: Exception) { LOG.error("Error monitoring extension process", e) } finally { @@ -244,6 +264,56 @@ class ExtensionProcessManager : Disposable { } } } + + /** + * Handle process crash and attempt recovery + */ + private fun handleProcessCrash(exitCode: Int) { + val now = System.currentTimeMillis() + + // Reset crash count if enough time has passed + if (now - lastCrashTime.get() > crashResetWindow) { + crashCount.set(0) + } + + val crashes = crashCount.incrementAndGet() + lastCrashTime.set(now) + + LOG.error("Extension process crashed with exit code $exitCode (crash #$crashes)") + + if (crashes <= maxCrashesBeforeGiveUp) { + LOG.info("Attempting automatic restart (attempt $crashes/$maxCrashesBeforeGiveUp)") + + try { + // Wait before restart with exponential backoff + val delay = 2000L * crashes + Thread.sleep(delay) + + // Attempt restart + val portOrPath = lastPortOrPath + if (portOrPath != null) { + val restarted = start(portOrPath) + if (restarted) { + LOG.info("Extension process restarted successfully after crash") + } else { + LOG.error("Failed to restart extension process after crash") + } + } else { + LOG.error("Cannot restart: no port/path information available") + } + } catch (e: InterruptedException) { + LOG.info("Restart attempt interrupted") + } catch (e: Exception) { + LOG.error("Error during crash recovery", e) + } + } else { + LOG.error("Max crash count reached ($crashes), giving up on automatic restart") + NotificationUtil.showError( + I18n.t("jetbrains:errors.extensionCrashed.title"), + I18n.t("jetbrains:errors.extensionCrashed.message", mapOf("crashes" to crashes)) + ) + } + } /** * Stop extension process diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionSocketServer.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionSocketServer.kt index 4d07b990ad0..3390c764a40 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionSocketServer.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionSocketServer.kt @@ -156,6 +156,9 @@ class ExtensionSocketServer() : ISocketServer { // Create extension host manager val manager = ExtensionHostManager(clientSocket, projectPath, project) clientManagers[clientSocket] = manager + + // Register with PluginContext for access from UI + project.getService(PluginContext::class.java).setExtensionHostManager(manager) handleClient(clientSocket, manager) } catch (e: IOException) { @@ -317,6 +320,9 @@ class ExtensionSocketServer() : ISocketServer { // Create extension host manager val manager = ExtensionHostManager(clientSocket, projectPath, project) clientManagers[clientSocket] = manager + + // Register with PluginContext for access from UI + project.getService(PluginContext::class.java).setExtensionHostManager(manager) // Start connection handling in background thread thread(start = true, name = "DebugHostHandler") { diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionUnixDomainSocketServer.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionUnixDomainSocketServer.kt index 39ed5efbbed..7da61129e93 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionUnixDomainSocketServer.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/ExtensionUnixDomainSocketServer.kt @@ -4,6 +4,7 @@ package ai.kilocode.jetbrains.core +import ai.kilocode.jetbrains.plugin.SystemObjectProvider import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.project.Project import java.io.IOException @@ -146,6 +147,24 @@ class ExtensionUnixDomainSocketServer : ISocketServer { logger.info("[UDS] New client connected") val manager = ExtensionHostManager(clientChannel, projectPath, project) clientManagers[clientChannel] = manager + + // Register ExtensionHostManager in SystemObjectProvider for access by other components + try { + val systemObjectProvider = SystemObjectProvider.getInstance(project) + systemObjectProvider.register("extensionHostManager", manager) + logger.info("[UDS] Registered ExtensionHostManager in SystemObjectProvider") + } catch (e: Exception) { + logger.error("[UDS] Failed to register ExtensionHostManager in SystemObjectProvider", e) + } + + // Also register with PluginContext for UI access + try { + project.getService(PluginContext::class.java).setExtensionHostManager(manager) + logger.info("[UDS] Registered ExtensionHostManager in PluginContext") + } catch (e: Exception) { + logger.error("[UDS] Failed to register ExtensionHostManager in PluginContext", e) + } + handleClient(clientChannel, manager) // Start client handler thread } catch (e: Exception) { if (isRunning) { @@ -202,6 +221,24 @@ class ExtensionUnixDomainSocketServer : ISocketServer { // Connection close and resource release manager.dispose() clientManagers.remove(clientChannel) + + // Remove ExtensionHostManager from SystemObjectProvider + try { + val systemObjectProvider = SystemObjectProvider.getInstance(project) + systemObjectProvider.remove("extensionHostManager") + logger.info("[UDS] Removed ExtensionHostManager from SystemObjectProvider") + } catch (e: Exception) { + logger.warn("[UDS] Failed to remove ExtensionHostManager from SystemObjectProvider", e) + } + + // Also clear from PluginContext + try { + project.getService(PluginContext::class.java).clear() + logger.info("[UDS] Cleared PluginContext") + } catch (e: Exception) { + logger.warn("[UDS] Failed to clear PluginContext", e) + } + try { clientChannel.close() } catch (e: IOException) { diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/InitializationHealthCheck.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/InitializationHealthCheck.kt new file mode 100644 index 00000000000..3bf71228ec3 --- /dev/null +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/InitializationHealthCheck.kt @@ -0,0 +1,145 @@ +package ai.kilocode.jetbrains.core + +import com.intellij.openapi.diagnostic.Logger + +/** + * Health check mechanism for initialization process. + * Monitors initialization progress and provides recovery suggestions. + */ +class InitializationHealthCheck(private val stateMachine: InitializationStateMachine) { + private val logger = Logger.getInstance(InitializationHealthCheck::class.java) + + enum class HealthStatus { + HEALTHY, + STUCK, + FAILED + } + + /** + * Check the health of the initialization process. + * @return Current health status + */ + fun checkHealth(): HealthStatus { + val currentState = stateMachine.getCurrentState() + val stateAge = stateMachine.getStateDuration(currentState) ?: 0L + + return when { + currentState == InitializationState.FAILED -> HealthStatus.FAILED + stateAge > getMaxDuration(currentState) -> HealthStatus.STUCK + else -> HealthStatus.HEALTHY + } + } + + /** + * Get suggestions for the current health status. + * @return List of suggestions for the user + */ + fun getSuggestions(): List { + val status = checkHealth() + val currentState = stateMachine.getCurrentState() + + return when (status) { + HealthStatus.STUCK -> getSuggestionsForStuckState(currentState) + HealthStatus.FAILED -> listOf("Initialization failed. Please check logs and restart the IDE.") + else -> emptyList() + } + } + + /** + * Get maximum allowed duration for a state before it's considered stuck. + * This is typically 3-5x the expected duration to account for slow machines. + * @param state The state to check + * @return Maximum duration in milliseconds + */ + private fun getMaxDuration(state: InitializationState): Long { + return when (state) { + InitializationState.SOCKET_CONNECTING -> 20000L // 20 seconds + InitializationState.SOCKET_CONNECTED -> 5000L + InitializationState.READY_RECEIVED -> 5000L + InitializationState.INIT_DATA_SENT -> 10000L + InitializationState.INITIALIZED_RECEIVED -> 10000L + InitializationState.RPC_CREATING -> 5000L + InitializationState.RPC_CREATED -> 5000L + InitializationState.EXTENSION_ACTIVATING -> 20000L // 20 seconds + InitializationState.EXTENSION_ACTIVATED -> 15000L // 15 seconds + InitializationState.WEBVIEW_REGISTERING -> 5000L + InitializationState.WEBVIEW_REGISTERED -> 3000L + InitializationState.WEBVIEW_RESOLVING -> 5000L + InitializationState.WEBVIEW_RESOLVED -> 10000L + InitializationState.HTML_LOADING -> 30000L // 30 seconds + InitializationState.HTML_LOADED -> 5000L + InitializationState.THEME_INJECTING -> 25000L // 25 seconds (10 retries with backoff) + InitializationState.THEME_INJECTED -> 3000L + else -> 5000L + } + } + + /** + * Get suggestions for a stuck state. + * @param state The state that is stuck + * @return List of suggestions + */ + private fun getSuggestionsForStuckState(state: InitializationState): List { + return when (state) { + InitializationState.SOCKET_CONNECTING -> listOf( + "Socket connection is taking longer than expected.", + "Check if Node.js is installed and accessible.", + "Check firewall settings.", + "Try restarting the IDE." + ) + InitializationState.EXTENSION_ACTIVATING -> listOf( + "Extension activation is taking longer than expected.", + "This might be due to slow disk I/O or CPU.", + "Try closing other applications to free up resources.", + "Check if antivirus software is scanning the plugin files." + ) + InitializationState.HTML_LOADING -> listOf( + "HTML loading is taking longer than expected.", + "This might be due to slow disk I/O.", + "Try closing other applications to free up resources.", + "Check if antivirus software is interfering." + ) + InitializationState.THEME_INJECTING -> listOf( + "Theme injection is taking longer than expected.", + "This might be due to slow JCEF initialization.", + "The webview should still work without theme.", + "Try restarting the IDE if the issue persists." + ) + InitializationState.WEBVIEW_REGISTERING, + InitializationState.WEBVIEW_RESOLVING -> listOf( + "WebView registration is taking longer than expected.", + "This might be due to slow JCEF initialization.", + "Try closing other applications to free up resources." + ) + else -> listOf( + "Initialization is taking longer than expected. Please wait...", + "If the issue persists, try restarting the IDE." + ) + } + } + + /** + * Get a diagnostic report including health status and suggestions. + * @return Diagnostic report as a string + */ + fun getDiagnosticReport(): String { + val status = checkHealth() + val suggestions = getSuggestions() + val currentState = stateMachine.getCurrentState() + val stateAge = stateMachine.getStateDuration(currentState) ?: 0L + + return buildString { + appendLine("=== Initialization Health Check ===") + appendLine("Status: $status") + appendLine("Current State: $currentState") + appendLine("State Age: ${stateAge}ms") + appendLine() + if (suggestions.isNotEmpty()) { + appendLine("Suggestions:") + suggestions.forEach { suggestion -> + appendLine(" - $suggestion") + } + } + } + } +} diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/InitializationStateMachine.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/InitializationStateMachine.kt new file mode 100644 index 00000000000..bcbfcf8ac97 --- /dev/null +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/InitializationStateMachine.kt @@ -0,0 +1,238 @@ +package ai.kilocode.jetbrains.core + +import com.intellij.openapi.diagnostic.Logger +import java.util.concurrent.CompletableFuture +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock + +enum class InitializationState { + NOT_STARTED, + SOCKET_CONNECTING, + SOCKET_CONNECTED, + READY_RECEIVED, + INIT_DATA_SENT, + INITIALIZED_RECEIVED, + RPC_CREATING, + RPC_CREATED, + EXTENSION_ACTIVATING, + EXTENSION_ACTIVATED, + WEBVIEW_REGISTERING, + WEBVIEW_REGISTERED, + WEBVIEW_RESOLVING, + WEBVIEW_RESOLVED, + HTML_LOADING, + HTML_LOADED, + THEME_INJECTING, + THEME_INJECTED, + COMPLETE, + FAILED; + + fun canTransitionTo(newState: InitializationState): Boolean { + return when (this) { + NOT_STARTED -> newState == SOCKET_CONNECTING + SOCKET_CONNECTING -> newState in setOf(SOCKET_CONNECTED, FAILED) + SOCKET_CONNECTED -> newState in setOf(READY_RECEIVED, FAILED) + READY_RECEIVED -> newState in setOf(INIT_DATA_SENT, FAILED) + INIT_DATA_SENT -> newState in setOf(INITIALIZED_RECEIVED, FAILED) + INITIALIZED_RECEIVED -> newState in setOf(RPC_CREATING, FAILED) + RPC_CREATING -> newState in setOf(RPC_CREATED, FAILED) + RPC_CREATED -> newState in setOf(EXTENSION_ACTIVATING, FAILED) + EXTENSION_ACTIVATING -> newState in setOf(EXTENSION_ACTIVATED, WEBVIEW_REGISTERING, FAILED) + EXTENSION_ACTIVATED -> newState in setOf(WEBVIEW_REGISTERING, WEBVIEW_REGISTERED, HTML_LOADING, HTML_LOADED, COMPLETE, FAILED) // Allow forward progress even if webview registration happened during activation + WEBVIEW_REGISTERING -> newState in setOf(WEBVIEW_REGISTERED, EXTENSION_ACTIVATED, FAILED) // Allow EXTENSION_ACTIVATED for race condition + WEBVIEW_REGISTERED -> newState in setOf(WEBVIEW_RESOLVING, FAILED) + WEBVIEW_RESOLVING -> newState in setOf(WEBVIEW_RESOLVED, FAILED) + WEBVIEW_RESOLVED -> newState in setOf(HTML_LOADING, FAILED) + HTML_LOADING -> newState in setOf(HTML_LOADED, FAILED) + HTML_LOADED -> newState in setOf(THEME_INJECTING, COMPLETE, FAILED) + THEME_INJECTING -> newState in setOf(THEME_INJECTED, FAILED) + THEME_INJECTED -> newState in setOf(COMPLETE, FAILED) + COMPLETE -> false // No transitions from COMPLETE + FAILED -> false // No transitions from FAILED + } + } +} + +class InitializationStateMachine { + private val logger = Logger.getInstance(InitializationStateMachine::class.java) + private val state = AtomicReference(InitializationState.NOT_STARTED) + private val stateLock = ReentrantLock() + private val stateTimestamps = ConcurrentHashMap() + private val stateCompletions = ConcurrentHashMap>() + private val stateListeners = ConcurrentHashMap Unit>>() + + init { + // Create completion futures for all states + InitializationState.values().forEach { state -> + stateCompletions[state] = CompletableFuture() + } + // Mark NOT_STARTED as complete immediately + stateTimestamps[InitializationState.NOT_STARTED] = System.currentTimeMillis() + stateCompletions[InitializationState.NOT_STARTED]?.complete(Unit) + } + + fun getCurrentState(): InitializationState = state.get() + + fun transitionTo(newState: InitializationState, context: String = ""): Boolean { + return stateLock.withLock { + val currentState = state.get() + + // Idempotent: if already at target state, return success without logging error + if (currentState == newState) { + logger.debug("Already at state $newState, ignoring duplicate transition (context: $context)") + return true + } + + // Terminal states: once reached, no further transitions allowed + if (currentState == InitializationState.COMPLETE) { + logger.debug("Already in COMPLETE state, ignoring transition to $newState (context: $context)") + return false + } + + if (currentState == InitializationState.FAILED) { + logger.debug("Already in FAILED state, ignoring transition to $newState (context: $context)") + return false + } + + if (!currentState.canTransitionTo(newState)) { + logger.warn("Invalid state transition: $currentState -> $newState (context: $context)") + return false + } + + val now = System.currentTimeMillis() + val previousTimestamp = stateTimestamps[currentState] ?: now + val duration = now - previousTimestamp + + // Check if transition took longer than expected and log warning + val expectedDuration = getExpectedDuration(currentState) + if (duration > expectedDuration) { + logger.warn("Slow state transition: $currentState -> $newState took ${duration}ms (expected: ${expectedDuration}ms, context: $context)") + } else { + logger.info("State transition: $currentState -> $newState (took ${duration}ms, context: $context)") + } + + state.set(newState) + stateTimestamps[newState] = now + + // Complete the future for this state + stateCompletions[newState]?.complete(Unit) + + // Notify listeners + stateListeners[newState]?.forEach { listener -> + try { + listener(newState) + } catch (e: Exception) { + logger.error("Error in state listener for $newState", e) + } + } + + // If failed, complete all remaining futures exceptionally + if (newState == InitializationState.FAILED) { + val error = IllegalStateException("Initialization failed at state $currentState (context: $context)") + stateCompletions.values.forEach { future -> + if (!future.isDone) { + future.completeExceptionally(error) + } + } + } + + true + } + } + + /** + * Get expected duration for a state transition in milliseconds. + * These are baseline expectations for normal machines. + * @param state The state to get expected duration for + * @return Expected duration in milliseconds + */ + private fun getExpectedDuration(state: InitializationState): Long { + return when (state) { + InitializationState.SOCKET_CONNECTING -> 5000L + InitializationState.SOCKET_CONNECTED -> 1000L + InitializationState.READY_RECEIVED -> 1000L + InitializationState.INIT_DATA_SENT -> 2000L + InitializationState.INITIALIZED_RECEIVED -> 2000L + InitializationState.RPC_CREATING -> 1000L + InitializationState.RPC_CREATED -> 1000L + InitializationState.EXTENSION_ACTIVATING -> 5000L + InitializationState.EXTENSION_ACTIVATED -> 2000L + InitializationState.WEBVIEW_REGISTERING -> 1000L + InitializationState.WEBVIEW_REGISTERED -> 500L + InitializationState.WEBVIEW_RESOLVING -> 1000L + InitializationState.WEBVIEW_RESOLVED -> 2000L + InitializationState.HTML_LOADING -> 10000L + InitializationState.HTML_LOADED -> 1000L + InitializationState.THEME_INJECTING -> 2000L + InitializationState.THEME_INJECTED -> 500L + else -> 1000L + } + } + + fun waitForState(targetState: InitializationState): CompletableFuture { + val currentState = state.get() + + // If already at or past target state, return completed future + if (currentState.ordinal >= targetState.ordinal && currentState != InitializationState.FAILED) { + return CompletableFuture.completedFuture(Unit) + } + + // If failed, return failed future + if (currentState == InitializationState.FAILED) { + return CompletableFuture.failedFuture( + IllegalStateException("Initialization failed before reaching $targetState") + ) + } + + // Otherwise return the completion future for that state + return stateCompletions[targetState] ?: CompletableFuture.failedFuture( + IllegalStateException("No completion future for state $targetState") + ) + } + + fun onStateReached(targetState: InitializationState, listener: (InitializationState) -> Unit) { + stateListeners.computeIfAbsent(targetState) { mutableListOf() }.add(listener) + + // If already at this state, call listener immediately + if (state.get() == targetState) { + try { + listener(targetState) + } catch (e: Exception) { + logger.error("Error in immediate state listener for $targetState", e) + } + } + } + + fun getStateDuration(state: InitializationState): Long? { + val timestamp = stateTimestamps[state] ?: return null + val nextState = InitializationState.values().getOrNull(state.ordinal + 1) + val nextTimestamp = nextState?.let { stateTimestamps[it] } ?: System.currentTimeMillis() + return nextTimestamp - timestamp + } + + fun generateReport(): String { + val report = StringBuilder() + report.appendLine("=== Initialization State Machine Report ===") + report.appendLine("Current State: ${state.get()}") + report.appendLine() + + val startTime = stateTimestamps[InitializationState.NOT_STARTED] ?: System.currentTimeMillis() + + InitializationState.values().forEach { state -> + val timestamp = stateTimestamps[state] + if (timestamp != null) { + val elapsed = timestamp - startTime + val duration = getStateDuration(state) + report.append("$state: ${elapsed}ms from start") + if (duration != null) { + report.append(" (duration: ${duration}ms)") + } + report.appendLine() + } + } + + return report.toString() + } +} diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/PluginContext.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/PluginContext.kt index 01b195d6422..9329d885f4a 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/PluginContext.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/core/PluginContext.kt @@ -21,6 +21,10 @@ class PluginContext { // RPC protocol instance @Volatile private var rpcProtocol: IRPCProtocol? = null + + // Extension host manager instance + @Volatile + private var extensionHostManager: ExtensionHostManager? = null /** * Set RPC protocol instance @@ -38,6 +42,23 @@ class PluginContext { fun getRPCProtocol(): IRPCProtocol? { return rpcProtocol } + + /** + * Set extension host manager instance + * @param manager Extension host manager instance + */ + fun setExtensionHostManager(manager: ExtensionHostManager) { + logger.info("Setting extension host manager instance") + extensionHostManager = manager + } + + /** + * Get extension host manager instance + * @return Extension host manager instance, or null if not set + */ + fun getExtensionHostManager(): ExtensionHostManager? { + return extensionHostManager + } /** * Clear all resources @@ -45,6 +66,7 @@ class PluginContext { fun clear() { logger.info("Clearing resources in PluginContext") rpcProtocol = null + extensionHostManager = null } companion object { diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorAndDocManager.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorAndDocManager.kt index e3c7c6270d7..b10c6c1af1f 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorAndDocManager.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorAndDocManager.kt @@ -4,6 +4,7 @@ package ai.kilocode.jetbrains.editor +import ai.kilocode.jetbrains.plugin.SystemObjectProvider import ai.kilocode.jetbrains.util.URI import com.intellij.diff.DiffContentFactory import com.intellij.diff.chains.DiffRequestChain @@ -131,18 +132,48 @@ class EditorAndDocManager(val project: Project) : Disposable { fun initCurrentIdeaEditor() { CoroutineScope(Dispatchers.Default).launch { - FileEditorManager.getInstance(project).allEditors.forEach { editor -> - // Record and synchronize - if (editor is FileEditor) { - val uri = URI.file(editor.file.path) - val handle = sync2ExtHost(uri, false) - handle.ideaEditor = editor - val group = tabManager.createTabGroup(EditorGroupColumn.BESIDE.value, true) - val options = TabOptions(isActive = true) - val tab = group.addTab(EditorTabInput(uri, uri.path, ""), options) - handle.tab = tab - handle.group = group + // Wait for extension host to be ready before initializing editors + try { + // Get ExtensionHostManager from SystemObjectProvider + val systemObjectProvider = SystemObjectProvider.getInstance(project) + val extensionHostManager = systemObjectProvider.get("extensionHostManager") + + if (extensionHostManager == null) { + logger.error("ExtensionHostManager not available in SystemObjectProvider, skipping editor initialization") + return@launch } + + val isReady = try { + extensionHostManager.waitForReady().get() + } catch (e: Exception) { + logger.error("Error waiting for extension host to be ready", e) + false + } + + if (!isReady) { + logger.error("Extension host failed to initialize, skipping editor initialization") + return@launch + } + + logger.info("Extension host ready, initializing current IDE editors") + + FileEditorManager.getInstance(project).allEditors.forEach { editor -> + // Record and synchronize + if (editor is FileEditor) { + val uri = URI.file(editor.file.path) + val handle = sync2ExtHost(uri, false) + handle.ideaEditor = editor + val group = tabManager.createTabGroup(EditorGroupColumn.BESIDE.value, true) + val options = TabOptions(isActive = true) + val tab = group.addTab(EditorTabInput(uri, uri.path, ""), options) + handle.tab = tab + handle.group = group + } + } + + logger.info("Completed initialization of ${FileEditorManager.getInstance(project).allEditors.size} editors") + } catch (e: Exception) { + logger.error("Error during editor initialization", e) } } } diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorStateService.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorStateService.kt index 995e27ca988..8a6f401810d 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorStateService.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/editor/EditorStateService.kt @@ -4,6 +4,7 @@ package ai.kilocode.jetbrains.editor +import ai.kilocode.jetbrains.core.ExtensionHostManager import ai.kilocode.jetbrains.core.PluginContext import ai.kilocode.jetbrains.core.ServiceProxyRegistry import ai.kilocode.jetbrains.ipc.proxy.interfaces.ExtHostDocumentsAndEditorsProxy @@ -17,35 +18,45 @@ class EditorStateService(val project: Project) { var extHostDocumentsAndEditorsProxy: ExtHostDocumentsAndEditorsProxy? = null var extHostEditorsProxy: ExtHostEditorsProxy? = null var extHostDocumentsProxy: ExtHostDocumentsProxy? = null + + private fun getExtensionHostManager(): ExtensionHostManager? { + return PluginContext.getInstance(project).getExtensionHostManager() + } fun acceptDocumentsAndEditorsDelta(detail: DocumentsAndEditorsDelta) { - val protocol = PluginContext.getInstance(project).getRPCProtocol() - if (extHostDocumentsAndEditorsProxy == null) { - extHostDocumentsAndEditorsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostDocumentsAndEditors) + getExtensionHostManager()?.queueMessage { + val protocol = PluginContext.getInstance(project).getRPCProtocol() + if (extHostDocumentsAndEditorsProxy == null) { + extHostDocumentsAndEditorsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostDocumentsAndEditors) + } + extHostDocumentsAndEditorsProxy?.acceptDocumentsAndEditorsDelta(detail) } - extHostDocumentsAndEditorsProxy?.acceptDocumentsAndEditorsDelta(detail) } fun acceptEditorPropertiesChanged(detail: Map) { - val protocol = PluginContext.getInstance(project).getRPCProtocol() - if (extHostEditorsProxy == null) { - extHostEditorsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditors) - } - extHostEditorsProxy?.let { - for ((id, data) in detail) { - it.acceptEditorPropertiesChanged(id, data) + getExtensionHostManager()?.queueMessage { + val protocol = PluginContext.getInstance(project).getRPCProtocol() + if (extHostEditorsProxy == null) { + extHostEditorsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditors) + } + extHostEditorsProxy?.let { + for ((id, data) in detail) { + it.acceptEditorPropertiesChanged(id, data) + } } } } fun acceptModelChanged(detail: Map) { - val protocol = PluginContext.getInstance(project).getRPCProtocol() - if (extHostDocumentsProxy == null) { - extHostDocumentsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostDocuments) - } - extHostDocumentsProxy?.let { - for ((uri, data) in detail) { - it.acceptModelChanged(uri, data, data.isDirty) + getExtensionHostManager()?.queueMessage { + val protocol = PluginContext.getInstance(project).getRPCProtocol() + if (extHostDocumentsProxy == null) { + extHostDocumentsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostDocuments) + } + extHostDocumentsProxy?.let { + for ((uri, data) in detail) { + it.acceptModelChanged(uri, data, data.isDirty) + } } } } @@ -53,28 +64,38 @@ class EditorStateService(val project: Project) { class TabStateService(val project: Project) { var extHostEditorTabsProxy: ExtHostEditorTabsProxy? = null + + private fun getExtensionHostManager(): ExtensionHostManager? { + return PluginContext.getInstance(project).getExtensionHostManager() + } fun acceptEditorTabModel(detail: List) { - val protocol = PluginContext.getInstance(project).getRPCProtocol() - if (extHostEditorTabsProxy == null) { - extHostEditorTabsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditorTabs) + getExtensionHostManager()?.queueMessage { + val protocol = PluginContext.getInstance(project).getRPCProtocol() + if (extHostEditorTabsProxy == null) { + extHostEditorTabsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditorTabs) + } + extHostEditorTabsProxy?.acceptEditorTabModel(detail) } - extHostEditorTabsProxy?.acceptEditorTabModel(detail) } fun acceptTabOperation(detail: TabOperation) { - val protocol = PluginContext.getInstance(project).getRPCProtocol() - if (extHostEditorTabsProxy == null) { - extHostEditorTabsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditorTabs) + getExtensionHostManager()?.queueMessage { + val protocol = PluginContext.getInstance(project).getRPCProtocol() + if (extHostEditorTabsProxy == null) { + extHostEditorTabsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditorTabs) + } + extHostEditorTabsProxy?.acceptTabOperation(detail) } - extHostEditorTabsProxy?.acceptTabOperation(detail) } fun acceptTabGroupUpdate(detail: EditorTabGroupDto) { - val protocol = PluginContext.getInstance(project).getRPCProtocol() - if (extHostEditorTabsProxy == null) { - extHostEditorTabsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditorTabs) + getExtensionHostManager()?.queueMessage { + val protocol = PluginContext.getInstance(project).getRPCProtocol() + if (extHostEditorTabsProxy == null) { + extHostEditorTabsProxy = protocol?.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostEditorTabs) + } + extHostEditorTabsProxy?.acceptTabGroupUpdate(detail) } - extHostEditorTabsProxy?.acceptTabGroupUpdate(detail) } } diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/events/EventBus.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/events/EventBus.kt index e55cd5e17f8..2a7f1747814 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/events/EventBus.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/events/EventBus.kt @@ -147,6 +147,12 @@ open class AbsEventBus : Disposable { } override fun dispose() { + logger.info("Disposing EventBus, clearing ${listeners.size} listener types") + + // Clear all listeners to prevent memory leaks + listeners.clear() + + logger.info("EventBus disposed") } } diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageHandler.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageHandler.kt index ded4bc8ec7f..81953e32ef8 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageHandler.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageHandler.kt @@ -43,7 +43,7 @@ class CommitMessageHandler( ) : CheckinHandler() { private val logger: Logger = Logger.getInstance(CommitMessageHandler::class.java) - private val commitMessageService = CommitMessageService.getInstance() + private val commitMessageService by lazy { CommitMessageService.getInstance(panel.project) } private val fileDiscoveryService = FileDiscoveryService() private lateinit var generateButton: JButton diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageService.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageService.kt index 97bda0d2731..14035125cf9 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageService.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/git/CommitMessageService.kt @@ -129,9 +129,10 @@ class CommitMessageService { companion object { /** * Gets or creates the CommitMessageService instance for the project. + * @param project The project context for which to get the service */ - fun getInstance(): CommitMessageService { - return ApplicationManager.getApplication().getService(CommitMessageService::class.java) + fun getInstance(project: Project): CommitMessageService { + return project.getService(CommitMessageService::class.java) } } } diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/PersistentProtocol.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/PersistentProtocol.kt index e9b912c4c4e..6c5419252b4 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/PersistentProtocol.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/PersistentProtocol.kt @@ -116,19 +116,36 @@ class PersistentProtocol(opts: PersistentProtocolOptions, msgListener: ((data: B } override fun dispose() { - _outgoingAckTimeout?.cancel() + // Cancel and purge all timers to prevent memory leaks + _outgoingAckTimeout?.let { timer -> + timer.cancel() + timer.purge() + } _outgoingAckTimeout = null - _incomingAckTimeout?.cancel() + _incomingAckTimeout?.let { timer -> + timer.cancel() + timer.purge() + } _incomingAckTimeout = null - _keepAliveInterval?.cancel() + _keepAliveInterval?.let { timer -> + timer.cancel() + timer.purge() + } _keepAliveInterval = null + // Dispose socket-related resources _socketDisposables.forEach { it.dispose() } _socketDisposables.clear() + // Clear message queues to free memory + val unackMsgCount = _outgoingUnackMsg.size + _outgoingUnackMsg.clear() + _isDisposed = true + + LOG.info("PersistentProtocol disposed, cleared $unackMsgCount unacknowledged messages") } override suspend fun drain() { diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/ProtocolConstants.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/ProtocolConstants.kt index 7c026cf0c6b..85865cd4f19 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/ProtocolConstants.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/ProtocolConstants.kt @@ -16,14 +16,16 @@ object ProtocolConstants { /** * Maximum delay time for sending acknowledgment messages (milliseconds) + * Increased from 2s to 5s to accommodate slower machines */ - const val ACKNOWLEDGE_TIME = 2000 // 2 seconds + const val ACKNOWLEDGE_TIME = 5000 // 5 seconds /** * If a sent message has not been acknowledged beyond this time, and no server data has been received during this period, * the connection is considered timed out + * Increased from 20s to 60s to accommodate slower machines and initialization delays */ - const val TIMEOUT_TIME = 20000 // 20 seconds + const val TIMEOUT_TIME = 60000 // 60 seconds /** * If no reconnection occurs within this time range, the connection is considered permanently closed diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/PendingRPCReply.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/PendingRPCReply.kt index ef442f7d893..22176a6c0b8 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/PendingRPCReply.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/PendingRPCReply.kt @@ -14,6 +14,8 @@ class PendingRPCReply( private val promise: LazyPromise, private val disposable: Disposable, ) { + val creationTime: Long = System.currentTimeMillis() + /** * Resolve reply successfully * @param value Result value diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/RPCProtocol.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/RPCProtocol.kt index 6f3529e07fb..8e5e9de15d1 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/RPCProtocol.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ipc/proxy/RPCProtocol.kt @@ -85,8 +85,24 @@ class RPCProtocol( /** * Unresponsive time threshold (milliseconds) + * Increased from 3s to 10s to accommodate slower machines and initialization delays */ - private const val UNRESPONSIVE_TIME = 3 * 1000 // 3s, same as TS implementation + private const val UNRESPONSIVE_TIME = 10 * 1000 // 10s + + /** + * Maximum pending replies before warning + */ + private const val PENDING_REPLY_WARNING_THRESHOLD = 500 + + /** + * Maximum pending replies before cleanup + */ + private const val MAX_PENDING_REPLIES = 1000 + + /** + * Stale reply timeout (5 minutes) + */ + private const val STALE_REPLY_TIMEOUT = 300000L /** * RPC protocol symbol (used to identify objects implementing this interface) @@ -459,6 +475,9 @@ class RPCProtocol( pendingRPCReplies[callId] = PendingRPCReply(result, disposable) onWillSendRequest(req) + + // Monitor pending reply count + checkPendingReplies() val usesCancellationToken = cancellationToken != null val msg = MessageIO.serializeRequest(req, rpcId, methodName, serializedRequestArguments, usesCancellationToken) @@ -476,6 +495,48 @@ class RPCProtocol( // Directly return Promise, do not block current thread return result } + + /** + * Check pending reply count and cleanup stale replies + */ + private fun checkPendingReplies() { + val pendingCount = pendingRPCReplies.size + + if (pendingCount > MAX_PENDING_REPLIES) { + LOG.error("Too many pending RPC replies ($pendingCount), possible leak or deadlock - cleaning up stale replies") + cleanupStalePendingReplies() + } else if (pendingCount > PENDING_REPLY_WARNING_THRESHOLD) { + LOG.warn("High number of pending RPC replies: $pendingCount") + } + } + + /** + * Cleanup stale pending replies that have been waiting too long + */ + private fun cleanupStalePendingReplies() { + val now = System.currentTimeMillis() + var cleanedCount = 0 + + pendingRPCReplies.entries.removeIf { (msgId, reply) -> + val age = now - reply.creationTime + if (age > STALE_REPLY_TIMEOUT) { + LOG.warn("Removing stale pending reply: msgId=$msgId, age=${age}ms") + try { + reply.resolveErr(java.util.concurrent.TimeoutException("Reply timeout after ${age}ms")) + } catch (e: Exception) { + LOG.error("Error resolving stale reply", e) + } + cleanedCount++ + true + } else { + false + } + } + + if (cleanedCount > 0) { + LOG.info("Cleaned up $cleanedCount stale pending replies") + } + } /** * Receive a message diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/plugin/WecoderPlugin.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/plugin/WecoderPlugin.kt index b6eb91870f4..3f3a1dad40f 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/plugin/WecoderPlugin.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/plugin/WecoderPlugin.kt @@ -65,8 +65,9 @@ class WecoderPlugin : StartupActivity.DumbAware { LOG.info("Project closed: ${project.name}") // Clean up resources for closed project try { - val pluginService = getInstance(project) - pluginService.dispose() + // Use getServiceIfCreated to avoid initializing service during disposal + val pluginService = project.getServiceIfCreated(WecoderPluginService::class.java) + pluginService?.dispose() } catch (e: Exception) { LOG.error("Failed to dispose plugin for closed project: ${project.name}", e) } @@ -191,6 +192,13 @@ class WecoderPluginService(private var currentProject: Project) : Disposable { // Whether initialized @Volatile private var isInitialized = false + + // Disposal state + @Volatile + private var isDisposing = false + + @Volatile + private var isDisposed = false // Plugin initialization complete flag private var initializationComplete = CompletableFuture() @@ -265,6 +273,12 @@ class WecoderPluginService(private var currentProject: Project) : Disposable { * Initialize plugin service */ fun initialize(project: Project) { + // Check if disposing or disposed + if (isDisposing || isDisposed) { + LOG.warn("Cannot initialize: service is disposing or disposed") + return + } + // Check if already initialized for the same project if (isInitialized && this.currentProject == project) { LOG.info("WecoderPluginService already initialized for project: ${project.name}") @@ -481,10 +495,17 @@ class WecoderPluginService(private var currentProject: Project) : Disposable { * Close service */ override fun dispose() { + if (isDisposed) { + LOG.warn("Service already disposed") + return + } + if (!isInitialized) { + isDisposed = true return } + isDisposing = true LOG.info("Disposing WecoderPluginService") currentProject?.getService(WebViewManager::class.java)?.dispose() @@ -495,6 +516,8 @@ class WecoderPluginService(private var currentProject: Project) : Disposable { // Clean up resources cleanup() + isDisposed = true + isDisposing = false LOG.info("WecoderPluginService disposed") } } diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ui/RooToolWindowFactory.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ui/RooToolWindowFactory.kt index 7bad8bb48f7..650304bf567 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ui/RooToolWindowFactory.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/ui/RooToolWindowFactory.kt @@ -1,13 +1,11 @@ -// SPDX-FileCopyrightText: 2025 Weibo, Inc. -// -// SPDX-License-Identifier: Apache-2.0 - package ai.kilocode.jetbrains.ui import ai.kilocode.jetbrains.actions.OpenDevToolsAction +import ai.kilocode.jetbrains.core.PluginContext import ai.kilocode.jetbrains.plugin.DebugMode import ai.kilocode.jetbrains.plugin.WecoderPlugin import ai.kilocode.jetbrains.plugin.WecoderPluginService +import ai.kilocode.jetbrains.util.NodeVersionUtil import ai.kilocode.jetbrains.util.PluginConstants import ai.kilocode.jetbrains.webview.DragDropHandler import ai.kilocode.jetbrains.webview.WebViewCreationCallback @@ -80,8 +78,49 @@ class RooToolWindowFactory : ToolWindowFactory { // Placeholder label private val placeholderLabel = JLabel(createSystemInfoText()) - // System info text for copying - private val systemInfoText = createSystemInfoPlainText() + // System info text for copying - will be updated + private var systemInfoText = createSystemInfoPlainText() + + // Timer for updating status display + private var statusUpdateTimer: java.util.Timer? = null + + /** + * Get initialization state text from state machine + */ + private fun getInitStateText(): String { + val pluginContext = try { + project.getService(PluginContext::class.java) + } catch (e: Exception) { + null + } + + val extensionHostManager = pluginContext?.getExtensionHostManager() + val initState = extensionHostManager?.stateMachine?.getCurrentState() + return when (initState?.name) { + null -> "Initializing..." + "NOT_STARTED" -> "Starting..." + "SOCKET_CONNECTING" -> "Connecting to extension host..." + "SOCKET_CONNECTED" -> "Connected to extension host" + "READY_RECEIVED" -> "Extension host ready" + "INIT_DATA_SENT" -> "Sending initialization data..." + "INITIALIZED_RECEIVED" -> "Extension host initialized" + "RPC_CREATING" -> "Creating RPC protocol..." + "RPC_CREATED" -> "RPC protocol created" + "EXTENSION_ACTIVATING" -> "Activating extension..." + "EXTENSION_ACTIVATED" -> "Extension activated" + "WEBVIEW_REGISTERING" -> "Registering webview..." + "WEBVIEW_REGISTERED" -> "Webview registered" + "WEBVIEW_RESOLVING" -> "Resolving webview..." + "WEBVIEW_RESOLVED" -> "Webview resolved" + "HTML_LOADING" -> "Loading UI..." + "HTML_LOADED" -> "UI loaded" + "THEME_INJECTING" -> "Applying theme..." + "THEME_INJECTED" -> "Theme applied" + "COMPLETE" -> "Ready" + "FAILED" -> "Failed" + else -> initState.name.replace("_", " ").lowercase().replaceFirstChar { it.uppercase() } + } + } /** * Create system information text in HTML format @@ -97,12 +136,36 @@ class RooToolWindowFactory : ToolWindowFactory { // Check for Linux ARM system val isLinuxArm = osName.lowercase().contains("linux") && (osArch.lowercase().contains("aarch64") || osArch.lowercase().contains("arm")) + + // Get initialization status + val initStateText = getInitStateText() + + // Get Node.js version + val nodeVersion = try { + val nodePath = ai.kilocode.jetbrains.util.PluginResourceUtil.getResourcePath( + PluginConstants.PLUGIN_ID, + PluginConstants.NODE_MODULES_PATH + )?.let { resourcePath -> + val nodeFile = java.io.File(resourcePath, if (System.getProperty("os.name").lowercase().contains("windows")) "node.exe" else ".bin/node") + if (nodeFile.exists()) nodeFile.absolutePath else null + } ?: com.intellij.execution.configurations.PathEnvironmentVariableUtil.findExecutableInPathOnAnyOS("node")?.absolutePath + + if (nodePath != null) { + NodeVersionUtil.getNodeVersion(nodePath)?.toString() ?: "unknown" + } else { + "not found" + } + } catch (e: Exception) { + "error: ${e.message}" + } return buildString { - append("") - append("

Kilo Code is initializing...") + append("") + append("

Kilo Code Initialization

") + append("

Status: $initStateText

") append("

System Information

") append("") + append("") append("") append("") append("") @@ -153,10 +216,37 @@ class RooToolWindowFactory : ToolWindowFactory { // Check for Linux ARM system val isLinuxArm = osName.lowercase().contains("linux") && (osArch.lowercase().contains("aarch64") || osArch.lowercase().contains("arm")) + + // Get initialization status + val initStateText = getInitStateText() + + // Get Node.js version + val nodeVersion = try { + val nodePath = ai.kilocode.jetbrains.util.PluginResourceUtil.getResourcePath( + PluginConstants.PLUGIN_ID, + PluginConstants.NODE_MODULES_PATH + )?.let { resourcePath -> + val nodeFile = java.io.File(resourcePath, if (System.getProperty("os.name").lowercase().contains("windows")) "node.exe" else ".bin/node") + if (nodeFile.exists()) nodeFile.absolutePath else null + } ?: com.intellij.execution.configurations.PathEnvironmentVariableUtil.findExecutableInPathOnAnyOS("node")?.absolutePath + + if (nodePath != null) { + NodeVersionUtil.getNodeVersion(nodePath)?.toString() ?: "unknown" + } else { + "not found" + } + } catch (e: Exception) { + "error: ${e.message}" + } return buildString { + append("Kilo Code Initialization\n") + append("========================\n") + append("Status: $initStateText\n") + append("\n") append("System Information\n") append("==================\n") + append("Node.js Version: $nodeVersion\n") append("CPU Architecture: $osArch\n") append("Operating System: $osName $osVersion\n") append("IDE Version: ${appInfo.fullApplicationName} (build ${appInfo.build})\n") @@ -226,6 +316,9 @@ class RooToolWindowFactory : ToolWindowFactory { } init { + // Start timer to update status display + startStatusUpdateTimer() + // Try to get existing WebView webViewManager.getLatestWebView()?.let { webView -> // Add WebView component immediately when created @@ -236,16 +329,54 @@ class RooToolWindowFactory : ToolWindowFactory { webView.setPageLoadCallback { ApplicationManager.getApplication().invokeLater { hideSystemInfo() + stopStatusUpdateTimer() } } // If page is already loaded, hide system info immediately if (webView.isPageLoaded()) { ApplicationManager.getApplication().invokeLater { hideSystemInfo() + stopStatusUpdateTimer() } } } ?: webViewManager.addCreationCallback(this, toolWindow.disposable) } + + /** + * Start timer to update status display + */ + private fun startStatusUpdateTimer() { + statusUpdateTimer = java.util.Timer().apply { + scheduleAtFixedRate(object : java.util.TimerTask() { + override fun run() { + ApplicationManager.getApplication().invokeLater { + updateStatusDisplay() + } + } + }, 500, 500) // Update every 500ms + } + } + + /** + * Stop status update timer + */ + private fun stopStatusUpdateTimer() { + statusUpdateTimer?.cancel() + statusUpdateTimer?.purge() + statusUpdateTimer = null + } + + /** + * Update status display + */ + private fun updateStatusDisplay() { + try { + placeholderLabel.text = createSystemInfoText() + systemInfoText = createSystemInfoPlainText() + } catch (e: Exception) { + logger.error("Error updating status display", e) + } + } /** * WebView creation callback implementation @@ -278,8 +409,11 @@ class RooToolWindowFactory : ToolWindowFactory { return } } + + // Remove placeholder and buttons before adding webview + contentPanel.removeAll() - // Add WebView component without removing existing components + // Add WebView component contentPanel.add(webView.browser.component, BorderLayout.CENTER) setupDragAndDropSupport(webView) @@ -287,8 +421,11 @@ class RooToolWindowFactory : ToolWindowFactory { // Relayout contentPanel.revalidate() contentPanel.repaint() + + // Stop status update timer since webview is now visible + stopStatusUpdateTimer() - logger.info("WebView component added to tool window") + logger.info("WebView component added to tool window, placeholder removed") } /** @@ -296,6 +433,9 @@ class RooToolWindowFactory : ToolWindowFactory { */ private fun hideSystemInfo() { logger.info("Hiding system info placeholder") + + // Stop status update timer + stopStatusUpdateTimer() // Remove all components from content panel except WebView component val components = contentPanel.components diff --git a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/webview/WebViewManager.kt b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/webview/WebViewManager.kt index 9ef9e5d64eb..2373aac9464 100644 --- a/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/webview/WebViewManager.kt +++ b/jetbrains/plugin/src/main/kotlin/ai/kilocode/jetbrains/webview/WebViewManager.kt @@ -1,9 +1,7 @@ -// SPDX-FileCopyrightText: 2025 Weibo, Inc. -// -// SPDX-License-Identifier: Apache-2.0 - package ai.kilocode.jetbrains.webview +import ai.kilocode.jetbrains.core.InitializationState +import ai.kilocode.jetbrains.core.InitializationStateMachine import ai.kilocode.jetbrains.core.PluginContext import ai.kilocode.jetbrains.core.ServiceProxyRegistry import ai.kilocode.jetbrains.events.WebviewHtmlUpdateData @@ -91,6 +89,23 @@ class WebViewManager(var project: Project) : Disposable, ThemeChangeListener { // Prevent repeated dispose private var isDisposed = false private var themeInitialized = false + + // State machine reference for tracking initialization progress (lazy initialization) + private val stateMachine: InitializationStateMachine? by lazy { + try { + val pluginContext = project.getService(PluginContext::class.java) + val sm = pluginContext.getExtensionHostManager()?.stateMachine + if (sm == null) { + logger.warn("State machine not available from PluginContext") + } else { + logger.info("State machine reference obtained successfully") + } + sm + } catch (e: Exception) { + logger.error("Failed to get state machine reference", e) + null + } + } /** * Initialize theme manager @@ -240,64 +255,105 @@ class WebViewManager(var project: Project) : Disposable, ThemeChangeListener { */ fun registerProvider(data: WebviewViewProviderData) { logger.info("Register WebView provider and create WebView instance: ${data.viewType} for project: ${project.name}") - val extension = data.extension - - // Clean up any existing WebView for this project before creating a new one - disposeLatestWebView() - - // Get location info from extension and set resource root directory + try { - @Suppress("UNCHECKED_CAST") - val location = extension.get("location") as? Map - val fsPath = location?.get("fsPath") as? String - - if (fsPath != null) { - // Set resource root directory - val path = Paths.get(fsPath) - logger.info("Get resource directory path from extension: $path") - - // Ensure the resource directory exists - if (!path.exists()) { - path.createDirectories() + val currentState = stateMachine?.getCurrentState() + + // Check if we should transition to WEBVIEW_REGISTERING + // Only transition if we're at or past EXTENSION_ACTIVATING and haven't registered yet + if (currentState != null) { + when { + currentState.ordinal < InitializationState.EXTENSION_ACTIVATING.ordinal -> { + logger.warn("Webview registration called before extension activation (state: $currentState)") + // Don't transition yet, but continue with registration + } + currentState.ordinal >= InitializationState.WEBVIEW_REGISTERING.ordinal -> { + logger.debug("Webview already registering or registered (state: $currentState)") + // Don't transition, already past this state + } + else -> { + // Safe to transition to WEBVIEW_REGISTERING + stateMachine?.transitionTo(InitializationState.WEBVIEW_REGISTERING, "registerProvider() called") + } } - - // Update resource root directory - resourceRootDir = path - - // Initialize theme manager - initializeThemeManager(fsPath) } - } catch (e: Exception) { - logger.error("Failed to get resource directory from extension", e) - } + + val extension = data.extension - val protocol = project.getService(PluginContext::class.java).getRPCProtocol() - if (protocol == null) { - logger.error("Cannot get RPC protocol instance, cannot register WebView provider: ${data.viewType}") - return - } - // When registration event is notified, create a new WebView instance - val viewId = UUID.randomUUID().toString() + // Clean up any existing WebView for this project before creating a new one + disposeLatestWebView() + + // Get location info from extension and set resource root directory + try { + @Suppress("UNCHECKED_CAST") + val location = extension.get("location") as? Map + val fsPath = location?.get("fsPath") as? String + + if (fsPath != null) { + // Set resource root directory + val path = Paths.get(fsPath) + logger.info("Get resource directory path from extension: $path") + + // Ensure the resource directory exists + if (!path.exists()) { + path.createDirectories() + } - val title = data.options["title"] as? String ?: data.viewType + // Update resource root directory + resourceRootDir = path - @Suppress("UNCHECKED_CAST") - val state = data.options["state"] as? Map ?: emptyMap() + // Initialize theme manager + initializeThemeManager(fsPath) + } + } catch (e: Exception) { + logger.error("Failed to get resource directory from extension", e) + } - val webview = WebViewInstance(data.viewType, viewId, title, state, project, data.extension) - // DEBUG HERE! - // webview.showDebugWindow() + val protocol = project.getService(PluginContext::class.java).getRPCProtocol() + if (protocol == null) { + logger.error("Cannot get RPC protocol instance, cannot register WebView provider: ${data.viewType}") + stateMachine?.transitionTo(InitializationState.FAILED, "RPC protocol not available") + return + } + // When registration event is notified, create a new WebView instance + val viewId = UUID.randomUUID().toString() - val proxy = protocol.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostWebviewViews) - proxy.resolveWebviewView(viewId, data.viewType, title, state, null) + val title = data.options["title"] as? String ?: data.viewType - // Set as the latest created WebView - latestWebView = webview + @Suppress("UNCHECKED_CAST") + val state = data.options["state"] as? Map ?: emptyMap() + + val webview = WebViewInstance(data.viewType, viewId, title, state, project, data.extension, stateMachine) + // DEBUG HERE! + // webview.showDebugWindow() + + stateMachine?.transitionTo(InitializationState.WEBVIEW_REGISTERED, "WebView instance created") + + stateMachine?.transitionTo(InitializationState.WEBVIEW_RESOLVING, "Resolving webview") + val proxy = protocol.getProxy(ServiceProxyRegistry.ExtHostContext.ExtHostWebviewViews) + proxy.resolveWebviewView(viewId, data.viewType, title, state, null) + stateMachine?.transitionTo(InitializationState.WEBVIEW_RESOLVED, "Webview resolved") + + // Set as the latest created WebView + latestWebView = webview + + // If theme config is already available, send it to the newly created WebView + if (currentThemeConfig != null) { + logger.info("Theme config available, sending to newly created WebView") + webview.sendThemeConfigToWebView(currentThemeConfig!!, bodyThemeClass) + } else { + logger.debug("No theme config available yet for newly created WebView") + } - logger.info("Create WebView instance: viewType=${data.viewType}, viewId=$viewId for project: ${project.name}") + logger.info("Create WebView instance: viewType=${data.viewType}, viewId=$viewId for project: ${project.name}") - // Notify callback - notifyWebViewCreated(webview) + // Notify callback + notifyWebViewCreated(webview) + } catch (e: Exception) { + logger.error("Failed to register WebView provider", e) + stateMachine?.transitionTo(InitializationState.FAILED, "registerProvider() exception: ${e.message}") + throw e + } } /** @@ -312,122 +368,142 @@ class WebViewManager(var project: Project) : Disposable, ThemeChangeListener { * @param data HTML update data */ fun updateWebViewHtml(data: WebviewHtmlUpdateData) { - data.htmlContent = data.htmlContent.replace("/jetbrains/resources/kilocode/", "./") - data.htmlContent = data.htmlContent.replace("", "") - val encodedState = getLatestWebView()?.state.toString().replace("\"", "\\\"") - val mRst = """""".toRegex().find(data.htmlContent) - val str = mRst?.value ?: "" - data.htmlContent = data.htmlContent.replace( - str, - """ - $str - // First define the function to send messages - window.sendMessageToPlugin = function(message) { - // Convert JS object to JSON string - // console.log("sendMessageToPlugin: ", message); - const msgStr = JSON.stringify(message); - ${getLatestWebView()?.jsQuery?.inject("msgStr")} - }; - - // Inject VSCode API mock - globalThis.acquireVsCodeApi = (function() { - let acquired = false; - - let state = JSON.parse('$encodedState'); - - if (typeof window !== "undefined" && !window.receiveMessageFromPlugin) { - console.log("VSCodeAPIWrapper: Setting up receiveMessageFromPlugin for IDEA plugin compatibility"); - window.receiveMessageFromPlugin = (message) => { - // console.log("receiveMessageFromPlugin received message:", JSON.stringify(message)); - // Create a new MessageEvent and dispatch it to maintain compatibility with existing code - const event = new MessageEvent("message", { - data: message, - }); - window.dispatchEvent(event); - }; - } + try { + stateMachine?.transitionTo(InitializationState.HTML_LOADING, "Loading HTML content") + + data.htmlContent = data.htmlContent.replace("/jetbrains/resources/kilocode/", "./") + data.htmlContent = data.htmlContent.replace("", "") + val encodedState = getLatestWebView()?.state.toString().replace("\"", "\\\"") + val mRst = """""".toRegex().find(data.htmlContent) + val str = mRst?.value ?: "" + data.htmlContent = data.htmlContent.replace( + str, + """ + $str + // First define the function to send messages + window.sendMessageToPlugin = function(message) { + // Convert JS object to JSON string + // console.log("sendMessageToPlugin: ", message); + const msgStr = JSON.stringify(message); + ${getLatestWebView()?.jsQuery?.inject("msgStr")} + }; - return () => { - if (acquired) { - throw new Error('An instance of the VS Code API has already been acquired'); + // Inject VSCode API mock + globalThis.acquireVsCodeApi = (function() { + let acquired = false; + + let state = JSON.parse('$encodedState'); + + if (typeof window !== "undefined" && !window.receiveMessageFromPlugin) { + console.log("VSCodeAPIWrapper: Setting up receiveMessageFromPlugin for IDEA plugin compatibility"); + window.receiveMessageFromPlugin = (message) => { + // console.log("receiveMessageFromPlugin received message:", JSON.stringify(message)); + // Create a new MessageEvent and dispatch it to maintain compatibility with existing code + const event = new MessageEvent("message", { + data: message, + }); + window.dispatchEvent(event); + }; } - acquired = true; - return Object.freeze({ - postMessage: function(message, transfer) { - // console.log("postMessage: ", message); - window.sendMessageToPlugin(message); - }, - setState: function(newState) { - state = newState; - window.sendMessageToPlugin(newState); - return newState; - }, - getState: function() { - return state; + + return () => { + if (acquired) { + throw new Error('An instance of the VS Code API has already been acquired'); } - }); - }; - })(); + acquired = true; + return Object.freeze({ + postMessage: function(message, transfer) { + // console.log("postMessage: ", message); + window.sendMessageToPlugin(message); + }, + setState: function(newState) { + state = newState; + window.sendMessageToPlugin(newState); + return newState; + }, + getState: function() { + return state; + } + }); + }; + })(); - // Clean up references to window parent for security - delete window.parent; - delete window.top; - delete window.frameElement; + // Clean up references to window parent for security + delete window.parent; + delete window.top; + delete window.frameElement; - console.log("VSCode API mock injected"); - """, - ) + console.log("VSCode API mock injected"); + """, + ) - logger.info("=== Received HTML update event ===") - logger.info("Handle: ${data.handle}") - logger.info("HTML length: ${data.htmlContent.length}") + logger.info("=== Received HTML update event ===") + logger.info("Handle: ${data.handle}") + logger.info("HTML length: ${data.htmlContent.length}") - val webView = getLatestWebView() + val webView = getLatestWebView() - if (webView != null) { - try { - // If HTTP server is running - if (resourceRootDir != null) { - logger.info("Resource root directory is set: ${resourceRootDir?.pathString}") + if (webView != null) { + try { + // If HTTP server is running + if (resourceRootDir != null) { + logger.info("Resource root directory is set: ${resourceRootDir?.pathString}") - // Generate unique file name for WebView - val filename = "index-${project.hashCode()}.html" + // Generate unique file name for WebView + val filename = "index-${project.hashCode()}.html" - // Save HTML content to file - val savedPath = saveHtmlToResourceDir(data.htmlContent, filename) - logger.info("HTML saved to: ${savedPath?.pathString}") + // Save HTML content to file + val savedPath = saveHtmlToResourceDir(data.htmlContent, filename) + logger.info("HTML saved to: ${savedPath?.pathString}") - // Use HTTP URL to load WebView content - val url = "http://localhost:12345/$filename" - logger.info("Loading WebView via HTTP URL: $url") + // Use HTTP URL to load WebView content + val url = "http://localhost:12345/$filename" + logger.info("Loading WebView via HTTP URL: $url") - webView.loadUrl(url) - } else { - // Fallback to direct HTML loading - logger.warn("Resource root directory is NULL - loading HTML content directly") - webView.loadHtml(data.htmlContent) - } - - logger.info("WebView HTML content updated: handle=${data.handle}") + webView.loadUrl(url) + } else { + // Fallback to direct HTML loading + logger.warn("Resource root directory is NULL - loading HTML content directly") + webView.loadHtml(data.htmlContent) + } - // If there is already a theme config, send it after content is loaded - if (currentThemeConfig != null) { - // Delay sending theme config to ensure HTML is loaded - ApplicationManager.getApplication().invokeLater { - try { - webView.sendThemeConfigToWebView(currentThemeConfig!!, this.bodyThemeClass) - } catch (e: Exception) { - logger.error("Failed to send theme config to WebView", e) + logger.info("WebView HTML content updated: handle=${data.handle}") + + // If there is already a theme config, send it after content is loaded + if (currentThemeConfig != null) { + // Set callback to inject theme after page loads + webView.setPageLoadCallback { + try { + logger.info("Page load callback triggered, injecting theme") + webView.sendThemeConfigToWebView(currentThemeConfig!!, this.bodyThemeClass) + } catch (e: Exception) { + logger.error("Failed to send theme config to WebView in page load callback", e) + } + } + + // Also try to inject immediately in case page is already loaded + if (webView.isPageLoaded()) { + try { + webView.sendThemeConfigToWebView(currentThemeConfig!!, this.bodyThemeClass) + } catch (e: Exception) { + logger.error("Failed to send theme config to WebView immediately", e) + } } } + } catch (e: Exception) { + logger.error("Failed to update WebView HTML content", e) + stateMachine?.transitionTo(InitializationState.FAILED, "HTML loading failed: ${e.message}") + // Fallback to direct HTML loading + webView.loadHtml(data.htmlContent) } - } catch (e: Exception) { - logger.error("Failed to update WebView HTML content", e) - // Fallback to direct HTML loading - webView.loadHtml(data.htmlContent) + } else { + logger.warn("WebView instance not found: handle=${data.handle}") + stateMachine?.transitionTo(InitializationState.FAILED, "WebView instance not found") } - } else { - logger.warn("WebView instance not found: handle=${data.handle}") + } catch (e: Exception) { + logger.error("Failed in updateWebViewHtml", e) + stateMachine?.transitionTo(InitializationState.FAILED, "updateWebViewHtml() exception: ${e.message}") + throw e } } @@ -516,6 +592,7 @@ class WebViewInstance( val state: Map, val project: Project, val extension: Map, + private val stateMachine: InitializationStateMachine? = null, ) : Disposable { private val logger = Logger.getInstance(WebViewInstance::class.java) @@ -537,6 +614,9 @@ class WebViewInstance( // Coroutine scope private val coroutineScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + // Synchronization for page load state + private val pageLoadLock = Any() + @Volatile private var isPageLoaded = false private var isInitialPageLoad = true @@ -544,6 +624,16 @@ class WebViewInstance( // Callback for page load completion private var pageLoadCallback: (() -> Unit)? = null + + // Theme injection retry mechanism + private var themeInjectionAttempts = 0 + private val maxThemeInjectionAttempts = 10 // Increased from 3 for slow machines + private val themeInjectionRetryDelay = 2000L // Increased from 1s to 2s for slow machines + private val themeInjectionBackoffMultiplier = 1.5 // Exponential backoff multiplier + + // Track if initial theme injection has completed + @Volatile + private var initialThemeInjectionComplete = false init { setupJSBridge() @@ -555,17 +645,23 @@ class WebViewInstance( * Send theme config to the specified WebView instance */ fun sendThemeConfigToWebView(themeConfig: JsonObject, bodyThemeClass: String) { - currentThemeConfig = themeConfig - this.bodyThemeClass = bodyThemeClass if (isDisposed) { logger.warn("WebView has been disposed, cannot send theme config") return } - if (!isPageLoaded) { - logger.debug("WebView page not yet loaded, theme will be injected after page load") - return + + // Always store the theme config, even if page isn't loaded yet + currentThemeConfig = themeConfig + this.bodyThemeClass = bodyThemeClass + logger.debug("Theme config stored for WebView($viewId), will inject when page loads") + + synchronized(pageLoadLock) { + if (!isPageLoaded) { + logger.debug("WebView page not yet loaded, theme will be injected after page load") + return + } + injectTheme() } - injectTheme() } /** @@ -573,7 +669,9 @@ class WebViewInstance( * @return true if page is loaded, false otherwise */ fun isPageLoaded(): Boolean { - return isPageLoaded + synchronized(pageLoadLock) { + return isPageLoaded + } } /** @@ -586,9 +684,57 @@ class WebViewInstance( private fun injectTheme() { if (currentThemeConfig == null) { + logger.warn("Cannot inject theme: currentThemeConfig is null for WebView($viewId)") + return + } + logger.info("Starting theme injection for WebView($viewId)") + + // Check if we're in a terminal state + val currentState = stateMachine?.getCurrentState() + if (currentState == InitializationState.COMPLETE || + currentState == InitializationState.FAILED) { + logger.debug("Skipping theme state transitions, already in terminal state: $currentState") + + // Still inject the theme (for theme changes), but don't update state machine + injectThemeWithoutStateTransitions() return } + + // Check if page is loaded with synchronization + synchronized(pageLoadLock) { + if (!isPageLoaded) { + if (themeInjectionAttempts < maxThemeInjectionAttempts) { + themeInjectionAttempts++ + // Calculate exponential backoff delay + val delay = (themeInjectionRetryDelay * Math.pow(themeInjectionBackoffMultiplier, (themeInjectionAttempts - 1).toDouble())).toLong() + logger.debug("Page not loaded, scheduling theme injection retry (attempt $themeInjectionAttempts/$maxThemeInjectionAttempts, delay: ${delay}ms)") + + // Schedule retry with exponential backoff + Timer().schedule(object : TimerTask() { + override fun run() { + injectTheme() + } + }, delay) + } else { + // Graceful degradation: continue without theme instead of failing + logger.warn("Max theme injection attempts ($maxThemeInjectionAttempts) reached, continuing without theme") + stateMachine?.transitionTo(InitializationState.COMPLETE, "Initialization complete (theme injection skipped)") + initialThemeInjectionComplete = true + } + return + } + + // Reset attempts on successful injection + themeInjectionAttempts = 0 + } + try { + // Only transition states during initial theme injection + val shouldTransitionStates = !initialThemeInjectionComplete + + if (shouldTransitionStates) { + stateMachine?.transitionTo(InitializationState.THEME_INJECTING, "Injecting theme") + } var cssContent: String? = null // Get cssContent from themeConfig and save, then remove from object @@ -603,10 +749,15 @@ class WebViewInstance( if (cssContent != null) { val injectThemeScript = """ (function() { + // Check if already injected at the top level + if (window.__cssVariablesInjected) { + console.log("CSS variables already injected, skipping"); + return; + } + // Set flag immediately to prevent race conditions + window.__cssVariablesInjected = true; + function injectCSSVariables() { - if (window.__cssVariablesInjected) { - return; - } if(document.documentElement) { // Convert cssContent to style attribute of html tag try { @@ -754,7 +905,6 @@ class WebViewInstance( } `; console.log("Default style injected to id=_defaultStyles"); - window.__cssVariablesInjected = true; } } else { // If html tag does not exist yet, wait for DOM to load and try again @@ -801,8 +951,249 @@ class WebViewInstance( postMessageToWebView(message) logger.info("Theme config has been sent to WebView") } + + if (shouldTransitionStates) { + stateMachine?.transitionTo(InitializationState.THEME_INJECTED, "Theme injected") + stateMachine?.transitionTo(InitializationState.COMPLETE, "Initialization complete") + initialThemeInjectionComplete = true + } else { + logger.debug("Theme injected (runtime theme change, no state transitions)") + } } catch (e: Exception) { logger.error("Failed to send theme config to WebView", e) + if (!initialThemeInjectionComplete) { + stateMachine?.transitionTo(InitializationState.FAILED, "Theme injection failed: ${e.message}") + } + } + } + + /** + * Inject theme without state machine transitions (for runtime theme changes) + */ + private fun injectThemeWithoutStateTransitions() { + if (currentThemeConfig == null) { + return + } + + try { + var cssContent: String? = null + + // Get cssContent from themeConfig and save, then remove from object + if (currentThemeConfig!!.has("cssContent")) { + cssContent = currentThemeConfig!!.get("cssContent").asString + // Create a copy of themeConfig to modify without affecting the original object + val themeConfigCopy = currentThemeConfig!!.deepCopy() + // Remove cssContent property from the copy + themeConfigCopy.remove("cssContent") + + // Inject CSS variables into WebView + if (cssContent != null) { + val injectThemeScript = """ + (function() { + // Check if already injected at the top level + if (window.__cssVariablesInjected) { + console.log("CSS variables already injected, skipping"); + return; + } + // Set flag immediately to prevent race conditions + window.__cssVariablesInjected = true; + + function injectCSSVariables() { + if(document.documentElement) { + // Convert cssContent to style attribute of html tag + try { + // Extract CSS variables (format: --name:value;) + const cssLines = `$cssContent`.split('\n'); + const cssVariables = []; + + // Process each line, extract CSS variable declarations + for (const line of cssLines) { + const trimmedLine = line.trim(); + // Skip comments and empty lines + if (trimmedLine.startsWith('/*') || trimmedLine.startsWith('*') || trimmedLine.startsWith('*/') || trimmedLine === '') { + continue; + } + // Extract CSS variable part + if (trimmedLine.startsWith('--')) { + cssVariables.push(trimmedLine); + } + } + + // Merge extracted CSS variables into style attribute string + const styleAttrValue = cssVariables.join(' '); + + // Set as style attribute of html tag + document.documentElement.setAttribute('style', styleAttrValue); + console.log("CSS variables set as style attribute of HTML tag"); + + // Add theme class to body element for styled-components compatibility + // Remove existing theme classes + document.body.classList.remove('vscode-dark', 'vscode-light'); + + // Add appropriate theme class based on current theme + document.body.classList.add('$bodyThemeClass'); + console.log("Added theme class to body: $bodyThemeClass"); + } catch (error) { + console.error("Error processing CSS variables and theme classes:", error); + } + + // Keep original default style injection logic + if(document.head) { + // Inject default theme style into head, use id="_defaultStyles" + let defaultStylesElement = document.getElementById('_defaultStyles'); + if (!defaultStylesElement) { + defaultStylesElement = document.createElement('style'); + defaultStylesElement.id = '_defaultStyles'; + document.head.appendChild(defaultStylesElement); + } + + // Add default_themes.css content + defaultStylesElement.textContent = ` + html { + background: var(--vscode-sideBar-background); + scrollbar-color: var(--vscode-scrollbarSlider-background) var(--vscode-sideBar-background); + } + + body { + overscroll-behavior-x: none; + background-color: transparent; + color: var(--vscode-editor-foreground); + font-family: var(--vscode-font-family); + font-weight: var(--vscode-font-weight); + font-size: var(--vscode-font-size); + margin: 0; + padding: 0 20px; + } + + img, video { + max-width: 100%; + max-height: 100%; + } + + a, a code { + color: var(--vscode-textLink-foreground); + } + + p > a { + text-decoration: var(--text-link-decoration); + } + + a:hover { + color: var(--vscode-textLink-activeForeground); + } + + a:focus, + input:focus, + select:focus, + textarea:focus { + outline: 1px solid -webkit-focus-ring-color; + outline-offset: -1px; + } + + code { + font-family: var(--monaco-monospace-font); + color: var(--vscode-textPreformat-foreground); + background-color: var(--vscode-textPreformat-background); + padding: 1px 3px; + border-radius: 4px; + } + + pre code { + padding: 0; + } + + blockquote { + background: var(--vscode-textBlockQuote-background); + border-color: var(--vscode-textBlockQuote-border); + } + + kbd { + background-color: var(--vscode-keybindingLabel-background); + color: var(--vscode-keybindingLabel-foreground); + border-style: solid; + border-width: 1px; + border-radius: 3px; + border-color: var(--vscode-keybindingLabel-border); + border-bottom-color: var(--vscode-keybindingLabel-bottomBorder); + box-shadow: inset 0 -1px 0 var(--vscode-widget-shadow); + vertical-align: middle; + padding: 1px 3px; + } + + ::-webkit-scrollbar { + width: 10px; + height: 10px; + } + + ::-webkit-scrollbar-corner { + background-color: var(--vscode-editor-background); + } + + ::-webkit-scrollbar-thumb { + background-color: var(--vscode-scrollbarSlider-background); + } + ::-webkit-scrollbar-thumb:hover { + background-color: var(--vscode-scrollbarSlider-hoverBackground); + } + ::-webkit-scrollbar-thumb:active { + background-color: var(--vscode-scrollbarSlider-activeBackground); + } + ::highlight(find-highlight) { + background-color: var(--vscode-editor-findMatchHighlightBackground); + } + ::highlight(current-find-highlight) { + background-color: var(--vscode-editor-findMatchBackground); + } + `; + console.log("Default style injected to id=_defaultStyles"); + } + } else { + // If html tag does not exist yet, wait for DOM to load and try again + setTimeout(injectCSSVariables, 10); + } + } + // If document is already loaded + if (document.readyState === 'complete' || document.readyState === 'interactive') { + console.log("Document loaded, inject CSS variables immediately"); + injectCSSVariables(); + } else { + // Otherwise wait for DOMContentLoaded event + console.log("Document not loaded, waiting for DOMContentLoaded event"); + document.addEventListener('DOMContentLoaded', injectCSSVariables); + } + })() + """.trimIndent() + + logger.debug("Injecting theme style into WebView($viewId) without state transitions, size: ${cssContent.length} bytes") + executeJavaScript(injectThemeScript) + } + + // Pass the theme config without cssContent via message + val themeConfigJson = gson.toJson(themeConfigCopy) + val message = """ + { + "type": "theme", + "text": "${themeConfigJson.replace("\"", "\\\"")}" + } + """.trimIndent() + + postMessageToWebView(message) + logger.debug("Theme config without cssContent has been sent to WebView (runtime theme change)") + } else { + // If there is no cssContent, send the original config directly + val themeConfigJson = gson.toJson(currentThemeConfig) + val message = """ + { + "type": "theme", + "text": "${themeConfigJson.replace("\"", "\\\"")}" + } + """.trimIndent() + + postMessageToWebView(message) + logger.debug("Theme config has been sent to WebView (runtime theme change)") + } + } catch (e: Exception) { + logger.error("Failed to inject theme without state transitions", e) } } @@ -836,13 +1227,47 @@ class WebViewInstance( */ fun postMessageToWebView(message: String) { if (!isDisposed) { - // Send message to WebView via JavaScript function + // Send message to WebView via JavaScript function with retry mechanism val script = """ - if (window.receiveMessageFromPlugin) { - window.receiveMessageFromPlugin($message); - } else { - console.warn("receiveMessageFromPlugin not available"); - } + (function() { + function sendMessage() { + if (window.receiveMessageFromPlugin) { + window.receiveMessageFromPlugin($message); + return true; + } + return false; + } + + // Try to send immediately + if (sendMessage()) { + return; + } + + // If not available, retry with exponential backoff + let attempts = 0; + const maxAttempts = 10; + const baseDelay = 50; // Start with 50ms + + function retryWithBackoff() { + if (attempts >= maxAttempts) { + console.warn("receiveMessageFromPlugin not available after " + maxAttempts + " attempts"); + return; + } + + attempts++; + const delay = baseDelay * Math.pow(1.5, attempts - 1); + + setTimeout(function() { + if (sendMessage()) { + console.log("Message sent successfully after " + attempts + " attempts"); + } else { + retryWithBackoff(); + } + }, delay); + } + + retryWithBackoff(); + })(); """.trimIndent() executeJavaScript(script) } @@ -895,8 +1320,10 @@ class WebViewInstance( transitionType: CefRequest.TransitionType?, ) { logger.info("WebView started loading: ${frame?.url}, transition type: $transitionType") - isPageLoaded = false - isInitialPageLoad = true + synchronized(pageLoadLock) { + isPageLoaded = false + isInitialPageLoad = true + } } override fun onLoadEnd( @@ -905,12 +1332,18 @@ class WebViewInstance( httpStatusCode: Int, ) { logger.info("WebView finished loading: ${frame?.url}, status code: $httpStatusCode") - isPageLoaded = true - - if (isInitialPageLoad) { - injectTheme() - pageLoadCallback?.invoke() - isInitialPageLoad = false + + synchronized(pageLoadLock) { + // Only process initial page load once + if (isInitialPageLoad) { + isInitialPageLoad = false + isPageLoaded = true + stateMachine?.transitionTo(InitializationState.HTML_LOADED, "HTML loaded") + injectTheme() + pageLoadCallback?.invoke() + } else { + logger.debug("Ignoring subsequent onLoadEnd event (not initial page load)") + } } } @@ -921,7 +1354,8 @@ class WebViewInstance( errorText: String?, failedUrl: String?, ) { - logger.info("WebView load error: $failedUrl, error code: $errorCode, error message: $errorText") + logger.error("WebView load error: $failedUrl, error code: $errorCode, error message: $errorText") + stateMachine?.transitionTo(InitializationState.FAILED, "HTML load error: $errorCode - $errorText") } }, browser.cefBrowser, @@ -1002,7 +1436,29 @@ class WebViewInstance( fun executeJavaScript(script: String) { if (!isDisposed) { logger.info("WebView executing JavaScript, script length: ${script.length}") - browser.cefBrowser.executeJavaScript(script, browser.cefBrowser.url, 0) + try { + // Check if JCEF browser is initialized before executing JavaScript + val url = browser.cefBrowser.url + if (url == null || url.isEmpty()) { + logger.warn("JCEF browser not fully initialized (URL is null/empty), deferring JavaScript execution") + // Retry after a short delay + Timer().schedule(object : TimerTask() { + override fun run() { + executeJavaScript(script) + } + }, 100) + return + } + browser.cefBrowser.executeJavaScript(script, url, 0) + } catch (e: Exception) { + logger.error("Failed to execute JavaScript, will retry", e) + // Retry after a short delay + Timer().schedule(object : TimerTask() { + override fun run() { + executeJavaScript(script) + } + }, 100) + } } } diff --git a/jetbrains/plugin/src/test/kotlin/ai/kilocode/jetbrains/core/InitializationHealthCheckTest.kt b/jetbrains/plugin/src/test/kotlin/ai/kilocode/jetbrains/core/InitializationHealthCheckTest.kt new file mode 100644 index 00000000000..f4732436a85 --- /dev/null +++ b/jetbrains/plugin/src/test/kotlin/ai/kilocode/jetbrains/core/InitializationHealthCheckTest.kt @@ -0,0 +1,125 @@ +package ai.kilocode.jetbrains.core + +import org.junit.Assert.* +import org.junit.Before +import org.junit.Test + +class InitializationHealthCheckTest { + private lateinit var stateMachine: InitializationStateMachine + private lateinit var healthCheck: InitializationHealthCheck + + @Before + fun setUp() { + stateMachine = InitializationStateMachine() + healthCheck = InitializationHealthCheck(stateMachine) + } + + @Test + fun testHealthyStatusForNormalInitialization() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + val status = healthCheck.checkHealth() + assertEquals(InitializationHealthCheck.HealthStatus.HEALTHY, status) + } + + @Test + fun testFailedStatusWhenInitializationFails() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.FAILED, "test failure") + + val status = healthCheck.checkHealth() + assertEquals(InitializationHealthCheck.HealthStatus.FAILED, status) + } + + @Test + fun testStuckStatusForLongRunningState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + + // Wait longer than the max duration for SOCKET_CONNECTING (20 seconds) + // For testing, we'll just verify the logic exists + // In a real scenario, this would require mocking time or waiting + + // The health check should detect stuck states + // This is a basic test to ensure the method works + val status = healthCheck.checkHealth() + // Should be HEALTHY since we just transitioned + assertEquals(InitializationHealthCheck.HealthStatus.HEALTHY, status) + } + + @Test + fun testSuggestionsForFailedState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.FAILED, "test failure") + + val suggestions = healthCheck.getSuggestions() + assertFalse(suggestions.isEmpty()) + assertTrue(suggestions.any { it.contains("failed") || it.contains("Failed") }) + } + + @Test + fun testSuggestionsForSocketConnectingState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + + // Simulate stuck state by checking suggestions + // In a real stuck scenario, suggestions would be provided + val suggestions = healthCheck.getSuggestions() + // Should be empty for healthy state + assertTrue(suggestions.isEmpty()) + } + + @Test + fun testDiagnosticReportIncludesStatusAndState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + val report = healthCheck.getDiagnosticReport() + + assertTrue(report.contains("Health Check")) + assertTrue(report.contains("Status:")) + assertTrue(report.contains("Current State:")) + assertTrue(report.contains("SOCKET_CONNECTED")) + } + + @Test + fun testDiagnosticReportIncludesSuggestionsWhenStuck() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.FAILED, "test failure") + + val report = healthCheck.getDiagnosticReport() + + assertTrue(report.contains("Suggestions:")) + } + + @Test + fun testHealthyStateHasNoSuggestions() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + val suggestions = healthCheck.getSuggestions() + assertTrue(suggestions.isEmpty()) + } + + @Test + fun testDifferentStatesHaveAppropriateSuggestions() { + // Test HTML_LOADING state suggestions + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + stateMachine.transitionTo(InitializationState.READY_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.INIT_DATA_SENT, "test") + stateMachine.transitionTo(InitializationState.INITIALIZED_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATING, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATED, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATING, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATED, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_REGISTERING, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_REGISTERED, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_RESOLVING, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_RESOLVED, "test") + stateMachine.transitionTo(InitializationState.HTML_LOADING, "test") + + // For a healthy state, no suggestions + val suggestions = healthCheck.getSuggestions() + assertTrue(suggestions.isEmpty()) + } +} diff --git a/jetbrains/plugin/src/test/kotlin/ai/kilocode/jetbrains/core/InitializationStateMachineTest.kt b/jetbrains/plugin/src/test/kotlin/ai/kilocode/jetbrains/core/InitializationStateMachineTest.kt new file mode 100644 index 00000000000..dece08d961f --- /dev/null +++ b/jetbrains/plugin/src/test/kotlin/ai/kilocode/jetbrains/core/InitializationStateMachineTest.kt @@ -0,0 +1,233 @@ +package ai.kilocode.jetbrains.core + +import org.junit.Assert.* +import org.junit.Before +import org.junit.Test +import java.util.concurrent.TimeUnit + +class InitializationStateMachineTest { + private lateinit var stateMachine: InitializationStateMachine + + @Before + fun setUp() { + stateMachine = InitializationStateMachine() + } + + @Test + fun testInitialStateIsNotStarted() { + assertEquals(InitializationState.NOT_STARTED, stateMachine.getCurrentState()) + } + + @Test + fun testValidStateTransitions() { + assertTrue(stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test")) + assertEquals(InitializationState.SOCKET_CONNECTING, stateMachine.getCurrentState()) + + assertTrue(stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test")) + assertEquals(InitializationState.SOCKET_CONNECTED, stateMachine.getCurrentState()) + } + + @Test + fun testTransitionToFailedFromAnyState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + assertTrue(stateMachine.transitionTo(InitializationState.FAILED, "test failure")) + assertEquals(InitializationState.FAILED, stateMachine.getCurrentState()) + } + + @Test + fun testWaitForStateCompletesWhenStateIsReached() { + val future = stateMachine.waitForState(InitializationState.SOCKET_CONNECTED) + + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + // Should not throw + future.get(1, TimeUnit.SECONDS) + } + + @Test + fun testWaitForStateReturnsImmediatelyIfAlreadyAtTargetState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + val future = stateMachine.waitForState(InitializationState.SOCKET_CONNECTED) + + assertTrue(future.isDone) + // Should not throw + future.get(100, TimeUnit.MILLISECONDS) + } + + @Test + fun testStateDurationTracking() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + Thread.sleep(100) // Wait a bit + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + val duration = stateMachine.getStateDuration(InitializationState.SOCKET_CONNECTING) + assertNotNull(duration) + assertTrue("Duration should be at least 100ms, was $duration", duration!! >= 100) + } + + @Test + fun testGenerateReportIncludesStateInformation() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + val report = stateMachine.generateReport() + + assertTrue(report.contains("Current State: SOCKET_CONNECTED")) + assertTrue(report.contains("SOCKET_CONNECTING")) + assertTrue(report.contains("SOCKET_CONNECTED")) + } + + @Test + fun testStateListenersAreNotified() { + var listenerCalled = false + var receivedState: InitializationState? = null + + stateMachine.onStateReached(InitializationState.SOCKET_CONNECTED) { state -> + listenerCalled = true + receivedState = state + } + + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + assertTrue(listenerCalled) + assertEquals(InitializationState.SOCKET_CONNECTED, receivedState) + } + + @Test + fun testListenerCalledImmediatelyIfAlreadyAtTargetState() { + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + var listenerCalled = false + stateMachine.onStateReached(InitializationState.SOCKET_CONNECTED) { + listenerCalled = true + } + + assertTrue(listenerCalled) + } + + @Test + fun testSlowTransitionWarningThreshold() { + // This test verifies that the expected duration method exists and returns reasonable values + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + Thread.sleep(100) + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + + // The transition should complete without errors + assertEquals(InitializationState.SOCKET_CONNECTED, stateMachine.getCurrentState()) + } + + @Test + fun testInvalidStateTransitionIsRejected() { + // Transition to HTML_LOADED state + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + stateMachine.transitionTo(InitializationState.READY_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.INIT_DATA_SENT, "test") + stateMachine.transitionTo(InitializationState.INITIALIZED_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATING, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATED, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATING, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATED, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_REGISTERING, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_REGISTERED, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_RESOLVING, "test") + stateMachine.transitionTo(InitializationState.WEBVIEW_RESOLVED, "test") + stateMachine.transitionTo(InitializationState.HTML_LOADING, "test") + stateMachine.transitionTo(InitializationState.HTML_LOADED, "test") + + assertEquals(InitializationState.HTML_LOADED, stateMachine.getCurrentState()) + + // Attempting to transition to HTML_LOADED again should succeed (idempotent) + val result = stateMachine.transitionTo(InitializationState.HTML_LOADED, "duplicate transition") + assertTrue("Idempotent transition should succeed", result) + + // State should remain HTML_LOADED + assertEquals(InitializationState.HTML_LOADED, stateMachine.getCurrentState()) + + // But transitioning to an invalid state (e.g., SOCKET_CONNECTING) should fail + val invalidResult = stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "invalid backward transition") + assertFalse("Invalid backward transition should fail", invalidResult) + assertEquals(InitializationState.HTML_LOADED, stateMachine.getCurrentState()) + } + + @Test + fun testIdempotentTransitions() { + // Transition to a state + assertTrue(stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test")) + assertEquals(InitializationState.SOCKET_CONNECTING, stateMachine.getCurrentState()) + + // Attempt same transition again - should succeed (idempotent) + assertTrue(stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "duplicate")) + assertEquals(InitializationState.SOCKET_CONNECTING, stateMachine.getCurrentState()) + } + + @Test + fun testTerminalStateProtection() { + // Transition through to COMPLETE + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + stateMachine.transitionTo(InitializationState.READY_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.INIT_DATA_SENT, "test") + stateMachine.transitionTo(InitializationState.INITIALIZED_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATING, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATED, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATING, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATED, "test") + stateMachine.transitionTo(InitializationState.COMPLETE, "test") + + assertEquals(InitializationState.COMPLETE, stateMachine.getCurrentState()) + + // Attempt to transition from COMPLETE - should fail + assertFalse(stateMachine.transitionTo(InitializationState.HTML_LOADED, "after complete")) + assertEquals(InitializationState.COMPLETE, stateMachine.getCurrentState()) + + // Attempt to transition to COMPLETE again - should succeed (idempotent) + assertTrue(stateMachine.transitionTo(InitializationState.COMPLETE, "duplicate complete")) + assertEquals(InitializationState.COMPLETE, stateMachine.getCurrentState()) + } + + @Test + fun testRaceConditionScenario() { + // Simulate the race condition: WEBVIEW_REGISTERING -> EXTENSION_ACTIVATED + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "test") + stateMachine.transitionTo(InitializationState.READY_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.INIT_DATA_SENT, "test") + stateMachine.transitionTo(InitializationState.INITIALIZED_RECEIVED, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATING, "test") + stateMachine.transitionTo(InitializationState.RPC_CREATED, "test") + stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATING, "test") + + // Webview registration starts before activation completes + assertTrue(stateMachine.transitionTo(InitializationState.WEBVIEW_REGISTERING, "webview starts")) + assertEquals(InitializationState.WEBVIEW_REGISTERING, stateMachine.getCurrentState()) + + // Extension activation completes - this should now be allowed + assertTrue(stateMachine.transitionTo(InitializationState.EXTENSION_ACTIVATED, "activation completes")) + assertEquals(InitializationState.EXTENSION_ACTIVATED, stateMachine.getCurrentState()) + } + + @Test + fun testFailedStateProtection() { + // Transition to FAILED state + stateMachine.transitionTo(InitializationState.SOCKET_CONNECTING, "test") + stateMachine.transitionTo(InitializationState.FAILED, "test failure") + + assertEquals(InitializationState.FAILED, stateMachine.getCurrentState()) + + // Attempt to transition from FAILED - should fail + assertFalse(stateMachine.transitionTo(InitializationState.SOCKET_CONNECTED, "after failed")) + assertEquals(InitializationState.FAILED, stateMachine.getCurrentState()) + + // Attempt to transition to FAILED again - should succeed (idempotent) + assertTrue(stateMachine.transitionTo(InitializationState.FAILED, "duplicate failed")) + assertEquals(InitializationState.FAILED, stateMachine.getCurrentState()) + } +} diff --git a/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts b/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts index 04e15488088..7ff201978fe 100644 --- a/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts +++ b/packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts @@ -33,6 +33,7 @@ describe("ExtensionChannel", () => { wrapperCode: null, wrapperVersion: null, machineId: null, + vscodeIsTelemetryEnabled: null, // kilocode_change end hostname: "test-host", } diff --git a/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts b/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts index 6119b4fe028..18478759f67 100644 --- a/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts +++ b/packages/cloud/src/bridge/__tests__/TaskChannel.test.ts @@ -36,6 +36,7 @@ describe("TaskChannel", () => { wrapperCode: null, wrapperVersion: null, machineId: null, + vscodeIsTelemetryEnabled: null, // kilocode_change end hostname: "test-host", } diff --git a/packages/types/src/telemetry.ts b/packages/types/src/telemetry.ts index c021a3d5ff2..023ce48afa1 100644 --- a/packages/types/src/telemetry.ts +++ b/packages/types/src/telemetry.ts @@ -133,6 +133,7 @@ export const staticAppPropertiesSchema = z.object({ wrapperCode: z.string().nullable(), wrapperVersion: z.string().nullable(), machineId: z.string().nullable(), + vscodeIsTelemetryEnabled: z.boolean().nullable(), // kilocode_change end hostname: z.string().optional(), }) diff --git a/src/api/providers/__tests__/kilocode-openrouter.spec.ts b/src/api/providers/__tests__/kilocode-openrouter.spec.ts index 3e94c90e023..016eac8fb40 100644 --- a/src/api/providers/__tests__/kilocode-openrouter.spec.ts +++ b/src/api/providers/__tests__/kilocode-openrouter.spec.ts @@ -2,14 +2,27 @@ // npx vitest run src/api/providers/__tests__/kilocode-openrouter.spec.ts // Mock vscode first to avoid import errors -vitest.mock("vscode", () => ({})) +vitest.mock("vscode", () => ({ + env: { + uriScheme: "vscode", + language: "en", + uiKind: 1, + appName: "Visual Studio Code", + }, + version: "1.85.0", +})) import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" import { KilocodeOpenrouterHandler } from "../kilocode-openrouter" import { ApiHandlerOptions } from "../../../shared/api" -import { X_KILOCODE_TASKID, X_KILOCODE_ORGANIZATIONID, X_KILOCODE_PROJECTID } from "../../../shared/kilocode/headers" +import { + X_KILOCODE_TASKID, + X_KILOCODE_ORGANIZATIONID, + X_KILOCODE_PROJECTID, + X_KILOCODE_EDITORNAME, +} from "../../../shared/kilocode/headers" import { streamSse } from "../../../services/continuedev/core/fetch/stream" // Mock the stream module @@ -69,6 +82,7 @@ describe("KilocodeOpenrouterHandler", () => { expect(result).toEqual({ headers: { [X_KILOCODE_TASKID]: "test-task-id", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }, }) }) @@ -84,6 +98,7 @@ describe("KilocodeOpenrouterHandler", () => { headers: { [X_KILOCODE_TASKID]: "test-task-id", [X_KILOCODE_ORGANIZATIONID]: "test-org-id", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }, }) }) @@ -104,6 +119,7 @@ describe("KilocodeOpenrouterHandler", () => { [X_KILOCODE_TASKID]: "test-task-id", [X_KILOCODE_ORGANIZATIONID]: "test-org-id", [X_KILOCODE_PROJECTID]: "https://github.com/user/repo.git", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }, }) }) @@ -124,6 +140,7 @@ describe("KilocodeOpenrouterHandler", () => { [X_KILOCODE_TASKID]: "test-task-id", [X_KILOCODE_PROJECTID]: "https://github.com/user/repo.git", [X_KILOCODE_ORGANIZATIONID]: "test-org-id", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }, }) }) @@ -139,6 +156,7 @@ describe("KilocodeOpenrouterHandler", () => { headers: { [X_KILOCODE_TASKID]: "test-task-id", [X_KILOCODE_ORGANIZATIONID]: "test-org-id", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }, }) expect(result?.headers).not.toHaveProperty(X_KILOCODE_PROJECTID) @@ -155,16 +173,21 @@ describe("KilocodeOpenrouterHandler", () => { expect(result).toEqual({ headers: { [X_KILOCODE_TASKID]: "test-task-id", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }, }) expect(result?.headers).not.toHaveProperty(X_KILOCODE_PROJECTID) }) - it("returns undefined when no headers are needed", () => { + it("returns only editorName header when no other headers are needed", () => { const handler = new KilocodeOpenrouterHandler(mockOptions) const result = handler.customRequestOptions() - expect(result).toBeUndefined() + expect(result).toEqual({ + headers: { + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", + }, + }) }) }) @@ -209,6 +232,7 @@ describe("KilocodeOpenrouterHandler", () => { [X_KILOCODE_TASKID]: "test-task-id", [X_KILOCODE_PROJECTID]: "https://github.com/user/repo.git", [X_KILOCODE_ORGANIZATIONID]: "test-org-id", + [X_KILOCODE_EDITORNAME]: "Visual Studio Code 1.85.0", }), }), // kilocode_change end diff --git a/src/api/providers/kilocode-openrouter.ts b/src/api/providers/kilocode-openrouter.ts index 2fea1dc9709..6bcb55ffa27 100644 --- a/src/api/providers/kilocode-openrouter.ts +++ b/src/api/providers/kilocode-openrouter.ts @@ -14,10 +14,12 @@ import { X_KILOCODE_TASKID, X_KILOCODE_PROJECTID, X_KILOCODE_TESTER, + X_KILOCODE_EDITORNAME, } from "../../shared/kilocode/headers" import { KILOCODE_TOKEN_REQUIRED_ERROR } from "../../shared/kilocode/errorUtils" import { DEFAULT_HEADERS } from "./constants" import { streamSse } from "../../services/continuedev/core/fetch/stream" +import { getEditorNameHeader } from "../../core/kilocode/wrapper" /** * A custom OpenRouter handler that overrides the getModel function @@ -52,7 +54,9 @@ export class KilocodeOpenrouterHandler extends OpenRouterHandler { } override customRequestOptions(metadata?: ApiHandlerCreateMessageMetadata) { - const headers: Record = {} + const headers: Record = { + [X_KILOCODE_EDITORNAME]: getEditorNameHeader(), + } if (metadata?.taskId) { headers[X_KILOCODE_TASKID] = metadata.taskId diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 0ea50ba0bc8..3419bba4b27 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -1227,9 +1227,11 @@ async function checkpointSaveAndMark(task: Task) { return } try { - // kilocode_change: order changed to prevent second execution while still awaiting the save + // kilocode_change start: order changed to prevent second execution while still awaiting the save task.currentStreamingDidCheckpoint = true - await task.checkpointSave(true) + // kilocode_change: don't force empty checkpoints - only create checkpoint if there are actual file changes + await task.checkpointSave(false) + // kilocode_change end } catch (error) { console.error(`[Task#presentAssistantMessage] Error saving checkpoint: ${error.message}`, error) } diff --git a/src/core/kilocode/wrapper.ts b/src/core/kilocode/wrapper.ts index 50efbd189eb..42e57745af6 100644 --- a/src/core/kilocode/wrapper.ts +++ b/src/core/kilocode/wrapper.ts @@ -31,3 +31,14 @@ export const getKiloCodeWrapperProperties = (): KiloCodeWrapperProperties => { kiloCodeWrapperJetbrains, } } + +export const getEditorNameHeader = () => { + const props = getKiloCodeWrapperProperties() + return ( + props.kiloCodeWrapped + ? [props.kiloCodeWrapperTitle, props.kiloCodeWrapperVersion] + : [vscode.env.appName, vscode.version] + ) + .filter(Boolean) + .join(" ") +} diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 02528d72466..9dbc296eefc 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -3307,6 +3307,7 @@ ${prompt} wrapperVersion: kiloCodeWrapperVersion, wrapperTitle: kiloCodeWrapperTitle, machineId: vscode.env.machineId, + vscodeIsTelemetryEnabled: vscode.env.isTelemetryEnabled, // kilocode_change end } } diff --git a/src/package.json b/src/package.json index 26a7bb26f93..687ea0d80ec 100644 --- a/src/package.json +++ b/src/package.json @@ -3,7 +3,7 @@ "displayName": "%extension.displayName%", "description": "%extension.description%", "publisher": "kilocode", - "version": "4.140.3", + "version": "4.141.0", "icon": "assets/icons/logo-outline-black.png", "galleryBanner": { "color": "#FFFFFF", diff --git a/src/shared/kilocode/headers.ts b/src/shared/kilocode/headers.ts index 9f7be47c0ef..080606622a8 100644 --- a/src/shared/kilocode/headers.ts +++ b/src/shared/kilocode/headers.ts @@ -2,4 +2,5 @@ export const X_KILOCODE_VERSION = "X-KiloCode-Version" export const X_KILOCODE_ORGANIZATIONID = "X-KiloCode-OrganizationId" export const X_KILOCODE_TASKID = "X-KiloCode-TaskId" export const X_KILOCODE_PROJECTID = "X-KiloCode-ProjectId" +export const X_KILOCODE_EDITORNAME = "X-KiloCode-EditorName" export const X_KILOCODE_TESTER = "X-KILOCODE-TESTER"
Node.js Version:$nodeVersion
CPU Architecture:$osArch
Operating System:$osName $osVersion
IDE Version:${appInfo.fullApplicationName} (build ${appInfo.build})