diff --git a/README.md b/README.md index 9f737591..3a48d7b6 100644 --- a/README.md +++ b/README.md @@ -60,18 +60,18 @@ npm run dev ``` > ### Windows Users -> +> > If you're developing on Windows, you should use the Windows-specific command: -> +> > ```bash > npm run dev:win > ``` -> +> > **Technical Reason:** Windows has two different asyncio event loop implementations: -> +> > - **SelectorEventLoop** (default): Uses select-based I/O and doesn't support subprocesses properly > - **ProactorEventLoop**: Uses I/O completion ports and fully supports subprocesses -> +> > Playwright requires subprocess support to launch browsers. When hot reloading is enabled, the default SelectorEventLoop is used, causing a `NotImplementedError` when Playwright tries to create a subprocess. > Reference Issue: https://github.com/steel-dev/surf.new/issues/32 diff --git a/api/index.py b/api/index.py index 69cf1aa3..2fe8c4a6 100644 --- a/api/index.py +++ b/api/index.py @@ -14,6 +14,8 @@ import asyncio import subprocess import re +import time +import json # 1) Import the Steel client try: @@ -25,7 +27,7 @@ load_dotenv(".env.local") app = FastAPI() -app.add_middleware(ProfilingMiddleware) # Uncomment this when profiling is not needed +app.add_middleware(ProfilingMiddleware) # Uncomment this when profiling is not needed STEEL_API_KEY = os.getenv("STEEL_API_KEY") ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") STEEL_API_URL = os.getenv("STEEL_API_URL") @@ -148,8 +150,7 @@ async def on_disconnect(): ) # Use background=on_disconnect to catch client-aborted requests - response = StreamingResponse( - streaming_response, background=on_disconnect) + response = StreamingResponse(streaming_response, background=on_disconnect) response.headers["x-vercel-ai-data-stream"] = "v1" # response.headers["model_used"] = request.model_name return response @@ -162,8 +163,7 @@ async def on_disconnect(): "code": getattr(e, "code", 500), } } - raise HTTPException(status_code=getattr( - e, "code", 500), detail=error_response) + raise HTTPException(status_code=getattr(e, "code", 500), detail=error_response) @app.get("/api/agents", tags=["Agents"]) @@ -187,17 +187,19 @@ class OllamaModel(BaseModel): tag: str base_name: str + class OllamaModelsResponse(BaseModel): models: List[OllamaModel] + @app.get("/api/ollama/models", response_model=OllamaModelsResponse, tags=["Ollama"]) async def get_ollama_models(): """ Fetches available models from a local Ollama instance using the 'ollama list' command. - + Returns: A list of model objects with full tags and base names that can be used with Ollama. - + Example response: { "models": [ @@ -214,37 +216,29 @@ async def get_ollama_models(): """ try: result = subprocess.run( - ["ollama", "list"], - capture_output=True, - text=True, - check=True + ["ollama", "list"], capture_output=True, text=True, check=True ) - + models = [] - lines = result.stdout.strip().split('\n') - + lines = result.stdout.strip().split("\n") + if lines and "NAME" in lines[0] and "ID" in lines[0]: lines = lines[1:] - + for line in lines: if line.strip(): - parts = re.split(r'\s{2,}', line.strip()) + parts = re.split(r"\s{2,}", line.strip()) if parts and parts[0]: full_tag = parts[0] - base_name = full_tag.split(':')[0] if ':' in full_tag else full_tag - models.append({ - "tag": full_tag, - "base_name": base_name - }) - + base_name = full_tag.split(":")[0] if ":" in full_tag else full_tag + models.append({"tag": full_tag, "base_name": base_name}) + return {"models": models} except subprocess.CalledProcessError as e: raise HTTPException( - status_code=500, - detail=f"Failed to fetch Ollama models: {e.stderr}" + status_code=500, detail=f"Failed to fetch Ollama models: {e.stderr}" ) except Exception as e: raise HTTPException( - status_code=500, - detail=f"Error fetching Ollama models: {str(e)}" + status_code=500, detail=f"Error fetching Ollama models: {str(e)}" ) diff --git a/api/models.py b/api/models.py index edba5693..bb415e11 100644 --- a/api/models.py +++ b/api/models.py @@ -6,6 +6,7 @@ class ModelProvider(str, Enum): OPENAI = "openai" ANTHROPIC = "anthropic" ANTHROPIC_COMPUTER_USE = "anthropic_computer_use" + OPENAI_COMPUTER_USE = "openai_computer_use" GEMINI = "gemini" DEEPSEEK = "deepseek" OLLAMA = "ollama" @@ -79,6 +80,7 @@ def default_model(provider: ModelProvider) -> str: ModelProvider.OPENAI: "gpt-4o-mini", ModelProvider.ANTHROPIC: "claude-3-7-sonnet-latest", ModelProvider.ANTHROPIC_COMPUTER_USE: "claude-3-5-sonnet-20241022", + ModelProvider.OPENAI_COMPUTER_USE: "computer-use-preview", ModelProvider.GEMINI: "gemini-2.0-flash", ModelProvider.DEEPSEEK: "deepseek-chat", ModelProvider.OLLAMA: "llama3.3", diff --git a/api/plugins/__init__.py b/api/plugins/__init__.py index 717cd040..44d79104 100644 --- a/api/plugins/__init__.py +++ b/api/plugins/__init__.py @@ -13,8 +13,10 @@ from .base import base_agent from .claude_computer_use import claude_computer_use from .browser_use import browser_use_agent +from .openai_computer_use import openai_computer_use_agent from ..utils.types import AgentSettings -from .claude_computer_use.prompts import SYSTEM_PROMPT +from .claude_computer_use.prompts import SYSTEM_PROMPT as CLAUDE_SYSTEM_PROMPT +from .openai_computer_use.prompts import SYSTEM_PROMPT as OPENAI_SYSTEM_PROMPT # from .example_plugin import example_agent @@ -24,6 +26,7 @@ class WebAgentType(Enum): EXAMPLE = "example" CLAUDE_COMPUTER_USE = "claude_computer_use" BROWSER_USE = "browser_use_agent" + OPENAI_COMPUTER_USE = "openai_computer_use_agent" class SettingType(Enum): @@ -197,7 +200,7 @@ class SettingConfig(TypedDict): "agent_settings": { "system_prompt": { "type": SettingType.TEXTAREA.value, - "default": SYSTEM_PROMPT, + "default": CLAUDE_SYSTEM_PROMPT, "maxLength": 4000, "description": "System prompt for the agent", }, @@ -217,7 +220,62 @@ class SettingConfig(TypedDict): }, }, }, - + WebAgentType.OPENAI_COMPUTER_USE.value: { + "name": "OpenAI Computer Use", + "description": "Agent that uses OpenAI's Computer-Using Agent (CUA) via the /v1/responses API", + "supported_models": [ + { + "provider": ModelProvider.OPENAI_COMPUTER_USE.value, + "models": ["computer-use-preview"], + } + ], + "model_settings": { + "max_tokens": { + "type": SettingType.INTEGER.value, + "default": 3000, + "min": 1, + "max": 4096, + "description": "Maximum tokens for the responses endpoint", + }, + "temperature": { + "type": SettingType.FLOAT.value, + "default": 0.2, + "min": 0, + "max": 1, + "step": 0.05, + "description": "Optional temperature param for final assistant messages", + }, + }, + "agent_settings": { + "system_prompt": { + "type": SettingType.TEXTAREA.value, + "default": OPENAI_SYSTEM_PROMPT, + "maxLength": 4000, + "description": "Custom system prompt for the agent", + }, + "num_images_to_keep": { + "type": SettingType.INTEGER.value, + "default": 10, + "min": 1, + "max": 50, + "description": "Number of images to keep in memory", + }, + "wait_time_between_steps": { + "type": SettingType.INTEGER.value, + "default": 1, + "min": 0, + "max": 10, + "description": "Wait time between steps in seconds", + }, + "max_steps": { + "type": SettingType.INTEGER.value, + "default": 30, + "min": 10, + "max": 50, + "description": "Maximum number of steps the agent can take", + } + }, + }, } @@ -232,6 +290,8 @@ def get_web_agent( return claude_computer_use elif name == WebAgentType.BROWSER_USE: return browser_use_agent + elif name == WebAgentType.OPENAI_COMPUTER_USE: + return openai_computer_use_agent else: raise ValueError(f"Invalid agent type: {name}") diff --git a/api/plugins/openai_computer_use/__init__.py b/api/plugins/openai_computer_use/__init__.py new file mode 100644 index 00000000..95d4526c --- /dev/null +++ b/api/plugins/openai_computer_use/__init__.py @@ -0,0 +1,3 @@ +from .agent import openai_computer_use_agent + +__all__ = ["openai_computer_use_agent"] \ No newline at end of file diff --git a/api/plugins/openai_computer_use/agent.py b/api/plugins/openai_computer_use/agent.py new file mode 100644 index 00000000..18e77e26 --- /dev/null +++ b/api/plugins/openai_computer_use/agent.py @@ -0,0 +1,430 @@ +""" +OpenAI Computer Use Agent - Main Module + +This module contains the main agent function that orchestrates the interaction +between the OpenAI computer-use-preview model and the Steel browser automation. +""" + +import asyncio +import json +import logging +import os +import signal +from typing import AsyncIterator, Any, Dict, List, Mapping, Optional, Set + +from fastapi import HTTPException +from playwright.async_api import async_playwright +from langchain_core.messages import BaseMessage +from api.models import ModelConfig +from api.utils.types import AgentSettings +from api.utils.prompt import chat_dict_to_base_messages + +# Import from our own modules +from .config import ( + STEEL_API_KEY, + STEEL_API_URL, + STEEL_CONNECT_URL, + OPENAI_RESPONSES_URL, + VALID_OPENAI_CUA_MODELS, + DEFAULT_MAX_STEPS, + DEFAULT_WAIT_TIME_BETWEEN_STEPS, + DEFAULT_NUM_IMAGES_TO_KEEP, +) +from .prompts import SYSTEM_PROMPT +from .tools import _create_tools +from .steel_computer import SteelComputer +from .conversation_manager import ConversationManager +from .message_handler import MessageHandler + +logger = logging.getLogger("openai_computer_use") + + +async def openai_computer_use_agent( + model_config: ModelConfig, + agent_settings: AgentSettings, + history: List[Mapping[str, Any]], + session_id: str, + cancel_event: Optional[asyncio.Event] = None, +) -> AsyncIterator[Any]: + """ + OpenAI's computer-use-preview model integration, refactored for clarity. + + Steps: + 1. Validate model, create session & connect to browser + 2. Initialize ConversationManager + MessageHandler + 3. Main loop that: + - Prepares conversation => calls /v1/responses => processes items + - Yields messages or tool calls => executes tool calls => yields results + - Repeats until we get a final "assistant" item or exceed max steps + """ + # Keep track of background tasks we create + pending_tasks: Set[asyncio.Task] = set() + + # Helper to track and clean up tasks + def track_task(task: asyncio.Task) -> None: + pending_tasks.add(task) + task.add_done_callback(lambda t: pending_tasks.discard(t)) + + logger.info(f"Starting openai_computer_use_agent with session_id: {session_id}") + openai_api_key = model_config.api_key or os.getenv("OPENAI_API_KEY") + if not openai_api_key: + raise HTTPException(status_code=400, detail="No OPENAI_API_KEY configured") + + # Validate model name + model_name = model_config.model_name or "computer-use-preview-2025-02-04" + if model_name not in VALID_OPENAI_CUA_MODELS: + raise HTTPException( + status_code=400, + detail=f"Invalid model name: {model_name}. Must be one of: {VALID_OPENAI_CUA_MODELS}", + ) + + # Extract settings with defaults + max_steps = getattr(agent_settings, "max_steps", DEFAULT_MAX_STEPS) + wait_time = getattr( + agent_settings, "wait_time_between_steps", DEFAULT_WAIT_TIME_BETWEEN_STEPS + ) + num_images = getattr( + agent_settings, "num_images_to_keep", DEFAULT_NUM_IMAGES_TO_KEEP + ) + + # Create a Steel session + from steel import Steel + + client = Steel(steel_api_key=STEEL_API_KEY, base_url=STEEL_API_URL) + try: + session = client.sessions.retrieve(session_id) + logger.info(f"Successfully connected to Steel session: {session.id}") + logger.info(f"Session viewer URL: {session.session_viewer_url}") + yield "[OPENAI-CUA] Session loaded. Connecting to remote browser..." + except Exception as exc: + logger.error(f"Failed to retrieve Steel session: {exc}") + raise HTTPException(400, f"Failed to retrieve Steel session: {exc}") + + # Set up a handler for SIGINT (keyboard interrupt) to allow cleanup + original_sigint_handler = None + if hasattr(signal, "SIGINT"): + original_sigint_handler = signal.getsignal(signal.SIGINT) + + def sigint_handler(sig, frame): + logger.info("SIGINT received, preparing for shutdown") + if cancel_event: + cancel_event.set() + # Don't call the default handler yet - let cleanup run first + + signal.signal(signal.SIGINT, sigint_handler) + + # Connect to browser + steel_computer = None + playwright_instance = None + try: + # Create a shared cancel event if one wasn't provided + local_cancel_event = False + if cancel_event is None: + cancel_event = asyncio.Event() + local_cancel_event = True + + # Launch playwright + playwright_instance = await async_playwright().start() + try: + browser = await playwright_instance.chromium.connect_over_cdp( + f"{STEEL_CONNECT_URL}?apiKey={STEEL_API_KEY}&sessionId={session.id}" + ) + yield "[OPENAI-CUA] Playwright connected!" + except Exception as e: + logger.error(f"Failed to connect Playwright over CDP: {e}") + yield f"Error: could not connect to browser session (CDP) - {e}" + return + + # Initialize SteelComputer - this handles all browser management + steel_computer = await SteelComputer.create(browser) + + # Initialize MessageHandler and ConversationManager + msg_handler = MessageHandler(steel_computer) + conversation = ConversationManager(num_images_to_keep=num_images) + + # Load history + system prompt + system_prompt = agent_settings.system_prompt or SYSTEM_PROMPT + base_msgs = chat_dict_to_base_messages(history) + conversation.initialize_from_history(base_msgs, system_prompt=system_prompt) + + # Get viewport size for computer-preview tool + viewport_size = await steel_computer.get_viewport_size() + + # Setup model request parameters + headers = { + "Authorization": f"Bearer {openai_api_key}", + "Content-Type": "application/json", + # OpenAI "Beta" header for /v1/responses + "OpenAI-Beta": "responses=v1", + "Openai-Beta": "responses=v1", + } + + step_count = 0 + request_task = None + + # Main loop + try: + while True: + if cancel_event and cancel_event.is_set(): + logger.info("Cancel event detected, exiting agent loop") + yield "[OPENAI-CUA] Cancel event detected, stopping..." + break + + if step_count >= max_steps: + logger.info( + f"Reached maximum steps ({max_steps}), exiting agent loop" + ) + yield f"[OPENAI-CUA] Reached max steps ({max_steps}), stopping..." + break + + step_count += 1 + logger.info(f"Starting step {step_count}/{max_steps}") + + # Prepare the conversation for /v1/responses + items_for_model = conversation.prepare_for_model() + tools_for_model = _create_tools() + + # Update the display dimensions in the computer-preview tool + for tool in tools_for_model: + if tool.get("type") == "computer-preview": + tool["display_width"] = viewport_size["width"] + tool["display_height"] = viewport_size["height"] + break + + request_body = { + "model": model_name, + "input": items_for_model, + "tools": tools_for_model, + "truncation": "auto", + "reasoning": {"generate_summary": "concise"}, + } + + # Make the request + try: + logger.info("Sending request to OpenAI /v1/responses endpoint...") + + # Create a task for the request with a timeout + async def make_request(): + import aiohttp + + max_retries = 3 + retry_count = 0 + retry_delay = 1 # Start with 1 second delay + + while retry_count <= max_retries: + try: + async with aiohttp.ClientSession() as session: + async with session.post( + OPENAI_RESPONSES_URL, + json=request_body, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as resp: + if not resp.ok: + error_detail = "" + try: + error_json = await resp.json() + error_detail = json.dumps( + error_json, indent=2 + ) + except: + error_detail = await resp.text() + + logger.error( + f"OpenAI API error response ({resp.status}):" + ) + logger.error( + f"Response headers: {dict(resp.headers)}" + ) + logger.error( + f"Response body: {error_detail}" + ) + + # Retry only on 5xx (server) errors + if ( + 500 <= resp.status < 600 + and retry_count < max_retries + ): + retry_count += 1 + logger.info( + f"Retrying request (attempt {retry_count}/{max_retries}) after 5xx error" + ) + await asyncio.sleep(retry_delay) + # Exponential backoff + retry_delay *= 2 + continue + + # For other errors or if we've exhausted retries, raise the exception + resp.raise_for_status() + + return await resp.json() + except aiohttp.ClientConnectionError as e: + # Also retry on connection errors + if retry_count < max_retries: + retry_count += 1 + logger.info( + f"Connection error, retrying ({retry_count}/{max_retries}): {e}" + ) + await asyncio.sleep(retry_delay) + retry_delay *= 2 + continue + raise # Re-raise if we've exhausted retries + + # This should never be reached if properly handled above + raise RuntimeError("Unexpected exit from retry loop") + + # Create and track the request task + request_task = asyncio.create_task(make_request()) + track_task(request_task) + + # Wait for either the request to complete or cancellation + # Create a task that waits for cancellation + if cancel_event: + cancellation_task = asyncio.create_task(cancel_event.wait()) + track_task(cancellation_task) + + # Wait for either request to complete or cancellation + done, pending = await asyncio.wait( + [request_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # If cancellation happened first + if cancellation_task in done: + # Cancel the request_task + if not request_task.done(): + request_task.cancel() + logger.info("Request cancelled due to cancel event") + yield "[OPENAI-CUA] Request cancelled..." + break + + # Otherwise, cancel the cancellation_task (no longer needed) + if not cancellation_task.done(): + cancellation_task.cancel() + else: + # Just wait for the request if no cancellation event + await request_task + + # Check if cancelled while we were waiting + if cancel_event and cancel_event.is_set(): + logger.info("Cancel event detected after request") + if not request_task.done(): + request_task.cancel() + yield "[OPENAI-CUA] Request cancelled..." + break + + # Get the result (will raise if cancelled) + data = request_task.result() + + except asyncio.CancelledError: + logger.info("Request was cancelled") + yield "[OPENAI-CUA] Request cancelled..." + break + except Exception as ex: + logger.error(f"Error making request to OpenAI: {ex}") + yield f"[OPENAI-CUA] Error from OpenAI: {str(ex)}" + break + + if "output" not in data: + logger.error(f"No 'output' in response: {data}") + yield f"No 'output' in /v1/responses result: {data}" + break + + new_items = data["output"] + logger.info(f"Received {len(new_items)} new items from OpenAI") + + got_final_assistant = False + for item in new_items: + # Check for cancellation inside loop + if cancel_event and cancel_event.is_set(): + logger.info("Cancel event detected while processing items") + break + + # Add this item to conversation first + conversation.add_item(item) + + # Process the item + immediate_msg, action_needed = await msg_handler.process_item(item) + + # 1. If there's an immediate message (e.g. partial AI chunk), yield it + if immediate_msg: + yield immediate_msg + # Check if it's a "reasoning" item, yield a stop marker for visual break + if item.get("type") == "reasoning": + yield {"stop": True} + # If it's an assistant item, mark as final + if item.get("role") == "assistant": + got_final_assistant = True + + # 2. If an action is required (tool call), do it + if action_needed: + # Wait the configured time between steps if needed + if wait_time > 0: + logger.debug(f"Waiting {wait_time}s between steps") + await asyncio.sleep(wait_time) + + # Execute the action and get results + result_item, result_tool_msg = await msg_handler.execute_action( + action_needed + ) + # Add the result to conversation + conversation.add_item(result_item) + + # Yield the tool result message + yield result_tool_msg + + # Check again for cancellation + if cancel_event and cancel_event.is_set(): + break + + # If we got a final assistant message, end the conversation + if got_final_assistant: + logger.info("Received final assistant message, ending conversation") + break + finally: + # Clean up any pending tasks we created + logger.info(f"Cleaning up {len(pending_tasks)} pending tasks") + for task in pending_tasks: + if not task.done(): + task.cancel() + + # Wait briefly for tasks to clean up + if pending_tasks: + try: + await asyncio.wait(pending_tasks, timeout=0.5) + except asyncio.CancelledError: + pass + + # Clean up browser resources + if steel_computer: + logger.info("Cleaning up SteelComputer resources") + await steel_computer.cleanup() + + # End of main loop + logger.info("Exited main loop, finishing agent execution") + yield "[OPENAI-CUA] Agent ended." + except Exception as e: + logger.error(f"Unexpected error in agent: {e}", exc_info=True) + # Attempt cleanup even on error + if steel_computer: + try: + await steel_computer.cleanup() + except Exception as cleanup_err: + logger.error(f"Error during emergency cleanup: {cleanup_err}") + yield f"[OPENAI-CUA] Error: {str(e)}" + finally: + # Close the playwright instance + if playwright_instance: + try: + await playwright_instance.stop() + logger.info("Closed Playwright instance") + except Exception as e: + logger.error(f"Error closing Playwright instance: {e}") + + # Restore original SIGINT handler + if original_sigint_handler and hasattr(signal, "SIGINT"): + signal.signal(signal.SIGINT, original_sigint_handler) + + # Clean up our local cancel event if we created one + if local_cancel_event and cancel_event and not cancel_event.is_set(): + cancel_event.set() diff --git a/api/plugins/openai_computer_use/config.py b/api/plugins/openai_computer_use/config.py new file mode 100644 index 00000000..aef9a776 --- /dev/null +++ b/api/plugins/openai_computer_use/config.py @@ -0,0 +1,20 @@ +import os +from dotenv import load_dotenv + +# Load environment variables +load_dotenv(".env.local") + +# Environment variables and constants +STEEL_API_KEY = os.getenv("STEEL_API_KEY") +STEEL_CONNECT_URL = os.getenv("STEEL_CONNECT_URL") +STEEL_API_URL = os.getenv("STEEL_API_URL") +OPENAI_RESPONSES_URL = "https://api.openai.com/v1/responses" +VALID_OPENAI_CUA_MODELS = { + "computer-use-preview", + "computer-use-preview-2025-02-04", +} + +# Default settings +DEFAULT_MAX_STEPS = 30 +DEFAULT_WAIT_TIME_BETWEEN_STEPS = 1 +DEFAULT_NUM_IMAGES_TO_KEEP = 10 diff --git a/api/plugins/openai_computer_use/conversation_manager.py b/api/plugins/openai_computer_use/conversation_manager.py new file mode 100644 index 00000000..f4fe5d12 --- /dev/null +++ b/api/plugins/openai_computer_use/conversation_manager.py @@ -0,0 +1,166 @@ +import datetime +import json +import logging +from typing import Any, Dict, List, Optional + +from langchain_core.messages import BaseMessage + +from .tools import _make_cua_content_for_role + +logger = logging.getLogger("openai_computer_use.conversation_manager") + +# Minimal 1x1 transparent pixel as base64 - very small footprint +MINIMAL_BASE64_IMAGE = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" + + +class ConversationManager: + """ + Manages conversation state/history for the OpenAI computer-use model. + - Holds a list of conversation 'items' + - Can initialize from user & system prompts + - Provides a method to convert items into the format required by the /v1/responses endpoint + - (Optional) can do image-trimming to reduce context size + """ + + def __init__(self, num_images_to_keep: int = 10): + self.items: List[Dict[str, Any]] = [] + self.num_images_to_keep = num_images_to_keep + self.logger = logger + + def initialize_from_history( + self, + base_msgs: List[BaseMessage], + system_prompt: Optional[str] = None, + ) -> None: + """ + Convert existing chat history into a set of conversation items + recognized by the OpenAI 'responses' endpoint. Optionally prepend + a system prompt. + """ + # 1) Add system message if provided + if system_prompt: + sys_text = f"{system_prompt}\nCurrent date/time: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}" + self.items.append( + { + "role": "system", + "content": _make_cua_content_for_role("system", sys_text), + } + ) + self.logger.info("Added system prompt to conversation") + + # 2) Convert each message in history to the correct format + for msg in base_msgs: + # In standard usage, msg can be an AIMessage, ToolMessage, or HumanMessage + if hasattr(msg, "tool_call_id") and msg.tool_call_id: + # It's a tool result (ToolMessage) + # In the original code, tool messages become "computer_call_output" items + self.logger.info( + f"Processing tool message with call_id: {msg.tool_call_id}" + ) + tool_content = msg.content + if isinstance(tool_content, str): + # Keep strings as is + content_for_output = tool_content + else: + # Serialize dictionaries or other objects + content_for_output = json.dumps(tool_content) + + self.items.append( + { + "type": "computer_call_output", + "call_id": msg.tool_call_id, + "output": { + "type": "input_image", + "image_url": content_for_output, + }, + } + ) + self.logger.info( + f"Added computer_call_output for tool response with call_id: {msg.tool_call_id}" + ) + elif msg.type == "ai": + # It's an AI message + self.logger.info("Processing AI message") + text = msg.content + text = text if isinstance(text, str) else json.dumps(text) + self.items.append( + { + "role": "assistant", + "content": _make_cua_content_for_role("assistant", text), + } + ) + self.logger.info("Added assistant role item") + elif msg.type == "human": + # It's a user message + self.logger.info("Processing human message") + text = msg.content + text = text if isinstance(text, str) else json.dumps(text) + self.items.append( + { + "role": "user", + "content": _make_cua_content_for_role("user", text), + } + ) + self.logger.info("Added user role item") + else: + # Fallback for system or other + self.logger.info(f"Processing {msg.type} message") + text = msg.content + text = text if isinstance(text, str) else json.dumps(text) + self.items.append( + { + "role": msg.type, + "content": _make_cua_content_for_role(msg.type, text), + } + ) + self.logger.info(f"Added {msg.type} role item") + + self.logger.info(f"Processed {len(self.items)} total conversation items") + + def add_item(self, item: Dict[str, Any]) -> None: + """Add a single item from the new model response (or user input) to the conversation.""" + self.items.append(item) + + def trim_images(self) -> None: + """ + Optionally trim older images from conversation to save tokens. + This implementation keeps the most recent images and replaces older ones with a minimal base64 image. + """ + # Count computer_call_output items with images + image_items = [ + (i, item) + for i, item in enumerate(self.items) + if item.get("type") == "computer_call_output" + and item.get("output", {}).get("type") == "input_image" + ] + + if len(image_items) <= self.num_images_to_keep: + return # No trimming needed + + # Keep only the most recent images + images_to_trim = image_items[: -self.num_images_to_keep] + + for idx, item in images_to_trim: + # Replace image with minimal base64 image placeholder + call_id = item.get("call_id", "unknown") + self.items[idx] = { + "type": "computer_call_output", + "call_id": call_id, + "output": { + **item.get("output", {}), + "image_url": MINIMAL_BASE64_IMAGE, + }, + } + + self.logger.info( + f"Trimmed {len(images_to_trim)} older images from conversation" + ) + + def prepare_for_model(self) -> List[Dict[str, Any]]: + """ + Return a copy of self.items with all transformations needed + for the /v1/responses request body. + """ + + self.trim_images() + return self.items.copy() diff --git a/api/plugins/openai_computer_use/cursor_overlay.py b/api/plugins/openai_computer_use/cursor_overlay.py new file mode 100644 index 00000000..1e17f1d7 --- /dev/null +++ b/api/plugins/openai_computer_use/cursor_overlay.py @@ -0,0 +1,46 @@ +""" +JavaScript code and helper function for injecting a visible cursor overlay into the browser. +""" + +from playwright.async_api import Page + +async def inject_cursor_overlay(page: Page) -> None: + """Inject a custom cursor overlay into the page.""" + await page.add_init_script(""" + // Only run in the top frame + if (window.self === window.top) { + function initCursor() { + const CURSOR_ID = '__cursor__'; + if (document.getElementById(CURSOR_ID)) return; + + const cursor = document.createElement('div'); + cursor.id = CURSOR_ID; + Object.assign(cursor.style, { + position: 'fixed', + top: '0px', + left: '0px', + width: '20px', + height: '20px', + backgroundImage: 'url("data:image/svg+xml;utf8,")', + backgroundSize: 'cover', + pointerEvents: 'none', + zIndex: '99999', + transform: 'translate(-2px, -2px)', + }); + + document.body.appendChild(cursor); + document.addEventListener("mousemove", (e) => { + cursor.style.top = e.clientY + "px"; + cursor.style.left = e.clientX + "px"; + }); + } + + requestAnimationFrame(function checkBody() { + if (document.body) { + initCursor(); + } else { + requestAnimationFrame(checkBody); + } + }); + } + """) \ No newline at end of file diff --git a/api/plugins/openai_computer_use/key_mapping.py b/api/plugins/openai_computer_use/key_mapping.py new file mode 100644 index 00000000..bc317fc2 --- /dev/null +++ b/api/plugins/openai_computer_use/key_mapping.py @@ -0,0 +1,100 @@ +""" +Map OpenAI CUA keys to Playwright-compatible keys. +""" + +CUA_KEY_TO_PLAYWRIGHT_KEY = { + # Common / Basic Keys + "return": "Enter", + "enter": "Enter", + "tab": "Tab", + "backspace": "Backspace", + "up": "ArrowUp", + "down": "ArrowDown", + "left": "ArrowLeft", + "right": "ArrowRight", + "space": "Space", + "ctrl": "Control", + "control": "Control", + "alt": "Alt", + "shift": "Shift", + "meta": "Meta", + "command": "Meta", + "windows": "Meta", + "esc": "Escape", + "escape": "Escape", + # Numpad Keys + "kp_0": "Numpad0", + "kp_1": "Numpad1", + "kp_2": "Numpad2", + "kp_3": "Numpad3", + "kp_4": "Numpad4", + "kp_5": "Numpad5", + "kp_6": "Numpad6", + "kp_7": "Numpad7", + "kp_8": "Numpad8", + "kp_9": "Numpad9", + # Numpad Operations + "kp_enter": "NumpadEnter", + "kp_multiply": "NumpadMultiply", + "kp_add": "NumpadAdd", + "kp_subtract": "NumpadSubtract", + "kp_decimal": "NumpadDecimal", + "kp_divide": "NumpadDivide", + # Navigation + "page_down": "PageDown", + "page_up": "PageUp", + "home": "Home", + "end": "End", + "insert": "Insert", + "delete": "Delete", + # Function Keys + "f1": "F1", + "f2": "F2", + "f3": "F3", + "f4": "F4", + "f5": "F5", + "f6": "F6", + "f7": "F7", + "f8": "F8", + "f9": "F9", + "f10": "F10", + "f11": "F11", + "f12": "F12", + # Left/Right Variants + "shift_l": "ShiftLeft", + "shift_r": "ShiftRight", + "control_l": "ControlLeft", + "control_r": "ControlRight", + "alt_l": "AltLeft", + "alt_r": "AltRight", + # Media Keys + "audiovolumemute": "AudioVolumeMute", + "audiovolumedown": "AudioVolumeDown", + "audiovolumeup": "AudioVolumeUp", + # Additional Special Keys + "print": "PrintScreen", + "scroll_lock": "ScrollLock", + "pause": "Pause", + "menu": "ContextMenu", + # Additional mappings for common variations + "/": "Divide", + "\\": "Backslash", + "capslock": "CapsLock", + "option": "Alt", # Mac "option" maps to Alt + "super": "Meta", # Mac "⌘" or Win "⊞" + "win": "Meta", +} + + +def translate_key(key: str) -> str: + """ + Map CUA-style key strings to Playwright-compatible keys. + Reference: https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values + + Args: + key: The key string to translate + + Returns: + The Playwright-compatible key string + """ + return CUA_KEY_TO_PLAYWRIGHT_KEY.get(key.lower(), key) \ No newline at end of file diff --git a/api/plugins/openai_computer_use/message_handler.py b/api/plugins/openai_computer_use/message_handler.py new file mode 100644 index 00000000..47542f33 --- /dev/null +++ b/api/plugins/openai_computer_use/message_handler.py @@ -0,0 +1,255 @@ +import json +import logging +from typing import Any, Dict, Optional, Tuple + +from langchain.schema import AIMessage +from langchain_core.messages import BaseMessage, ToolMessage + +from .steel_computer import SteelComputer + +logger = logging.getLogger("openai_computer_use.message_handler") + + +class MessageHandler: + """ + Processes the model's output items. Responsible for: + - Converting them into immediate yieldable messages (AIMessage, etc.) + - Identifying + packaging browser actions (tool calls) + - Executing those actions using a SteelComputer + - Returning the final result (screenshot, error, etc.) + """ + + def __init__(self, computer: SteelComputer): + self.computer = computer + self.logger = logger + + async def process_item( + self, item: Dict[str, Any] + ) -> Tuple[Optional[BaseMessage], Optional[dict]]: + """ + Converts a single item from the model's response into: + - a BaseMessage to yield to the client, or None + - an action dict to pass to `execute_action()`, or None + + For example: + - "message" -> immediate AIMessage + - "computer_call" -> yield a tool call, return an action + - "function_call" -> yield a tool call, return an action + - "assistant" -> final AIMessage + - "reasoning" -> yield partial chain-of-thought (hidden or shown) + """ + item_type = item.get("type") + self.logger.debug(f"Processing item of type: {item_type}") + + # 1) message chunk + if item_type == "message": + # It's a chunk of text from the user or assistant + # Usually "input_text" or "output_text" parts + text_segments = item.get("content", []) + combined_text = "" + # Gather text from "output_text" or "input_text" + for seg in text_segments: + if seg["type"] in ("output_text", "input_text"): + combined_text += seg["text"] + + if combined_text.strip(): + # Return an AIMessage chunk + self.logger.info(f"Yielding message text: {combined_text[:100]}...") + return AIMessage(content=combined_text), None + return None, None + + # 2) computer_call -> a direct request to do "click", "scroll", etc. + if item_type == "computer_call": + call_id = item["call_id"] + action = item["action"] + ack_checks = item.get("pending_safety_checks", []) + + self.logger.info( + f"[TOOL_CALL] Processing computer action call: {action['type']} (id: {call_id})" + ) + + # We'll yield a minimal AIMessage with a tool call + # Then return the action to be executed + tool_call_msg = AIMessage( + content="", + tool_calls=[{"name": action["type"], "args": action, "id": call_id}], + ) + return tool_call_msg, { + "call_id": call_id, + "action": action, + "action_type": "computer_call", + "ack_checks": ack_checks, + } + + # 3) function_call -> a request to call "goto", "back", "forward" + if item_type == "function_call": + call_id = item["call_id"] + fn_name = item["name"] + self.logger.info( + f"Processing function_call: {fn_name} with call_id: {call_id}" + ) + + try: + fn_args = json.loads(item["arguments"]) + self.logger.info( + f"Successfully parsed arguments for {fn_name}: {json.dumps(fn_args)}" + ) + except Exception as arg_err: + self.logger.error(f"Failed to parse arguments for {fn_name}: {arg_err}") + fn_args = {} + + # yield a minimal AIMessage with the function call + tool_call_msg = AIMessage( + content="", + tool_calls=[{"name": fn_name, "args": fn_args, "id": call_id}], + ) + return tool_call_msg, { + "call_id": call_id, + "action": fn_args, + "action_type": fn_name, + } + + # 4) reasoning -> partial chain-of-thought (if any) + if item_type == "reasoning": + # We can yield it as a "thoughts" message + self.logger.info("Processing reasoning item") + + reasoning_text = None + + # Check for tokens first + if "tokens" in item: + reasoning_text = item["tokens"] + self.logger.info(f"Found reasoning tokens: {reasoning_text}") + # Then check for summary + elif "summary" in item: + summary_text = [ + s.get("text", "") + for s in item["summary"] + if s.get("type") == "summary_text" + ] + if summary_text: + reasoning_text = "\n".join(summary_text) + self.logger.info(f"Found reasoning summary: {reasoning_text}") + + if reasoning_text and reasoning_text.strip(): + # We'll yield it as an AI message with "Thoughts" + self.logger.info("Yielding reasoning as AIMessage with thoughts format") + return AIMessage(content=f"*Thoughts*:\n{reasoning_text}"), None + return None, None + + # 5) assistant -> final assistant message + if item_type == "assistant": + self.logger.info("Received final assistant message") + content_array = item.get("content", []) + final_text = "" + for seg in content_array: + if seg.get("type") == "output_text": + final_text += seg["text"] + if final_text.strip(): + self.logger.info(f"Yielding final assistant msg: {final_text[:100]}...") + return AIMessage(content=final_text), None + return None, None + + # By default, do nothing + self.logger.warning(f"Unknown item type {item_type} - ignoring.") + return None, None + + async def execute_action( + self, action_dict: dict + ) -> Tuple[Dict[str, Any], ToolMessage]: + """ + Execute the previously identified action and build the final tool result message. + + Returns a tuple: + (item_to_add_to_history, tool_result_message_to_yield) + """ + call_id = action_dict["call_id"] + action_type = action_dict["action_type"] + action = action_dict["action"] + ack_checks = action_dict.get("ack_checks", []) + + self.logger.info(f"Executing action: {action_type} (call_id: {call_id})") + + if action_type in ("goto", "back", "forward"): + if action_type == "goto": + final_action = {"type": "goto", "url": action.get("url", "about:blank")} + elif action_type == "back": + final_action = {"type": "back"} + elif action_type == "forward": + final_action = {"type": "forward"} + else: + final_action = {"type": "screenshot"} + else: + final_action = action + + result = await self.computer.execute_action(final_action) + + if result.get("type") == "error": + item_to_add = { + "type": "computer_call_output", + "call_id": call_id, + "output": { + "type": "error", + "error": result["error"], + "tool_name": result.get("tool_name", action_type), + "tool_args": result.get("tool_args", final_action), + }, + } + else: + if action_type in ("goto", "back", "forward") or result.get( + "tool_name" + ) in ("goto", "back", "forward"): + item_to_add = { + "type": "function_call_output", + "call_id": call_id, + "output": "success", + } + else: + item_to_add = { + "type": "computer_call_output", + "call_id": call_id, + "acknowledged_safety_checks": ack_checks, + "output": { + "type": "input_image", + "image_url": f"data:image/png;base64,{result.get('source', {}).get('data', '')}", + "current_url": result.get("current_url", "about:blank"), + "toolName": result.get("tool_name", action_type), + "args": result.get("tool_args", final_action), + }, + } + + # Build the tool result message to yield + content_for_tool_msg = [] + if result.get("type") == "error": + # Return an error structure + content_for_tool_msg.append( + { + "type": "error", + "error": result["error"], + "tool_name": result.get("tool_name", action_type), + "tool_args": result.get("tool_args", final_action), + } + ) + else: + # Return an image structure + content_for_tool_msg.append(result) + + tool_result_msg = ToolMessage( + content=content_for_tool_msg, + tool_call_id=call_id, + type="tool", + name=action_type, + args=final_action, + metadata={"message_type": "tool_result"}, + ) + + return item_to_add, tool_result_msg + + async def cleanup(self): + try: + # Clean up any pending Playwright tasks + for context in self.computer.browser.contexts: + await context.close() + await self.computer.browser.close() + except Exception as e: + self.logger.warning(f"Error during browser cleanup: {e}") diff --git a/api/plugins/openai_computer_use/prompts.py b/api/plugins/openai_computer_use/prompts.py new file mode 100644 index 00000000..bb93f206 --- /dev/null +++ b/api/plugins/openai_computer_use/prompts.py @@ -0,0 +1,17 @@ +SYSTEM_PROMPT = """You are an OpenAI Computer-Using Agent with full power to control a web browser. +You can see the screen and perform actions like clicking, typing, scrolling, and more. +Your goal is to help the user accomplish their tasks by interacting with the web interface. + +When you need to perform an action: +1. Carefully analyze the current state of the screen +2. Decide on the most appropriate action to take +3. Execute the action precisely + +For browser navigation: +- ALWAYS use the 'back' function to go back in browser history +- ALWAYS use the 'forward' function to go forward in browser history +- NEVER try to navigate back/forward by clicking browser buttons or using keyboard shortcuts +- Use 'goto' or 'change_url' for direct URL navigation + +Always explain what you're doing and why, and ask for clarification if needed. +""" diff --git a/api/plugins/openai_computer_use/steel_computer.py b/api/plugins/openai_computer_use/steel_computer.py new file mode 100644 index 00000000..968a15ab --- /dev/null +++ b/api/plugins/openai_computer_use/steel_computer.py @@ -0,0 +1,366 @@ +import asyncio +import base64 +import json +import logging +from typing import Any, Dict, Tuple, Optional, List, Set + +from playwright.async_api import Page, Browser, BrowserContext + +from .tools import translate_key +from .cursor_overlay import inject_cursor_overlay + +logger = logging.getLogger("openai_computer_use.steel_computer") + + +class SteelComputer: + """ + Wraps the browser and page interactions for the OpenAI computer-use agent. + Responsible for: + - Managing the browser, context, and pages + - Handling page navigation and tab management + - Executing user actions (click, scroll, type, etc.) + - Capturing screenshots and page state + """ + + def __init__(self, page: Page): + """Initialize with an active page.""" + self.active_page = page + self.browser = page.context.browser + self.context = page.context + self.logger = logger + + # Set up event handlers + self.context.on("page", self._handle_new_page) + + # Initialize tracking for pages that have been set up + self._setup_pages: Set[Page] = set() + + # Track pending tasks + self._pending_tasks: Set[asyncio.Task] = set() + + # Track if we're cleaned up + self._cleanup_done = False + + @classmethod + async def create(cls, browser: Browser) -> "SteelComputer": + """Create a new SteelComputer instance with a fresh page.""" + # Get or create a context + if not browser.contexts: + context = await browser.new_context() + else: + context = browser.contexts[0] + + # Get or create a page + if not context.pages: + page = await context.new_page() + else: + page = context.pages[0] + + # Create instance + computer = cls(page) + + # Set up the initial page + await computer.setup_page(page) + + # Navigate to a starting page + await page.goto("https://www.google.com") + + return computer + + async def setup_page(self, page: Page) -> None: + """Set up a page with necessary scripts and settings.""" + if page in self._setup_pages: + return + + # Get and set viewport size + viewport_size = await page.evaluate("""() => ({ + width: window.innerWidth, + height: window.innerHeight + })""") + await page.set_viewport_size(viewport_size) + self.logger.info(f"Set viewport size to {viewport_size['width']}x{viewport_size['height']}") + + # Add cursor overlay to make mouse movements visible + await inject_cursor_overlay(page) + self.logger.info("Cursor overlay injected") + + # Apply same-tab navigation script + await self._apply_same_tab_script(page) + self.logger.info("Same-tab navigation script injected") + + # Add page to setup tracking + self._setup_pages.add(page) + + async def _handle_new_page(self, new_page: Page) -> None: + """Handler for when a new page is created.""" + self.logger.info(f"New page created: {new_page.url}") + + # Set as active page + self.active_page = new_page + + # Wait for the page to load + await new_page.wait_for_load_state("domcontentloaded") + new_page.on("close", lambda: self._handle_page_close(new_page)) + + # Set up the page + await self.setup_page(new_page) + + self.logger.info(f"New page ready: {new_page.url}") + + def _handle_page_close(self, closed_page: Page) -> None: + """Handler for when a page is closed.""" + self.logger.info(f"Page closed: {closed_page.url}") + + # Remove from tracking + if closed_page in self._setup_pages: + self._setup_pages.remove(closed_page) + + # If the closed page was the active page, set active page to another open page + if self.active_page == closed_page: + remaining_pages = self.context.pages + if remaining_pages: + self.active_page = remaining_pages[0] + self.logger.info(f"Active page switched to: {self.active_page.url}") + else: + self.logger.warning("No remaining pages open") + # If no pages are left, create a new one + task = asyncio.create_task(self._create_new_page()) + self._track_task(task) + + def _track_task(self, task: asyncio.Task) -> None: + """Track an async task to ensure it's completed before cleanup.""" + self._pending_tasks.add(task) + task.add_done_callback(lambda t: self._pending_tasks.discard(t)) + + async def _create_new_page(self) -> None: + """Create a new page when all pages are closed.""" + try: + new_page = await self.context.new_page() + self.active_page = new_page + await self.setup_page(new_page) + await new_page.goto("https://www.google.com") + self.logger.info("Created new page after all were closed") + except Exception as e: + self.logger.error(f"Error creating new page: {e}") + + async def _apply_same_tab_script(self, target_page: Page) -> None: + """Apply script to make links open in the same tab.""" + await target_page.add_init_script(""" + window.addEventListener('load', () => { + // Initial cleanup + document.querySelectorAll('a[target="_blank"]').forEach(a => a.target = '_self'); + + // Watch for dynamic changes + const observer = new MutationObserver((mutations) => { + mutations.forEach((mutation) => { + if (mutation.addedNodes) { + mutation.addedNodes.forEach((node) => { + if (node.nodeType === 1) { // ELEMENT_NODE + // Check the added element itself + if (node.tagName === 'A' && node.target === '_blank') { + node.target = '_self'; + } + // Check any anchor children + node.querySelectorAll('a[target="_blank"]').forEach(a => a.target = '_self'); + } + }); + } + }); + }); + + observer.observe(document.body, { + childList: true, + subtree: true + }); + }); + """) + + @property + def environment(self) -> str: + """Return the environment type (always 'browser' here).""" + return "browser" + + async def get_viewport_size(self) -> Dict[str, int]: + """Return the current viewport dimensions.""" + view_size = await self.active_page.evaluate( + """() => ({ width: window.innerWidth, height: window.innerHeight })""" + ) + return view_size + + async def execute_action(self, action: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a single action dictionary from the model, e.g.: + { + "type": "click", + "x": 100, + "y": 200, + "button": "left" + } + + Returns a dict with screenshot + current URL (or error). + """ + # Make sure we're using the active page + page = self.active_page + + if page.is_closed(): + self.logger.warning("Page is closed; cannot execute action.") + return { + "type": "error", + "error": "Page is closed; cannot execute action", + } + + action_type = action.get("type") + self.logger.info(f"Executing action: {action_type}") + + try: + if action_type == "click": + x = action.get("x", 0) + y = action.get("y", 0) + button = action.get("button", "left") + await page.mouse.move(x, y) + await page.mouse.click(x, y, button=button) + + elif action_type == "scroll": + x, y = action.get("x", 0), action.get("y", 0) + sx, sy = action.get("scroll_x", 0), action.get("scroll_y", 0) + await page.mouse.move(x, y) + # Simple approach: evaluate scrollBy + await page.evaluate(f"window.scrollBy({sx}, {sy})") + + elif action_type == "type": + text = action.get("text", "") + await page.keyboard.type(text) + + elif action_type == "keypress": + keys = action.get("keys", []) + for k in keys: + mapped = translate_key(k) + await page.keyboard.press(mapped) + + elif action_type == "wait": + ms = action.get("ms", 1000) + await asyncio.sleep(ms / 1000.0) + + elif action_type == "move": + x, y = action.get("x", 0), action.get("y", 0) + await page.mouse.move(x, y) + + elif action_type == "drag": + path = action.get("path", []) + if not path: + raise ValueError("No path provided for drag action.") + first = path[0] + await page.mouse.move(first[0], first[1]) + await page.mouse.down() + for pt in path[1:]: + await page.mouse.move(pt[0], pt[1]) + await page.mouse.up() + + elif action_type == "back": + await page.go_back() + + elif action_type == "forward": + await page.go_forward() + + elif action_type == "goto": + url = action.get("url") + if not url: + raise ValueError("URL is required for goto action.") + await page.goto(url, wait_until="networkidle") + + elif action_type == "screenshot": + # We do nothing here because screenshot is done automatically after the action + pass + + else: + self.logger.warning(f"Unknown action type: {action_type}") + + # After action, take screenshot + screenshot_data = await page.screenshot(full_page=False) + data64 = base64.b64encode(screenshot_data).decode("utf-8") + current_url = page.url if not page.is_closed() else "about:blank" + + return { + "type": "image", + "source": { + "media_type": "image/png", + "data": data64 + }, + "current_url": current_url, + "tool_name": action_type, + "tool_args": action, + } + + except Exception as e: + self.logger.error(f"Error executing action '{action_type}': {e}") + return { + "type": "error", + "error": str(e), + "tool_name": action_type, + "tool_args": action, + } + + async def cleanup(self) -> None: + """Close all browser contexts and pages properly.""" + # Prevent multiple cleanups + if self._cleanup_done: + self.logger.info("Cleanup already performed, skipping") + return + + self._cleanup_done = True + self.logger.info("Starting browser cleanup") + + try: + # First, cancel any pending tasks we're tracking + for task in self._pending_tasks: + if not task.done(): + task.cancel() + + # Wait for a short time to let tasks properly cancel + if self._pending_tasks: + pending = list(self._pending_tasks) + self.logger.info(f"Waiting for {len(pending)} pending tasks to complete") + try: + # Wait for tasks with a timeout + await asyncio.wait(pending, timeout=1.0) + except asyncio.CancelledError: + self.logger.warning("Task cancellation was itself cancelled") + pass + + # Remove our page event handlers to prevent new callbacks + if hasattr(self.context, "_listeners"): + self.logger.info("Removing page event handlers") + self.context.remove_listener("page", self._handle_new_page) + + # First close all pages explicitly except active page + for page in list(self._setup_pages): + if page != self.active_page and not page.is_closed(): + try: + await page.close() + except Exception as e: + self.logger.warning(f"Error closing page: {e}") + + # Close active page last + if self.active_page and not self.active_page.is_closed(): + try: + await self.active_page.close() + except Exception as e: + self.logger.warning(f"Error closing active page: {e}") + + # Now close all contexts explicitly + try: + await self.context.close() + except Exception as e: + self.logger.warning(f"Error closing context: {e}") + + # Finally close the browser + try: + await self.browser.close() + self.logger.info("Browser closed successfully") + except Exception as e: + self.logger.warning(f"Error during browser close: {e}") + + except Exception as e: + self.logger.error(f"Error during browser cleanup: {e}") + import traceback + self.logger.error(traceback.format_exc()) \ No newline at end of file diff --git a/api/plugins/openai_computer_use/tools.py b/api/plugins/openai_computer_use/tools.py new file mode 100644 index 00000000..c2d8ae4a --- /dev/null +++ b/api/plugins/openai_computer_use/tools.py @@ -0,0 +1,170 @@ +import asyncio +import base64 +from typing import Dict, Any, List +from playwright.async_api import Page +import logging +from .key_mapping import translate_key + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("openai_computer_use.tools") + + +async def _execute_computer_action(page: Page, action: Dict[str, Any]) -> None: + """ + Given a single computer action dict, do that action via Playwright. + No longer returns screenshot as that's handled by the caller. + """ + action_type = action.get("type") + logger.info(f"Executing computer action: {action_type}") + + # If the page or browser closed unexpectedly, short-circuit + if page.is_closed(): + logger.warning("Page is already closed, skipping action") + return + + try: + if action_type == "click": + x = action.get("x", 0) + y = action.get("y", 0) + button = action.get("button", "left") + logger.debug(f"Clicking at ({x}, {y}), button={button}") + await page.mouse.move(x, y) + await page.mouse.click(x, y, button=button) + + elif action_type == "scroll": + x, y = action.get("x", 0), action.get("y", 0) + sx, sy = action.get("scroll_x", 0), action.get("scroll_y", 0) + logger.debug(f"Scrolling at ({x}, {y}) by ({sx}, {sy})") + await page.mouse.move(x, y) + await page.evaluate(f"window.scrollBy({sx}, {sy})") + + elif action_type == "type": + text = action.get("text", "") + logger.debug(f"Typing text: {text[:50]} ...") + await page.keyboard.type(text) + + elif action_type == "keypress": + keys = action.get("keys", []) + logger.debug(f"Pressing keys: {keys}") + for k in keys: + mapped_key = translate_key(k) + logger.debug(f"Mapped key '{k}' to '{mapped_key}'") + await page.keyboard.press(mapped_key) + + elif action_type == "wait": + ms = action.get("ms", 1000) + logger.debug(f"Waiting {ms} ms") + await asyncio.sleep(ms / 1000) + + elif action_type == "move": + x, y = action.get("x", 0), action.get("y", 0) + logger.debug(f"Moving to ({x}, {y})") + await page.mouse.move(x, y) + + elif action_type == "drag": + path = action.get("path", []) + logger.debug(f"Dragging path with {len(path)} points") + if path: + first = path[0] + await page.mouse.move(first[0], first[1]) + await page.mouse.down() + for pt in path[1:]: + await page.mouse.move(pt[0], pt[1]) + await page.mouse.up() + + elif action_type == "back": + logger.debug("Navigating back in browser history") + await page.go_back() + + elif action_type == "forward": + logger.debug("Navigating forward in browser history") + await page.go_forward() + + elif action_type == "goto": + url = action.get("url") + if not url: + raise ValueError("URL is required for goto action") + logger.debug(f"Navigating to URL: {url}") + await page.goto(url, wait_until="networkidle") + + elif action_type == "screenshot": + logger.debug("CUA requested screenshot action. No-op since screenshots are handled by caller.") + + else: + logger.warning(f"Unknown action type: {action_type}") + + except Exception as e: + logger.error(f"Error executing computer action '{action_type}': {e}") + raise + + +def _create_tools() -> List[Dict[str, Any]]: + """ + Return a list of 'tools' recognized by the CUA model, including the + 'computer-preview' tool for environment AND navigation functions + for URL navigation and browser history. + """ + return [ + # The required computer-preview tool: + { + "type": "computer-preview", + "display_width": 1280, + "display_height": 800, + "environment": "browser" + }, + # Our custom 'goto' function tool: + { + "type": "function", + "name": "goto", + "description": "Navigate to a specific URL", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The fully-qualified URL to navigate to" + } + }, + "required": ["url"], + "additionalProperties": False + } + }, + # Back navigation tool + { + "type": "function", + "name": "back", + "description": "Go back one page in browser history", + "parameters": { + "type": "object", + "properties": {}, + "additionalProperties": False + } + }, + # Forward navigation tool + { + "type": "function", + "name": "forward", + "description": "Go forward one page in browser history", + "parameters": { + "type": "object", + "properties": {}, + "additionalProperties": False + } + }, + ] + + +def _make_cua_content_for_role(role: str, text: str) -> List[Dict[str, str]]: + """ + Convert user/system vs assistant text into the correct format: + user/system -> input_text + assistant -> output_text + """ + if role in ("user", "system"): + return [{"type": "input_text", "text": text}] + elif role == "assistant": + return [{"type": "output_text", "text": text}] + else: + # fallback if you have other roles + return [{"type": "input_text", "text": text}] diff --git a/api/utils/prompt.py b/api/utils/prompt.py index 00c38e62..889eab06 100644 --- a/api/utils/prompt.py +++ b/api/utils/prompt.py @@ -44,7 +44,7 @@ def convert_to_chat_messages(messages: List[ClientMessage]): "type": "function", "function": { "name": tool_invocation.toolName, - "arguments": json.dumps(tool_invocation.args), + "arguments": tool_invocation.args, }, } for tool_invocation in message.toolInvocations @@ -55,7 +55,7 @@ def convert_to_chat_messages(messages: List[ClientMessage]): tool_results = [ { "role": "tool", - "content": json.dumps(tool_invocation.result), + "content": tool_invocation.result, "tool_call_id": tool_invocation.toolCallId, } for tool_invocation in message.toolInvocations @@ -95,7 +95,7 @@ def extract_content(content_array): ( ToolMessage( tool_call_id=message["tool_call_id"], - content=json.loads(message["content"]), + content=message["content"], ) if message["role"] == "tool" else ( @@ -106,7 +106,7 @@ def extract_content(content_array): id=tool["id"], type=tool["type"], name=tool["function"]["name"], - args=json.loads(tool["function"]["arguments"]), + args=tool["function"]["arguments"], ) for tool in message["tool_calls"] ], diff --git a/api/utils/types.py b/api/utils/types.py index c1647ea8..62c40bf7 100644 --- a/api/utils/types.py +++ b/api/utils/types.py @@ -10,11 +10,18 @@ class ToolInvocation(BaseModel): class AgentSettings(BaseModel): - steps: Optional[int] = None + # General settings system_prompt: Optional[str] = None + + # Image and timing settings num_images_to_keep: Optional[int] = Field(default=10, ge=1, le=50) wait_time_between_steps: Optional[int] = Field(default=1, ge=0, le=10) - steps: Optional[int] = None + + # Step control + max_steps: Optional[int] = Field(default=30, ge=10, le=50) + + # Legacy field for backward compatibility + steps: Optional[int] = None # Deprecated in favor of max_steps class ModelSettings(BaseModel): diff --git a/app/chat/page.tsx b/app/chat/page.tsx index 62664c69..c821542e 100644 --- a/app/chat/page.tsx +++ b/app/chat/page.tsx @@ -18,6 +18,8 @@ import { ToolInvocations } from "@/components/ui/tool"; import { useToast } from "@/hooks/use-toast"; +import { isLocalhost } from "@/lib/utils"; + import { useChatContext } from "@/app/contexts/ChatContext"; import { useSettings } from "@/app/contexts/SettingsContext"; import { useSteelContext } from "@/app/contexts/SteelContext"; @@ -172,7 +174,7 @@ function MarkdownText({ content }: { content: string }) { )}
-
{title}
+
{title}
{strippedContent ? (
{parseContent(strippedContent)}
) : ( @@ -262,17 +264,16 @@ export default function ChatPage() { const pendingMessageRef = useRef(""); const checkApiKey = () => { - // // For Ollama, we don't need an API key as it connects to a local instance - // if (currentSettings?.selectedProvider === 'ollama') { - // return true; - // } - - // // For other providers, check if API key exists - // const provider = currentSettings?.selectedProvider; - // if (!provider) return false; - // const hasKey = !!currentSettings?.providerApiKeys?.[provider]; - // return hasKey; - return true; + // For Ollama, we don't need an API key as it connects to a local instance + if (currentSettings?.selectedProvider === "ollama" && isLocalhost()) { + return true; + } + + // For other providers, check if API key exists + const provider = currentSettings?.selectedProvider; + if (!provider) return false; + const hasKey = !!currentSettings?.providerApiKeys?.[provider]; + return hasKey; }; const handleApiKeySubmit = (key: string) => { @@ -337,6 +338,9 @@ export default function ChatPage() { onFinish: message => { console.info("✅ Chat finished:", message); }, + onResponse: response => { + console.info("🔄 Chat response:", response); + }, onError: error => { console.error("❌ Chat error:", error); toast({ @@ -652,7 +656,7 @@ export default function ChatPage() { ) : (
- {messages.map((message, index) => { + {(() => { const isSpecialMessage = (message.content && (message.content.includes("*Memory*:") || @@ -731,7 +735,7 @@ export default function ChatPage() {
) : null; - })} + })()}
)}
diff --git a/app/contexts/SettingsContext.tsx b/app/contexts/SettingsContext.tsx index 0a51bbf2..7a5be128 100644 --- a/app/contexts/SettingsContext.tsx +++ b/app/contexts/SettingsContext.tsx @@ -14,10 +14,18 @@ export interface ModelSettings { } export interface AgentSettings { + // General settings system_prompt?: string; + + // Image and timing settings num_images_to_keep?: number; wait_time_between_steps?: number; - steps?: number; + + // Step control + max_steps?: number; + steps?: number; // Legacy field + + // Allow additional string-keyed settings [key: string]: string | number | undefined; } diff --git a/app/contexts/SteelContext.tsx b/app/contexts/SteelContext.tsx index b6fa5074..53df41b8 100644 --- a/app/contexts/SteelContext.tsx +++ b/app/contexts/SteelContext.tsx @@ -10,8 +10,6 @@ interface SteelContextType { createSession: () => Promise; isCreatingSession: boolean; resetSession: () => Promise; - sessionTimeElapsed: number; - isExpired: boolean; maxSessionDuration: number; } @@ -23,39 +21,8 @@ export function SteelProvider({ children }: { children: React.ReactNode }) { console.info("🔄 Initializing SteelProvider"); const [currentSession, setCurrentSession] = useState(null); const [isCreatingSession, setIsCreatingSession] = useState(false); - const [sessionTimeElapsed, setSessionTimeElapsed] = useState(0); - const [isExpired, setIsExpired] = useState(false); const { currentSettings } = useSettings(); - // Timer effect - useEffect(() => { - console.info("⏱️ Timer effect triggered", { currentSession, isExpired }); - let intervalId: NodeJS.Timeout; - - if (currentSession && !isExpired) { - console.info("⏰ Starting session timer"); - intervalId = setInterval(() => { - setSessionTimeElapsed(prev => { - const newTime = prev + 1; - if (newTime >= MAX_SESSION_DURATION) { - console.warn("⚠️ Session expired after reaching MAX_SESSION_DURATION"); - setIsExpired(true); - clearInterval(intervalId); - return MAX_SESSION_DURATION; - } - return newTime; - }); - }, 1000); - } - - return () => { - if (intervalId) { - console.info("🛑 Clearing session timer"); - clearInterval(intervalId); - } - }; - }, [currentSession, isExpired]); - // Helper function to release a session const releaseSession = async (sessionId: string) => { console.info("🔓 Attempting to release session:", sessionId); @@ -114,8 +81,6 @@ export function SteelProvider({ children }: { children: React.ReactNode }) { const session = await response.json(); console.info("✅ Session created successfully:", session); setCurrentSession(session); - setSessionTimeElapsed(0); - setIsExpired(false); return session; } } catch (err) { @@ -135,8 +100,6 @@ export function SteelProvider({ children }: { children: React.ReactNode }) { setCurrentSession(null); setIsCreatingSession(false); - setSessionTimeElapsed(0); - setIsExpired(false); console.info("✅ Session reset complete"); }; @@ -147,8 +110,6 @@ export function SteelProvider({ children }: { children: React.ReactNode }) { createSession, isCreatingSession, resetSession, - sessionTimeElapsed, - isExpired, maxSessionDuration: MAX_SESSION_DURATION, }} > diff --git a/app/hooks/useOllamaModels.ts b/app/hooks/useOllamaModels.ts index 7505a879..6ef368ea 100644 --- a/app/hooks/useOllamaModels.ts +++ b/app/hooks/useOllamaModels.ts @@ -1,5 +1,9 @@ import { useQuery } from "@tanstack/react-query"; +import { isLocalhost } from "@/lib/utils"; + +import { useSettings } from "@/app/contexts/SettingsContext"; + interface OllamaModel { tag: string; base_name: string; @@ -18,10 +22,18 @@ async function fetchOllamaModels(): Promise { } export function useOllamaModels() { + const { currentSettings } = useSettings(); + const isOllamaSelected = currentSettings?.selectedProvider === "ollama"; + const isLocal = isLocalhost(); + return useQuery({ queryKey: ["ollama-models"], queryFn: fetchOllamaModels, - staleTime: 30 * 1000, + staleTime: 60 * 1000, // Increase stale time to 1 minute retry: 2, + // Only fetch when Ollama is selected and we're running locally + enabled: isOllamaSelected && isLocal, + // Skip refetching in the background when out of focus + refetchOnWindowFocus: false, }); } diff --git a/components/LayoutContent.tsx b/components/LayoutContent.tsx index 314702f5..834c8f93 100644 --- a/components/LayoutContent.tsx +++ b/components/LayoutContent.tsx @@ -12,15 +12,15 @@ import { SteelProvider } from "../app/contexts/SteelContext"; export function LayoutContent({ children }: { children: React.ReactNode }) { return ( - - + +
{children}
-
-
+ +
); } diff --git a/components/ui/Browser.tsx b/components/ui/Browser.tsx index c968433e..a118b8a8 100644 --- a/components/ui/Browser.tsx +++ b/components/ui/Browser.tsx @@ -4,35 +4,25 @@ import { useEffect, useRef, useState } from "react"; import { GlobeIcon } from "lucide-react"; import Image from "next/image"; -import { cn } from "@/lib/utils"; - import { useSteelContext } from "@/app/contexts/SteelContext"; +import Timer from "./Timer"; + export function Browser() { // WebSocket and canvas state const parentRef = useRef(null); const canvasRef = useRef(null); - const [canvasSize, setCanvasSize] = useState<{ + const [canvasSize] = useState<{ width: number; height: number; } | null>(null); - const [latestImage, setLatestImage] = useState(null); - const [isLoading, setIsLoading] = useState(false); - const [error, setError] = useState(null); - const [isConnected, setIsConnected] = useState(false); + const [latestImage] = useState(null); const [url, setUrl] = useState(null); const [favicon, setFavicon] = useState(null); - const { currentSession, sessionTimeElapsed, isExpired, maxSessionDuration } = useSteelContext(); + const { currentSession } = useSteelContext(); const debugUrl = currentSession?.debugUrl; - // Format time as MM:SS - const formatTime = (seconds: number) => { - const mins = Math.floor(seconds / 60); - const secs = seconds % 60; - return `${mins.toString().padStart(2, "0")}:${secs.toString().padStart(2, "0")}`; - }; - // Canvas rendering useEffect(() => { const renderFrame = () => { @@ -123,27 +113,7 @@ export function Browser() { {/* Status Bar */}
-
- -
- {currentSession - ? isExpired - ? "Session Expired" - : "Session Connected" - : "No Session"} - - - - {currentSession ? formatTime(sessionTimeElapsed) : "--:--"} - {" "} - /{formatTime(maxSessionDuration)} - -
+ Browser Powered by{" "} diff --git a/components/ui/SettingsDrawer.tsx b/components/ui/SettingsDrawer.tsx index 81280ccc..1f5f058e 100644 --- a/components/ui/SettingsDrawer.tsx +++ b/components/ui/SettingsDrawer.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Info, Settings } from "lucide-react"; import { useRouter } from "next/navigation"; @@ -93,8 +93,24 @@ function SettingInput({ .map(word => word.charAt(0).toUpperCase() + word.slice(1)) .join(" "); + // Extract text content from formatted structure if it exists + const extractTextContent = (val: any) => { + if (Array.isArray(val) && val.length > 0 && val[0]?.type === "input_text") { + return val[0].text; + } + return val; + }; + // Use config.default if value is undefined - const currentValue = value ?? config.default; + const currentValue = extractTextContent(value ?? config.default); + + // Prepare value for saving + const prepareValueForSave = (val: any) => { + if (settingKey === "system_prompt" && val) { + return [{ type: "input_text", text: val }]; + } + return val; + }; // Sanitize number inputs const sanitizeNumber = (value: number) => { @@ -181,7 +197,7 @@ function SettingInput({