Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backend/app/component/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ def sanitize_env_path(env_path: str | None) -> str | None:
# Convert to Path object for safe manipulation
user_path = Path(env_path)

# Reject absolute paths outside our control
# Reject absolute paths — they should always be relative to env_base_dir
if user_path.is_absolute():
# Check if it's already within env_base_dir
resolved_path = user_path.resolve()
logger.warning(
f"Security: Rejected absolute env_path. Path: {env_path}"
)
return None
else:
# Join relative path to base directory
resolved_path = (Path(env_base_dir) / user_path).resolve()
Expand Down
2 changes: 1 addition & 1 deletion backend/app/controller/tool_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ async def uninstall_tool(tool: str):
"Cancelled ongoing Google Calendar authorization"
)
# Clear the state completely to remove cached credentials
oauth_state_manager._states.pop("google_calendar", None)
oauth_state_manager.remove_state("google_calendar")
logger.info("Cleared Google Calendar OAuth state cache")

return {
Expand Down
48 changes: 26 additions & 22 deletions backend/app/service/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import logging
import threading
import weakref
from contextlib import contextmanager
from contextvars import ContextVar
Expand Down Expand Up @@ -528,6 +529,7 @@ def get_recent_context(self, max_entries: int = None) -> str:


task_locks = dict[str, TaskLock]()
_task_locks_mutex = threading.Lock()
# Cleanup task for removing stale task locks
_cleanup_task: asyncio.Task | None = None
task_index: dict[str, weakref.ref[Task]] = {}
Expand Down Expand Up @@ -557,35 +559,37 @@ def set_current_task_id(project_id: str, task_id: str) -> None:


def create_task_lock(id: str) -> TaskLock:
if id in task_locks:
logger.warning(
"Attempting to create task lock that already exists",
extra={"task_id": id},
)
raise ProgramException("Task already exists")

logger.info("Creating new task lock", extra={"task_id": id})
task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={})
with _task_locks_mutex:
if id in task_locks:
logger.warning(
"Attempting to create task lock that already exists",
extra={"task_id": id},
)
raise ProgramException("Task already exists")

# Start cleanup task if not running
# global _cleanup_task
# if _cleanup_task is None or _cleanup_task.done():
# _cleanup_task = asyncio.create_task(_periodic_cleanup())
logger.info("Creating new task lock", extra={"task_id": id})
task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={})

logger.info(
"Task lock created successfully",
extra={"task_id": id, "total_task_locks": len(task_locks)},
)
return task_locks[id]
logger.info(
"Task lock created successfully",
extra={"task_id": id, "total_task_locks": len(task_locks)},
)
return task_locks[id]


def get_or_create_task_lock(id: str) -> TaskLock:
"""Get existing task lock or create a new one if it doesn't exist"""
if id in task_locks:
logger.debug("Using existing task lock", extra={"task_id": id})
with _task_locks_mutex:
if id in task_locks:
logger.debug("Using existing task lock", extra={"task_id": id})
return task_locks[id]
logger.info("Task lock not found, creating new one", extra={"task_id": id})
task_locks[id] = TaskLock(id=id, queue=asyncio.Queue(), human_input={})
logger.info(
"Task lock created successfully",
extra={"task_id": id, "total_task_locks": len(task_locks)},
)
return task_locks[id]
logger.info("Task lock not found, creating new one", extra={"task_id": id})
return create_task_lock(id)


async def delete_task_lock(id: str):
Expand Down
6 changes: 5 additions & 1 deletion backend/app/utils/cookie_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import shutil
import sqlite3
import tempfile
from datetime import datetime
from typing import Any

Expand Down Expand Up @@ -60,7 +61,10 @@ def _get_cookies_connection(self) -> sqlite3.Connection | None:
)
return None

temp_db_path = self.cookies_db_path + ".tmp"
fd, temp_db_path = tempfile.mkstemp(
suffix=".tmp", dir=os.path.dirname(self.cookies_db_path)
)
os.close(fd)
conn = None
try:
shutil.copy2(self.cookies_db_path, temp_db_path)
Expand Down
5 changes: 5 additions & 0 deletions backend/app/utils/oauth_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def update_status(
state.completed_at = datetime.now()
logger.info(f"Updated {provider} OAuth status to {status}")

def remove_state(self, provider: str) -> None:
"""Remove the state for a provider under lock"""
with self._lock:
self._states.pop(provider, None)


# Global instance
oauth_state_manager = OAuthStateManager()
8 changes: 5 additions & 3 deletions backend/tests/app/component/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ def test_valid_relative_path():
assert result.endswith("project1.env")


def test_valid_absolute_path_within_base_dir():
"""Test that absolute path within base directory is accepted."""
def test_absolute_path_within_base_dir_rejected():
"""Test that absolute paths are always rejected for security."""
valid_path = os.path.join(env_base_dir, "valid.env")
result = sanitize_env_path(valid_path)
assert result == os.path.abspath(valid_path)
assert result is None, (
"Absolute paths should be rejected — callers must use relative paths"
)


def test_path_traversal_attack_rejected():
Expand Down
78 changes: 78 additions & 0 deletions backend/tests/app/service/test_task_lock_mutex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========

import threading

import pytest

from app.service.task import (
_task_locks_mutex,
create_task_lock,
delete_task_lock,
get_or_create_task_lock,
task_locks,
)


@pytest.mark.unit
class TestTaskLockMutex:
"""Tests verifying the _task_locks_mutex protects task_locks from races."""

def setup_method(self):
"""Clean up task_locks before each test."""
with _task_locks_mutex:
task_locks.clear()

def teardown_method(self):
"""Clean up task_locks after each test."""
with _task_locks_mutex:
task_locks.clear()

def test_create_task_lock_is_thread_safe(self):
"""Concurrent create_task_lock calls should not corrupt task_locks."""
errors = []
barrier = threading.Barrier(10)

def worker(idx):
try:
barrier.wait(timeout=5)
create_task_lock(f"task_{idx}")
except Exception as e:
errors.append(e)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)]
for t in threads:
t.start()
for t in threads:
t.join(timeout=10)

