diff --git a/progress.txt b/progress.txt new file mode 100644 index 00000000..7ebf8671 --- /dev/null +++ b/progress.txt @@ -0,0 +1,58 @@ +# Progress Log +Run: d7a8979f-776d-40d5-9f15-fbf698da7289 +Task: Integrate Prism (prismatoid) for screen reader output in Weather Assistant +Started: 2026-02-11 14:47 UTC + +## Codebase Patterns +- Python 3.12, pyproject.toml based, setuptools build +- Tests use pytest with xdist (parallel), run via `python -m pytest tests/ -x -q` +- Pre-commit hooks: trailing whitespace, end-of-file-fixer, ruff lint+format +- Pre-existing test failure in test_weather_assistant_dialog.py (55°F vs 55.0°F) — not our issue +- tomllib available for parsing pyproject.toml in tests +- Source in src/accessiweather/, tests in tests/ +- PR #285 targets dev from feature/weather-chat — reuse existing PRs when branch already has one + +--- + +## 2026-02-11 14:47 - US-001: Add prismatoid as optional dependency in pyproject.toml +- Added `screenreader` optional-dependencies group with `prismatoid>=0.7.0` +- Files changed: pyproject.toml, tests/test_pyproject_screenreader_dep.py +- **Learnings:** Pre-commit hooks auto-fix files; need to re-add and retry commit +--- + +## 2026-02-11 15:12 - US-002: Create pull request against dev branch +- Ran ruff format (fixed 1 file), ruff check (passed), tests (1209 passed, 2 pre-existing failures) +- Pushed all commits to origin/feature/weather-chat +- Updated existing PR #285 (feature/weather-chat → dev) with new title, body referencing #288 +- PR URL: https://github.com/Orinks/AccessiWeather/pull/285 +- **Learnings:** Check for existing PRs before creating new ones (gh pr create fails if one exists) +--- + +## 2026-02-11 16:52 - US-003: Add location resolution for tool calls +- Added `LocationResolver` class in `ai_tools.py` with default location matching (case-insensitive, whitespace-trimmed) and geocoding fallback +- `resolve()` returns `(lat, lon, display_name)` tuple; raises `ValueError` on failure +- Updated `WeatherToolExecutor.__init__` to accept `default_lat/lon/name` and create a `LocationResolver` +- `WeatherToolExecutor.execute()` now catches `ValueError`/`Exception` from handlers and returns error strings instead of crashing +- Updated existing test for geocoding failure (no longer raises, returns error string) +- Created `tests/test_location_resolver.py` with 14 tests: default matching, case-insensitive, whitespace, geocoding fallback, failure handling, no-default config, executor integration +- **Learnings:** WeatherToolExecutor.execute() should return error strings rather than raise for tool call loop compatibility (AI needs to see the error as tool result) +--- + +## 2026-02-11 17:00 - US-004: Format tool results for readable AI context +- Made formatters public: format_current_weather(), format_forecast(), format_alerts() in ai_tools.py +- format_current_weather: extracts temperature, feels_like/feelsLike, conditions (description/textDescription), humidity, wind/windSpeed, pressure/barometricPressure +- format_forecast: shows up to 7 periods with name, temperature, short/detailed forecast +- format_alerts: shows event, severity, headline, description (truncated to 300 chars); "No active alerts." when empty +- All formatters handle None/missing fields gracefully via _append_field helper +- WeatherToolExecutor.execute() uses the new public formatters +- Created tests/test_ai_tools_formatters.py with 31 tests covering all formatters +- **Learnings:** WeatherService response fields vary by backend (snake_case vs camelCase); formatters check both variants +--- + +## 2026-02-11 19:04 - US-005: Update system prompt and add integration test +- Updated SYSTEM_PROMPT in weather_assistant_dialog.py to mention available tools and guide AI to use them for location-specific queries +- Created tests/test_weather_assistant_integration.py with 8 tests covering full tool call flow (single, multiple, chained), direct response, error recovery, and system prompt validation +- Used AST parsing to read SYSTEM_PROMPT without importing wx/prism (avoids Linux CI failures) +- Files changed: src/accessiweather/ui/dialogs/weather_assistant_dialog.py, tests/test_weather_assistant_integration.py +- **Learnings:** Use ast.parse + ast.literal_eval to extract module-level constants from files that have unresolvable imports +--- diff --git a/pyproject.toml b/pyproject.toml index c047cf27..8d0cb17e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "keyring", "openai>=1.0.0", "playsound3", + "prismatoid>=0.7.0", ] readme = "README.md" classifiers = [ @@ -51,7 +52,6 @@ build = [ "pyinstaller>=6.0.0", "pillow>=10.0.0", ] - [project.scripts] accessiweather = "accessiweather.main:main" @@ -169,4 +169,3 @@ line-ending = "auto" [tool.ruff.lint.mccabe] max-complexity = 15 - diff --git a/src/accessiweather/ai_tools.py b/src/accessiweather/ai_tools.py new file mode 100644 index 00000000..c5ce3a53 --- /dev/null +++ b/src/accessiweather/ai_tools.py @@ -0,0 +1,927 @@ +""" +AI tool schemas and execution registry for weather function calling. + +This module defines OpenAI function-calling tool schemas and a registry +that maps tool names to executor functions using the app's WeatherService. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from accessiweather.geocoding import GeocodingService +from accessiweather.services.weather_service.weather_service import WeatherService + +logger = logging.getLogger(__name__) + + +def get_tools_for_message(message: str) -> list[dict[str, Any]]: + """ + Select which tools to send based on the user's message. + + Always includes core tools (current weather, forecast, alerts). + Adds extended tools only when keywords suggest they're needed, + saving tokens on simple weather questions. + + Args: + message: The user's message text. + + Returns: + List of tool schemas to send to the API. + + """ + tools = list(CORE_TOOLS) + msg_lower = message.lower() + + # Keywords that trigger extended tools + extended_triggers = ( + "hour", + "tonight", + "this afternoon", + "this morning", + "at ", + " pm", + " am", + "soil", + "uv", + "cloud", + "dew", + "snow depth", + "visibility", + "pressure", + "sunrise", + "sunset", + "cape", + "custom", + "add", + "save", + "location", + "list", + "my locations", + "search", + "find", + "where is", + "zip", + ) + if any(trigger in msg_lower for trigger in extended_triggers): + tools.extend(EXTENDED_TOOLS) + + # Keywords that trigger discussion tools + discussion_triggers = ( + "discussion", + "afd", + "forecast discussion", + "wpc", + "spc", + "storm prediction", + "weather prediction center", + "convective", + "outlook", + "severe", + "tornado", + "supercell", + "explain the forecast", + "why is", + "reasoning", + "meteorolog", + "synoptic", + "national", + ) + if any(trigger in msg_lower for trigger in discussion_triggers): + tools.extend(DISCUSSION_TOOLS) + + return tools + + +CORE_TOOLS: list[dict[str, Any]] = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather conditions for a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get weather for, e.g. 'New York, NY' or '10001'.", + } + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the forecast for, e.g. 'New York, NY' or '10001'.", + } + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_alerts", + "description": "Get active weather alerts for a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get weather alerts for, e.g. 'New York, NY' or '10001'.", + } + }, + "required": ["location"], + }, + }, + }, +] + +DISCUSSION_TOOLS: list[dict[str, Any]] = [ + { + "type": "function", + "function": { + "name": "get_area_forecast_discussion", + "description": ( + "Get the Area Forecast Discussion (AFD) for a location. " + "This is a detailed technical forecast discussion written by local NWS forecasters. " + "Great for understanding the reasoning behind the forecast." + ), + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "Location to get the AFD for, e.g. 'New York, NY'.", + } + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_wpc_discussion", + "description": ( + "Get the Weather Prediction Center (WPC) Short Range Forecast Discussion. " + "A nationwide weather discussion covering the next 1-3 days. Covers major " + "weather systems, precipitation patterns, and significant weather events " + "across the US." + ), + "parameters": { + "type": "object", + "properties": {}, + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_spc_outlook", + "description": ( + "Get the Storm Prediction Center (SPC) Day 1 Convective Outlook discussion. " + "Covers severe weather risks including tornadoes, large hail, and damaging winds. " + "Explains the meteorological reasoning behind severe weather risk areas." + ), + "parameters": { + "type": "object", + "properties": {}, + }, + }, + }, +] + +EXTENDED_TOOLS: list[dict[str, Any]] = [ + { + "type": "function", + "function": { + "name": "get_hourly_forecast", + "description": "Get an hourly weather forecast for a location. Useful for questions like 'will it rain at 3pm?' or 'what's the temperature tonight?'.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the hourly forecast for, e.g. 'New York, NY' or '10001'.", + } + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_location", + "description": "Search for a location by name or ZIP code to find its full name and coordinates. Useful when the user mentions an ambiguous place name.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Location name or ZIP code to search for, e.g. 'Paris' or '90210'.", + } + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "query_open_meteo", + "description": ( + "Query the Open-Meteo API with custom parameters. Use this for weather " + "questions not covered by other tools, such as soil temperature, cloud cover, " + "dew point, snow depth, precipitation probability, UV index, visibility, " + "surface pressure, cape, and more. Open-Meteo has global coverage and is free.\n\n" + "Common hourly variables: temperature_2m, relative_humidity_2m, dew_point_2m, " + "apparent_temperature, precipitation_probability, precipitation, rain, showers, " + "snowfall, snow_depth, weather_code, pressure_msl, surface_pressure, " + "cloud_cover, cloud_cover_low, cloud_cover_mid, cloud_cover_high, visibility, " + "wind_speed_10m, wind_direction_10m, wind_gusts_10m, uv_index, " + "soil_temperature_0cm, soil_temperature_6cm, soil_moisture_0_to_1cm\n\n" + "Common daily variables: temperature_2m_max, temperature_2m_min, " + "apparent_temperature_max, apparent_temperature_min, sunrise, sunset, " + "uv_index_max, precipitation_sum, rain_sum, showers_sum, snowfall_sum, " + "precipitation_hours, precipitation_probability_max, wind_speed_10m_max, " + "wind_gusts_10m_max, wind_direction_10m_dominant\n\n" + "Common current variables: temperature_2m, relative_humidity_2m, " + "apparent_temperature, is_day, precipitation, rain, showers, snowfall, " + "weather_code, cloud_cover, pressure_msl, surface_pressure, " + "wind_speed_10m, wind_direction_10m, wind_gusts_10m" + ), + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "Location name or coordinates, e.g. 'Paris, France'.", + }, + "hourly": { + "type": "array", + "items": {"type": "string"}, + "description": "List of hourly variables to fetch.", + }, + "daily": { + "type": "array", + "items": {"type": "string"}, + "description": "List of daily variables to fetch.", + }, + "current": { + "type": "array", + "items": {"type": "string"}, + "description": "List of current variables to fetch.", + }, + "forecast_days": { + "type": "integer", + "description": "Number of forecast days (1-16, default 7).", + }, + "timezone": { + "type": "string", + "description": "Timezone for results, e.g. 'America/New_York'. Default: auto.", + }, + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "add_location", + "description": "Add a location to the user's saved locations list. Use after confirming with the user which location they want to add.", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Display name for the location, e.g. 'New York, NY' or 'Paris, France'.", + }, + "latitude": { + "type": "number", + "description": "Latitude of the location.", + }, + "longitude": { + "type": "number", + "description": "Longitude of the location.", + }, + }, + "required": ["name", "latitude", "longitude"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "list_locations", + "description": "List all saved locations and show which one is currently selected.", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + }, +] + +# Complete list for backward compatibility +WEATHER_TOOLS = CORE_TOOLS + EXTENDED_TOOLS + DISCUSSION_TOOLS + + +class LocationResolver: + """ + Resolves location strings to (lat, lon) coordinates. + + Uses the app's current location as a shortcut when the query matches + the default location name (case-insensitive), otherwise falls back to + the GeocodingService. + """ + + def __init__( + self, + geocoding_service: GeocodingService, + default_lat: float | None = None, + default_lon: float | None = None, + default_name: str | None = None, + ) -> None: + """ + Initialize the location resolver. + + Args: + geocoding_service: The geocoding service for resolving locations. + default_lat: Latitude of the app's current/default location. + default_lon: Longitude of the app's current/default location. + default_name: Display name of the app's current/default location. + + """ + self.geocoding_service = geocoding_service + self.default_lat = default_lat + self.default_lon = default_lon + self.default_name = default_name + + def _matches_default(self, location_str: str) -> bool: + """Check if a location string matches the default location name.""" + if self.default_name is None or self.default_lat is None or self.default_lon is None: + return False + return location_str.strip().lower() == self.default_name.strip().lower() + + def resolve(self, location_str: str) -> tuple[float, float, str]: + """ + Resolve a location string to (lat, lon, display_name). + + If the string matches the default location name (case-insensitive), + the default coordinates are returned without an API call. Otherwise + the GeocodingService is used. + + Args: + location_str: The location to resolve (e.g. 'Paris' or 'New York, NY'). + + Returns: + A (latitude, longitude, display_name) tuple. + + Raises: + ValueError: If the location cannot be resolved. + + """ + if self._matches_default(location_str): + logger.debug( + "Location '%s' matches default location '%s', using cached coordinates", + location_str, + self.default_name, + ) + return (self.default_lat, self.default_lon, self.default_name) # type: ignore[return-value] + + result = self.geocoding_service.geocode_address(location_str) + if result is None: + raise ValueError(f"Could not resolve location: {location_str}") + return result + + +class WeatherToolExecutor: + """Executes weather tool calls using WeatherService and geocoding.""" + + def __init__( + self, + weather_service: WeatherService, + geocoding_service: GeocodingService, + config_manager: Any = None, + default_lat: float | None = None, + default_lon: float | None = None, + default_name: str | None = None, + ) -> None: + """ + Initialize the weather tool executor. + + Args: + weather_service: The weather service for fetching weather data. + geocoding_service: The geocoding service for resolving locations. + config_manager: The config manager for saving locations (optional). + default_lat: Latitude of the app's current/default location. + default_lon: Longitude of the app's current/default location. + default_name: Display name of the app's current/default location. + + """ + self.weather_service = weather_service + self.geocoding_service = geocoding_service + self.config_manager = config_manager + self.location_resolver = LocationResolver( + geocoding_service=geocoding_service, + default_lat=default_lat, + default_lon=default_lon, + default_name=default_name, + ) + self._tool_handlers = { + "get_current_weather": self._get_current_weather, + "get_forecast": self._get_forecast, + "get_alerts": self._get_alerts, + "get_hourly_forecast": self._get_hourly_forecast, + "search_location": self._search_location, + "add_location": self._add_location, + "list_locations": self._list_locations, + "query_open_meteo": self._query_open_meteo, + "get_area_forecast_discussion": self._get_afd, + "get_wpc_discussion": self._get_wpc_discussion, + "get_spc_outlook": self._get_spc_outlook, + } + + def execute(self, tool_name: str, arguments: dict[str, Any]) -> str: + """ + Execute a weather tool by name. + + Args: + tool_name: The name of the tool to execute. + arguments: The arguments for the tool call. + + Returns: + A formatted string with the tool's results, or an error + message string if location resolution or data fetching fails. + + Raises: + ValueError: If the tool name is not recognized. + + """ + handler = self._tool_handlers.get(tool_name) + if handler is None: + raise ValueError(f"Unknown tool: {tool_name}") + try: + return handler(arguments) + except ValueError as e: + logger.warning("Tool execution failed for %s: %s", tool_name, e) + return f"Error: {e}" + except Exception as e: + logger.error("Unexpected error executing tool %s: %s", tool_name, e) + return f"Error fetching weather data: {e}" + + def _resolve_location(self, location: str) -> tuple[float, float, str]: + """ + Resolve a location string to coordinates. + + Args: + location: Location string to geocode. + + Returns: + Tuple of (lat, lon, display_name). + + Raises: + ValueError: If the location cannot be resolved. + + """ + return self.location_resolver.resolve(location) + + def _get_current_weather(self, arguments: dict[str, Any]) -> str: + """Get current weather conditions.""" + location = arguments["location"] + lat, lon, display_name = self._resolve_location(location) + data = self.weather_service.get_current_conditions(lat, lon) + return format_current_weather(data, display_name) + + def _get_forecast(self, arguments: dict[str, Any]) -> str: + """Get weather forecast.""" + location = arguments["location"] + lat, lon, display_name = self._resolve_location(location) + data = self.weather_service.get_forecast(lat, lon) + return format_forecast(data, display_name) + + def _get_alerts(self, arguments: dict[str, Any]) -> str: + """Get weather alerts.""" + location = arguments["location"] + lat, lon, display_name = self._resolve_location(location) + data = self.weather_service.get_alerts(lat, lon) + return format_alerts(data, display_name) + + def _get_hourly_forecast(self, arguments: dict[str, Any]) -> str: + """Get hourly weather forecast.""" + location = arguments["location"] + lat, lon, display_name = self._resolve_location(location) + data = self.weather_service.get_hourly_forecast(lat, lon) + return format_hourly_forecast(data, display_name) + + def _search_location(self, arguments: dict[str, Any]) -> str: + """Search for a location by name or ZIP code.""" + query = arguments["query"] + suggestions = self.geocoding_service.suggest_locations(query, limit=5) + return format_location_search(suggestions, query) + + def _add_location(self, arguments: dict[str, Any]) -> str: + """Add a location to saved locations.""" + if self.config_manager is None: + return "Error: Cannot save locations (config manager unavailable)." + name = arguments["name"] + lat = arguments["latitude"] + lon = arguments["longitude"] + # Check if already saved + existing = self.config_manager.get_location_names() + if name in existing: + return f"'{name}' is already in your saved locations." + success = self.config_manager.add_location(name, lat, lon) + if success: + self.config_manager.save_config() + return f"Added '{name}' to your saved locations." + return f"Failed to add '{name}'. It may already exist under a similar name." + + def _query_open_meteo(self, arguments: dict[str, Any]) -> str: + """Query Open-Meteo API with custom parameters.""" + location = arguments["location"] + lat, lon, display_name = self._resolve_location(location) + + params: dict[str, Any] = { + "latitude": lat, + "longitude": lon, + "timezone": arguments.get("timezone", "auto"), + } + + if "current" in arguments: + params["current"] = ",".join(arguments["current"]) + if "hourly" in arguments: + params["hourly"] = ",".join(arguments["hourly"]) + if "daily" in arguments: + params["daily"] = ",".join(arguments["daily"]) + if "forecast_days" in arguments: + params["forecast_days"] = arguments["forecast_days"] + + # Need at least one data group + if not any(k in params for k in ("current", "hourly", "daily")): + return "Error: specify at least one of current, hourly, or daily variables." + + try: + from accessiweather.openmeteo_client import OpenMeteoApiClient + + client = OpenMeteoApiClient() + try: + data = client._make_request("forecast", params) + finally: + client.close() + + return format_open_meteo_response(data, display_name) + except Exception as e: + logger.warning("Open-Meteo query failed: %s", e) + return f"Error querying Open-Meteo: {e}" + + def _get_afd(self, arguments: dict[str, Any]) -> str: + """Get Area Forecast Discussion for a location.""" + location = arguments["location"] + lat, lon, display_name = self._resolve_location(location) + try: + text = self.weather_service.get_discussion(lat, lon) + if text: + # Truncate if very long to fit in context + if len(text) > 3000: + text = text[:3000] + "\n\n[Truncated — full discussion is longer]" + return f"Area Forecast Discussion for {display_name}:\n\n{text}" + return f"No Area Forecast Discussion available for {display_name}." + except Exception as e: + logger.warning("AFD fetch failed: %s", e) + return f"Error fetching AFD: {e}" + + def _get_wpc_discussion(self, arguments: dict[str, Any]) -> str: + """Get WPC Short Range Forecast Discussion.""" + try: + from accessiweather.services.national_discussion_scraper import ( + NationalDiscussionScraper, + ) + + scraper = NationalDiscussionScraper(request_delay=1.0, max_retries=2, timeout=15) + result = scraper.fetch_wpc_discussion() + text = result.get("full", "") + if text: + if len(text) > 3000: + text = text[:3000] + "\n\n[Truncated — full discussion is longer]" + return f"WPC Short Range Forecast Discussion:\n\n{text}" + return "WPC discussion unavailable." + except Exception as e: + logger.warning("WPC discussion fetch failed: %s", e) + return f"Error fetching WPC discussion: {e}" + + def _get_spc_outlook(self, arguments: dict[str, Any]) -> str: + """Get SPC Day 1 Convective Outlook.""" + try: + from accessiweather.services.national_discussion_scraper import ( + NationalDiscussionScraper, + ) + + scraper = NationalDiscussionScraper(request_delay=1.0, max_retries=2, timeout=15) + result = scraper.fetch_spc_discussion() + text = result.get("full", "") + if text: + if len(text) > 3000: + text = text[:3000] + "\n\n[Truncated — full discussion is longer]" + return f"SPC Day 1 Convective Outlook:\n\n{text}" + return "SPC outlook unavailable." + except Exception as e: + logger.warning("SPC outlook fetch failed: %s", e) + return f"Error fetching SPC outlook: {e}" + + def _list_locations(self, arguments: dict[str, Any]) -> str: + """List all saved locations.""" + if self.config_manager is None: + return "Error: Cannot access locations (config manager unavailable)." + locations = self.config_manager.get_all_locations() + if not locations: + return "No saved locations." + current = self.config_manager.get_current_location() + current_name = current.name if current else None + lines = ["Your saved locations:"] + for loc in locations: + marker = " (current)" if loc.name == current_name else "" + lines.append(f"- {loc.name} ({loc.latitude:.2f}, {loc.longitude:.2f}){marker}") + return "\n".join(lines) + + +def format_current_weather(data: dict[str, Any], display_name: str = "") -> str: + """ + Format current weather conditions data as readable text. + + Extracts key fields: temperature, feels like, conditions, humidity, + wind, and pressure. Handles missing or None fields gracefully. + + Args: + data: Weather data dict from WeatherService.get_current_conditions(). + display_name: Location display name for the header. + + Returns: + A human-readable text summary of current conditions. + + """ + header = f"Current weather for {display_name}:" if display_name else "Current weather:" + lines = [header] + + _append_field(lines, "Temperature", data.get("temperature")) + _append_field(lines, "Feels Like", data.get("feels_like") or data.get("feelsLike")) + _append_field( + lines, + "Conditions", + data.get("description") or data.get("textDescription") or data.get("conditions"), + ) + _append_field(lines, "Humidity", data.get("humidity")) + _append_field(lines, "Wind", data.get("wind") or data.get("windSpeed")) + _append_field(lines, "Pressure", data.get("pressure") or data.get("barometricPressure")) + + # Fallback: if no known fields matched, dump scalar values + if len(lines) == 1: + for key, value in data.items(): + if isinstance(value, (str, int, float)) and key not in ("lat", "lon"): + lines.append(f"{key}: {value}") + + return "\n".join(lines) + + +def format_forecast(data: dict[str, Any], display_name: str = "") -> str: + """ + Format forecast data as readable text with up to 7 periods. + + Each period includes name, temperature, and short forecast text. + Handles missing or None fields gracefully. + + Args: + data: Forecast data dict from WeatherService.get_forecast(). + display_name: Location display name for the header. + + Returns: + A human-readable text summary of the forecast. + + """ + header = f"Forecast for {display_name}:" if display_name else "Forecast:" + lines = [header] + + periods = data.get("periods", data.get("properties", {}).get("periods", [])) + if isinstance(periods, list): + for period in periods[:7]: + if not isinstance(period, dict): + continue + name = period.get("name") or "Unknown" + temp = period.get("temperature") + temp_unit = period.get("temperatureUnit", "") + short = period.get("shortForecast") or period.get("detailedForecast") or "" + + parts = [name] + if temp is not None: + parts.append(f"{temp}°{temp_unit}" if temp_unit else str(temp)) + if short: + parts.append(short) + + lines.append(" - ".join(parts)) + + if len(lines) == 1: + lines.append(json.dumps(data, indent=2, default=str)[:500]) + + return "\n".join(lines) + + +def format_alerts(data: dict[str, Any], display_name: str = "") -> str: + """ + Format weather alerts data as readable text. + + Shows event name, severity, headline, and description for each alert. + Returns 'No active alerts' when the alert list is empty. + Handles missing or None fields gracefully. + + Args: + data: Alerts data dict from WeatherService.get_alerts(). + display_name: Location display name for the header. + + Returns: + A human-readable text summary of weather alerts. + + """ + header = f"Weather alerts for {display_name}:" if display_name else "Weather alerts:" + lines = [header] + + alerts = data.get("alerts", data.get("features", [])) + if isinstance(alerts, list) and len(alerts) > 0: + for alert in alerts: + if not isinstance(alert, dict): + continue + props = alert.get("properties", alert) + event = props.get("event") or "Unknown Alert" + severity = props.get("severity") + headline = props.get("headline") + description = props.get("description") + + alert_line = f"- {event}" + if severity: + alert_line += f" (Severity: {severity})" + lines.append(alert_line) + if headline: + lines.append(f" {headline}") + if description: + lines.append(f" {description[:300]}") + else: + lines.append("No active alerts.") + + return "\n".join(lines) + + +def format_hourly_forecast(data: dict[str, Any], display_name: str = "") -> str: + """ + Format hourly forecast data as readable text with up to 12 periods. + + Args: + data: Hourly forecast data dict from WeatherService.get_hourly_forecast(). + display_name: Location display name for the header. + + Returns: + A human-readable text summary of the hourly forecast. + + """ + header = f"Hourly forecast for {display_name}:" if display_name else "Hourly forecast:" + lines = [header] + + periods = data.get("periods", data.get("properties", {}).get("periods", [])) + if isinstance(periods, list): + for period in periods[:12]: + if not isinstance(period, dict): + continue + name = period.get("name") or period.get("startTime", "") + temp = period.get("temperature") + temp_unit = period.get("temperatureUnit", "") + short = period.get("shortForecast") or "" + wind = period.get("windSpeed") or "" + + parts = [str(name)] + if temp is not None: + parts.append(f"{temp}°{temp_unit}" if temp_unit else str(temp)) + if short: + parts.append(short) + if wind: + parts.append(f"Wind: {wind}") + + lines.append(" - ".join(parts)) + + if len(lines) == 1: + lines.append("No hourly forecast data available.") + + return "\n".join(lines) + + +def format_open_meteo_response(data: dict[str, Any], display_name: str = "") -> str: + """ + Format an Open-Meteo API response as readable text. + + Handles current, hourly, and daily data sections. Limits output + to avoid overwhelming the AI context window. + + Args: + data: Raw Open-Meteo API response dict. + display_name: Location display name for the header. + + Returns: + A human-readable text summary of the queried data. + + """ + header = f"Open-Meteo data for {display_name}:" if display_name else "Open-Meteo data:" + lines = [header] + + # Current data + current = data.get("current") + if isinstance(current, dict): + units = data.get("current_units", {}) + lines.append("\nCurrent:") + for key, value in current.items(): + if key in ("time", "interval"): + continue + unit = units.get(key, "") + lines.append(f" {key}: {value}{unit}") + + # Hourly data (limit to 24 periods) + hourly = data.get("hourly") + if isinstance(hourly, dict): + units = data.get("hourly_units", {}) + times = hourly.get("time", [])[:24] + lines.append(f"\nHourly ({len(times)} periods):") + for i, t in enumerate(times): + parts = [t] + for key, values in hourly.items(): + if key == "time" or not isinstance(values, list): + continue + if i < len(values): + unit = units.get(key, "") + parts.append(f"{key}: {values[i]}{unit}") + lines.append(" " + " | ".join(parts)) + + # Daily data + daily = data.get("daily") + if isinstance(daily, dict): + units = data.get("daily_units", {}) + times = daily.get("time", []) + lines.append(f"\nDaily ({len(times)} days):") + for i, t in enumerate(times): + parts = [t] + for key, values in daily.items(): + if key == "time" or not isinstance(values, list): + continue + if i < len(values): + unit = units.get(key, "") + parts.append(f"{key}: {values[i]}{unit}") + lines.append(" " + " | ".join(parts)) + + if len(lines) == 1: + lines.append("No data returned.") + + return "\n".join(lines) + + +def format_location_search(suggestions: list[str], query: str = "") -> str: + """ + Format location search results as readable text. + + Args: + suggestions: List of location suggestion strings. + query: Original search query for context. + + Returns: + A human-readable list of matching locations. + + """ + if not suggestions: + return f"No locations found matching '{query}'." + + lines = [f"Locations matching '{query}':"] + for i, suggestion in enumerate(suggestions, 1): + lines.append(f"{i}. {suggestion}") + + return "\n".join(lines) + + +def _append_field(lines: list[str], label: str, value: Any) -> None: + """Append a labeled field to lines if the value is not None/empty.""" + if value is not None and value != "": + lines.append(f"{label}: {value}") + + +# Keep backward-compatible private aliases +_format_current_conditions = format_current_weather +_format_forecast = format_forecast +_format_alerts = format_alerts diff --git a/src/accessiweather/screen_reader.py b/src/accessiweather/screen_reader.py new file mode 100644 index 00000000..6ca96edf --- /dev/null +++ b/src/accessiweather/screen_reader.py @@ -0,0 +1,72 @@ +""" +Thin wrapper around prismatoid for screen reader announcements. + +Provides graceful fallback when prismatoid is not installed or crashes on import. +""" + +import logging + +logger = logging.getLogger(__name__) + + +def _try_import_prism(): + """Lazily import prism, returning the module or None.""" + try: + import prism + + return prism + except Exception: + logger.debug("prismatoid not available", exc_info=True) + return None + + +try: + # Module-level flag for tests to check + PRISM_AVAILABLE = _try_import_prism() is not None +except Exception: + PRISM_AVAILABLE = False + + +class ScreenReaderAnnouncer: + """Announces text via screen reader using prismatoid, with graceful fallback.""" + + def __init__(self) -> None: + """Initialize the announcer, acquiring a screen reader backend if possible.""" + self._backend = None + self._runtime_available = False + prism = _try_import_prism() + if prism is not None: + try: + ctx = prism.Context() + backend = ctx.acquire_best() + features = backend.features + if features.is_supported_at_runtime: + self._backend = backend + self._runtime_available = True + logger.info("Screen reader backend active: %s", backend.name) + else: + logger.debug( + "Screen reader backend found (%s) but not running at runtime", + backend.name, + ) + except Exception: + logger.warning("Failed to acquire screen reader backend", exc_info=True) + else: + logger.debug("prismatoid not available; announcer will be a no-op") + + def announce(self, text: str) -> None: + """Speak text via screen reader. No-op if unavailable.""" + if self._backend is not None: + try: + self._backend.speak(text, interrupt=False) + except Exception: + logger.warning("Failed to announce text", exc_info=True) + + def is_available(self) -> bool: + """Return whether a screen reader is actively running.""" + return self._runtime_available + + def shutdown(self) -> None: + """Clean up resources.""" + self._backend = None + self._runtime_available = False diff --git a/src/accessiweather/ui/dialogs/__init__.py b/src/accessiweather/ui/dialogs/__init__.py index fadd73bd..15855733 100644 --- a/src/accessiweather/ui/dialogs/__init__.py +++ b/src/accessiweather/ui/dialogs/__init__.py @@ -6,10 +6,12 @@ from .discussion_dialog import show_discussion_dialog from .explanation_dialog import show_explanation_dialog from .location_dialog import show_add_location_dialog +from .nationwide_discussion_dialog import show_nationwide_discussion_dialog from .settings_dialog import show_settings_dialog from .soundpack_manager_dialog import show_soundpack_manager_dialog from .soundpack_wizard_dialog import SoundPackWizardDialog from .uv_index_dialog import show_uv_index_dialog +from .weather_assistant_dialog import show_weather_assistant_dialog from .weather_history_dialog import show_weather_history_dialog __all__ = [ @@ -19,9 +21,11 @@ "show_aviation_dialog", "show_discussion_dialog", "show_explanation_dialog", + "show_nationwide_discussion_dialog", "show_settings_dialog", "show_soundpack_manager_dialog", "show_uv_index_dialog", + "show_weather_assistant_dialog", "show_weather_history_dialog", "SoundPackWizardDialog", ] diff --git a/src/accessiweather/ui/dialogs/nationwide_discussion_dialog.py b/src/accessiweather/ui/dialogs/nationwide_discussion_dialog.py new file mode 100644 index 00000000..31a7b86f --- /dev/null +++ b/src/accessiweather/ui/dialogs/nationwide_discussion_dialog.py @@ -0,0 +1,249 @@ +"""Nationwide weather discussions dialog with tabbed layout.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import wx + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class NationwideDiscussionDialog(wx.Dialog): + """ + Dialog displaying nationwide weather discussions in a tabbed interface. + + Tabs: WPC, SPC, NHC, CPC — each with labeled, read-only text controls + for the relevant discussion products. + """ + + def __init__( + self, + parent: wx.Window | None = None, + title: str = "Nationwide Weather Discussions", + ): + """ + Initialize the nationwide discussion dialog. + + Args: + parent: Parent window + title: Dialog title + + """ + super().__init__( + parent, + title=title, + style=wx.DEFAULT_DIALOG_STYLE | wx.RESIZE_BORDER, + ) + + self._create_widgets() + self._bind_events() + + self.SetSize((800, 600)) + self.CenterOnParent() + + def _create_widgets(self) -> None: + """Create all UI widgets.""" + panel = wx.Panel(self) + main_sizer = wx.BoxSizer(wx.VERTICAL) + + # Notebook (tabs) + self.notebook = wx.Notebook(panel, name="Discussion tabs") + main_sizer.Add(self.notebook, 1, wx.ALL | wx.EXPAND, 5) + + # --- WPC tab --- + self.wpc_panel = wx.Panel(self.notebook) + wpc_sizer = wx.BoxSizer(wx.VERTICAL) + + wpc_sizer.Add( + wx.StaticText(self.wpc_panel, label="Short Range Forecast:"), 0, wx.LEFT | wx.TOP, 5 + ) + self.wpc_short_range = wx.TextCtrl( + self.wpc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="WPC Short Range Forecast Discussion", + ) + wpc_sizer.Add(self.wpc_short_range, 1, wx.ALL | wx.EXPAND, 5) + + wpc_sizer.Add(wx.StaticText(self.wpc_panel, label="Medium Range Forecast:"), 0, wx.LEFT, 5) + self.wpc_medium_range = wx.TextCtrl( + self.wpc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="WPC Medium Range Forecast Discussion", + ) + wpc_sizer.Add(self.wpc_medium_range, 1, wx.ALL | wx.EXPAND, 5) + + wpc_sizer.Add(wx.StaticText(self.wpc_panel, label="Extended Forecast:"), 0, wx.LEFT, 5) + self.wpc_extended = wx.TextCtrl( + self.wpc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="WPC Extended Forecast Discussion", + ) + wpc_sizer.Add(self.wpc_extended, 1, wx.ALL | wx.EXPAND, 5) + + wpc_sizer.Add(wx.StaticText(self.wpc_panel, label="QPF Discussion:"), 0, wx.LEFT, 5) + self.wpc_qpf = wx.TextCtrl( + self.wpc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="WPC QPF Discussion", + ) + wpc_sizer.Add(self.wpc_qpf, 1, wx.ALL | wx.EXPAND, 5) + + self.wpc_panel.SetSizer(wpc_sizer) + self.notebook.AddPage(self.wpc_panel, "WPC") + + # --- SPC tab --- + self.spc_panel = wx.Panel(self.notebook) + spc_sizer = wx.BoxSizer(wx.VERTICAL) + + spc_sizer.Add( + wx.StaticText(self.spc_panel, label="Day 1 Convective Outlook:"), 0, wx.LEFT | wx.TOP, 5 + ) + self.spc_day1 = wx.TextCtrl( + self.spc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="SPC Day 1 Convective Outlook", + ) + spc_sizer.Add(self.spc_day1, 1, wx.ALL | wx.EXPAND, 5) + + spc_sizer.Add( + wx.StaticText(self.spc_panel, label="Day 2 Convective Outlook:"), 0, wx.LEFT, 5 + ) + self.spc_day2 = wx.TextCtrl( + self.spc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="SPC Day 2 Convective Outlook", + ) + spc_sizer.Add(self.spc_day2, 1, wx.ALL | wx.EXPAND, 5) + + spc_sizer.Add( + wx.StaticText(self.spc_panel, label="Day 3 Convective Outlook:"), 0, wx.LEFT, 5 + ) + self.spc_day3 = wx.TextCtrl( + self.spc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="SPC Day 3 Convective Outlook", + ) + spc_sizer.Add(self.spc_day3, 1, wx.ALL | wx.EXPAND, 5) + + self.spc_panel.SetSizer(spc_sizer) + self.notebook.AddPage(self.spc_panel, "SPC") + + # --- NHC tab --- + self.nhc_panel = wx.Panel(self.notebook) + nhc_sizer = wx.BoxSizer(wx.VERTICAL) + + nhc_sizer.Add( + wx.StaticText(self.nhc_panel, label="Atlantic Tropical Weather Outlook:"), + 0, + wx.LEFT | wx.TOP, + 5, + ) + self.nhc_atlantic = wx.TextCtrl( + self.nhc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="NHC Atlantic Tropical Weather Outlook", + ) + nhc_sizer.Add(self.nhc_atlantic, 1, wx.ALL | wx.EXPAND, 5) + + nhc_sizer.Add( + wx.StaticText(self.nhc_panel, label="East Pacific Tropical Weather Outlook:"), + 0, + wx.LEFT, + 5, + ) + self.nhc_east_pacific = wx.TextCtrl( + self.nhc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="NHC East Pacific Tropical Weather Outlook", + ) + nhc_sizer.Add(self.nhc_east_pacific, 1, wx.ALL | wx.EXPAND, 5) + + self.nhc_panel.SetSizer(nhc_sizer) + self.notebook.AddPage(self.nhc_panel, "NHC") + + # --- CPC tab --- + self.cpc_panel = wx.Panel(self.notebook) + cpc_sizer = wx.BoxSizer(wx.VERTICAL) + + cpc_sizer.Add( + wx.StaticText(self.cpc_panel, label="6-10 Day Outlook:"), 0, wx.LEFT | wx.TOP, 5 + ) + self.cpc_6_10_day = wx.TextCtrl( + self.cpc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="CPC 6-10 Day Outlook", + ) + cpc_sizer.Add(self.cpc_6_10_day, 1, wx.ALL | wx.EXPAND, 5) + + cpc_sizer.Add(wx.StaticText(self.cpc_panel, label="8-14 Day Outlook:"), 0, wx.LEFT, 5) + self.cpc_8_14_day = wx.TextCtrl( + self.cpc_panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_RICH2, + name="CPC 8-14 Day Outlook", + ) + cpc_sizer.Add(self.cpc_8_14_day, 1, wx.ALL | wx.EXPAND, 5) + + self.cpc_panel.SetSizer(cpc_sizer) + self.notebook.AddPage(self.cpc_panel, "CPC") + + # Close button + button_sizer = wx.BoxSizer(wx.HORIZONTAL) + self.close_button = wx.Button(panel, wx.ID_CLOSE, label="&Close") + button_sizer.Add(self.close_button, 0) + main_sizer.Add(button_sizer, 0, wx.ALL | wx.ALIGN_RIGHT, 10) + + panel.SetSizer(main_sizer) + + dialog_sizer = wx.BoxSizer(wx.VERTICAL) + dialog_sizer.Add(panel, 1, wx.EXPAND) + self.SetSizer(dialog_sizer) + + def _bind_events(self) -> None: + """Bind event handlers.""" + self.close_button.Bind(wx.EVT_BUTTON, self._on_close) + + def _on_close(self, event) -> None: + """Handle close button press.""" + self.EndModal(wx.ID_CLOSE) + + def set_discussion_text(self, tab: str, field: str, text: str) -> None: + """ + Set text for a specific discussion field. + + Args: + tab: Tab name ('wpc', 'spc', 'nhc', 'cpc') + field: Field name (e.g. 'short_range', 'day1', 'atlantic', '6_10_day') + text: The discussion text to display + + """ + attr_name = f"{tab}_{field}" + ctrl = getattr(self, attr_name, None) + if ctrl and isinstance(ctrl, wx.TextCtrl): + ctrl.SetValue(text) + + +def show_nationwide_discussion_dialog(parent: wx.Window | None = None) -> None: + """ + Show the Nationwide Weather Discussions dialog. + + Args: + parent: Parent window + + """ + try: + parent_ctrl = getattr(parent, "control", parent) + dlg = NationwideDiscussionDialog(parent_ctrl) + dlg.ShowModal() + dlg.Destroy() + except Exception as e: + logger.error(f"Failed to show nationwide discussion dialog: {e}") + wx.MessageBox( + f"Failed to open nationwide discussions: {e}", + "Error", + wx.OK | wx.ICON_ERROR, + ) diff --git a/src/accessiweather/ui/dialogs/weather_assistant_dialog.py b/src/accessiweather/ui/dialogs/weather_assistant_dialog.py new file mode 100644 index 00000000..2a8fb344 --- /dev/null +++ b/src/accessiweather/ui/dialogs/weather_assistant_dialog.py @@ -0,0 +1,665 @@ +""" +Weather Assistant dialog — conversational AI weather assistant. + +Phase 2: Multi-turn chat with function calling for weather lookups. +""" + +from __future__ import annotations + +import json +import logging +import threading +from datetime import datetime +from typing import TYPE_CHECKING + +import wx + +from ...ai_tools import WeatherToolExecutor, get_tools_for_message +from ...screen_reader import ScreenReaderAnnouncer + +if TYPE_CHECKING: + from ...app import AccessiWeatherApp + +logger = logging.getLogger(__name__) + +# Maximum conversation turns to keep in context +MAX_CONTEXT_TURNS = 20 + + +def _build_weather_context(app: AccessiWeatherApp) -> str: + """Build a weather context string from the app's current data.""" + weather = app.current_weather_data + if not weather: + return "No weather data currently loaded." + + parts: list[str] = [] + loc = weather.location + parts.append(f"Location: {loc.name} ({loc.latitude}, {loc.longitude})") + + cur = weather.current + if cur: + if cur.temperature_f is not None: + parts.append(f"Temperature: {cur.temperature_f:.0f}°F") + if cur.feels_like_f is not None: + parts.append(f"Feels like: {cur.feels_like_f:.0f}°F") + if cur.condition: + parts.append(f"Conditions: {cur.condition}") + if cur.humidity is not None: + parts.append(f"Humidity: {cur.humidity}%") + if cur.wind_speed_mph is not None: + wind = f"Wind: {cur.wind_speed_mph:.0f} mph" + if cur.wind_direction: + wind += f" from {cur.wind_direction}" + parts.append(wind) + if cur.pressure_in is not None: + parts.append(f"Pressure: {cur.pressure_in:.2f} inHg") + if cur.visibility_miles is not None: + parts.append(f"Visibility: {cur.visibility_miles:.1f} miles") + if cur.uv_index is not None: + parts.append(f"UV Index: {cur.uv_index}") + + forecast = weather.forecast + if forecast and forecast.periods: + parts.append("\nForecast:") + for period in forecast.periods[:6]: + line = f" {period.name}: {period.temperature}°{period.temperature_unit}" + if period.short_forecast: + line += f", {period.short_forecast}" + parts.append(line) + + if weather.alerts and weather.alerts.has_alerts(): + parts.append("\nActive Alerts:") + for alert in weather.alerts.alerts[:5]: + title = getattr(alert, "event", None) or getattr(alert, "title", "Alert") + severity = getattr(alert, "severity", "Unknown") + parts.append(f" - {title} (Severity: {severity})") + + if weather.trend_insights: + parts.append("\nTrend Insights:") + for insight in weather.trend_insights[:3]: + text = insight.summary or f"{insight.metric}: {insight.direction}" + if insight.change is not None and insight.unit: + text += f" ({insight.change:+.1f}{insight.unit})" + parts.append(f" - {text}") + + return "\n".join(parts) + + +SYSTEM_PROMPT = ( + "You are Weather Assistant, a friendly and knowledgeable weather assistant built into " + "AccessiWeather. You help users understand weather conditions in plain, accessible " + "language optimized for screen reader users.\n\n" + "You have access to live weather tools that can fetch current conditions, forecasts, " + "and active alerts for any location. The available tools are:\n" + "- get_current_weather: Get current weather conditions for a location\n" + "- get_forecast: Get the weather forecast for a location\n" + "- get_hourly_forecast: Get hourly forecast (great for specific time questions)\n" + "- get_alerts: Get active weather alerts for a location\n" + "- search_location: Search for a location by name or ZIP code\n" + "- add_location: Save a location to the user's locations list\n" + "- list_locations: Show all saved locations\n" + "- query_open_meteo: Custom Open-Meteo API query for any weather variable " + "(soil temp, UV, cloud cover, dew point, snow depth, visibility, etc.)\n" + "- get_area_forecast_discussion: Local NWS forecaster's detailed discussion\n" + "- get_wpc_discussion: National WPC short range forecast discussion\n" + "- get_spc_outlook: Storm Prediction Center severe weather outlook\n\n" + "Use the provided tools to fetch weather data when users ask about specific locations " + "or conditions not in the current context. You can call multiple tools if needed to " + "give a complete answer. Use search_location when a place name is ambiguous. " + "When adding locations, first resolve coordinates with search_location, then use " + "add_location with the resolved name and coordinates.\n\n" + "Guidelines:\n" + "- Be conversational and helpful\n" + "- Explain weather in practical terms (what to wear, activity suitability, etc.)\n" + "- Avoid visual-only descriptions\n" + "- When referencing data, use the weather context provided\n" + "- Keep responses concise but thorough\n" + "- Respond in plain text only — no markdown formatting\n" + "- Do not repeat information the user can already see\n\n" + "IMPORTANT: Respond in plain text. No bold, italic, headers, or bullet markers." +) + + +class WeatherAssistantDialog(wx.Dialog): + """Multi-turn conversational weather chat dialog.""" + + def __init__( + self, + parent: wx.Window, + app: AccessiWeatherApp, + title: str = "Weather Assistant", + ): + """ + Initialize the Weather Assistant dialog. + + Args: + parent: Parent window + app: Application instance + title: Dialog title + + """ + super().__init__( + parent, + title=title, + style=wx.DEFAULT_DIALOG_STYLE | wx.RESIZE_BORDER, + ) + self.app = app + self._conversation: list[dict[str, str]] = [] + self._is_generating = False + self._announcer = ScreenReaderAnnouncer() + + self._create_widgets() + self._bind_events() + self._add_welcome_message() + + self.SetSize((650, 500)) + self.CenterOnParent() + self.input_ctrl.SetFocus() + + def _create_widgets(self) -> None: + """Create all UI widgets.""" + panel = wx.Panel(self) + main_sizer = wx.BoxSizer(wx.VERTICAL) + + # Chat history label + history_label = wx.StaticText(panel, label="&Conversation:") + main_sizer.Add(history_label, 0, wx.LEFT | wx.RIGHT | wx.TOP, 10) + + # Chat history display (read-only) + self.history_display = wx.TextCtrl( + panel, + style=wx.TE_MULTILINE | wx.TE_READONLY | wx.TE_WORDWRAP | wx.TE_RICH2, + name="Conversation history", + ) + main_sizer.Add(self.history_display, 1, wx.ALL | wx.EXPAND, 10) + + # Status label + self.status_label = wx.StaticText(panel, label="") + main_sizer.Add(self.status_label, 0, wx.LEFT | wx.RIGHT | wx.EXPAND, 10) + + # Input area + input_sizer = wx.BoxSizer(wx.HORIZONTAL) + + input_label = wx.StaticText(panel, label="&Message:") + input_sizer.Add(input_label, 0, wx.ALIGN_CENTER_VERTICAL | wx.RIGHT, 5) + + self.input_ctrl = wx.TextCtrl( + panel, + style=wx.TE_PROCESS_ENTER, + name="Type your message", + ) + input_sizer.Add(self.input_ctrl, 1, wx.ALIGN_CENTER_VERTICAL | wx.RIGHT, 5) + + self.send_button = wx.Button(panel, label="&Send") + input_sizer.Add(self.send_button, 0, wx.ALIGN_CENTER_VERTICAL) + + main_sizer.Add(input_sizer, 0, wx.LEFT | wx.RIGHT | wx.BOTTOM | wx.EXPAND, 10) + + # Bottom buttons + button_sizer = wx.BoxSizer(wx.HORIZONTAL) + + self.clear_button = wx.Button(panel, label="C&lear Chat") + button_sizer.Add(self.clear_button, 0, wx.RIGHT, 5) + + self.copy_button = wx.Button(panel, label="Cop&y Chat") + button_sizer.Add(self.copy_button, 0, wx.RIGHT, 5) + + button_sizer.AddStretchSpacer() + + close_button = wx.Button(panel, wx.ID_CLOSE, label="&Close") + button_sizer.Add(close_button, 0) + + main_sizer.Add(button_sizer, 0, wx.LEFT | wx.RIGHT | wx.BOTTOM | wx.EXPAND, 10) + + panel.SetSizer(main_sizer) + + def _bind_events(self) -> None: + """Bind event handlers.""" + self.send_button.Bind(wx.EVT_BUTTON, self._on_send) + self.input_ctrl.Bind(wx.EVT_TEXT_ENTER, self._on_send) + self.clear_button.Bind(wx.EVT_BUTTON, self._on_clear) + self.copy_button.Bind(wx.EVT_BUTTON, self._on_copy) + self.Bind(wx.EVT_BUTTON, self._on_close, id=wx.ID_CLOSE) + self.Bind(wx.EVT_CLOSE, self._on_close) + + # Escape to close + self.Bind(wx.EVT_CHAR_HOOK, self._on_key) + + def _on_key(self, event: wx.KeyEvent) -> None: + """Handle key events.""" + if event.GetKeyCode() == wx.WXK_ESCAPE: + self.Close() + else: + event.Skip() + + def _add_welcome_message(self) -> None: + """Add initial welcome message to the chat.""" + location = ( + self.app.config_manager.get_current_location() if self.app.config_manager else None + ) + loc_name = location.name if location else "your area" + + welcome = ( + f"Welcome to Weather Assistant! I can help you understand the weather " + f"conditions for {loc_name}. Ask me anything about the current " + f"weather, forecast, what to wear, or how conditions might affect " + f"your plans." + ) + self._append_to_display("Weather Assistant", welcome) + self._announcer.announce(f"Weather Assistant: {welcome}") + + def _append_to_display(self, speaker: str, text: str) -> None: + """Append a message to the chat display.""" + timestamp = datetime.now().strftime("%I:%M %p") + formatted = f"[{timestamp}] {speaker}:\n{text}\n\n" + self.history_display.AppendText(formatted) + # Scroll to bottom + self.history_display.ShowPosition(self.history_display.GetLastPosition()) + + def _set_status(self, text: str) -> None: + """Update the status label.""" + self.status_label.SetLabel(text) + + def _set_generating(self, generating: bool) -> None: + """Toggle generating state.""" + self._is_generating = generating + self.send_button.Enable(not generating) + # Keep input_ctrl always enabled so screen readers don't lose focus. + # The _is_generating flag prevents sends during generation. + if generating: + self._set_status("Thinking...") + else: + self._set_status("Ready") + self.input_ctrl.SetFocus() + + def _on_send(self, event: wx.Event) -> None: + """Handle send button or Enter key.""" + message = self.input_ctrl.GetValue().strip() + if not message or self._is_generating: + return + + self.input_ctrl.SetValue("") + self._append_to_display("You", message) + self._announcer.announce(f"You: {message}") + + # Add to conversation history + self._conversation.append({"role": "user", "content": message}) + + # Trim conversation if too long + if len(self._conversation) > MAX_CONTEXT_TURNS * 2: + self._conversation = self._conversation[-(MAX_CONTEXT_TURNS * 2) :] + + self._set_generating(True) + self._generate_response() + + def _get_tool_executor(self) -> WeatherToolExecutor | None: + """ + Create a WeatherToolExecutor from the app's services. + + Returns: + A WeatherToolExecutor, or None if required services are unavailable. + + """ + try: + import asyncio + + from ...api_client import NoaaApiClient + from ...geocoding import GeocodingService + from ...models.location import Location + from ...openmeteo_client import OpenMeteoApiClient + from ...visual_crossing_client import VisualCrossingClient + + class _CombinedWeatherClient: + """Bridges NWS, Open-Meteo, and Visual Crossing for tool executor.""" + + def __init__(self, vc_api_key: str = ""): + self.nws = NoaaApiClient() + self.openmeteo = OpenMeteoApiClient() + self.vc: VisualCrossingClient | None = None + if vc_api_key: + self.vc = VisualCrossingClient(api_key=vc_api_key) + + def _run_async(self, coro): + """Run an async coroutine from sync context.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + # We're in a thread; create a new loop + return asyncio.run(coro) + return asyncio.run(coro) + + def _make_location(self, lat, lon): + return Location(name="", latitude=lat, longitude=lon) + + def get_current_conditions(self, lat, lon, **kw): + try: + return self.nws.get_current_conditions(lat, lon, **kw) + except Exception: + pass + # Fall back to Open-Meteo (global, free) + return self.openmeteo.get_current_weather(lat, lon) + + def get_forecast(self, lat, lon, **kw): + try: + return self.nws.get_forecast(lat, lon, **kw) + except Exception: + pass + return self.openmeteo.get_forecast(lat, lon) + + def get_hourly_forecast(self, lat, lon, **kw): + try: + return self.nws.get_hourly_forecast(lat, lon, **kw) + except Exception: + pass + return self.openmeteo.get_hourly_forecast(lat, lon) + + def get_alerts(self, lat, lon, **kw): + try: + return self.nws.get_alerts(lat, lon, **kw) + except Exception: + pass + # Visual Crossing has global alerts + if self.vc: + try: + loc = self._make_location(lat, lon) + result = self._run_async(self.vc.get_alerts(loc)) + if result: + return result.to_dict() if hasattr(result, "to_dict") else result + except Exception: + pass + return {"features": []} + + def get_discussion(self, lat, lon, **kw): + return self.nws.get_discussion(lat, lon, **kw) + + config_manager = getattr(self.app, "config_manager", None) + settings = config_manager.get_settings() if config_manager else None + vc_key = settings.visual_crossing_api_key if settings else "" + weather_client = _CombinedWeatherClient(vc_api_key=vc_key) + geocoding_service = GeocodingService() + return WeatherToolExecutor( + weather_client, geocoding_service, config_manager=config_manager + ) + except Exception: + logger.debug("Could not create WeatherToolExecutor", exc_info=True) + return None + + def _generate_response(self) -> None: + """Generate AI response in a background thread.""" + # Get config + settings = self.app.config_manager.get_settings() if self.app.config_manager else None + api_key = settings.openrouter_api_key if settings else "" + model = settings.ai_model_preference if settings else "" + + if not api_key: + wx.CallAfter( + self._on_response_error, + "No OpenRouter API key configured. Set one in Settings > AI Explanations.", + ) + return + + # Build weather context + weather_context = _build_weather_context(self.app) + + # Build messages for API + system_message = f"{SYSTEM_PROMPT}\n\nCurrent weather data:\n{weather_context}" + + messages: list[dict] = [{"role": "system", "content": system_message}] + messages.extend(self._conversation) + + tool_executor = self._get_tool_executor() + logger.info("Tool executor: %s", "available" if tool_executor else "NONE") + + def do_generate(): + try: + from openai import OpenAI + + client = OpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=api_key, + timeout=30.0, + ) + + effective_model = model if model else "meta-llama/llama-3.3-70b-instruct:free" + + # Models with reliable function calling on free tier (in preference order) + TOOL_CAPABLE_MODELS = [ + "qwen/qwen3-coder:free", + "mistralai/mistral-small-3.1-24b-instruct:free", + "meta-llama/llama-3.3-70b-instruct:free", + ] + + extra_kwargs: dict = {} + use_tool_fallback = False + if tool_executor is not None: + # Get last user message for tool selection + user_msg = "" + for msg in reversed(messages): + if msg.get("role") == "user": + user_msg = msg.get("content", "") + break + tools = get_tools_for_message(user_msg) + if tools: + extra_kwargs["tools"] = tools + # When tools are present, use a model known for + # reliable function calling if user hasn't picked + # a specific model (free router doesn't guarantee + # tool support) + # Use whatever the user selected; only flag + # fallback so rate-limit retry logic can kick in + # for free models + free_routers = { + "openrouter/free", + "meta-llama/llama-3.3-70b-instruct:free", + } + if effective_model in free_routers: + effective_model = TOOL_CAPABLE_MODELS[0] + use_tool_fallback = True + logger.info( + "Tools enabled: %d tools, model: %s", + len(tools), + effective_model, + ) + + max_tool_iterations = 5 + for _iteration in range(max_tool_iterations + 1): + # Try the API call, with fallback models on rate limit + last_error = None + for _model_attempt in range(3): + try: + response = client.chat.completions.create( + model=effective_model, + messages=messages, + max_tokens=2000, + extra_headers={ + "HTTP-Referer": "https://accessiweather.orinks.net", + "X-Title": "AccessiWeather Weather Assistant", + }, + **extra_kwargs, + ) + last_error = None + break + except Exception as api_err: + err_str = str(api_err).lower() + is_retryable = ( + "429" in str(api_err) + or "rate" in err_str + or "400" in str(api_err) + or "tool_calls" in err_str + or "provider returned error" in err_str + ) + if is_retryable and use_tool_fallback: + # Try next model in the fallback chain + try: + idx = TOOL_CAPABLE_MODELS.index(effective_model) + if idx + 1 < len(TOOL_CAPABLE_MODELS): + effective_model = TOOL_CAPABLE_MODELS[idx + 1] + logger.info( + "Provider error, falling back to %s", effective_model + ) + last_error = api_err + continue + except ValueError: + pass + raise + if last_error is not None: + raise last_error + + model_used = response.model or effective_model + + if not response.choices: + if extra_kwargs.get("tools") and use_tool_fallback: + # Model returned empty with tools; retry without + logger.warning( + "Empty response with tools on %s, retrying without tools", + effective_model, + ) + extra_kwargs.pop("tools", None) + continue + wx.CallAfter( + self._on_response_error, + "Received an empty response. Try again or switch models in Settings.", + ) + return + + choice = response.choices[0] + assistant_message = choice.message + + # Check for tool calls + logger.info( + "Response finish_reason=%s, tool_calls=%s", + choice.finish_reason, + bool(assistant_message.tool_calls), + ) + if assistant_message.tool_calls and tool_executor is not None: + # Append assistant message with tool calls to messages + tool_call_msg: dict = { + "role": "assistant", + "content": assistant_message.content or "", + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in assistant_message.tool_calls + ], + } + messages.append(tool_call_msg) + + # Execute each tool call + for tool_call in assistant_message.tool_calls: + tool_name = tool_call.function.name + try: + arguments = json.loads(tool_call.function.arguments) + result = tool_executor.execute(tool_name, arguments) + except Exception as exc: + logger.warning( + "Tool call %s failed: %s", tool_name, exc, exc_info=True + ) + result = f"Error executing {tool_name}: {exc}" + + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + + # Continue loop to get next response + continue + + # No tool calls — we have the final text response + content = assistant_message.content or "" + if content.strip(): + wx.CallAfter(self._on_response_received, content.strip(), model_used) + else: + wx.CallAfter( + self._on_response_error, + "Received an empty response. Try again or switch models in Settings.", + ) + return + + # Exhausted max iterations — use whatever content we have + fallback = assistant_message.content or "" if assistant_message else "" + if fallback.strip(): + wx.CallAfter(self._on_response_received, fallback.strip(), model_used) + else: + wx.CallAfter( + self._on_response_error, + "The assistant made too many tool calls. Please try a simpler question.", + ) + + except Exception as e: + error_msg = str(e) + logger.error(f"Weather Assistant generation error: {e}", exc_info=True) + + if "api key" in error_msg.lower() or "401" in error_msg: + friendly = "API key is invalid. Check Settings > AI Explanations." + elif "429" in error_msg or "rate limit" in error_msg.lower(): + friendly = ( + "Rate limited. Wait a moment and try again, or switch to a different model." + ) + elif "timeout" in error_msg.lower() or "timed out" in error_msg.lower(): + friendly = "Request timed out. The AI service may be busy, try again." + else: + friendly = f"Error: {error_msg}" + + wx.CallAfter(self._on_response_error, friendly) + + thread = threading.Thread(target=do_generate, daemon=True) + thread.start() + + def _on_response_received(self, text: str, model_used: str) -> None: + """Handle successful AI response.""" + self._conversation.append({"role": "assistant", "content": text}) + self._append_to_display("Weather Assistant", text) + self._announcer.announce(f"Weather Assistant: {text}") + self._set_status(f"Model: {model_used}") + self._set_generating(False) + + def _on_response_error(self, error: str) -> None: + """Handle AI response error.""" + error_message = f"Sorry, I couldn't respond: {error}" + self._append_to_display("Weather Assistant", error_message) + self._announcer.announce(f"Weather Assistant: {error_message}") + # Remove the last user message from conversation since we failed + if self._conversation and self._conversation[-1]["role"] == "user": + self._conversation.pop() + self._set_generating(False) + + def _on_clear(self, event: wx.Event) -> None: + """Clear chat history.""" + self._conversation.clear() + self.history_display.SetValue("") + self._set_status("") + self._add_welcome_message() + self.input_ctrl.SetFocus() + + def _on_copy(self, event: wx.Event) -> None: + """Copy chat to clipboard.""" + text = self.history_display.GetValue() + if text and wx.TheClipboard.Open(): + wx.TheClipboard.SetData(wx.TextDataObject(text)) + wx.TheClipboard.Close() + self._set_status("Chat copied to clipboard.") + + def _on_close(self, event: wx.Event) -> None: + """Handle dialog close.""" + self._announcer.shutdown() + self.EndModal(wx.ID_CLOSE) + + +def show_weather_assistant_dialog(parent: wx.Window, app: AccessiWeatherApp) -> None: + """Show the Weather Assistant dialog.""" + dlg = WeatherAssistantDialog(parent, app) + try: + dlg.ShowModal() + finally: + dlg.Destroy() diff --git a/src/accessiweather/ui/main_window.py b/src/accessiweather/ui/main_window.py index 59d3562c..097dfcaf 100644 --- a/src/accessiweather/ui/main_window.py +++ b/src/accessiweather/ui/main_window.py @@ -218,6 +218,10 @@ def _create_menu_bar(self) -> None: wx.ID_ANY, "Air &Quality...", "View air quality information" ) uv_index_item = view_menu.Append(wx.ID_ANY, "&UV Index...", "View UV index information") + view_menu.AppendSeparator() + weather_chat_item = view_menu.Append( + wx.ID_ANY, "Weather &Assistant...\tCtrl+T", "Chat with AI weather assistant" + ) menu_bar.Append(view_menu, "&View") # Tools menu @@ -258,6 +262,7 @@ def _create_menu_bar(self) -> None: self.Bind(wx.EVT_MENU, lambda e: self._on_aviation(), aviation_item) self.Bind(wx.EVT_MENU, lambda e: self._on_air_quality(), air_quality_item) self.Bind(wx.EVT_MENU, lambda e: self._on_uv_index(), uv_index_item) + self.Bind(wx.EVT_MENU, lambda e: self._on_weather_chat(), weather_chat_item) self.Bind(wx.EVT_MENU, lambda e: self._on_soundpack_manager(), soundpack_item) self.Bind(wx.EVT_MENU, lambda e: self._on_check_updates(), self._check_updates_item) self.Bind(wx.EVT_MENU, lambda e: self._on_report_issue(), report_issue_item) @@ -396,6 +401,12 @@ def _on_uv_index(self) -> None: show_uv_index_dialog(self, self.app) + def _on_weather_chat(self) -> None: + """Open Weather Assistant dialog.""" + from .dialogs import show_weather_assistant_dialog + + show_weather_assistant_dialog(self, self.app) + def _on_soundpack_manager(self) -> None: """Open the soundpack manager dialog.""" from .dialogs import show_soundpack_manager_dialog diff --git a/tests/test_ai_tools.py b/tests/test_ai_tools.py new file mode 100644 index 00000000..0cf5e212 --- /dev/null +++ b/tests/test_ai_tools.py @@ -0,0 +1,200 @@ +"""Tests for AI tool schemas and WeatherToolExecutor.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from accessiweather.ai_tools import WEATHER_TOOLS, WeatherToolExecutor + + +class TestWeatherToolSchemas: + """Tests for the WEATHER_TOOLS schema definitions.""" + + def test_weather_tools_has_expected_count(self): + assert len(WEATHER_TOOLS) == 11 + + def test_all_tools_have_function_type(self): + for tool in WEATHER_TOOLS: + assert tool["type"] == "function" + + def test_all_tools_have_required_fields(self): + for tool in WEATHER_TOOLS: + func = tool["function"] + assert "name" in func + assert "description" in func + assert "parameters" in func + + def test_all_tools_have_json_schema_parameters(self): + for tool in WEATHER_TOOLS: + params = tool["function"]["parameters"] + assert params["type"] == "object" + assert "properties" in params + + def test_core_tool_names(self): + names = [t["function"]["name"] for t in WEATHER_TOOLS] + assert "get_current_weather" in names + assert "get_forecast" in names + assert "get_alerts" in names + + def test_extended_tool_names(self): + names = [t["function"]["name"] for t in WEATHER_TOOLS] + assert "get_hourly_forecast" in names + assert "search_location" in names + assert "add_location" in names + assert "list_locations" in names + assert "query_open_meteo" in names + + def test_discussion_tool_names(self): + names = [t["function"]["name"] for t in WEATHER_TOOLS] + assert "get_area_forecast_discussion" in names + assert "get_wpc_discussion" in names + assert "get_spc_outlook" in names + + def test_all_tools_have_descriptions(self): + for tool in WEATHER_TOOLS: + assert len(tool["function"]["description"]) > 0 + + +class TestWeatherToolExecutor: + """Tests for WeatherToolExecutor.""" + + @pytest.fixture() + def mock_weather_service(self): + return MagicMock() + + @pytest.fixture() + def mock_geocoding_service(self): + service = MagicMock() + service.geocode_address.return_value = (40.7128, -74.0060, "New York, NY") + return service + + @pytest.fixture() + def executor(self, mock_weather_service, mock_geocoding_service): + return WeatherToolExecutor(mock_weather_service, mock_geocoding_service) + + def test_execute_unknown_tool_raises_value_error(self, executor): + with pytest.raises(ValueError, match="Unknown tool"): + executor.execute("unknown_tool", {"location": "NYC"}) + + def test_execute_get_current_weather( + self, executor, mock_weather_service, mock_geocoding_service + ): + mock_weather_service.get_current_conditions.return_value = { + "temperature": "72°F", + "humidity": "55%", + "wind": "5 mph NW", + "description": "Partly Cloudy", + } + + result = executor.execute("get_current_weather", {"location": "New York, NY"}) + + mock_geocoding_service.geocode_address.assert_called_once_with("New York, NY") + mock_weather_service.get_current_conditions.assert_called_once_with(40.7128, -74.0060) + assert "New York, NY" in result + assert "72°F" in result + assert "55%" in result + + def test_execute_get_forecast(self, executor, mock_weather_service, mock_geocoding_service): + mock_weather_service.get_forecast.return_value = { + "periods": [ + { + "name": "Tonight", + "detailedForecast": "Clear skies with a low of 60°F.", + "temperature": 60, + "temperatureUnit": "F", + }, + { + "name": "Tomorrow", + "detailedForecast": "Sunny with a high of 85°F.", + "temperature": 85, + "temperatureUnit": "F", + }, + ] + } + + result = executor.execute("get_forecast", {"location": "New York, NY"}) + + mock_weather_service.get_forecast.assert_called_once_with(40.7128, -74.0060) + assert "Forecast for New York, NY" in result + assert "Tonight" in result + assert "Tomorrow" in result + + def test_execute_get_alerts_with_alerts( + self, executor, mock_weather_service, mock_geocoding_service + ): + mock_weather_service.get_alerts.return_value = { + "alerts": [ + { + "properties": { + "event": "Heat Advisory", + "headline": "Heat advisory in effect until 8 PM", + } + } + ] + } + + result = executor.execute("get_alerts", {"location": "New York, NY"}) + + mock_weather_service.get_alerts.assert_called_once_with(40.7128, -74.0060) + assert "Heat Advisory" in result + assert "Heat advisory in effect" in result + + def test_execute_get_alerts_no_alerts( + self, executor, mock_weather_service, mock_geocoding_service + ): + mock_weather_service.get_alerts.return_value = {"alerts": []} + + result = executor.execute("get_alerts", {"location": "New York, NY"}) + + assert "No active alerts" in result + + def test_execute_geocoding_failure(self, executor, mock_geocoding_service): + mock_geocoding_service.geocode_address.return_value = None + + result = executor.execute("get_current_weather", {"location": "Nonexistent Place"}) + assert "Error" in result + assert "Could not resolve location" in result + + def test_execute_current_weather_minimal_data(self, executor, mock_weather_service): + mock_weather_service.get_current_conditions.return_value = { + "status": "ok", + } + + result = executor.execute("get_current_weather", {"location": "NYC"}) + assert "New York, NY" in result + assert "status: ok" in result + + def test_execute_forecast_nested_properties(self, executor, mock_weather_service): + mock_weather_service.get_forecast.return_value = { + "properties": { + "periods": [ + { + "name": "Today", + "shortForecast": "Sunny", + "temperature": 80, + "temperatureUnit": "F", + } + ] + } + } + + result = executor.execute("get_forecast", {"location": "NYC"}) + assert "Today" in result + + def test_execute_alerts_features_format(self, executor, mock_weather_service): + """Test alerts with GeoJSON features format.""" + mock_weather_service.get_alerts.return_value = { + "features": [ + { + "properties": { + "event": "Tornado Warning", + "headline": "Tornado warning for the area", + } + } + ] + } + + result = executor.execute("get_alerts", {"location": "NYC"}) + assert "Tornado Warning" in result diff --git a/tests/test_ai_tools_extended.py b/tests/test_ai_tools_extended.py new file mode 100644 index 00000000..19214c6d --- /dev/null +++ b/tests/test_ai_tools_extended.py @@ -0,0 +1,367 @@ +"""Tests for extended AI tools: hourly forecast, location search/management, open-meteo, discussions.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from accessiweather.ai_tools import ( + CORE_TOOLS, + DISCUSSION_TOOLS, + EXTENDED_TOOLS, + WeatherToolExecutor, + format_hourly_forecast, + format_location_search, + format_open_meteo_response, + get_tools_for_message, +) + + +class TestGetToolsForMessage: + """Tests for tiered tool selection.""" + + def test_simple_weather_returns_core_only(self): + tools = get_tools_for_message("What's the weather like?") + assert len(tools) == len(CORE_TOOLS) + + def test_hourly_trigger(self): + tools = get_tools_for_message("Will it rain at 3pm?") + names = [t["function"]["name"] for t in tools] + assert "get_hourly_forecast" in names + + def test_location_trigger(self): + tools = get_tools_for_message("Add Paris to my locations") + names = [t["function"]["name"] for t in tools] + assert "add_location" in names + assert "search_location" in names + + def test_discussion_trigger_severe(self): + tools = get_tools_for_message("Is there any severe weather risk?") + names = [t["function"]["name"] for t in tools] + assert "get_spc_outlook" in names + + def test_discussion_trigger_spc(self): + tools = get_tools_for_message("Show me the SPC outlook") + names = [t["function"]["name"] for t in tools] + assert "get_wpc_discussion" in names + + def test_discussion_trigger_tornado(self): + tools = get_tools_for_message("Any tornado risk today?") + names = [t["function"]["name"] for t in tools] + assert "get_area_forecast_discussion" in names + + def test_soil_triggers_extended(self): + tools = get_tools_for_message("What's the soil temperature?") + names = [t["function"]["name"] for t in tools] + assert "query_open_meteo" in names + + def test_uv_triggers_extended(self): + tools = get_tools_for_message("What's the UV index?") + names = [t["function"]["name"] for t in tools] + assert "query_open_meteo" in names + + def test_list_locations_trigger(self): + tools = get_tools_for_message("Show my locations") + names = [t["function"]["name"] for t in tools] + assert "list_locations" in names + + def test_no_triggers_core_only(self): + tools = get_tools_for_message("Is it cold outside?") + names = [t["function"]["name"] for t in tools] + assert "get_current_weather" in names + assert "query_open_meteo" not in names + assert "get_spc_outlook" not in names + + +class TestFormatHourlyForecast: + """Tests for format_hourly_forecast.""" + + def test_basic_periods(self): + data = { + "periods": [ + { + "name": "3 PM", + "temperature": 72, + "temperatureUnit": "F", + "shortForecast": "Sunny", + "windSpeed": "5 mph", + } + ] + } + result = format_hourly_forecast(data, "NYC") + assert "NYC" in result + assert "3 PM" in result + assert "72°F" in result + assert "Sunny" in result + assert "Wind: 5 mph" in result + + def test_nested_properties_periods(self): + data = { + "properties": { + "periods": [{"name": "Tonight", "temperature": 55, "temperatureUnit": "F"}] + } + } + result = format_hourly_forecast(data) + assert "Tonight" in result + + def test_empty_periods(self): + data = {"periods": []} + result = format_hourly_forecast(data) + assert "No hourly forecast data" in result + + def test_limits_to_12_periods(self): + data = { + "periods": [ + {"name": f"Hour {i}", "temperature": 70 + i, "temperatureUnit": "F"} + for i in range(20) + ] + } + result = format_hourly_forecast(data) + assert "Hour 11" in result + assert "Hour 12" not in result + + def test_missing_fields(self): + data = {"periods": [{"startTime": "2026-02-11T15:00"}]} + result = format_hourly_forecast(data) + assert "2026-02-11T15:00" in result + + +class TestFormatLocationSearch: + """Tests for format_location_search.""" + + def test_with_results(self): + result = format_location_search(["New York, NY", "New York Mills, MN"], "New York") + assert "New York" in result + assert "1. New York, NY" in result + assert "2. New York Mills, MN" in result + + def test_no_results(self): + result = format_location_search([], "Nonexistent") + assert "No locations found" in result + + def test_empty_query(self): + result = format_location_search([], "") + assert "No locations found" in result + + +class TestFormatOpenMeteoResponse: + """Tests for format_open_meteo_response.""" + + def test_current_data(self): + data = { + "current": {"temperature_2m": 22.5, "time": "2026-02-11T15:00", "interval": 900}, + "current_units": {"temperature_2m": "°C"}, + } + result = format_open_meteo_response(data, "Berlin") + assert "Berlin" in result + assert "temperature_2m: 22.5°C" in result + assert "time" not in result.split("Current:")[1] # time should be excluded + + def test_hourly_data(self): + data = { + "hourly": { + "time": ["2026-02-11T12:00", "2026-02-11T13:00"], + "temperature_2m": [20.0, 21.0], + }, + "hourly_units": {"temperature_2m": "°C"}, + } + result = format_open_meteo_response(data) + assert "Hourly (2 periods)" in result + assert "20.0°C" in result + + def test_daily_data(self): + data = { + "daily": { + "time": ["2026-02-11"], + "temperature_2m_max": [25.0], + }, + "daily_units": {"temperature_2m_max": "°C"}, + } + result = format_open_meteo_response(data) + assert "Daily (1 days)" in result + assert "25.0°C" in result + + def test_empty_response(self): + result = format_open_meteo_response({}) + assert "No data returned" in result + + def test_no_display_name(self): + data = {"current": {"temperature_2m": 15}, "current_units": {}} + result = format_open_meteo_response(data) + assert "Open-Meteo data:" in result + + +class TestWeatherToolExecutorExtended: + """Tests for extended WeatherToolExecutor methods.""" + + @pytest.fixture() + def mock_services(self): + weather = MagicMock() + geocoding = MagicMock() + geocoding.geocode_address.return_value = (40.7, -74.0, "New York, NY") + config = MagicMock() + config.get_location_names.return_value = ["Home"] + config.get_all_locations.return_value = [] + config.get_current_location.return_value = None + config.add_location.return_value = True + return weather, geocoding, config + + @pytest.fixture() + def executor(self, mock_services): + weather, geocoding, config = mock_services + return WeatherToolExecutor(weather, geocoding, config_manager=config) + + def test_get_hourly_forecast(self, executor, mock_services): + weather, _, _ = mock_services + weather.get_hourly_forecast.return_value = { + "periods": [ + { + "name": "3 PM", + "temperature": 72, + "temperatureUnit": "F", + "shortForecast": "Clear", + } + ] + } + result = executor.execute("get_hourly_forecast", {"location": "New York"}) + assert "3 PM" in result + weather.get_hourly_forecast.assert_called_once_with(40.7, -74.0) + + def test_search_location(self, executor, mock_services): + _, geocoding, _ = mock_services + geocoding.suggest_locations.return_value = ["Paris, France", "Paris, TX"] + result = executor.execute("search_location", {"query": "Paris"}) + assert "Paris, France" in result + geocoding.suggest_locations.assert_called_once_with("Paris", limit=5) + + def test_add_location_success(self, executor, mock_services): + _, _, config = mock_services + result = executor.execute( + "add_location", {"name": "NYC", "latitude": 40.7, "longitude": -74.0} + ) + assert "Added" in result + config.add_location.assert_called_once_with("NYC", 40.7, -74.0) + + def test_add_location_already_exists(self, executor, mock_services): + _, _, config = mock_services + config.get_location_names.return_value = ["NYC"] + result = executor.execute( + "add_location", {"name": "NYC", "latitude": 40.7, "longitude": -74.0} + ) + assert "already" in result + + def test_add_location_no_config_manager(self, mock_services): + weather, geocoding, _ = mock_services + executor = WeatherToolExecutor(weather, geocoding, config_manager=None) + result = executor.execute( + "add_location", {"name": "NYC", "latitude": 40.7, "longitude": -74.0} + ) + assert "unavailable" in result + + def test_list_locations_empty(self, executor): + result = executor.execute("list_locations", {}) + assert "No saved locations" in result + + def test_list_locations_with_data(self, mock_services): + weather, geocoding, config = mock_services + loc = MagicMock() + loc.name = "Home" + loc.latitude = 40.7 + loc.longitude = -74.0 + config.get_all_locations.return_value = [loc] + current = MagicMock() + current.name = "Home" + config.get_current_location.return_value = current + executor = WeatherToolExecutor(weather, geocoding, config_manager=config) + result = executor.execute("list_locations", {}) + assert "Home" in result + assert "(current)" in result + + def test_list_locations_no_config_manager(self, mock_services): + weather, geocoding, _ = mock_services + executor = WeatherToolExecutor(weather, geocoding, config_manager=None) + result = executor.execute("list_locations", {}) + assert "unavailable" in result + + def test_get_afd(self, executor, mock_services): + weather, _, _ = mock_services + weather.get_discussion.return_value = "This is the AFD text for testing." + result = executor.execute("get_area_forecast_discussion", {"location": "NYC"}) + assert "Area Forecast Discussion" in result + assert "AFD text" in result + + def test_get_afd_none(self, executor, mock_services): + weather, _, _ = mock_services + weather.get_discussion.return_value = None + result = executor.execute("get_area_forecast_discussion", {"location": "NYC"}) + assert "No Area Forecast Discussion" in result + + def test_get_afd_truncates_long_text(self, executor, mock_services): + weather, _, _ = mock_services + weather.get_discussion.return_value = "x" * 5000 + result = executor.execute("get_area_forecast_discussion", {"location": "NYC"}) + assert "[Truncated" in result + + def test_get_wpc_discussion(self, executor): + with patch( + "accessiweather.services.national_discussion_scraper.NationalDiscussionScraper" + ) as MockScraper: + instance = MockScraper.return_value + instance.fetch_wpc_discussion.return_value = {"full": "WPC discussion text here."} + result = executor.execute("get_wpc_discussion", {}) + assert "WPC" in result + + def test_get_spc_outlook(self, executor): + with patch( + "accessiweather.services.national_discussion_scraper.NationalDiscussionScraper" + ) as MockScraper: + instance = MockScraper.return_value + instance.fetch_spc_discussion.return_value = {"full": "SPC outlook text here."} + result = executor.execute("get_spc_outlook", {}) + assert "SPC" in result + + def test_query_open_meteo(self, executor): + with patch("accessiweather.openmeteo_client.OpenMeteoApiClient") as MockClient: + instance = MockClient.return_value + instance._make_request.return_value = { + "current": {"temperature_2m": 22.5}, + "current_units": {"temperature_2m": "°C"}, + } + result = executor.execute( + "query_open_meteo", + {"location": "NYC", "current": ["temperature_2m"]}, + ) + assert "22.5" in result + + def test_query_open_meteo_no_variables(self, executor): + result = executor.execute("query_open_meteo", {"location": "NYC"}) + assert "specify at least one" in result + + def test_query_open_meteo_error(self, executor): + with patch("accessiweather.openmeteo_client.OpenMeteoApiClient") as MockClient: + MockClient.side_effect = Exception("API down") + result = executor.execute( + "query_open_meteo", + {"location": "NYC", "current": ["temperature_2m"]}, + ) + assert "Error" in result + + +class TestToolListIntegrity: + """Tests for tool list structure.""" + + def test_core_tools_count(self): + assert len(CORE_TOOLS) == 3 + + def test_extended_tools_count(self): + assert len(EXTENDED_TOOLS) == 5 + + def test_discussion_tools_count(self): + assert len(DISCUSSION_TOOLS) == 3 + + def test_no_duplicate_names(self): + from accessiweather.ai_tools import WEATHER_TOOLS + + names = [t["function"]["name"] for t in WEATHER_TOOLS] + assert len(names) == len(set(names)) diff --git a/tests/test_ai_tools_formatters.py b/tests/test_ai_tools_formatters.py new file mode 100644 index 00000000..da75854d --- /dev/null +++ b/tests/test_ai_tools_formatters.py @@ -0,0 +1,285 @@ +"""Tests for weather data formatter functions in ai_tools.""" + +from __future__ import annotations + +from accessiweather.ai_tools import format_alerts, format_current_weather, format_forecast + + +class TestFormatCurrentWeather: + """Tests for format_current_weather.""" + + def test_all_fields(self): + data = { + "temperature": "72°F", + "feels_like": "75°F", + "description": "Partly Cloudy", + "humidity": "55%", + "wind": "5 mph NW", + "pressure": "30.12 inHg", + } + result = format_current_weather(data, "New York, NY") + assert "Current weather for New York, NY:" in result + assert "Temperature: 72°F" in result + assert "Feels Like: 75°F" in result + assert "Conditions: Partly Cloudy" in result + assert "Humidity: 55%" in result + assert "Wind: 5 mph NW" in result + assert "Pressure: 30.12 inHg" in result + + def test_text_description_field(self): + data = {"textDescription": "Sunny"} + result = format_current_weather(data, "Boston") + assert "Conditions: Sunny" in result + + def test_feels_like_camel_case(self): + data = {"feelsLike": "80°F"} + result = format_current_weather(data, "Miami") + assert "Feels Like: 80°F" in result + + def test_wind_speed_field(self): + data = {"windSpeed": "10 mph"} + result = format_current_weather(data, "Chicago") + assert "Wind: 10 mph" in result + + def test_barometric_pressure_field(self): + data = {"barometricPressure": "1013 hPa"} + result = format_current_weather(data, "Denver") + assert "Pressure: 1013 hPa" in result + + def test_missing_fields_graceful(self): + data = {} + result = format_current_weather(data, "Nowhere") + assert "Current weather for Nowhere:" in result + + def test_none_values_skipped(self): + data = {"temperature": None, "humidity": None, "wind": "5 mph"} + result = format_current_weather(data, "Test") + assert "Temperature" not in result + assert "Humidity" not in result + assert "Wind: 5 mph" in result + + def test_fallback_scalar_dump(self): + data = {"status": "ok", "code": 200} + result = format_current_weather(data, "Test") + assert "status: ok" in result + assert "code: 200" in result + + def test_no_display_name(self): + data = {"temperature": "70°F"} + result = format_current_weather(data) + assert result.startswith("Current weather:") + + def test_empty_string_values_skipped(self): + data = {"temperature": "", "humidity": "50%"} + result = format_current_weather(data, "Test") + assert "Temperature" not in result + assert "Humidity: 50%" in result + + +class TestFormatForecast: + """Tests for format_forecast.""" + + def test_basic_periods(self): + data = { + "periods": [ + { + "name": "Tonight", + "temperature": 60, + "temperatureUnit": "F", + "shortForecast": "Clear", + }, + { + "name": "Tomorrow", + "temperature": 85, + "temperatureUnit": "F", + "shortForecast": "Sunny", + }, + ] + } + result = format_forecast(data, "New York, NY") + assert "Forecast for New York, NY:" in result + assert "Tonight - 60°F - Clear" in result + assert "Tomorrow - 85°F - Sunny" in result + + def test_up_to_seven_periods(self): + data = { + "periods": [ + { + "name": f"Period {i}", + "temperature": 70 + i, + "temperatureUnit": "F", + "shortForecast": "Fair", + } + for i in range(10) + ] + } + result = format_forecast(data, "Test") + assert "Period 6" in result + assert "Period 7" not in result # 0-indexed, so Period 7 would be the 8th + + def test_nested_properties_periods(self): + data = { + "properties": { + "periods": [ + { + "name": "Today", + "temperature": 80, + "temperatureUnit": "F", + "shortForecast": "Sunny", + } + ] + } + } + result = format_forecast(data, "Test") + assert "Today" in result + + def test_detailed_forecast_fallback(self): + data = { + "periods": [ + { + "name": "Tonight", + "temperature": 60, + "temperatureUnit": "F", + "detailedForecast": "Clear skies with a low of 60.", + } + ] + } + result = format_forecast(data, "Test") + assert "Clear skies" in result + + def test_missing_temperature(self): + data = {"periods": [{"name": "Tonight", "shortForecast": "Cloudy"}]} + result = format_forecast(data, "Test") + assert "Tonight - Cloudy" in result + + def test_missing_forecast_text(self): + data = {"periods": [{"name": "Tonight", "temperature": 55, "temperatureUnit": "F"}]} + result = format_forecast(data, "Test") + assert "Tonight - 55°F" in result + + def test_empty_periods(self): + data = {"periods": []} + result = format_forecast(data, "Test") + # Falls back to JSON dump + assert "Forecast for Test:" in result + + def test_no_periods_key(self): + data = {"something": "else"} + result = format_forecast(data, "Test") + assert "Forecast for Test:" in result + + def test_none_name_defaults(self): + data = {"periods": [{"name": None, "temperature": 70, "temperatureUnit": "F"}]} + result = format_forecast(data, "Test") + assert "Unknown" in result + + def test_no_display_name(self): + data = {"periods": [{"name": "Today", "shortForecast": "Nice"}]} + result = format_forecast(data) + assert result.startswith("Forecast:") + + +class TestFormatAlerts: + """Tests for format_alerts.""" + + def test_alert_with_all_fields(self): + data = { + "alerts": [ + { + "properties": { + "event": "Heat Advisory", + "severity": "Moderate", + "headline": "Heat advisory until 8 PM", + "description": "Dangerously hot conditions expected.", + } + } + ] + } + result = format_alerts(data, "Phoenix") + assert "Weather alerts for Phoenix:" in result + assert "- Heat Advisory (Severity: Moderate)" in result + assert "Heat advisory until 8 PM" in result + assert "Dangerously hot conditions" in result + + def test_no_active_alerts(self): + data = {"alerts": []} + result = format_alerts(data, "Test") + assert "No active alerts." in result + + def test_no_alerts_key(self): + data = {} + result = format_alerts(data, "Test") + assert "No active alerts." in result + + def test_geojson_features_format(self): + data = { + "features": [ + { + "properties": { + "event": "Tornado Warning", + "severity": "Extreme", + "headline": "Tornado warning for the area", + } + } + ] + } + result = format_alerts(data, "Oklahoma City") + assert "Tornado Warning" in result + assert "Severity: Extreme" in result + + def test_alert_missing_severity(self): + data = { + "alerts": [{"properties": {"event": "Flood Watch", "headline": "Flooding possible"}}] + } + result = format_alerts(data, "Test") + assert "- Flood Watch" in result + assert "Severity" not in result + + def test_alert_missing_headline(self): + data = { + "alerts": [ + {"properties": {"event": "Wind Advisory", "description": "Strong winds expected."}} + ] + } + result = format_alerts(data, "Test") + assert "Wind Advisory" in result + assert "Strong winds expected." in result + + def test_alert_flat_dict_format(self): + """Test alerts that are flat dicts (no nested properties).""" + data = { + "alerts": [ + {"event": "Frost Advisory", "severity": "Minor", "headline": "Frost tonight"} + ] + } + result = format_alerts(data, "Test") + assert "Frost Advisory" in result + assert "Severity: Minor" in result + + def test_description_truncated(self): + data = {"alerts": [{"properties": {"event": "Test", "description": "x" * 500}}]} + result = format_alerts(data, "Test") + # Description should be truncated to 300 chars + desc_line = [line for line in result.split("\n") if line.startswith(" x")][0] + assert len(desc_line.strip()) == 300 + + def test_none_fields_graceful(self): + data = {"alerts": [{"properties": {"event": None, "severity": None, "headline": None}}]} + result = format_alerts(data, "Test") + assert "Unknown Alert" in result + + def test_no_display_name(self): + data = {"alerts": []} + result = format_alerts(data) + assert result.startswith("Weather alerts:") + + def test_multiple_alerts(self): + data = { + "alerts": [ + {"properties": {"event": "Heat Advisory", "headline": "Hot"}}, + {"properties": {"event": "Air Quality Alert", "headline": "Poor air"}}, + ] + } + result = format_alerts(data, "LA") + assert "Heat Advisory" in result + assert "Air Quality Alert" in result diff --git a/tests/test_location_resolver.py b/tests/test_location_resolver.py new file mode 100644 index 00000000..ce5303b9 --- /dev/null +++ b/tests/test_location_resolver.py @@ -0,0 +1,154 @@ +"""Tests for LocationResolver and updated WeatherToolExecutor location handling.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from accessiweather.ai_tools import LocationResolver, WeatherToolExecutor + + +class TestLocationResolver: + """Tests for the LocationResolver class.""" + + def _make_resolver( + self, + geocoding_service: MagicMock | None = None, + default_lat: float | None = 40.7, + default_lon: float | None = -74.0, + default_name: str | None = "New York, NY", + ) -> LocationResolver: + if geocoding_service is None: + geocoding_service = MagicMock() + return LocationResolver( + geocoding_service=geocoding_service, + default_lat=default_lat, + default_lon=default_lon, + default_name=default_name, + ) + + def test_resolve_returns_tuple(self): + resolver = self._make_resolver() + result = resolver.resolve("New York, NY") + assert isinstance(result, tuple) + assert len(result) == 3 + + def test_resolve_matches_default_exact(self): + geo = MagicMock() + resolver = self._make_resolver(geocoding_service=geo) + lat, lon, name = resolver.resolve("New York, NY") + assert lat == 40.7 + assert lon == -74.0 + assert name == "New York, NY" + geo.geocode_address.assert_not_called() + + def test_resolve_matches_default_case_insensitive(self): + geo = MagicMock() + resolver = self._make_resolver(geocoding_service=geo) + lat, lon, name = resolver.resolve("new york, ny") + assert lat == 40.7 + assert lon == -74.0 + geo.geocode_address.assert_not_called() + + def test_resolve_matches_default_with_whitespace(self): + geo = MagicMock() + resolver = self._make_resolver(geocoding_service=geo) + lat, lon, _name = resolver.resolve(" New York, NY ") + assert lat == 40.7 + assert lon == -74.0 + geo.geocode_address.assert_not_called() + + def test_resolve_falls_back_to_geocoding(self): + geo = MagicMock() + geo.geocode_address.return_value = (48.8566, 2.3522, "Paris, France") + resolver = self._make_resolver(geocoding_service=geo) + lat, lon, name = resolver.resolve("Paris") + assert lat == 48.8566 + assert lon == 2.3522 + assert name == "Paris, France" + geo.geocode_address.assert_called_once_with("Paris") + + def test_resolve_geocoding_failure_raises_valueerror(self): + geo = MagicMock() + geo.geocode_address.return_value = None + resolver = self._make_resolver(geocoding_service=geo) + with pytest.raises(ValueError, match="Could not resolve location"): + resolver.resolve("Nonexistent Place XYZ") + + def test_resolve_no_default_set(self): + """When no default location is configured, always use geocoding.""" + geo = MagicMock() + geo.geocode_address.return_value = (40.7, -74.0, "New York, NY") + resolver = self._make_resolver( + geocoding_service=geo, + default_lat=None, + default_lon=None, + default_name=None, + ) + resolver.resolve("New York, NY") + geo.geocode_address.assert_called_once() + + def test_resolve_partial_default_does_not_match(self): + """If only name is set but not coords, don't match.""" + geo = MagicMock() + geo.geocode_address.return_value = (40.7, -74.0, "New York, NY") + resolver = self._make_resolver( + geocoding_service=geo, + default_lat=None, + default_lon=None, + default_name="New York, NY", + ) + resolver.resolve("New York, NY") + geo.geocode_address.assert_called_once() + + +class TestWeatherToolExecutorWithLocationResolver: + """Tests for WeatherToolExecutor using LocationResolver.""" + + def _make_executor(self) -> tuple[WeatherToolExecutor, MagicMock, MagicMock]: + weather_svc = MagicMock() + geo_svc = MagicMock() + executor = WeatherToolExecutor( + weather_service=weather_svc, + geocoding_service=geo_svc, + default_lat=40.7, + default_lon=-74.0, + default_name="New York, NY", + ) + return executor, weather_svc, geo_svc + + def test_execute_uses_default_location(self): + executor, weather_svc, geo_svc = self._make_executor() + weather_svc.get_current_conditions.return_value = {"temperature": "72°F"} + result = executor.execute("get_current_weather", {"location": "New York, NY"}) + geo_svc.geocode_address.assert_not_called() + weather_svc.get_current_conditions.assert_called_once_with(40.7, -74.0) + assert "72°F" in result + + def test_execute_geocodes_other_location(self): + executor, weather_svc, geo_svc = self._make_executor() + geo_svc.geocode_address.return_value = (48.85, 2.35, "Paris, France") + weather_svc.get_forecast.return_value = {"periods": []} + executor.execute("get_forecast", {"location": "Paris"}) + geo_svc.geocode_address.assert_called_once_with("Paris") + weather_svc.get_forecast.assert_called_once_with(48.85, 2.35) + + def test_execute_returns_error_on_geocoding_failure(self): + executor, weather_svc, geo_svc = self._make_executor() + geo_svc.geocode_address.return_value = None + result = executor.execute("get_current_weather", {"location": "Nowhere XYZ"}) + assert "Error" in result + assert "Could not resolve location" in result + weather_svc.get_current_conditions.assert_not_called() + + def test_execute_returns_error_on_weather_service_failure(self): + executor, weather_svc, geo_svc = self._make_executor() + weather_svc.get_alerts.side_effect = Exception("API down") + result = executor.execute("get_alerts", {"location": "New York, NY"}) + assert "Error" in result + + def test_unknown_tool_raises(self): + executor, _, _ = self._make_executor() + with pytest.raises(ValueError, match="Unknown tool"): + executor.execute("nonexistent_tool", {}) diff --git a/tests/test_nationwide_discussion_dialog.py b/tests/test_nationwide_discussion_dialog.py new file mode 100644 index 00000000..e73c530d --- /dev/null +++ b/tests/test_nationwide_discussion_dialog.py @@ -0,0 +1,194 @@ +"""Tests for NationwideDiscussionDialog structure and exports.""" + +from __future__ import annotations + +import ast + +import pytest + + +class TestNationwideDiscussionDialogStructure: + """Test the dialog module structure using AST parsing (no wx required).""" + + @pytest.fixture + def module_ast(self): + """Parse the module AST.""" + path = "src/accessiweather/ui/dialogs/nationwide_discussion_dialog.py" + with open(path) as f: + return ast.parse(f.read()) + + def _get_class_node(self, tree: ast.Module, name: str) -> ast.ClassDef | None: + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == name: + return node + return None + + def _get_function_node(self, tree: ast.Module, name: str) -> ast.FunctionDef | None: + for node in ast.iter_child_nodes(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name: + return node + return None + + def test_class_exists(self, module_ast): + """AC1: NationwideDiscussionDialog class exists.""" + cls = self._get_class_node(module_ast, "NationwideDiscussionDialog") + assert cls is not None, "NationwideDiscussionDialog class not found" + + def test_class_inherits_wx_dialog(self, module_ast): + """Dialog inherits from wx.Dialog.""" + cls = self._get_class_node(module_ast, "NationwideDiscussionDialog") + assert cls is not None + base_names = [] + for base in cls.bases: + if isinstance(base, ast.Attribute): + base_names.append(f"{base.value.id}.{base.attr}") # type: ignore[union-attr] + elif isinstance(base, ast.Name): + base_names.append(base.id) + assert any("Dialog" in b for b in base_names) + + def test_convenience_function_exists(self, module_ast): + """AC9: show_nationwide_discussion_dialog function exists.""" + func = self._get_function_node(module_ast, "show_nationwide_discussion_dialog") + assert func is not None + + def test_has_create_widgets_method(self, module_ast): + """Dialog has _create_widgets method.""" + cls = self._get_class_node(module_ast, "NationwideDiscussionDialog") + assert cls is not None + methods = [n.name for n in ast.walk(cls) if isinstance(n, ast.FunctionDef)] + assert "_create_widgets" in methods + + def test_has_close_handler(self, module_ast): + """AC8: Dialog has close button handler.""" + cls = self._get_class_node(module_ast, "NationwideDiscussionDialog") + assert cls is not None + methods = [n.name for n in ast.walk(cls) if isinstance(n, ast.FunctionDef)] + assert "_on_close" in methods + + def test_has_set_discussion_text_method(self, module_ast): + """Dialog has set_discussion_text helper.""" + cls = self._get_class_node(module_ast, "NationwideDiscussionDialog") + assert cls is not None + methods = [n.name for n in ast.walk(cls) if isinstance(n, ast.FunctionDef)] + assert "set_discussion_text" in methods + + +class TestDialogSourceContent: + """Verify widget names and tab structure by inspecting source text.""" + + @pytest.fixture + def source(self): + with open("src/accessiweather/ui/dialogs/nationwide_discussion_dialog.py") as f: + return f.read() + + def test_notebook_created(self, source): + """AC2: wx.Notebook is used for tabs.""" + assert "wx.Notebook" in source + + @pytest.mark.parametrize( + "tab_name", + ["WPC", "SPC", "NHC", "CPC"], + ) + def test_tab_pages_added(self, source, tab_name): + """AC2: All four tab pages are added.""" + assert f'AddPage(self.{tab_name.lower()}_panel, "{tab_name}")' in source + + @pytest.mark.parametrize( + "attr_name", + [ + "wpc_short_range", + "wpc_medium_range", + "wpc_extended", + "wpc_qpf", + ], + ) + def test_wpc_text_controls(self, source, attr_name): + """AC3: WPC tab has all required text controls.""" + assert f"self.{attr_name}" in source + + @pytest.mark.parametrize( + "attr_name", + ["spc_day1", "spc_day2", "spc_day3"], + ) + def test_spc_text_controls(self, source, attr_name): + """AC4: SPC tab has all required text controls.""" + assert f"self.{attr_name}" in source + + @pytest.mark.parametrize( + "attr_name", + ["nhc_atlantic", "nhc_east_pacific"], + ) + def test_nhc_text_controls(self, source, attr_name): + """AC5: NHC tab has all required text controls.""" + assert f"self.{attr_name}" in source + + @pytest.mark.parametrize( + "attr_name", + ["cpc_6_10_day", "cpc_8_14_day"], + ) + def test_cpc_text_controls(self, source, attr_name): + """AC6: CPC tab has all required text controls.""" + assert f"self.{attr_name}" in source + + @pytest.mark.parametrize( + "name_str", + [ + "WPC Short Range Forecast Discussion", + "WPC Medium Range Forecast Discussion", + "WPC Extended Forecast Discussion", + "WPC QPF Discussion", + "SPC Day 1 Convective Outlook", + "SPC Day 2 Convective Outlook", + "SPC Day 3 Convective Outlook", + "NHC Atlantic Tropical Weather Outlook", + "NHC East Pacific Tropical Weather Outlook", + "CPC 6-10 Day Outlook", + "CPC 8-14 Day Outlook", + ], + ) + def test_accessible_names(self, source, name_str): + """AC7: All text controls have name= parameter for accessibility.""" + assert f'name="{name_str}"' in source + + def test_close_button(self, source): + """AC8: Close button exists.""" + assert "wx.ID_CLOSE" in source + + def test_readonly_style(self, source): + """All text controls are read-only.""" + assert "wx.TE_READONLY" in source + + def test_default_size(self, source): + """Dialog has 800x600 default size.""" + assert "800, 600" in source + + +class TestDialogPackageExport: + """Test that the dialog is properly exported from the package.""" + + def test_exported_in_init(self): + """AC9: show_nationwide_discussion_dialog is in __init__.py exports.""" + init_path = "src/accessiweather/ui/dialogs/__init__.py" + with open(init_path) as f: + content = f.read() + assert "show_nationwide_discussion_dialog" in content + + def test_in_all_list(self): + """show_nationwide_discussion_dialog is in __all__.""" + init_path = "src/accessiweather/ui/dialogs/__init__.py" + with open(init_path) as f: + tree = ast.parse(f.read()) + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "__all__" + and isinstance(node.value, ast.List) + ): + elements = [ + elt.value for elt in node.value.elts if isinstance(elt, ast.Constant) + ] + assert "show_nationwide_discussion_dialog" in elements + return + pytest.fail("__all__ not found or doesn't contain the function") diff --git a/tests/test_pyproject_screenreader_dep.py b/tests/test_pyproject_screenreader_dep.py new file mode 100644 index 00000000..97185612 --- /dev/null +++ b/tests/test_pyproject_screenreader_dep.py @@ -0,0 +1,29 @@ +"""Tests for pyproject.toml prismatoid dependency.""" + +from pathlib import Path + +import tomllib + +PYPROJECT = Path(__file__).resolve().parent.parent / "pyproject.toml" + + +def _load_pyproject(): + """Load pyproject.toml.""" + with open(PYPROJECT, "rb") as f: + return tomllib.load(f) + + +class TestPrismatoidDependency: + """Verify prismatoid is listed as a main dependency.""" + + def test_prismatoid_in_main_dependencies(self): + """Prismatoid should be a default dependency.""" + data = _load_pyproject() + main_deps = data["project"]["dependencies"] + assert any("prismatoid" in d for d in main_deps) + + def test_prismatoid_version_constraint(self): + """Prismatoid should require >= 0.7.0.""" + data = _load_pyproject() + main_deps = data["project"]["dependencies"] + assert any("prismatoid>=0.7.0" in d for d in main_deps) diff --git a/tests/test_screen_reader.py b/tests/test_screen_reader.py new file mode 100644 index 00000000..904b4efb --- /dev/null +++ b/tests/test_screen_reader.py @@ -0,0 +1,124 @@ +"""Tests for the screen_reader module.""" + +import importlib +import sys +from contextlib import contextmanager +from unittest import mock + + +def _make_mock_prism(runtime_supported=True): + """Create a mock prism module with configurable runtime support.""" + mock_features = mock.MagicMock() + mock_features.is_supported_at_runtime = runtime_supported + mock_backend = mock.MagicMock() + mock_backend.features = mock_features + mock_backend.name = "MockReader" + mock_ctx = mock.MagicMock() + mock_ctx.acquire_best.return_value = mock_backend + mock_prism = mock.MagicMock() + mock_prism.Context.return_value = mock_ctx + return mock_prism, mock_backend + + +@contextmanager +def _patched_prism(mock_prism): + """Patch sys.modules with a mock prism and reload screen_reader.""" + with mock.patch.dict(sys.modules, {"prism": mock_prism}): + import accessiweather.screen_reader as sr_mod + + reloaded = importlib.reload(sr_mod) + yield reloaded + + +class TestWithoutPrismatoid: + """Test behavior when prismatoid is not installed.""" + + def test_prism_available_false_when_not_installed(self): + """PRISM_AVAILABLE should be False when prismatoid is missing.""" + with mock.patch.dict(sys.modules, {"prism": None}): + import accessiweather.screen_reader as sr_mod + + reloaded = importlib.reload(sr_mod) + assert reloaded.PRISM_AVAILABLE is False + + def test_announcer_instantiates_without_prismatoid(self): + """ScreenReaderAnnouncer should instantiate without error.""" + with mock.patch.dict(sys.modules, {"prism": None}): + import accessiweather.screen_reader as sr_mod + + reloaded = importlib.reload(sr_mod) + announcer = reloaded.ScreenReaderAnnouncer() + assert announcer.is_available() is False + + def test_announce_is_noop_without_prismatoid(self): + """announce() should not raise when prismatoid is missing.""" + with mock.patch.dict(sys.modules, {"prism": None}): + import accessiweather.screen_reader as sr_mod + + reloaded = importlib.reload(sr_mod) + announcer = reloaded.ScreenReaderAnnouncer() + announcer.announce("hello") # should not raise + + def test_shutdown_safe_without_prismatoid(self): + """shutdown() should not raise when prismatoid is missing.""" + with mock.patch.dict(sys.modules, {"prism": None}): + import accessiweather.screen_reader as sr_mod + + reloaded = importlib.reload(sr_mod) + announcer = reloaded.ScreenReaderAnnouncer() + announcer.shutdown() # should not raise + + +class TestWithMockedPrismatoid: + """Test behavior when prismatoid is available (mocked).""" + + def test_prism_available_true(self): + mock_prism, _ = _make_mock_prism() + with _patched_prism(mock_prism) as sr_mod: + assert sr_mod.PRISM_AVAILABLE is True + + def test_announcer_is_available(self): + mock_prism, _ = _make_mock_prism() + with _patched_prism(mock_prism) as sr_mod: + announcer = sr_mod.ScreenReaderAnnouncer() + assert announcer.is_available() is True + + def test_announce_calls_speak(self): + mock_prism, mock_backend = _make_mock_prism() + with _patched_prism(mock_prism) as sr_mod: + announcer = sr_mod.ScreenReaderAnnouncer() + announcer.announce("test message") + mock_backend.speak.assert_called_once_with("test message", interrupt=False) + + def test_shutdown_clears_backend(self): + mock_prism, _ = _make_mock_prism() + with _patched_prism(mock_prism) as sr_mod: + announcer = sr_mod.ScreenReaderAnnouncer() + assert announcer.is_available() is True + announcer.shutdown() + assert announcer.is_available() is False + + def test_graceful_fallback_on_acquire_exception(self): + """If acquire_best() raises, announcer should fall back gracefully.""" + mock_prism = mock.MagicMock() + mock_prism.Context.return_value.acquire_best.side_effect = RuntimeError("no SR") + + with _patched_prism(mock_prism) as sr_mod: + announcer = sr_mod.ScreenReaderAnnouncer() + assert announcer.is_available() is False + announcer.announce("test") # no-op, no raise + + def test_announce_handles_speak_exception(self): + """If speak() raises, announce should not propagate.""" + mock_prism, mock_backend = _make_mock_prism() + with _patched_prism(mock_prism) as sr_mod: + mock_backend.speak.side_effect = RuntimeError("speak failed") + announcer = sr_mod.ScreenReaderAnnouncer() + announcer.announce("test") # should not raise + + def test_runtime_not_supported_returns_unavailable(self): + """Backend exists but is_supported_at_runtime=False means unavailable.""" + mock_prism, _ = _make_mock_prism(runtime_supported=False) + with _patched_prism(mock_prism) as sr_mod: + announcer = sr_mod.ScreenReaderAnnouncer() + assert announcer.is_available() is False diff --git a/tests/test_weather_assistant_dialog.py b/tests/test_weather_assistant_dialog.py new file mode 100644 index 00000000..9aa0ba43 --- /dev/null +++ b/tests/test_weather_assistant_dialog.py @@ -0,0 +1,224 @@ +"""Tests for WeatherChat dialog.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from unittest.mock import MagicMock + +from accessiweather.ui.dialogs.weather_assistant_dialog import ( + MAX_CONTEXT_TURNS, + SYSTEM_PROMPT, + _build_weather_context, + show_weather_assistant_dialog, +) + + +@dataclass +class MockLocation: + """Mock Location for testing.""" + + name: str = "Test City" + latitude: float = 40.0 + longitude: float = -74.0 + timezone: str | None = None + country_code: str | None = None + + +@dataclass +class MockCurrentConditions: + """Mock CurrentConditions for testing.""" + + temperature_f: float | None = 72.0 + feels_like_f: float | None = 70.0 + condition: str | None = "Partly Cloudy" + humidity: int | None = 65 + wind_speed_mph: float | None = 8.0 + wind_direction: str | None = "NW" + pressure_in: float | None = 30.12 + visibility_miles: float | None = 10.0 + uv_index: float | None = 5.0 + + +@dataclass +class MockForecastPeriod: + """Mock ForecastPeriod for testing.""" + + name: str = "Tonight" + temperature: float = 55.0 + temperature_unit: str = "F" + short_forecast: str | None = "Clear" + + +@dataclass +class MockForecast: + """Mock Forecast for testing.""" + + periods: list = field(default_factory=lambda: [MockForecastPeriod()]) + + def has_data(self): + return bool(self.periods) + + +@dataclass +class MockAlert: + """Mock alert for testing.""" + + event: str = "Wind Advisory" + severity: str = "Moderate" + title: str = "Wind Advisory" + + +@dataclass +class MockAlerts: + """Mock WeatherAlerts for testing.""" + + alerts: list = field(default_factory=lambda: [MockAlert()]) + + def has_alerts(self): + return bool(self.alerts) + + +@dataclass +class MockTrendInsight: + """Mock TrendInsight for testing.""" + + metric: str = "temperature" + direction: str = "rising" + change: float | None = 5.0 + unit: str | None = "°F" + summary: str | None = "Temperature rising 5°F over the last 3 hours" + + +@dataclass +class MockWeatherData: + """Mock WeatherData for testing.""" + + location: MockLocation = field(default_factory=MockLocation) + current: MockCurrentConditions | None = field(default_factory=MockCurrentConditions) + forecast: MockForecast | None = field(default_factory=MockForecast) + alerts: MockAlerts | None = None + trend_insights: list = field(default_factory=list) + + +class TestBuildWeatherContext: + """Tests for _build_weather_context.""" + + def test_no_weather_data(self): + """Test with no weather data loaded.""" + app = MagicMock() + app.current_weather_data = None + result = _build_weather_context(app) + assert result == "No weather data currently loaded." + + def test_basic_current_conditions(self): + """Test with basic current conditions.""" + app = MagicMock() + app.current_weather_data = MockWeatherData() + result = _build_weather_context(app) + + assert "Test City" in result + assert "72°F" in result + assert "Partly Cloudy" in result + assert "65%" in result + assert "8 mph" in result + assert "NW" in result + assert "30.12 inHg" in result + assert "10.0 miles" in result + assert "UV Index: 5.0" in result + + def test_with_forecast(self): + """Test with forecast data.""" + app = MagicMock() + app.current_weather_data = MockWeatherData() + result = _build_weather_context(app) + + assert "Forecast:" in result + assert "Tonight" in result + assert "55" in result and "F" in result + assert "Clear" in result + + def test_with_alerts(self): + """Test with active alerts.""" + app = MagicMock() + app.current_weather_data = MockWeatherData(alerts=MockAlerts()) + result = _build_weather_context(app) + + assert "Active Alerts:" in result + assert "Wind Advisory" in result + assert "Moderate" in result + + def test_with_trend_insights(self): + """Test with trend insights.""" + app = MagicMock() + app.current_weather_data = MockWeatherData(trend_insights=[MockTrendInsight()]) + result = _build_weather_context(app) + + assert "Trend Insights:" in result + assert "Temperature rising" in result + + def test_no_current_conditions(self): + """Test with weather data but no current conditions.""" + app = MagicMock() + app.current_weather_data = MockWeatherData(current=None) + result = _build_weather_context(app) + + assert "Test City" in result + # Should not crash with None current + assert "Temperature:" not in result + + def test_partial_current_conditions(self): + """Test with some fields None.""" + app = MagicMock() + conditions = MockCurrentConditions() + conditions.wind_speed_mph = None + conditions.wind_direction = None + conditions.pressure_in = None + app.current_weather_data = MockWeatherData(current=conditions) + result = _build_weather_context(app) + + assert "72°F" in result + assert "Wind:" not in result + assert "Pressure:" not in result + + def test_wind_without_direction(self): + """Test wind with speed but no direction.""" + app = MagicMock() + conditions = MockCurrentConditions() + conditions.wind_direction = None + app.current_weather_data = MockWeatherData(current=conditions) + result = _build_weather_context(app) + + assert "Wind: 8 mph" in result + assert "from" not in result.split("Wind:")[1].split("\n")[0] + + +class TestSystemPrompt: + """Tests for the system prompt.""" + + def test_prompt_exists(self): + """Test system prompt is defined.""" + assert SYSTEM_PROMPT + assert "Weather Assistant" in SYSTEM_PROMPT + assert "screen reader" in SYSTEM_PROMPT + + def test_prompt_no_markdown_instruction(self): + """Test prompt instructs no markdown.""" + assert "plain text" in SYSTEM_PROMPT + assert "No bold" in SYSTEM_PROMPT + + +class TestMaxContextTurns: + """Tests for conversation limits.""" + + def test_max_turns_defined(self): + """Test max turns is reasonable.""" + assert MAX_CONTEXT_TURNS > 0 + assert MAX_CONTEXT_TURNS <= 50 + + +class TestShowWeatherChatDialog: + """Tests for the show function.""" + + def test_function_exists(self): + """Test the show function is importable.""" + assert callable(show_weather_assistant_dialog) diff --git a/tests/test_weather_assistant_integration.py b/tests/test_weather_assistant_integration.py new file mode 100644 index 00000000..f2de5446 --- /dev/null +++ b/tests/test_weather_assistant_integration.py @@ -0,0 +1,440 @@ +""" +Integration tests for the weather assistant tool call flow. + +Tests the full loop: user message -> tool_call response -> tool execution -> +tool result -> final text response. Uses mock OpenAI client and mock services. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from unittest.mock import MagicMock + +from accessiweather.ai_tools import WeatherToolExecutor + +# --- Mock OpenAI response objects --- + + +@dataclass +class MockFunction: + name: str + arguments: str + + +@dataclass +class MockToolCall: + id: str + type: str + function: MockFunction + + +@dataclass +class MockMessage: + content: str | None = None + tool_calls: list[MockToolCall] | None = None + role: str = "assistant" + + +@dataclass +class MockChoice: + message: MockMessage + finish_reason: str = "stop" + + +@dataclass +class MockResponse: + choices: list[MockChoice] + model: str = "test-model" + + +def _make_tool_call_response( + tool_calls: list[tuple[str, str, dict]], + content: str = "", +) -> MockResponse: + """ + Create a mock response with tool calls. + + Args: + tool_calls: List of (call_id, function_name, arguments_dict). + content: Optional text content alongside tool calls. + + """ + return MockResponse( + choices=[ + MockChoice( + message=MockMessage( + content=content, + tool_calls=[ + MockToolCall( + id=call_id, + type="function", + function=MockFunction( + name=name, + arguments=json.dumps(args), + ), + ) + for call_id, name, args in tool_calls + ], + ), + finish_reason="tool_calls", + ) + ] + ) + + +def _make_text_response(text: str) -> MockResponse: + """Create a mock response with only text content.""" + return MockResponse( + choices=[ + MockChoice( + message=MockMessage(content=text, tool_calls=None), + finish_reason="stop", + ) + ] + ) + + +def _create_mock_services(): + """Create mock WeatherService and GeocodingService.""" + weather_service = MagicMock() + geocoding_service = MagicMock() + + # geocode_address returns (lat, lon, display_name) + geocoding_service.geocode_address.return_value = (40.7128, -74.0060, "New York, NY") + + # Mock weather data + weather_service.get_current_conditions.return_value = { + "temperature": 72, + "temperatureUnit": "F", + "shortForecast": "Partly Cloudy", + "windSpeed": "10 mph", + "windDirection": "NW", + "relativeHumidity": {"value": 55}, + } + + weather_service.get_forecast.return_value = { + "periods": [ + { + "name": "Today", + "temperature": 72, + "temperatureUnit": "F", + "shortForecast": "Partly Cloudy", + "detailedForecast": "Partly cloudy with a high near 72.", + "windSpeed": "10 mph", + "windDirection": "NW", + } + ] + } + + weather_service.get_alerts.return_value = {"alerts": []} + + return weather_service, geocoding_service + + +def _simulate_tool_call_loop( + messages: list[dict], + mock_responses: list[MockResponse], + tool_executor: WeatherToolExecutor, + max_iterations: int = 5, +) -> str: + """ + Simulate the tool call loop from weather_assistant_dialog.do_generate. + + This mirrors the exact logic in the dialog's do_generate method. + + Returns: + The final text response from the assistant. + + """ + response_index = 0 + + for _iteration in range(max_iterations + 1): + assert response_index < len(mock_responses), "Ran out of mock responses" + response = mock_responses[response_index] + response_index += 1 + + if not response.choices: + raise RuntimeError("Empty response") + + choice = response.choices[0] + assistant_message = choice.message + + if assistant_message.tool_calls and tool_executor is not None: + # Append assistant message with tool calls + tool_call_msg: dict = { + "role": "assistant", + "content": assistant_message.content or "", + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in assistant_message.tool_calls + ], + } + messages.append(tool_call_msg) + + # Execute each tool call + for tool_call in assistant_message.tool_calls: + tool_name = tool_call.function.name + try: + arguments = json.loads(tool_call.function.arguments) + result = tool_executor.execute(tool_name, arguments) + except Exception as exc: + result = f"Error executing {tool_name}: {exc}" + + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + # Continue loop for next API call + continue + + # No tool calls — final text response + final_text = assistant_message.content or "" + messages.append({"role": "assistant", "content": final_text}) + return final_text + + raise RuntimeError("Exceeded max tool iterations") + + +class TestFullToolCallFlow: + """Integration tests for the complete tool call flow.""" + + def test_single_tool_call_flow(self): + """Test: user asks about weather -> tool call -> execute -> final response.""" + weather_service, geocoding_service = _create_mock_services() + executor = WeatherToolExecutor(weather_service, geocoding_service) + + messages: list[dict] = [ + {"role": "system", "content": "You are a weather assistant."}, + {"role": "user", "content": "What's the weather in New York?"}, + ] + + # First response: AI calls get_current_weather + # Second response: AI gives final text + mock_responses = [ + _make_tool_call_response( + [ + ("call_001", "get_current_weather", {"location": "New York"}), + ] + ), + _make_text_response( + "The current weather in New York is 72°F and partly cloudy " + "with northwest winds at 10 mph." + ), + ] + + final = _simulate_tool_call_loop(messages, mock_responses, executor) + + assert "72" in final + assert "New York" in final + + # Verify conversation history structure + roles = [m["role"] for m in messages] + assert roles == ["system", "user", "assistant", "tool", "assistant"] + + # Verify tool call message + tool_call_msg = messages[2] + assert tool_call_msg["role"] == "assistant" + assert len(tool_call_msg["tool_calls"]) == 1 + assert tool_call_msg["tool_calls"][0]["function"]["name"] == "get_current_weather" + + # Verify tool result message + tool_result_msg = messages[3] + assert tool_result_msg["role"] == "tool" + assert tool_result_msg["tool_call_id"] == "call_001" + assert len(tool_result_msg["content"]) > 0 # Has weather data + + # Verify geocoding was called + geocoding_service.geocode_address.assert_called_once_with("New York") + weather_service.get_current_conditions.assert_called_once() + + def test_multiple_tool_calls_flow(self): + """Test: AI calls multiple tools in one response.""" + weather_service, geocoding_service = _create_mock_services() + executor = WeatherToolExecutor(weather_service, geocoding_service) + + messages: list[dict] = [ + {"role": "system", "content": "You are a weather assistant."}, + {"role": "user", "content": "Give me weather and alerts for New York"}, + ] + + mock_responses = [ + _make_tool_call_response( + [ + ("call_001", "get_current_weather", {"location": "New York"}), + ("call_002", "get_alerts", {"location": "New York"}), + ] + ), + _make_text_response("New York is 72°F and partly cloudy. No active weather alerts."), + ] + + final = _simulate_tool_call_loop(messages, mock_responses, executor) + + assert "New York" in final + + # Should have: system, user, assistant(tool_calls), tool, tool, assistant(final) + roles = [m["role"] for m in messages] + assert roles == ["system", "user", "assistant", "tool", "tool", "assistant"] + + # Both tool results present + tool_msgs = [m for m in messages if m["role"] == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "call_001" + assert tool_msgs[1]["tool_call_id"] == "call_002" + + def test_chained_tool_calls_flow(self): + """Test: AI makes a tool call, then makes another tool call based on results.""" + weather_service, geocoding_service = _create_mock_services() + executor = WeatherToolExecutor(weather_service, geocoding_service) + + messages: list[dict] = [ + {"role": "system", "content": "You are a weather assistant."}, + {"role": "user", "content": "What's the weather and forecast for NYC?"}, + ] + + mock_responses = [ + # First: get current weather + _make_tool_call_response( + [ + ("call_001", "get_current_weather", {"location": "NYC"}), + ] + ), + # Second: get forecast + _make_tool_call_response( + [ + ("call_002", "get_forecast", {"location": "NYC"}), + ] + ), + # Third: final text + _make_text_response("NYC is currently 72°F. Today's forecast: partly cloudy."), + ] + + final = _simulate_tool_call_loop(messages, mock_responses, executor) + + assert "NYC" in final + + # system, user, assistant(tc1), tool, assistant(tc2), tool, assistant(final) + roles = [m["role"] for m in messages] + assert roles == ["system", "user", "assistant", "tool", "assistant", "tool", "assistant"] + + # Verify both tool calls and results are in history + tool_call_msgs = [m for m in messages if m["role"] == "assistant" and "tool_calls" in m] + assert len(tool_call_msgs) == 2 + + tool_result_msgs = [m for m in messages if m["role"] == "tool"] + assert len(tool_result_msgs) == 2 + + def test_no_tool_calls_direct_response(self): + """Test: AI responds directly without tool calls (e.g., general question).""" + weather_service, geocoding_service = _create_mock_services() + executor = WeatherToolExecutor(weather_service, geocoding_service) + + messages: list[dict] = [ + {"role": "system", "content": "You are a weather assistant."}, + {"role": "user", "content": "What does humidity mean?"}, + ] + + mock_responses = [ + _make_text_response("Humidity measures the amount of water vapor in the air."), + ] + + final = _simulate_tool_call_loop(messages, mock_responses, executor) + + assert "humidity" in final.lower() + + # No tool messages in history + roles = [m["role"] for m in messages] + assert roles == ["system", "user", "assistant"] + assert not any(m["role"] == "tool" for m in messages) + + def test_tool_execution_error_in_flow(self): + """Test: tool execution fails, error is sent back, AI recovers.""" + weather_service, geocoding_service = _create_mock_services() + # Make geocoding fail + geocoding_service.geocode_address.return_value = None + executor = WeatherToolExecutor(weather_service, geocoding_service) + + messages: list[dict] = [ + {"role": "system", "content": "You are a weather assistant."}, + {"role": "user", "content": "Weather in Nonexistentville?"}, + ] + + mock_responses = [ + _make_tool_call_response( + [ + ("call_001", "get_current_weather", {"location": "Nonexistentville"}), + ] + ), + _make_text_response( + "I couldn't find weather data for Nonexistentville. Could you check the spelling?" + ), + ] + + final = _simulate_tool_call_loop(messages, mock_responses, executor) + + assert "Nonexistentville" in final + + # Tool result should contain error info + tool_result = [m for m in messages if m["role"] == "tool"][0] + assert ( + "could not" in tool_result["content"].lower() + or "error" in tool_result["content"].lower() + or "unable" in tool_result["content"].lower() + ) + + +def _read_system_prompt() -> str: + """ + Read SYSTEM_PROMPT from the source file without importing the module. + + The dialog module imports wx and prism which are unavailable on Linux CI, + so we parse the constant directly from the source. + """ + import ast + from pathlib import Path + + src = Path(__file__).resolve().parent.parent / ( + "src/accessiweather/ui/dialogs/weather_assistant_dialog.py" + ) + tree = ast.parse(src.read_text()) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "SYSTEM_PROMPT": + return ast.literal_eval(node.value) + raise RuntimeError("SYSTEM_PROMPT not found in source") + + +class TestSystemPrompt: + """Tests for the updated system prompt.""" + + def test_system_prompt_mentions_tools(self): + """SYSTEM_PROMPT mentions available weather tools.""" + prompt = _read_system_prompt() + assert "get_current_weather" in prompt + assert "get_forecast" in prompt + assert "get_alerts" in prompt + + def test_system_prompt_guides_tool_usage(self): + """SYSTEM_PROMPT guides AI to use tools for location-specific queries.""" + prompt = _read_system_prompt() + assert "tools" in prompt.lower() + assert "location" in prompt.lower() + assert "fetch" in prompt.lower() + + def test_system_prompt_lists_tool_capabilities(self): + """SYSTEM_PROMPT describes what each tool does.""" + prompt = _read_system_prompt() + assert "current" in prompt.lower() + assert "forecast" in prompt.lower() + assert "alert" in prompt.lower() diff --git a/tests/test_weather_assistant_tool_calls.py b/tests/test_weather_assistant_tool_calls.py new file mode 100644 index 00000000..6e8b02b8 --- /dev/null +++ b/tests/test_weather_assistant_tool_calls.py @@ -0,0 +1,342 @@ +"""Tests for weather assistant tool call loop.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from unittest.mock import MagicMock + +from accessiweather.ai_tools import WEATHER_TOOLS, WeatherToolExecutor + +# --- Mock OpenAI response objects --- + + +@dataclass +class MockFunction: + name: str + arguments: str + + +@dataclass +class MockToolCall: + id: str + type: str + function: MockFunction + + +@dataclass +class MockMessage: + content: str | None = None + tool_calls: list[MockToolCall] | None = None + role: str = "assistant" + + +@dataclass +class MockChoice: + message: MockMessage + finish_reason: str = "stop" + + +@dataclass +class MockResponse: + choices: list[MockChoice] + model: str = "test-model" + + +def _make_tool_call_response( + tool_calls: list[tuple[str, str, dict]], + content: str = "", +) -> MockResponse: + """ + Create a mock response with tool calls. + + Args: + tool_calls: List of (id, function_name, arguments) tuples. + content: Optional content text. + + """ + tc_objs = [ + MockToolCall( + id=tc_id, + type="function", + function=MockFunction(name=name, arguments=json.dumps(args)), + ) + for tc_id, name, args in tool_calls + ] + return MockResponse( + choices=[MockChoice(message=MockMessage(content=content, tool_calls=tc_objs))] + ) + + +def _make_text_response(content: str) -> MockResponse: + """Create a mock response with just text content.""" + return MockResponse(choices=[MockChoice(message=MockMessage(content=content, tool_calls=None))]) + + +class TestToolCallLoop: + """Test the tool call handling logic extracted from do_generate.""" + + def _run_tool_loop( + self, + responses: list[MockResponse], + tool_executor: WeatherToolExecutor | None = None, + ) -> dict: + """ + Simulate the tool call loop logic from do_generate. + + Returns dict with 'content', 'error', 'tool_calls_made', 'iterations'. + """ + messages: list[dict] = [ + {"role": "system", "content": "test"}, + {"role": "user", "content": "What's the weather?"}, + ] + + response_idx = 0 + tool_calls_made: list[tuple[str, dict]] = [] + max_tool_iterations = 5 + content = "" + error = "" + iterations = 0 + + extra_kwargs: dict = {} + if tool_executor is not None: + extra_kwargs["tools"] = WEATHER_TOOLS + + for _iteration in range(max_tool_iterations + 1): + iterations += 1 + # Simulate API call + if response_idx < len(responses): + response = responses[response_idx] + response_idx += 1 + else: + # Return last response again + response = responses[-1] + + if not response.choices: + error = "Empty response" + break + + choice = response.choices[0] + assistant_message = choice.message + + if assistant_message.tool_calls and tool_executor is not None: + tool_call_msg: dict = { + "role": "assistant", + "content": assistant_message.content or "", + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in assistant_message.tool_calls + ], + } + messages.append(tool_call_msg) + + for tool_call in assistant_message.tool_calls: + tool_name = tool_call.function.name + try: + arguments = json.loads(tool_call.function.arguments) + result = tool_executor.execute(tool_name, arguments) + except Exception as exc: + result = f"Error executing {tool_name}: {exc}" + + tool_calls_made.append((tool_name, json.loads(tool_call.function.arguments))) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + continue + + content = assistant_message.content or "" + break + else: + # Exhausted iterations + content = assistant_message.content or "" + if not content.strip(): + error = "Too many tool calls" + + return { + "content": content, + "error": error, + "tool_calls_made": tool_calls_made, + "iterations": iterations, + "messages": messages, + } + + def _make_executor(self) -> WeatherToolExecutor: + """Create a mock WeatherToolExecutor.""" + weather_service = MagicMock() + weather_service.get_current_conditions.return_value = { + "temperature": "72°F", + "humidity": "45%", + "description": "Clear", + } + weather_service.get_forecast.return_value = { + "periods": [{"name": "Tonight", "detailedForecast": "Clear skies"}] + } + weather_service.get_alerts.return_value = {"alerts": []} + + geocoding_service = MagicMock() + geocoding_service.geocode_address.return_value = (40.0, -74.0, "New York, NY") + + return WeatherToolExecutor(weather_service, geocoding_service) + + def test_no_tool_calls_returns_text_directly(self): + """When response has no tool_calls, return text immediately.""" + result = self._run_tool_loop( + [_make_text_response("It's sunny today!")], + tool_executor=self._make_executor(), + ) + assert result["content"] == "It's sunny today!" + assert result["tool_calls_made"] == [] + assert result["iterations"] == 1 + + def test_single_tool_call_then_text(self): + """One tool call followed by a text response.""" + responses = [ + _make_tool_call_response([("tc1", "get_current_weather", {"location": "NYC"})]), + _make_text_response("The current temperature in NYC is 72°F with clear skies."), + ] + result = self._run_tool_loop(responses, tool_executor=self._make_executor()) + assert "72°F" in result["content"] + assert len(result["tool_calls_made"]) == 1 + assert result["tool_calls_made"][0] == ("get_current_weather", {"location": "NYC"}) + assert result["iterations"] == 2 + + def test_multiple_tool_calls_in_one_response(self): + """Multiple tool calls in a single response message.""" + responses = [ + _make_tool_call_response( + [ + ("tc1", "get_current_weather", {"location": "NYC"}), + ("tc2", "get_forecast", {"location": "NYC"}), + ] + ), + _make_text_response("Here's a full weather summary."), + ] + result = self._run_tool_loop(responses, tool_executor=self._make_executor()) + assert result["content"] == "Here's a full weather summary." + assert len(result["tool_calls_made"]) == 2 + assert result["iterations"] == 2 + + def test_chained_tool_calls(self): + """Multiple sequential tool call rounds.""" + responses = [ + _make_tool_call_response([("tc1", "get_current_weather", {"location": "NYC"})]), + _make_tool_call_response([("tc2", "get_alerts", {"location": "NYC"})]), + _make_text_response("Weather is fine, no alerts."), + ] + result = self._run_tool_loop(responses, tool_executor=self._make_executor()) + assert result["content"] == "Weather is fine, no alerts." + assert len(result["tool_calls_made"]) == 2 + assert result["iterations"] == 3 + + def test_max_iterations_prevents_infinite_loop(self): + """Loop terminates after max 5 tool call iterations.""" + # All responses have tool calls — should stop after 6 iterations (0..5) + responses = [ + _make_tool_call_response([("tc", "get_current_weather", {"location": "NYC"})]) + for _ in range(10) + ] + result = self._run_tool_loop(responses, tool_executor=self._make_executor()) + assert result["iterations"] == 6 # 0..5 inclusive = 6 + assert len(result["tool_calls_made"]) == 6 + + def test_tool_execution_error_handled_gracefully(self): + """Error during tool execution is sent back as tool result.""" + executor = self._make_executor() + # Make geocoding fail + executor.geocoding_service.geocode_address.return_value = None + + responses = [ + _make_tool_call_response([("tc1", "get_current_weather", {"location": "Nowhere"})]), + _make_text_response("I couldn't look up that location."), + ] + result = self._run_tool_loop(responses, tool_executor=executor) + assert result["content"] == "I couldn't look up that location." + # Check that error message was sent as tool result + tool_msgs = [m for m in result["messages"] if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert "Error" in tool_msgs[0]["content"] or "Could not resolve" in tool_msgs[0]["content"] + + def test_no_executor_skips_tools(self): + """When tool_executor is None, tools are not used.""" + result = self._run_tool_loop( + [_make_text_response("Just text.")], + tool_executor=None, + ) + assert result["content"] == "Just text." + assert result["tool_calls_made"] == [] + + def test_tool_results_appended_as_tool_role_messages(self): + """Tool results are added to messages with role=tool.""" + responses = [ + _make_tool_call_response([("tc1", "get_current_weather", {"location": "NYC"})]), + _make_text_response("Done."), + ] + result = self._run_tool_loop(responses, tool_executor=self._make_executor()) + tool_msgs = [m for m in result["messages"] if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["tool_call_id"] == "tc1" + assert "New York" in tool_msgs[0]["content"] + + def test_assistant_tool_call_message_preserved(self): + """The assistant message with tool_calls is added to messages.""" + responses = [ + _make_tool_call_response([("tc1", "get_forecast", {"location": "NYC"})]), + _make_text_response("Forecast looks good."), + ] + result = self._run_tool_loop(responses, tool_executor=self._make_executor()) + assistant_with_tools = [ + m for m in result["messages"] if m.get("role") == "assistant" and m.get("tool_calls") + ] + assert len(assistant_with_tools) == 1 + assert assistant_with_tools[0]["tool_calls"][0]["function"]["name"] == "get_forecast" + + def test_tools_parameter_included_when_executor_available(self): + """Verify the WEATHER_TOOLS are passed when executor is available.""" + # This test checks the extra_kwargs logic + executor = self._make_executor() + extra_kwargs: dict = {} + if executor is not None: + extra_kwargs["tools"] = WEATHER_TOOLS + assert "tools" in extra_kwargs + assert len(extra_kwargs["tools"]) == 11 + + def test_get_alerts_tool_call(self): + """Test that get_alerts tool call works in the loop.""" + executor = self._make_executor() + responses = [ + _make_tool_call_response([("tc1", "get_alerts", {"location": "Miami"})]), + _make_text_response("No active alerts for Miami."), + ] + result = self._run_tool_loop(responses, tool_executor=executor) + assert result["tool_calls_made"][0] == ("get_alerts", {"location": "Miami"}) + assert result["content"] == "No active alerts for Miami." + + +class TestGetToolExecutor: + """Test _get_tool_executor method.""" + + def test_returns_executor_when_services_available(self): + """Returns a WeatherToolExecutor when constructed with valid services.""" + weather_service = MagicMock() + geocoding_service = MagicMock() + + executor = WeatherToolExecutor(weather_service, geocoding_service) + assert executor is not None + assert executor.weather_service is weather_service + assert executor.geocoding_service is geocoding_service + + def test_returns_none_for_missing_service(self): + """Returns None when weather_service is not available.""" + app = MagicMock(spec=[]) # No attributes + assert getattr(app, "weather_service", None) is None