diff --git a/backend/app/component/environment.py b/backend/app/component/environment.py index 285a1fb14..93379d392 100644 --- a/backend/app/component/environment.py +++ b/backend/app/component/environment.py @@ -57,9 +57,9 @@ def sanitize_env_path(env_path: str | None) -> str | None: # Convert to Path object for safe manipulation user_path = Path(env_path) - # Reject absolute paths outside our control + # Resolve path: absolute paths are checked directly, relative paths + # are joined to env_base_dir. Both are validated against env_base_dir. if user_path.is_absolute(): - # Check if it's already within env_base_dir resolved_path = user_path.resolve() else: # Join relative path to base directory diff --git a/backend/app/controller/tool_controller.py b/backend/app/controller/tool_controller.py index a8277dd8e..799a3b6bb 100644 --- a/backend/app/controller/tool_controller.py +++ b/backend/app/controller/tool_controller.py @@ -481,7 +481,7 @@ async def uninstall_tool(tool: str): "Cancelled ongoing Google Calendar authorization" ) # Clear the state completely to remove cached credentials - oauth_state_manager._states.pop("google_calendar", None) + oauth_state_manager.remove_state("google_calendar") logger.info("Cleared Google Calendar OAuth state cache") return { diff --git a/backend/app/service/task.py b/backend/app/service/task.py index 604fbc717..ac64eb55c 100644 --- a/backend/app/service/task.py +++ b/backend/app/service/task.py @@ -14,6 +14,7 @@ import asyncio import logging +import threading import weakref from contextlib import contextmanager from contextvars import ContextVar @@ -528,22 +529,25 @@ def get_recent_context(self, max_entries: int = None) -> str: task_locks = dict[str, TaskLock]() +_task_locks_mutex = threading.Lock() # Cleanup task for removing stale task locks _cleanup_task: asyncio.Task | None = None task_index: dict[str, weakref.ref[Task]] = {} def get_task_lock(id: str) -> TaskLock: - if id not in task_locks: - logger.error("Task lock not found", extra={"task_id": id}) - raise ProgramException("Task not found") - logger.debug("Task lock retrieved", extra={"task_id": id}) - return task_locks[id] + with _task_locks_mutex: + if id not in task_locks: + logger.error("Task lock not found", extra={"task_id": id}) + raise ProgramException("Task not found") + logger.debug("Task lock retrieved", extra={"task_id": id}) + return task_locks[id] def get_task_lock_if_exists(id: str) -> TaskLock | None: """Get task lock if it exists, otherwise return None""" - return task_locks.get(id) + with _task_locks_mutex: + return task_locks.get(id) def set_current_task_id(project_id: str, task_id: str) -> None: @@ -557,61 +561,68 @@ def set_current_task_id(project_id: str, task_id: str) -> None: def create_task_lock(id: str) -> TaskLock: - if id in task_locks: - logger.warning( - "Attempting to create task lock that already exists", - extra={"task_id": id}, - ) - raise ProgramException("Task already exists") - - logger.info("Creating new task lock", extra={"task_id": id}) - task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={}) + with _task_locks_mutex: + if id in task_locks: + logger.warning( + "Attempting to create task lock that already exists", + extra={"task_id": id}, + ) + raise ProgramException("Task already exists") - # Start cleanup task if not running - # global _cleanup_task - # if _cleanup_task is None or _cleanup_task.done(): - # _cleanup_task = asyncio.create_task(_periodic_cleanup()) + logger.info("Creating new task lock", extra={"task_id": id}) + task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={}) - logger.info( - "Task lock created successfully", - extra={"task_id": id, "total_task_locks": len(task_locks)}, - ) - return task_locks[id] + logger.info( + "Task lock created successfully", + extra={"task_id": id, "total_task_locks": len(task_locks)}, + ) + return task_locks[id] def get_or_create_task_lock(id: str) -> TaskLock: """Get existing task lock or create a new one if it doesn't exist""" - if id in task_locks: - logger.debug("Using existing task lock", extra={"task_id": id}) + with _task_locks_mutex: + if id in task_locks: + logger.debug("Using existing task lock", extra={"task_id": id}) + return task_locks[id] + logger.info( + "Task lock not found, creating new one", extra={"task_id": id} + ) + task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={}) + logger.info( + "Task lock created successfully", + extra={"task_id": id, "total_task_locks": len(task_locks)}, + ) return task_locks[id] - logger.info("Task lock not found, creating new one", extra={"task_id": id}) - return create_task_lock(id) async def delete_task_lock(id: str): - if id not in task_locks: - logger.warning( - "Attempting to delete non-existent task lock", - extra={"task_id": id}, + with _task_locks_mutex: + if id not in task_locks: + logger.warning( + "Attempting to delete non-existent task lock", + extra={"task_id": id}, + ) + raise ProgramException("Task not found") + + # Clean up background tasks before deletion + task_lock = task_locks[id] + logger.info( + "Cleaning up task lock", + extra={ + "task_id": id, + "background_tasks": len(task_lock.background_tasks), + }, ) - raise ProgramException("Task not found") - # Clean up background tasks before deletion - task_lock = task_locks[id] - logger.info( - "Cleaning up task lock", - extra={ - "task_id": id, - "background_tasks": len(task_lock.background_tasks), - }, - ) await task_lock.cleanup() - del task_locks[id] - logger.info( - "Task lock deleted successfully", - extra={"task_id": id, "remaining_task_locks": len(task_locks)}, - ) + with _task_locks_mutex: + del task_locks[id] + logger.info( + "Task lock deleted successfully", + extra={"task_id": id, "remaining_task_locks": len(task_locks)}, + ) def get_camel_task(id: str, tasks: list[Task]) -> None | Task: diff --git a/backend/app/utils/cookie_manager.py b/backend/app/utils/cookie_manager.py index 8e72cf870..9d1816830 100644 --- a/backend/app/utils/cookie_manager.py +++ b/backend/app/utils/cookie_manager.py @@ -16,6 +16,7 @@ import os import shutil import sqlite3 +import tempfile from datetime import datetime from typing import Any @@ -60,7 +61,10 @@ def _get_cookies_connection(self) -> sqlite3.Connection | None: ) return None - temp_db_path = self.cookies_db_path + ".tmp" + fd, temp_db_path = tempfile.mkstemp( + suffix=".tmp", dir=os.path.dirname(self.cookies_db_path) + ) + os.close(fd) conn = None try: shutil.copy2(self.cookies_db_path, temp_db_path) diff --git a/backend/app/utils/oauth_state_manager.py b/backend/app/utils/oauth_state_manager.py index 1fe90af26..dc1217f5a 100644 --- a/backend/app/utils/oauth_state_manager.py +++ b/backend/app/utils/oauth_state_manager.py @@ -110,6 +110,11 @@ def update_status( state.completed_at = datetime.now() logger.info(f"Updated {provider} OAuth status to {status}") + def remove_state(self, provider: str) -> None: + """Remove the state for a provider under lock""" + with self._lock: + self._states.pop(provider, None) + # Global instance oauth_state_manager = OAuthStateManager() diff --git a/backend/tests/app/component/test_environment.py b/backend/tests/app/component/test_environment.py index bb7202e11..677eac264 100644 --- a/backend/tests/app/component/test_environment.py +++ b/backend/tests/app/component/test_environment.py @@ -39,11 +39,14 @@ def test_valid_relative_path(): assert result.endswith("project1.env") -def test_valid_absolute_path_within_base_dir(): - """Test that absolute path within base directory is accepted.""" +def test_absolute_path_within_base_dir_accepted(): + """Test that absolute paths within base dir are accepted.""" valid_path = os.path.join(env_base_dir, "valid.env") result = sanitize_env_path(valid_path) - assert result == os.path.abspath(valid_path) + assert result is not None, ( + "Absolute paths within base dir should be accepted" + ) + assert result.startswith(env_base_dir) def test_path_traversal_attack_rejected(): diff --git a/backend/tests/app/service/test_task.py b/backend/tests/app/service/test_task.py index 4ca49b8eb..1bcd81d3a 100644 --- a/backend/tests/app/service/test_task.py +++ b/backend/tests/app/service/test_task.py @@ -13,6 +13,7 @@ # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import asyncio +import threading import weakref from datetime import datetime, timedelta from unittest.mock import patch @@ -36,9 +37,11 @@ Agents, ImprovePayload, TaskLock, + _task_locks_mutex, create_task_lock, delete_task_lock, get_camel_task, + get_or_create_task_lock, get_task_lock, process_task, set_process_task, @@ -544,6 +547,62 @@ async def test_periodic_cleanup_handles_exceptions(self): mock_logger.error.assert_called() +@pytest.fixture +def clean_task_locks_with_mutex(): + """Clean up task_locks using mutex before and after each test.""" + with _task_locks_mutex: + task_locks.clear() + yield + with _task_locks_mutex: + task_locks.clear() + + +@pytest.mark.unit +def test_create_task_lock_is_thread_safe(clean_task_locks_with_mutex): + """Concurrent create_task_lock calls should not corrupt task_locks.""" + errors = [] + barrier = threading.Barrier(10) + + def worker(idx): + try: + barrier.wait(timeout=5) + create_task_lock(f"task_{idx}") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Unexpected errors: {errors}" + with _task_locks_mutex: + assert len(task_locks) == 10 + + +@pytest.mark.unit +def test_get_or_create_is_idempotent(clean_task_locks_with_mutex): + """get_or_create_task_lock called twice returns the same lock.""" + lock1 = get_or_create_task_lock("same_id") + lock2 = get_or_create_task_lock("same_id") + assert lock1 is lock2 + + +@pytest.mark.unit +def test_create_task_lock_raises_on_duplicate(clean_task_locks_with_mutex): + """create_task_lock should raise for an existing id.""" + create_task_lock("dup_id") + with pytest.raises(Exception): + create_task_lock("dup_id") + + +@pytest.mark.unit +def test_mutex_attribute_exists(): + """_task_locks_mutex should be a threading.Lock instance.""" + assert isinstance(_task_locks_mutex, type(threading.Lock())) + + @pytest.mark.integration class TestTaskServiceIntegration: """Integration tests for task service components.""" diff --git a/backend/tests/app/utils/test_cookie_manager.py b/backend/tests/app/utils/test_cookie_manager.py new file mode 100644 index 000000000..b8350f024 --- /dev/null +++ b/backend/tests/app/utils/test_cookie_manager.py @@ -0,0 +1,69 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +import os + +import pytest + +from app.utils.cookie_manager import CookieManager + + +@pytest.mark.unit +class TestCookieManagerTempFile: + """Tests verifying cookie_manager uses unique temp files via mkstemp.""" + + def test_get_cookies_connection_creates_unique_temp(self, tmp_path): + """_get_cookies_connection should create a uniquely-named temp copy.""" + # Create a minimal SQLite database to act as the cookies DB + import sqlite3 + + cookies_db = tmp_path / "Cookies" + conn = sqlite3.connect(str(cookies_db)) + conn.execute( + "CREATE TABLE cookies (" + "host_key TEXT, name TEXT, value TEXT, path TEXT, " + "expires_utc INTEGER, is_secure INTEGER, is_httponly INTEGER, " + "last_access_utc INTEGER)" + ) + conn.commit() + conn.close() + + manager = CookieManager(str(tmp_path)) + result_conn = manager._get_cookies_connection() + assert result_conn is not None + + # The temp file should NOT be the predictable ".tmp" suffix + # but a unique mkstemp-generated file + predictable_tmp = str(cookies_db) + ".tmp" + # The actual temp file is in the same directory + temp_files = [ + f + for f in os.listdir(str(tmp_path)) + if f.endswith(".tmp") and f != "Cookies.tmp" + ] + assert len(temp_files) >= 1, ( + "mkstemp temp file not found — still using predictable .tmp suffix?" + ) + + result_conn.close() + # Cleanup temp files + for f in temp_files: + full = os.path.join(str(tmp_path), f) + if os.path.exists(full): + os.remove(full) + + def test_missing_cookies_db_returns_none(self, tmp_path): + """_get_cookies_connection should return None for missing DB.""" + manager = CookieManager(str(tmp_path / "nonexistent")) + assert manager._get_cookies_connection() is None diff --git a/backend/tests/app/utils/test_oauth_state_manager.py b/backend/tests/app/utils/test_oauth_state_manager.py new file mode 100644 index 000000000..ff5f73ef1 --- /dev/null +++ b/backend/tests/app/utils/test_oauth_state_manager.py @@ -0,0 +1,73 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +import pytest + +from app.utils.oauth_state_manager import OAuthStateManager + + +@pytest.mark.unit +class TestOAuthStateManager: + """Tests for OAuthStateManager including the new remove_state method.""" + + def setup_method(self): + self.manager = OAuthStateManager() + + def test_create_and_get_state(self): + """create_state should return a retrievable state object.""" + state = self.manager.create_state("google") + assert state.provider == "google" + assert state.status == "pending" + retrieved = self.manager.get_state("google") + assert retrieved is state + + def test_remove_state_cleans_up(self): + """remove_state should remove the provider's state under lock.""" + self.manager.create_state("github") + assert self.manager.get_state("github") is not None + self.manager.remove_state("github") + assert self.manager.get_state("github") is None + + def test_remove_state_nonexistent_does_not_raise(self): + """remove_state on a missing provider should not raise.""" + self.manager.remove_state("nonexistent") + + def test_update_status_sets_completed_at(self): + """Updating to a terminal status should set completed_at.""" + self.manager.create_state("slack") + self.manager.update_status("slack", "success") + state = self.manager.get_state("slack") + assert state.status == "success" + assert state.completed_at is not None + + def test_create_state_cancels_previous_pending(self): + """Creating a new state for the same provider cancels the old one.""" + old = self.manager.create_state("google") + assert old.status == "pending" + _new = self.manager.create_state("google") + assert old.status == "cancelled" + + def test_to_dict(self): + """to_dict should include all expected keys.""" + state = self.manager.create_state("test") + d = state.to_dict() + assert set(d.keys()) == { + "provider", + "status", + "error", + "started_at", + "completed_at", + } + assert d["provider"] == "test" + assert d["status"] == "pending"