Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/test_model_tools_async_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def test_vision_dispatch_keeps_loop_alive(self, tmp_path):
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
),
patch(
"tools.vision_tools._validate_image_url",
"tools.vision_tools._validate_image_url_async",
new_callable=AsyncMock,
return_value=True,
),
patch(
Expand Down Expand Up @@ -278,7 +279,8 @@ def test_two_consecutive_vision_dispatches(self, tmp_path):
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
),
patch(
"tools.vision_tools._validate_image_url",
"tools.vision_tools._validate_image_url_async",
new_callable=AsyncMock,
return_value=True,
),
patch(
Expand Down
20 changes: 19 additions & 1 deletion tests/tools/test_url_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
from unittest.mock import patch

from tools.url_safety import is_safe_url, _is_blocked_ip
from tools.url_safety import async_is_safe_url, is_safe_url, _is_blocked_ip

import ipaddress
import pytest
Expand Down Expand Up @@ -153,6 +153,24 @@ def test_non_cgnat_100_allowed(self):
assert is_safe_url("http://legit-host.example/") is True


class TestAsyncIsSafeUrl:
    """async_is_safe_url must match is_safe_url (runs DNS in a thread pool)."""

    @staticmethod
    def _fake_addrinfo(ip):
        # Minimal getaddrinfo() shape: (family, type, proto, canonname, sockaddr).
        return [(2, 1, 6, "", (ip, 0))]

    @pytest.mark.asyncio
    async def test_public_url_allowed(self):
        # Public address resolves -> URL passes the SSRF check.
        fake = self._fake_addrinfo("93.184.216.34")
        with patch("socket.getaddrinfo", return_value=fake):
            assert await async_is_safe_url("https://example.com/x") is True

    @pytest.mark.asyncio
    async def test_localhost_blocked(self):
        # Loopback address resolves -> URL is rejected.
        fake = self._fake_addrinfo("127.0.0.1")
        with patch("socket.getaddrinfo", return_value=fake):
            assert await async_is_safe_url("http://localhost:8080/") is False


class TestIsBlockedIp:
"""Direct tests for the _is_blocked_ip helper."""

Expand Down
15 changes: 13 additions & 2 deletions tests/tools/test_vision_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,11 @@ async def test_download_failure_logs_exc_info(self, tmp_path, caplog):
async def test_analysis_error_logs_exc_info(self, caplog):
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch(
"tools.vision_tools._validate_image_url_async",
new_callable=AsyncMock,
return_value=True,
),
patch(
"tools.vision_tools._download_image",
new_callable=AsyncMock,
Expand Down Expand Up @@ -316,7 +320,11 @@ async def fake_download(url, dest, max_retries=3):
return dest

with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch(
"tools.vision_tools._validate_image_url_async",
new_callable=AsyncMock,
return_value=True,
),
patch("tools.vision_tools._download_image", side_effect=fake_download),
patch(
"tools.vision_tools._image_to_base64_data_url",
Expand Down Expand Up @@ -408,7 +416,9 @@ async def test_tilde_path_expanded_to_local_file(self, tmp_path, monkeypatch):
img = fake_home / "test_image.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)

# Windows expanduser() prefers USERPROFILE over HOME; POSIX uses HOME.
monkeypatch.setenv("HOME", str(fake_home))
monkeypatch.setenv("USERPROFILE", str(fake_home))

mock_response = MagicMock()
mock_choice = MagicMock()
Expand Down Expand Up @@ -439,6 +449,7 @@ async def test_tilde_path_nonexistent_file_gives_error(self, tmp_path, monkeypat
fake_home = tmp_path / "fakehome"
fake_home.mkdir()
monkeypatch.setenv("HOME", str(fake_home))
monkeypatch.setenv("USERPROFILE", str(fake_home))

result = await vision_analyze_tool(
"~/nonexistent.png", "describe this", "test/model"
Expand Down
20 changes: 16 additions & 4 deletions tests/tools/test_website_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ async def test_web_extract_short_circuits_blocked_url(monkeypatch):
from tools import web_tools

# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
async def _allow_ssrf(_url: str) -> bool:
return True

monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)
monkeypatch.setattr(
web_tools,
"check_website_access",
Expand Down Expand Up @@ -394,7 +397,10 @@ async def test_web_extract_blocks_redirected_final_url(monkeypatch):
from tools import web_tools

# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
async def _allow_ssrf(_url: str) -> bool:
return True

monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)

def fake_check(url):
if url == "https://allowed.test":
Expand Down Expand Up @@ -436,7 +442,10 @@ async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
# web_crawl_tool checks for Firecrawl env before website policy
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
async def _allow_ssrf(_url: str) -> bool:
return True

monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)
monkeypatch.setattr(
web_tools,
"check_website_access",
Expand Down Expand Up @@ -467,7 +476,10 @@ async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
# web_crawl_tool checks for Firecrawl env before website policy
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
async def _allow_ssrf(_url: str) -> bool:
return True

monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)

def fake_check(url):
if url == "https://allowed.test":
Expand Down
10 changes: 10 additions & 0 deletions tools/url_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
"""

import asyncio
import ipaddress
import logging
import socket
Expand Down Expand Up @@ -94,3 +95,12 @@ def is_safe_url(url: str) -> bool:
# become SSRF bypass vectors
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
return False


async def async_is_safe_url(url: str) -> bool:
    """Same rules as :func:`is_safe_url`, but run the DNS work off the event loop.

    ``socket.getaddrinfo`` can block; call this from async code paths (gateway,
    ``web_extract_tool``, vision download hooks) instead of ``is_safe_url``.

    Args:
        url: The URL to validate.

    Returns:
        bool: ``True`` if the URL passes the same safety rules as
        :func:`is_safe_url`, ``False`` otherwise.
    """
    # asyncio.to_thread dispatches the blocking sync check to the default
    # executor so the event loop stays responsive during DNS resolution.
    return await asyncio.to_thread(is_safe_url, url)
41 changes: 19 additions & 22 deletions tools/vision_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,35 +45,32 @@
_debug = DebugSession("vision_tools", env_var="VISION_TOOLS_DEBUG")


def _validate_image_url(url: str) -> bool:
"""
Basic validation of image URL format.

Args:
url (str): The URL to validate

Returns:
bool: True if URL appears to be valid, False otherwise
"""
def _image_url_shape_ok(url: str) -> bool:
"""HTTP(S) shape check only (scheme, netloc). No DNS."""
if not url or not isinstance(url, str):
return False

# Basic HTTP/HTTPS URL check
if not (url.startswith("http://") or url.startswith("https://")):
return False

# Parse to ensure we at least have a network location; still allow URLs
# without file extensions (e.g. CDN endpoints that redirect to images).
parsed = urlparse(url)
if not parsed.netloc:
return False
return True

# Block private/internal addresses to prevent SSRF
from tools.url_safety import is_safe_url
if not is_safe_url(url):

def _validate_image_url(url: str) -> bool:
    """Validate image URL for sync callers and tests (SSRF via sync DNS check)."""
    # Cheap shape check first; only do the (blocking) SSRF/DNS check when the
    # URL at least looks like a fetchable HTTP(S) address.
    ok = _image_url_shape_ok(url)
    if ok:
        from tools.url_safety import is_safe_url

        ok = is_safe_url(url)
    return ok

return True

async def _validate_image_url_async(url: str) -> bool:
    """Validate remote image URL without blocking the event loop on DNS."""
    # Shape check is pure string work; bail out before touching the network.
    if not _image_url_shape_ok(url):
        return False

    # Defer to the async SSRF check, which runs DNS in a thread pool.
    from tools.url_safety import async_is_safe_url

    return await async_is_safe_url(url)


async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
Expand Down Expand Up @@ -106,8 +103,8 @@ async def _ssrf_redirect_guard(response):
"""
if response.is_redirect and response.next_request:
redirect_url = str(response.next_request.url)
from tools.url_safety import is_safe_url
if not is_safe_url(redirect_url):
from tools.url_safety import async_is_safe_url
if not await async_is_safe_url(redirect_url):
raise ValueError(
f"Blocked redirect to private/internal address: {redirect_url}"
)
Expand Down Expand Up @@ -273,7 +270,7 @@ async def vision_analyze_tool(
logger.info("Using local image file: %s", image_url)
temp_image_path = local_path
should_cleanup = False # Don't delete cached/local files
elif _validate_image_url(image_url):
elif await _validate_image_url_async(image_url):
# Remote URL -- download to a temporary location
logger.info("Downloading image from URL...")
temp_dir = Path("./temp_vision_images")
Expand Down
8 changes: 4 additions & 4 deletions tools/web_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from firecrawl import Firecrawl
from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning
from tools.debug_helpers import DebugSession
from tools.url_safety import is_safe_url
from tools.url_safety import async_is_safe_url
from tools.website_policy import check_website_access

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -881,7 +881,7 @@ async def web_extract_tool(
safe_urls = []
ssrf_blocked: List[Dict[str, Any]] = []
for url in urls:
if not is_safe_url(url):
if not await async_is_safe_url(url):
ssrf_blocked.append({
"url": url, "title": "", "content": "",
"error": "Blocked: URL targets a private or internal network address",
Expand Down Expand Up @@ -1209,7 +1209,7 @@ async def web_crawl_tool(
url = f'https://{url}'

# SSRF protection — block private/internal addresses
if not is_safe_url(url):
if not await async_is_safe_url(url):
return json.dumps({"results": [{"url": url, "title": "", "content": "",
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)

Expand Down Expand Up @@ -1299,7 +1299,7 @@ async def _process_tavily_crawl(result):
logger.info("Crawling %s%s", url, instructions_text)

# SSRF protection — block private/internal addresses
if not is_safe_url(url):
if not await async_is_safe_url(url):
return json.dumps({"results": [{"url": url, "title": "", "content": "",
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)

Expand Down
Loading