Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions backend/app/agent/listen_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,11 @@ def step(
f"tokens used: {total_tokens}"
)

assert message is not None
if message is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we change to warning and let the step continue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — changed to raise RuntimeError instead, keeping fail-fast behavior consistent with the res=None guard below. Fixed in latest push.

raise RuntimeError(
f"Agent {self.agent_name}: message is None after step "
f"completion — this indicates a missing LLM response"
)

_schedule_async_task(
task_lock.put_queue(
Expand All @@ -310,7 +314,11 @@ def step(

if error_info is not None:
raise error_info
assert res is not None
if res is None:
raise RuntimeError(
f"Agent {self.agent_name}: step() returned None "
f"without setting error_info"
)
return res

async def astep(
Expand Down Expand Up @@ -397,7 +405,11 @@ async def astep(

# Send deactivation for all non-streaming cases (success or error)
# Streaming responses handle deactivation in _astream_chunks
assert message is not None
if message is None:
raise RuntimeError(
f"Agent {self.agent_name}: message is None after astep "
f"completion — this indicates a missing LLM response"
)

asyncio.create_task(
task_lock.put_queue(
Expand All @@ -415,7 +427,11 @@ async def astep(

if error_info is not None:
raise error_info
assert res is not None
if res is None:
raise RuntimeError(
f"Agent {self.agent_name}: astep() returned None "
f"without setting error_info"
)
return res

def _execute_tool(
Expand Down
5 changes: 4 additions & 1 deletion backend/app/controller/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,11 @@ def improve(id: str, data: SupplementChat):
if task_lock.status == Status.done:
# Reset status to allow processing new messages
task_lock.status = Status.confirming
# Clear any existing background tasks since workforce was stopped
# Cancel and clear any existing background tasks since workforce was stopped
if hasattr(task_lock, "background_tasks"):
for bg_task in list(task_lock.background_tasks):
if not bg_task.done():
bg_task.cancel()
task_lock.background_tasks.clear()
# Note: conversation_history and last_task_result are preserved

Expand Down
116 changes: 116 additions & 0 deletions backend/tests/app/agent/test_listen_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,122 @@ def test_listen_chat_agent_step_with_task_lock_error(self):
agent.step("Test message")


@pytest.mark.unit
def test_step_raises_runtime_error_on_none_message(mock_task_lock):
"""step() raises RuntimeError (not AssertionError) when message is None."""
with (
patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock),
patch("camel.models.ModelFactory.create") as mock_create_model,
):
mock_backend = MagicMock()
mock_backend.model_type = "gpt-4"
mock_backend.current_model = MagicMock()
mock_backend.current_model.model_type = "gpt-4"
mock_create_model.return_value = mock_backend

agent = ListenChatAgent(
api_task_id="test_task",
agent_name="TestAgent",
model="gpt-4",
)
agent.process_task_id = "test_process"

error = Exception("Some unexpected error")
with patch.object(ChatAgent, "step", side_effect=error):
with pytest.raises(Exception, match="Some unexpected error"):
agent.step("test input")

mock_task_lock.put_queue.assert_called()


@pytest.mark.unit
def test_step_raises_runtime_error_when_res_is_none(mock_task_lock):
"""step() raises RuntimeError (not AssertionError) when res is None.

When ChatAgent.step returns None, message also remains None,
so the message-is-None guard fires first.
"""
with (
patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock),
patch("camel.models.ModelFactory.create") as mock_create_model,
):
mock_backend = MagicMock()
mock_backend.model_type = "gpt-4"
mock_backend.current_model = MagicMock()
mock_backend.current_model.model_type = "gpt-4"
mock_create_model.return_value = mock_backend

agent = ListenChatAgent(
api_task_id="test_task",
agent_name="TestAgent",
model="gpt-4",
)
agent.process_task_id = "test_process"

with patch.object(ChatAgent, "step", return_value=None):
with pytest.raises(RuntimeError, match="message is None"):
agent.step("test input")


@pytest.mark.unit
@pytest.mark.asyncio
async def test_astep_raises_runtime_error_on_none_message(mock_task_lock):
"""astep() raises RuntimeError (not AssertionError) when message is None."""
with (
patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock),
patch("camel.models.ModelFactory.create") as mock_create_model,
patch("asyncio.create_task"),
):
mock_backend = MagicMock()
mock_backend.model_type = "gpt-4"
mock_backend.current_model = MagicMock()
mock_backend.current_model.model_type = "gpt-4"
mock_create_model.return_value = mock_backend

agent = ListenChatAgent(
api_task_id="test_task",
agent_name="TestAgent",
model="gpt-4",
)
agent.process_task_id = "test_process"

error = Exception("Async unexpected error")
with patch.object(ChatAgent, "astep", side_effect=error):
with pytest.raises(Exception, match="Async unexpected error"):
await agent.astep("test input")


@pytest.mark.unit
@pytest.mark.asyncio
async def test_astep_raises_runtime_error_when_res_is_none(mock_task_lock):
"""astep() raises RuntimeError (not AssertionError) when res is None.

When ChatAgent.astep returns None, message also remains None,
so the message-is-None guard fires first.
"""
with (
patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock),
patch("camel.models.ModelFactory.create") as mock_create_model,
patch("asyncio.create_task"),
):
mock_backend = MagicMock()
mock_backend.model_type = "gpt-4"
mock_backend.current_model = MagicMock()
mock_backend.current_model.model_type = "gpt-4"
mock_create_model.return_value = mock_backend

agent = ListenChatAgent(
api_task_id="test_task",
agent_name="TestAgent",
model="gpt-4",
)
agent.process_task_id = "test_process"

with patch.object(ChatAgent, "astep", return_value=None):
with pytest.raises(RuntimeError, match="message is None"):
await agent.astep("test input")


@pytest.mark.model_backend
class TestAgentWithLLM:
"""Tests that require LLM backend (marked for selective running)."""
Expand Down
60 changes: 60 additions & 0 deletions backend/tests/app/controller/test_chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,66 @@ async def test_full_chat_workflow_with_llm(
assert True # Placeholder


@pytest.mark.unit
class TestBackgroundTaskCancellation:
"""Test that background_tasks are cancelled, not just cleared."""

def test_improve_cancels_background_tasks_when_done(self, mock_task_lock):
"""When task is done, improve() should cancel running background tasks
before clearing the set, not just drop references."""
task_id = "test_task_123"
supplement_data = SupplementChat(question="Follow-up question")
mock_task_lock.status = Status.done

# Create mock background tasks - one running, one already done
running_task = MagicMock()
running_task.done.return_value = False
done_task = MagicMock()
done_task.done.return_value = True

mock_task_lock.background_tasks = {running_task, done_task}

with (
patch(
"app.controller.chat_controller.get_task_lock",
return_value=mock_task_lock,
),
patch("asyncio.run") as mock_run,
):
response = improve(task_id, supplement_data)

assert response.status_code == 201
# The running task should have been cancelled
running_task.cancel.assert_called_once()
# The already-done task should NOT have been cancelled
done_task.cancel.assert_not_called()
# The set should be empty after clear()
assert len(mock_task_lock.background_tasks) == 0

def test_improve_handles_missing_background_tasks_attr(
self, mock_task_lock
):
"""improve() should handle task_lock without background_tasks attr."""
task_id = "test_task_123"
supplement_data = SupplementChat(question="Follow-up question")
mock_task_lock.status = Status.done

# Remove background_tasks attribute
if hasattr(mock_task_lock, "background_tasks"):
del mock_task_lock.background_tasks

with (
patch(
"app.controller.chat_controller.get_task_lock",
return_value=mock_task_lock,
),
patch("asyncio.run"),
):
# Should not raise AttributeError
response = improve(task_id, supplement_data)
assert response.status_code == 201


@pytest.mark.unit
class TestChatControllerErrorCases:
"""Test error cases and edge conditions."""
Expand Down
5 changes: 3 additions & 2 deletions electron/main/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1177,8 +1177,9 @@ function registerIpcHandlers() {
const childWindow = new BrowserWindow({
webPreferences: {
preload,
nodeIntegration: true,
contextIsolation: false,
nodeIntegration: false,
contextIsolation: true,
sandbox: true,
},
});

Expand Down
9 changes: 4 additions & 5 deletions server/app/component/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from fastapi import Depends, Header
from fastapi.security import OAuth2PasswordBearer
from fastapi_babel import _
from jwt.exceptions import InvalidTokenError
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from sqlmodel import Session, select

from app.component import code
Expand Down Expand Up @@ -51,12 +51,11 @@ def user(self):
def decode_token(cls, token: str):
try:
payload = jwt.decode(token, Auth.SECRET_KEY, algorithms=["HS256"])
id = payload["id"]
if payload["exp"] < int(datetime.now().timestamp()):
raise TokenException(code.token_expired, _("Validate credentials expired"))
except ExpiredSignatureError:
raise TokenException(code.token_expired, _("Validate credentials expired"))
except InvalidTokenError:
raise TokenException(code.token_invalid, _("Could not validate credentials"))
return Auth(id, payload["exp"])
return Auth(payload["id"], payload["exp"])

@classmethod
def create_access_token(cls, user_id: int, expires_delta: timedelta | None = None):
Expand Down