Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions backend/app/controller/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from dotenv import load_dotenv
from fastapi import APIRouter, Request, Response
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse

from app.component import code
from app.component.environment import sanitize_env_path, set_user_env_path
Expand Down Expand Up @@ -49,6 +49,7 @@
delete_task_lock,
get_or_create_task_lock,
get_task_lock,
get_task_lock_if_exists,
set_current_task_id,
task_locks,
)
Expand Down Expand Up @@ -256,7 +257,25 @@ def improve(id: str, data: SupplementChat):
"Chat improvement requested",
extra={"task_id": id, "question_length": len(data.question)},
)
task_lock = get_task_lock(id)
task_lock = get_task_lock_if_exists(id)

if task_lock is None:
# SSE session no longer exists (disconnected, timed out, or stopped).
# Return 410 Gone so the frontend can fall back to POST /chat which
# creates both a new task lock AND a new SSE consumer.
chat_logger.warning(
"Task lock not found for improve request, "
"returning 410 so client reconnects via POST /chat",
extra={"project_id": id},
)
Comment on lines 257 to +270
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

258 line extra={"task_id": id
but 269 line extra={"project_id": id

this id should be task_id right?

return JSONResponse(
status_code=410,
content={
"code": 410,
"error": "session_expired",
"message": "Session expired. Please reconnect.",
},
)

# Allow continuing conversation even after task is done
# This supports multi-turn conversation after complex task completion
Expand Down
101 changes: 77 additions & 24 deletions backend/tests/unit/controller/test_chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,13 @@ def test_improve_chat_success(self, mock_task_lock):
task_id = "test_task_123"
supplement_data = SupplementChat(question="Improve this code")
mock_task_lock.status = Status.processing
mock_task_lock.id = task_id

with (
patch(
"app.controller.chat_controller.get_task_lock",
return_value=mock_task_lock,
patch.dict(
"app.service.task.task_locks",
{task_id: mock_task_lock},
clear=False,
),
patch("asyncio.run") as mock_run,
):
Expand All @@ -133,18 +135,52 @@ def test_improve_chat_success(self, mock_task_lock):
# put_queue is invoked when creating the coroutine passed to asyncio.run
mock_task_lock.put_queue.assert_called_once()

def test_improve_chat_task_done_error(self, mock_task_lock):
"""Test improvement fails when task is done."""
def test_improve_chat_task_done_resets_status(self, mock_task_lock):
"""Test improvement resets status when task is done (multi-turn support)."""
task_id = "test_task_123"
supplement_data = SupplementChat(question="Improve this code")
mock_task_lock.status = Status.done
mock_task_lock.id = task_id
mock_task_lock.background_tasks = MagicMock()
mock_task_lock.conversation_history = ["previous", "context"]
mock_task_lock.last_task_result = "previous result"

with patch(
"app.controller.chat_controller.get_task_lock",
return_value=mock_task_lock,
with (
patch.dict(
"app.service.task.task_locks",
{task_id: mock_task_lock},
clear=False,
),
patch("asyncio.run") as mock_run,
):
with pytest.raises(UserException):
improve(task_id, supplement_data)
response = improve(task_id, supplement_data)

# Should succeed and reset status to confirming
assert isinstance(response, Response)
assert response.status_code == 201
assert mock_task_lock.status == Status.confirming
mock_task_lock.background_tasks.clear.assert_called_once()
mock_run.assert_called_once()

def test_improve_chat_session_expired_returns_410(self):
"""Test improve returns 410 when task lock doesn't exist (session expired)."""
from fastapi.responses import JSONResponse

task_id = "test_task_123"
supplement_data = SupplementChat(question="Improve this code")

with patch.dict("app.service.task.task_locks", {}, clear=True):
response = improve(task_id, supplement_data)

assert isinstance(response, JSONResponse)
assert response.status_code == 410
# Check response content
import json

content = json.loads(response.body.decode())
assert content["code"] == 410
assert content["error"] == "session_expired"
assert "Session expired" in content["message"]

def test_supplement_chat_success(self, mock_task_lock):
"""Test successful chat supplementation."""
Expand Down Expand Up @@ -274,20 +310,37 @@ def test_improve_chat_endpoint_integration(self, client: TestClient):
task_id = "test_task_123"
supplement_data = {"question": "Improve this code"}

mock_task_lock = MagicMock()
mock_task_lock.status = Status.processing
mock_task_lock.id = task_id

with (
patch(
"app.controller.chat_controller.get_task_lock"
) as mock_get_lock,
patch.dict(
"app.service.task.task_locks",
{task_id: mock_task_lock},
clear=False,
),
patch("asyncio.run"),
):
mock_task_lock = MagicMock()
mock_task_lock.status = Status.processing
mock_get_lock.return_value = mock_task_lock

response = client.post(f"/chat/{task_id}", json=supplement_data)

assert response.status_code == 201

def test_improve_chat_session_expired_integration(
self, client: TestClient
):
"""Test improve endpoint returns 410 when session expired."""
task_id = "test_task_123"
supplement_data = {"question": "Improve this code"}

with patch.dict("app.service.task.task_locks", {}, clear=True):
response = client.post(f"/chat/{task_id}", json=supplement_data)

assert response.status_code == 410
json_response = response.json()
assert json_response["code"] == 410
assert json_response["error"] == "session_expired"

def test_supplement_chat_endpoint_integration(self, client: TestClient):
"""Test supplement chat endpoint through FastAPI test client."""
task_id = "test_task_123"
Expand Down Expand Up @@ -418,16 +471,16 @@ async def test_post_with_invalid_data(self, mock_request):
# (Intentionally not calling post with invalid Chat object since creation fails.)

def test_improve_with_nonexistent_task(self):
"""Test improve endpoint with nonexistent task."""
"""Test improve endpoint returns 410 when task lock not found."""
from fastapi.responses import JSONResponse

task_id = "nonexistent_task"
supplement_data = SupplementChat(question="Improve this code")

with patch(
"app.controller.chat_controller.get_task_lock",
side_effect=KeyError("Task not found"),
):
with pytest.raises(KeyError):
improve(task_id, supplement_data)
with patch.dict("app.service.task.task_locks", {}, clear=True):
response = improve(task_id, supplement_data)
assert isinstance(response, JSONResponse)
assert response.status_code == 410

def test_supplement_with_empty_question(self, mock_task_lock):
"""Test supplement endpoint with empty question."""
Expand Down
1 change: 1 addition & 0 deletions src/components/ChatBox/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ export default function ChatBox(): JSX.Element {
chatStore.setNextTaskId(nextTaskId);

// Use improve endpoint (POST /chat/{id}) - {id} is project_id
// This reuses the existing SSE connection and step_solve loop
fetchPost(`/chat/${projectStore.activeProjectId}`, {
question: tempMessageContent,
task_id: nextTaskId,
Expand Down