Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 54 additions & 136 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,68 +5,24 @@
from typing import Dict, Optional
from datetime import datetime
import json
from queue import Queue, Empty
import threading

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# Add the parent directory to the path to import langroid
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
# Add the repo root to the path so we can import local packages
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

import langroid as lr
from langroid.language_models import MockLMConfig
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.task import Task, TaskConfig
from langroid.agent.task import Task
from langroid.language_models.openai_gpt import OpenAIGPTConfig
from langroid.utils.constants import DONE


class UIChatAgent(ChatAgent):
    """Chat agent whose user turn is fed from a queue rather than stdin."""

    def __init__(self, config: ChatAgentConfig, input_queue: Queue[str], running_flag):
        """Store the input queue and the zero-argument ``running_flag`` callable."""
        super().__init__(config)
        self.input_queue = input_queue
        self.running_flag = running_flag

    def user_response(self, msg: Optional[str] = None) -> Optional[lr.ChatDocument]:
        """Return the next queued user input as a ChatDocument.

        Prints ``msg`` when it is non-empty, then polls the queue in 0.5 s
        slices for as long as ``running_flag()`` is truthy. One of the quit
        words, or the flag turning falsy, yields a DONE document; an
        unexpected error is printed and ``None`` is returned.
        """

        def done_document() -> lr.ChatDocument:
            # User-originated document whose content is the DONE marker.
            return lr.ChatDocument(
                content=DONE,
                metadata=lr.ChatDocMetaData(
                    sender=lr.Entity.USER,
                )
            )

        quit_words = ('quit', 'exit', 'q', 'x')

        # Echo the agent's message, if any
        if msg:
            print(f"Agent says: {msg}")

        while self.running_flag():
            try:
                # Short timeout so the running flag is re-checked regularly
                text = self.input_queue.get(timeout=0.5)
                if text.lower() in quit_words:
                    return done_document()
                return self.create_user_response(text)
            except Empty:
                # Nothing queued yet; poll again
                continue
            except Exception as e:
                print(f"Error in user_response: {e}")
                return None

        # The session is no longer running
        return done_document()



class Message(BaseModel):
Expand All @@ -76,24 +32,40 @@ class Message(BaseModel):
timestamp: datetime


class WebCallbacks:
    """Callbacks to bridge Langroid Agent with the WebSocket frontend.

    Constructing an instance installs its methods on
    ``session.agent.callbacks`` as ``show_agent_response``,
    ``show_llm_response`` and ``get_user_response_async``.
    """

    def __init__(self, session: "ChatSession") -> None:
        """Attach to ``session`` and register the callbacks on its agent."""
        self.session = session
        # Strong references to in-flight send tasks. The event loop only keeps
        # weak references to tasks, so a fire-and-forget task could otherwise
        # be garbage-collected before it has finished sending.
        self._pending_sends: set = set()
        agent = session.agent
        agent.callbacks.show_agent_response = self.show_message
        agent.callbacks.show_llm_response = self.show_message
        agent.callbacks.get_user_response_async = self.get_user_response_async

    async def get_user_response_async(self, prompt: str) -> str:
        """Forward a non-empty ``prompt`` to the frontend, then await user input.

        Returns:
            The next item taken from ``session.user_queue``.
        """
        if prompt:
            await self.session._send_bot_message(prompt)
        return await self.session.user_queue.get()

    def show_message(self, content: str, language: str = "text", is_tool: bool = False) -> None:  # noqa: D401
        """Schedule ``content`` to be sent to the frontend as a bot message.

        ``language`` and ``is_tool`` are accepted but not used here. This is
        a synchronous callback, so the send is scheduled as a task on the
        running event loop; ``asyncio.create_task`` raises ``RuntimeError``
        when no loop is running.
        """
        send_task = asyncio.create_task(self.session._send_bot_message(content))
        # Hold the task until it completes, then drop the reference.
        self._pending_sends.add(send_task)
        send_task.add_done_callback(self._pending_sends.discard)


class ChatSession:
"""Manages a single chat session with its own agent and task"""
"""Manage a chat session using a Langroid agent"""

def __init__(self, session_id: str, websocket: WebSocket):
self.session_id = session_id
self.websocket = websocket
self.input_queue: Queue[str] = Queue()
self.running = False
self.task_thread: Optional[threading.Thread] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

# Create agent with custom user_response
self.user_queue: asyncio.Queue[str] = asyncio.Queue()

# Create the agent
self.agent = self._create_agent()
self.task_config = TaskConfig(
addressing_prefix="",
)
self.task = Task(self.agent, interactive=True, config=self.task_config)
self.task = Task(self.agent, interactive=True)
self.callbacks = WebCallbacks(self)