assert not errors, f"Unexpected errors: {errors}"
with _task_locks_mutex:
assert len(task_locks) == 10

def test_get_or_create_is_idempotent(self):
"""get_or_create_task_lock called twice returns the same lock."""
lock1 = get_or_create_task_lock("same_id")
lock2 = get_or_create_task_lock("same_id")
assert lock1 is lock2

def test_create_task_lock_raises_on_duplicate(self):
"""create_task_lock should raise for an existing id."""
create_task_lock("dup_id")
with pytest.raises(Exception):
create_task_lock("dup_id")

def test_mutex_attribute_exists(self):
"""_task_locks_mutex should be a threading.Lock instance."""
assert isinstance(_task_locks_mutex, type(threading.Lock()))
70 changes: 70 additions & 0 deletions backend/tests/app/utils/test_cookie_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========

import os
import tempfile

import pytest

from app.utils.cookie_manager import CookieManager


@pytest.mark.unit
class TestCookieManagerTempFile:
"""Tests verifying cookie_manager uses unique temp files via mkstemp."""

def test_get_cookies_connection_creates_unique_temp(self, tmp_path):
"""_get_cookies_connection should create a uniquely-named temp copy."""
# Create a minimal SQLite database to act as the cookies DB
import sqlite3

cookies_db = tmp_path / "Cookies"
conn = sqlite3.connect(str(cookies_db))
conn.execute(
"CREATE TABLE cookies ("
"host_key TEXT, name TEXT, value TEXT, path TEXT, "
"expires_utc INTEGER, is_secure INTEGER, is_httponly INTEGER, "
"last_access_utc INTEGER)"
)
conn.commit()
conn.close()

manager = CookieManager(str(tmp_path))
result_conn = manager._get_cookies_connection()
assert result_conn is not None

# The temp file should NOT be the predictable ".tmp" suffix
# but a unique mkstemp-generated file
predictable_tmp = str(cookies_db) + ".tmp"
# The actual temp file is in the same directory
temp_files = [
f
for f in os.listdir(str(tmp_path))
if f.endswith(".tmp") and f != "Cookies.tmp"
]
assert len(temp_files) >= 1, (
"mkstemp temp file not found — still using predictable .tmp suffix?"
)

result_conn.close()
# Cleanup temp files
for f in temp_files:
full = os.path.join(str(tmp_path), f)
if os.path.exists(full):
os.remove(full)

def test_missing_cookies_db_returns_none(self, tmp_path):
"""_get_cookies_connection should return None for missing DB."""
manager = CookieManager(str(tmp_path / "nonexistent"))
assert manager._get_cookies_connection() is None
67 changes: 67 additions & 0 deletions backend/tests/app/utils/test_oauth_state_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========

import pytest

from app.utils.oauth_state_manager import OAuthStateManager


@pytest.mark.unit
class TestOAuthStateManager:
"""Tests for OAuthStateManager including the new remove_state method."""

def setup_method(self):
self.manager = OAuthStateManager()

def test_create_and_get_state(self):
"""create_state should return a retrievable state object."""
state = self.manager.create_state("google")
assert state.provider == "google"
assert state.status == "pending"
retrieved = self.manager.get_state("google")
assert retrieved is state

def test_remove_state_cleans_up(self):
"""remove_state should remove the provider's state under lock."""
self.manager.create_state("github")
assert self.manager.get_state("github") is not None
self.manager.remove_state("github")
assert self.manager.get_state("github") is None

def test_remove_state_nonexistent_does_not_raise(self):
"""remove_state on a missing provider should not raise."""
self.manager.remove_state("nonexistent")

def test_update_status_sets_completed_at(self):
"""Updating to a terminal status should set completed_at."""
self.manager.create_state("slack")
self.manager.update_status("slack", "success")
state = self.manager.get_state("slack")
assert state.status == "success"
assert state.completed_at is not None

def test_create_state_cancels_previous_pending(self):
"""Creating a new state for the same provider cancels the old one."""
old = self.manager.create_state("google")
assert old.status == "pending"
_new = self.manager.create_state("google")
assert old.status == "cancelled"

def test_to_dict(self):
"""to_dict should include all expected keys."""
state = self.manager.create_state("test")
d = state.to_dict()
assert set(d.keys()) == {"provider", "status", "error", "started_at", "completed_at"}
assert d["provider"] == "test"
assert d["status"] == "pending"
Loading