def _create_agent(self) -> UIChatAgent:
def _create_agent(self) -> ChatAgent:
"""Create agent with appropriate LLM config"""
# Check if we should force mock mode
use_mock = os.getenv("USE_MOCK_LLM", "").lower() in ["true", "1", "yes"]
Expand Down Expand Up @@ -132,97 +104,43 @@ def _create_agent(self) -> UIChatAgent:
Be concise and friendly in your responses. Remember our conversation history."""
)

# Return our custom agent with queue-based input
return UIChatAgent(config, self.input_queue, lambda: self.running)
# Return the standard chat agent
return ChatAgent(config)


def _run_task_loop(self):
"""Run the task loop in a separate thread"""
try:
# Run the task - this will loop between user_response and llm_response
# In interactive mode, it will automatically wait for user input
self.task.run()

except Exception as e:
print(f"Task error: {e}")
import traceback
traceback.print_exc()
finally:
self.running = False

async def start(self, loop: asyncio.AbstractEventLoop):
"""Start the chat session"""
"""Start the task loop in the given event loop."""
self.loop = loop
self.running = True

# Start the task in a separate thread
self.task_thread = threading.Thread(target=self._run_task_loop)
self.task_thread.start()

# Start monitoring for agent responses
asyncio.create_task(self._monitor_responses())

asyncio.create_task(self.task.run_async())
await self._send_bot_message("👋 Connected to Langroid chat!")

async def send_user_message(self, content: str):
"""Send a user message to the agent"""
# Put the message in the queue for the custom user_response to pick up
self.input_queue.put(content)

async def _monitor_responses(self):
    """Monitor the agent's message history and send new messages to the frontend.

    While ``self.running`` is true, polls ``self.agent.message_history``
    every 0.1 s and forwards each not-yet-seen entry whose
    ``metadata.sender`` is ``lr.Entity.LLM`` or ``lr.Entity.ASSISTANT``
    through ``self._send_to_frontend``. Any exception is printed with its
    traceback and ends the monitoring loop.
    """
    import traceback

    forwarded_senders = (lr.Entity.LLM, lr.Entity.ASSISTANT)
    last_index = 0

    while self.running:
        try:
            history = self.agent.message_history

            # Read the length exactly once per pass. The history may grow
            # while this coroutine is suspended in the awaits below; taking
            # len() again afterwards would mark those entries as seen
            # without ever sending them.
            end_index = len(history)

            if end_index > last_index:
                for msg in history[last_index:end_index]:
                    # Entries without (truthy) metadata are not forwarded
                    metadata = getattr(msg, "metadata", None)
                    if not metadata:
                        continue

                    # Only send LLM/ASSISTANT messages to frontend
                    if metadata.sender in forwarded_senders:
                        now = datetime.now()
                        data = {
                            "type": "message",
                            "message": {
                                "id": f"{now.timestamp()}",
                                "content": msg.content,
                                "sender": "assistant",
                                "timestamp": now.isoformat(),
                            }
                        }
                        await self._send_to_frontend(data)

                last_index = end_index

            await asyncio.sleep(0.1)

        except Exception as e:
            print(f"Monitor error: {e}")
            traceback.print_exc()
            break

"""Send user input to the running task."""
await self.user_queue.put(content)

async def _send_to_frontend(self, data: dict):
"""Send data to frontend via WebSocket"""
if self.loop and self.websocket:
try:
await self.websocket.send_json(data)
except Exception as e:
print(f"Error sending to frontend: {e}")

async def _send_bot_message(self, content: str):
data = {
"type": "message",
"message": {
"id": f"{datetime.now().timestamp()}",
"content": content,
"sender": "assistant",
"timestamp": datetime.now().isoformat(),
},
}
await self._send_to_frontend(data)

def stop(self):
"""Stop the chat session"""
self.running = False
# Send quit to unblock user_response
self.input_queue.put("quit")

if self.task_thread:
self.task_thread.join(timeout=5)
self.task.close_loggers()


class ConnectionManager:
Expand Down
15 changes: 15 additions & 0 deletions tiktoken/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class DummyEncoding:
    """Whitespace-based stand-in for a tokenizer encoding.

    Tokens are the whitespace-separated words of the text, not integer ids.
    """

    def encode(self, text, *args, **kwargs):
        """Split ``text`` on whitespace; extra arguments are ignored."""
        return text.split()

    def decode(self, tokens, *args, **kwargs):
        """Rebuild text from ``tokens``; extra arguments are ignored.

        String tokens are joined with single spaces. If any token is not a
        string, all tokens are converted to text and concatenated without a
        separator; bytes tokens are decoded as UTF-8 (invalid sequences
        replaced) rather than rendered as their ``b'...'`` repr. A bare
        bytes object is treated as one token.
        """
        if isinstance(tokens, (bytes, bytearray)):
            tokens = [tokens]
        # Materialize first: a one-shot iterator would be partly consumed by
        # the failed join below and the fallback would lose those tokens.
        tokens = list(tokens)
        try:
            return " ".join(tokens)
        except TypeError:
            return "".join(self._token_text(t) for t in tokens)

    @staticmethod
    def _token_text(token):
        """Return the text form of a single token of any type."""
        if isinstance(token, (bytes, bytearray)):
            return bytes(token).decode("utf-8", errors="replace")
        return str(token)

def get_encoding(name: str):
    """Return a new ``DummyEncoding``; ``name`` is accepted but not consulted."""
    encoding = DummyEncoding()
    return encoding


# Model-based lookup is the very same callable as the name-based one.
encoding_for_model = get_encoding
Loading