diff --git a/app/backend/database/connection.py b/app/backend/database/connection.py index 46a4fd2ed..22355026f 100644 --- a/app/backend/database/connection.py +++ b/app/backend/database/connection.py @@ -1,6 +1,5 @@ from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, declarative_base import os from pathlib import Path diff --git a/app/backend/main.py b/app/backend/main.py index 770f4a18c..2f3acda6e 100644 --- a/app/backend/main.py +++ b/app/backend/main.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware import logging @@ -12,26 +13,11 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -app = FastAPI(title="AI Hedge Fund API", description="Backend API for AI Hedge Fund", version="0.1.0") -# Initialize database tables (this is safe to run multiple times) -Base.metadata.create_all(bind=engine) - -# Configure CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Frontend URLs - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Include all routes -app.include_router(api_router) - -@app.on_event("startup") -async def startup_event(): - """Startup event to check Ollama availability.""" +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for startup and shutdown events.""" + # Startup try: logger.info("Checking Ollama availability...") status = await ollama_service.check_ollama_status() @@ -53,3 +39,31 @@ async def startup_event(): except Exception as e: logger.warning(f"Could not check Ollama status: {e}") logger.info("ℹ Ollama integration is available if you install it later") + + yield + + # Shutdown (cleanup if needed) + logger.info("Shutting down AI Hedge Fund API...") + + +app = FastAPI( + title="AI Hedge Fund API", + description="Backend API for AI Hedge Fund", + version="0.1.0", + lifespan=lifespan +) + +# Initialize database tables (this is safe to run multiple times) +Base.metadata.create_all(bind=engine) + +# Configure CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Frontend URLs + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include all routes +app.include_router(api_router) diff --git a/pyproject.toml b/pyproject.toml index 7faa9432f..9d0d3f4fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.black] -line-length = 420 +line-length = 120 target-version = ['py311'] include = '\.pyi?$' diff --git a/src/agents/news_sentiment.py b/src/agents/news_sentiment.py index 07e4ca4ab..a9d9c083b 100644 --- a/src/agents/news_sentiment.py +++ b/src/agents/news_sentiment.py @@ -1,5 +1,4 @@ - - +from concurrent.futures import ThreadPoolExecutor, as_completed from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from src.data.models import CompanyNews @@ -22,13 +21,49 @@ class Sentiment(BaseModel): confidence: int = Field(description="Confidence 0-100") +def _analyze_single_article( + news: CompanyNews, + ticker: str, + agent_id: str, + state: AgentState +) -> tuple[CompanyNews, int]: + """ + Analyze sentiment for a single news article. 
+ + Args: + news: The news article to analyze + ticker: The stock ticker symbol + agent_id: The agent ID for LLM calls + state: The agent state + + Returns: + Tuple of (updated news article, confidence score) + """ + prompt = ( + f"Analyze the sentiment of this news headline for stock {ticker}. " + f"Determine if it's 'positive', 'negative', or 'neutral' for the stock. " + f"Provide a confidence score from 0 to 100. Respond in JSON format.\n\n" + f"Headline: {news.title}" + ) + + response = call_llm(prompt, Sentiment, agent_name=agent_id, state=state) + + if response: + news.sentiment = response.sentiment.lower() + return news, response.confidence + else: + news.sentiment = "neutral" + return news, 0 + + def news_sentiment_agent(state: AgentState, agent_id: str = "news_sentiment_agent"): """ Analyzes news sentiment for a list of tickers and generates trading signals. This agent fetches company news, uses an LLM to classify the sentiment of articles - with missing sentiment data, and then aggregates the sentiments to produce an - overall signal (bullish, bearish, or neutral) and a confidence score for each ticker. + with missing sentiment data (in parallel for speed), and then aggregates the + sentiments to produce an overall signal (bullish, bearish, or neutral) and a + confidence score for each ticker. Args: state: The current state of the agent graph. @@ -60,40 +95,46 @@ def news_sentiment_agent(state: AgentState, agent_id: str = "news_sentiment_agen recent_articles = company_news[:10] articles_without_sentiment = [news for news in recent_articles if news.sentiment is None] - # Analyze only the 5 most recent articles without sentiment to reduce LLM calls + # Analyze articles without sentiment using parallel processing sentiments_classified_by_llm = 0 if articles_without_sentiment: - # We only take the first 5 articles, but this is configurable - num_articles_to_analyze = 5 - articles_to_analyze = articles_without_sentiment[:num_articles_to_analyze] - progress.update_status(agent_id, ticker, f"Analyzing sentiment for {len(articles_to_analyze)} articles") - - for idx, news in enumerate(articles_to_analyze): - # We analyze based on title, but can also pass in the entire article text, - # but this is more expensive and requires extracting the text from the article. - # Note: this is an opportunity for improvement! - progress.update_status(agent_id, ticker, f"Analyzing sentiment for article {idx + 1} of {len(articles_to_analyze)}") - prompt = ( - f"Please analyze the sentiment of the following news headline " - f"with the following context: " - f"The stock is {ticker}. " - f"Determine if sentiment is 'positive', 'negative', or 'neutral' for the stock {ticker} only. " - f"Also provide a confidence score for your prediction from 0 to 100. 
" - f"Respond in JSON format.\n\n" - f"Headline: {news.title}" + num_articles_to_analyze = min(5, len(articles_without_sentiment)) + articles_to_analyze = articles_without_sentiment[:num_articles_to_analyze] + progress.update_status( + agent_id, + ticker, + f"Analyzing sentiment for {len(articles_to_analyze)} articles (parallel)" ) - response = call_llm(prompt, Sentiment, agent_name=agent_id, state=state) - if response: - news.sentiment = response.sentiment.lower() - sentiment_confidences[id(news)] = response.confidence - else: - news.sentiment = "neutral" - sentiment_confidences[id(news)] = 0 - sentiments_classified_by_llm += 1 + + # Use ThreadPoolExecutor for parallel LLM calls + with ThreadPoolExecutor(max_workers=3) as executor: + futures = { + executor.submit( + _analyze_single_article, + news, + ticker, + agent_id, + state + ): news for news in articles_to_analyze + } + + for future in as_completed(futures): + try: + updated_news, confidence = future.result() + sentiment_confidences[id(updated_news)] = confidence + sentiments_classified_by_llm += 1 + except Exception as e: + # If analysis fails, mark as neutral + original_news = futures[future] + original_news.sentiment = "neutral" + sentiment_confidences[id(original_news)] = 0 # Aggregate sentiment across all articles sentiment = pd.Series([n.sentiment for n in company_news]).dropna() - news_signals = np.where(sentiment == "negative","bearish", np.where(sentiment == "positive", "bullish", "neutral")).tolist() + news_signals = np.where( + sentiment == "negative", "bearish", + np.where(sentiment == "positive", "bullish", "neutral") + ).tolist() progress.update_status(agent_id, ticker, "Aggregating signals") diff --git a/src/data/cache.py b/src/data/cache.py index 4127934e3..4057f7fca 100644 --- a/src/data/cache.py +++ b/src/data/cache.py @@ -1,12 +1,34 @@ +import time +from typing import Any + + +class CacheEntry: + """A cache entry with TTL support.""" + + def __init__(self, data: list[dict[str, Any]], ttl_seconds: int = 3600): + self.data = data + self.created_at = time.time() + self.ttl_seconds = ttl_seconds + + def is_expired(self) -> bool: + """Check if the cache entry has expired.""" + return time.time() - self.created_at > self.ttl_seconds + + class Cache: - """In-memory cache for API responses.""" + """In-memory cache for API responses with TTL support.""" + + # Default TTL: 1 hour for most data, 24 hours for less volatile data + DEFAULT_TTL = 3600 # 1 hour + METRICS_TTL = 86400 # 24 hours (metrics don't change frequently) + NEWS_TTL = 1800 # 30 minutes (news is more time-sensitive) def __init__(self): - self._prices_cache: dict[str, list[dict[str, any]]] = {} - self._financial_metrics_cache: dict[str, list[dict[str, any]]] = {} - self._line_items_cache: dict[str, list[dict[str, any]]] = {} - self._insider_trades_cache: dict[str, list[dict[str, any]]] = {} - self._company_news_cache: dict[str, list[dict[str, any]]] = {} + self._prices_cache: dict[str, CacheEntry] = {} + self._financial_metrics_cache: dict[str, CacheEntry] = {} + self._line_items_cache: dict[str, CacheEntry] = {} + self._insider_trades_cache: dict[str, CacheEntry] = {} + self._company_news_cache: dict[str, CacheEntry] = {} def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_field: str) -> list[dict]: """Merge existing and new data, avoiding duplicates based on a key field.""" @@ -21,45 +43,89 @@ def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_fie merged.extend([item for item in new_data if item[key_field] 
not in existing_keys]) return merged - def get_prices(self, ticker: str) -> list[dict[str, any]] | None: - """Get cached price data if available.""" - return self._prices_cache.get(ticker) - - def set_prices(self, ticker: str, data: list[dict[str, any]]): - """Append new price data to cache.""" - self._prices_cache[ticker] = self._merge_data(self._prices_cache.get(ticker), data, key_field="time") - - def get_financial_metrics(self, ticker: str) -> list[dict[str, any]]: - """Get cached financial metrics if available.""" - return self._financial_metrics_cache.get(ticker) - - def set_financial_metrics(self, ticker: str, data: list[dict[str, any]]): - """Append new financial metrics to cache.""" - self._financial_metrics_cache[ticker] = self._merge_data(self._financial_metrics_cache.get(ticker), data, key_field="report_period") - - def get_line_items(self, ticker: str) -> list[dict[str, any]] | None: - """Get cached line items if available.""" - return self._line_items_cache.get(ticker) - - def set_line_items(self, ticker: str, data: list[dict[str, any]]): - """Append new line items to cache.""" - self._line_items_cache[ticker] = self._merge_data(self._line_items_cache.get(ticker), data, key_field="report_period") - - def get_insider_trades(self, ticker: str) -> list[dict[str, any]] | None: - """Get cached insider trades if available.""" - return self._insider_trades_cache.get(ticker) - - def set_insider_trades(self, ticker: str, data: list[dict[str, any]]): - """Append new insider trades to cache.""" - self._insider_trades_cache[ticker] = self._merge_data(self._insider_trades_cache.get(ticker), data, key_field="filing_date") # Could also use transaction_date if preferred - - def get_company_news(self, ticker: str) -> list[dict[str, any]] | None: - """Get cached company news if available.""" - return self._company_news_cache.get(ticker) - - def set_company_news(self, ticker: str, data: list[dict[str, any]]): - """Append new company news to cache.""" - self._company_news_cache[ticker] = self._merge_data(self._company_news_cache.get(ticker), data, key_field="date") + def _cleanup_expired(self, cache_dict: dict[str, CacheEntry]) -> None: + """Remove expired entries from a cache dictionary.""" + expired_keys = [key for key, entry in cache_dict.items() if entry.is_expired()] + for key in expired_keys: + del cache_dict[key] + + def get_prices(self, ticker: str) -> list[dict[str, Any]] | None: + """Get cached price data if available and not expired.""" + self._cleanup_expired(self._prices_cache) + entry = self._prices_cache.get(ticker) + if entry and not entry.is_expired(): + return entry.data + return None + + def set_prices(self, ticker: str, data: list[dict[str, Any]]): + """Append new price data to cache with TTL.""" + existing_data = self.get_prices(ticker) + merged = self._merge_data(existing_data, data, key_field="time") + self._prices_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL) + + def get_financial_metrics(self, ticker: str) -> list[dict[str, Any]] | None: + """Get cached financial metrics if available and not expired.""" + self._cleanup_expired(self._financial_metrics_cache) + entry = self._financial_metrics_cache.get(ticker) + if entry and not entry.is_expired(): + return entry.data + return None + + def set_financial_metrics(self, ticker: str, data: list[dict[str, Any]]): + """Append new financial metrics to cache with TTL.""" + existing_data = self.get_financial_metrics(ticker) + merged = self._merge_data(existing_data, data, key_field="report_period") + 
self._financial_metrics_cache[ticker] = CacheEntry(merged, ttl_seconds=self.METRICS_TTL) + + def get_line_items(self, ticker: str) -> list[dict[str, Any]] | None: + """Get cached line items if available and not expired.""" + self._cleanup_expired(self._line_items_cache) + entry = self._line_items_cache.get(ticker) + if entry and not entry.is_expired(): + return entry.data + return None + + def set_line_items(self, ticker: str, data: list[dict[str, Any]]): + """Append new line items to cache with TTL.""" + existing_data = self.get_line_items(ticker) + merged = self._merge_data(existing_data, data, key_field="report_period") + self._line_items_cache[ticker] = CacheEntry(merged, ttl_seconds=self.METRICS_TTL) + + def get_insider_trades(self, ticker: str) -> list[dict[str, Any]] | None: + """Get cached insider trades if available and not expired.""" + self._cleanup_expired(self._insider_trades_cache) + entry = self._insider_trades_cache.get(ticker) + if entry and not entry.is_expired(): + return entry.data + return None + + def set_insider_trades(self, ticker: str, data: list[dict[str, Any]]): + """Append new insider trades to cache with TTL.""" + existing_data = self.get_insider_trades(ticker) + merged = self._merge_data(existing_data, data, key_field="filing_date") + self._insider_trades_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL) + + def get_company_news(self, ticker: str) -> list[dict[str, Any]] | None: + """Get cached company news if available and not expired.""" + self._cleanup_expired(self._company_news_cache) + entry = self._company_news_cache.get(ticker) + if entry and not entry.is_expired(): + return entry.data + return None + + def set_company_news(self, ticker: str, data: list[dict[str, Any]]): + """Append new company news to cache with TTL.""" + existing_data = self.get_company_news(ticker) + merged = self._merge_data(existing_data, data, key_field="date") + self._company_news_cache[ticker] = CacheEntry(merged, ttl_seconds=self.NEWS_TTL) + + def clear(self) -> None: + """Clear all caches.""" + self._prices_cache.clear() + self._financial_metrics_cache.clear() + self._line_items_cache.clear() + self._insider_trades_cache.clear() + self._company_news_cache.clear() # Global cache instance diff --git a/src/utils/llm.py b/src/utils/llm.py index c7535d5d2..88e4c7459 100644 --- a/src/utils/llm.py +++ b/src/utils/llm.py @@ -1,22 +1,33 @@ """Helper functions for LLM""" import json +import logging +import time +from typing import Any + from pydantic import BaseModel from src.llm.models import get_model, get_model_info from src.utils.progress import progress from src.graph.state import AgentState +# Configure logging +logger = logging.getLogger(__name__) + def call_llm( - prompt: any, + prompt: Any, pydantic_model: type[BaseModel], agent_name: str | None = None, state: AgentState | None = None, max_retries: int = 3, default_factory=None, + temperature: float | None = None, ) -> BaseModel: """ - Makes an LLM call with retry logic, handling both JSON supported and non-JSON supported models. + Makes an LLM call with retry logic and exponential backoff. + + Handles both JSON-supported and non-JSON-supported models, with automatic + parsing and validation against the provided Pydantic model. 
Args: prompt: The prompt to send to the LLM @@ -25,6 +36,7 @@ def call_llm( state: Optional state object to extract agent-specific model configuration max_retries: Maximum number of retries (default: 3) default_factory: Optional factory function to create default response on failure + temperature: Optional temperature override for the model (0.0-1.0) Returns: An instance of the specified Pydantic model @@ -48,14 +60,15 @@ def call_llm( model_info = get_model_info(model_name, model_provider) llm = get_model(model_name, model_provider, api_keys) - # For non-JSON support models, we can use structured output + # For JSON-supporting models, use structured output if not (model_info and not model_info.has_json_mode()): llm = llm.with_structured_output( pydantic_model, method="json_mode", ) - # Call the LLM with retries + # Call the LLM with retries and exponential backoff + last_error = None for attempt in range(max_retries): try: # Call the LLM @@ -66,15 +79,33 @@ def call_llm( parsed_result = extract_json_from_response(result.content) if parsed_result: return pydantic_model(**parsed_result) + else: + raise ValueError("Failed to extract JSON from LLM response") else: return result except Exception as e: + last_error = e + + # Calculate exponential backoff delay: 1s, 2s, 4s, ... + delay = min(2 ** attempt, 10) # Cap at 10 seconds + if agent_name: - progress.update_status(agent_name, None, f"Error - retry {attempt + 1}/{max_retries}") - - if attempt == max_retries - 1: - print(f"Error in LLM call after {max_retries} attempts: {e}") + progress.update_status( + agent_name, + None, + f"Error - retry {attempt + 1}/{max_retries} in {delay}s" + ) + + logger.warning( + f"LLM call failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay}s..." + ) + + if attempt < max_retries - 1: + time.sleep(delay) + else: + logger.error(f"LLM call failed after {max_retries} attempts: {last_error}") # Use default_factory if provided, otherwise create a basic default if default_factory: return default_factory() @@ -107,17 +138,79 @@ def create_default_response(model_class: type[BaseModel]) -> BaseModel: def extract_json_from_response(content: str) -> dict | None: - """Extracts JSON from markdown-formatted response.""" + """ + Extracts JSON from LLM response, handling multiple formats. + + Supports: + - JSON wrapped in ```json ... ``` code blocks + - JSON wrapped in ``` ... 
``` code blocks + - Raw JSON without code blocks + - JSON with leading/trailing text + + Args: + content: The raw string response from the LLM + + Returns: + Parsed JSON as a dictionary, or None if parsing fails + """ + if not content: + return None + + content = content.strip() + + # Try 1: Extract from ```json code block try: json_start = content.find("```json") if json_start != -1: - json_text = content[json_start + 7 :] # Skip past ```json + json_text = content[json_start + 7:] json_end = json_text.find("```") if json_end != -1: json_text = json_text[:json_end].strip() return json.loads(json_text) - except Exception as e: - print(f"Error extracting JSON from response: {e}") + except json.JSONDecodeError: + pass + + # Try 2: Extract from ``` code block (no language specified) + try: + json_start = content.find("```") + if json_start != -1: + json_text = content[json_start + 3:] + json_end = json_text.find("```") + if json_end != -1: + json_text = json_text[:json_end].strip() + # Skip language identifier if present on first line + if json_text and not json_text.startswith('{'): + first_newline = json_text.find('\n') + if first_newline != -1: + json_text = json_text[first_newline:].strip() + return json.loads(json_text) + except json.JSONDecodeError: + pass + + # Try 3: Parse the entire content as JSON + try: + return json.loads(content) + except json.JSONDecodeError: + pass + + # Try 4: Find JSON object boundaries { ... } + try: + start = content.find('{') + if start != -1: + # Find matching closing brace + brace_count = 0 + for i, char in enumerate(content[start:], start): + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0: + json_text = content[start:i+1] + return json.loads(json_text) + except json.JSONDecodeError: + pass + + logger.warning(f"Failed to extract JSON from response: {content[:200]}...") return None diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 000000000..301ffddbd --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,366 @@ +""" +Tests for the cache module. 
+ +This module tests the in-memory caching functionality including: +- Basic get/set operations +- TTL-based expiration +- Data deduplication on merge +- Cache cleanup +""" + +import time +import pytest +from unittest.mock import patch + +from src.data.cache import Cache, CacheEntry, get_cache + + +class TestCacheEntry: + """Tests for the CacheEntry class.""" + + def test_cache_entry_creation(self): + """Test that a cache entry is created with correct data.""" + data = [{"time": "2024-01-01", "price": 100.0}] + entry = CacheEntry(data, ttl_seconds=3600) + + assert entry.data == data + assert entry.ttl_seconds == 3600 + assert entry.created_at <= time.time() + + def test_cache_entry_not_expired(self): + """Test that a fresh cache entry is not expired.""" + data = [{"time": "2024-01-01", "price": 100.0}] + entry = CacheEntry(data, ttl_seconds=3600) + + assert entry.is_expired() is False + + def test_cache_entry_expired(self): + """Test that an old cache entry is correctly marked as expired.""" + data = [{"time": "2024-01-01", "price": 100.0}] + entry = CacheEntry(data, ttl_seconds=1) + + # Mock time to simulate expiration + with patch.object(entry, 'created_at', time.time() - 2): + assert entry.is_expired() is True + + def test_cache_entry_with_zero_ttl(self): + """Test that a cache entry with zero TTL expires immediately.""" + data = [{"time": "2024-01-01", "price": 100.0}] + entry = CacheEntry(data, ttl_seconds=0) + + # Even with 0 TTL, it shouldn't be expired in the same instant + # But after any time passes, it should be + with patch.object(entry, 'created_at', time.time() - 0.001): + assert entry.is_expired() is True + + +class TestCachePrices: + """Tests for price caching functionality.""" + + def test_set_and_get_prices(self): + """Test basic set and get for prices.""" + cache = Cache() + prices = [ + {"time": "2024-01-01", "close": 100.0}, + {"time": "2024-01-02", "close": 101.0}, + ] + + cache.set_prices("AAPL", prices) + result = cache.get_prices("AAPL") + + assert result == prices + + def test_get_prices_returns_none_for_missing_ticker(self): + """Test that getting prices for a non-existent ticker returns None.""" + cache = Cache() + + result = cache.get_prices("UNKNOWN") + + assert result is None + + def test_prices_expiration(self): + """Test that expired price data is not returned.""" + cache = Cache() + prices = [{"time": "2024-01-01", "close": 100.0}] + + cache.set_prices("AAPL", prices) + + # Mock the cache entry to be expired + cache._prices_cache["AAPL"].created_at = time.time() - cache.DEFAULT_TTL - 1 + + result = cache.get_prices("AAPL") + + assert result is None + + def test_prices_merge_deduplication(self): + """Test that duplicate prices are not added on merge.""" + cache = Cache() + + # Set initial prices + initial_prices = [ + {"time": "2024-01-01", "close": 100.0}, + {"time": "2024-01-02", "close": 101.0}, + ] + cache.set_prices("AAPL", initial_prices) + + # Add overlapping + new prices + new_prices = [ + {"time": "2024-01-02", "close": 101.0}, # Duplicate + {"time": "2024-01-03", "close": 102.0}, # New + ] + cache.set_prices("AAPL", new_prices) + + result = cache.get_prices("AAPL") + + # Should have 3 unique entries, not 4 + assert len(result) == 3 + times = [p["time"] for p in result] + assert times == ["2024-01-01", "2024-01-02", "2024-01-03"] + + +class TestCacheFinancialMetrics: + """Tests for financial metrics caching functionality.""" + + def test_set_and_get_financial_metrics(self): + """Test basic set and get for financial metrics.""" + cache = Cache() + 
metrics = [ + {"report_period": "2024-Q1", "revenue": 1000000}, + {"report_period": "2024-Q2", "revenue": 1100000}, + ] + + cache.set_financial_metrics("AAPL", metrics) + result = cache.get_financial_metrics("AAPL") + + assert result == metrics + + def test_financial_metrics_uses_longer_ttl(self): + """Test that financial metrics use the METRICS_TTL (24 hours).""" + cache = Cache() + metrics = [{"report_period": "2024-Q1", "revenue": 1000000}] + + cache.set_financial_metrics("AAPL", metrics) + + # Should still be valid after DEFAULT_TTL (1 hour) + cache._financial_metrics_cache["AAPL"].created_at = time.time() - cache.DEFAULT_TTL - 1 + result = cache.get_financial_metrics("AAPL") + + assert result is not None # Still valid because METRICS_TTL > DEFAULT_TTL + + +class TestCacheCompanyNews: + """Tests for company news caching functionality.""" + + def test_set_and_get_company_news(self): + """Test basic set and get for company news.""" + cache = Cache() + news = [ + {"date": "2024-01-01", "title": "News 1"}, + {"date": "2024-01-02", "title": "News 2"}, + ] + + cache.set_company_news("AAPL", news) + result = cache.get_company_news("AAPL") + + assert result == news + + def test_company_news_uses_shorter_ttl(self): + """Test that company news uses the shorter NEWS_TTL (30 minutes).""" + cache = Cache() + news = [{"date": "2024-01-01", "title": "Breaking News"}] + + cache.set_company_news("AAPL", news) + + # Should be expired after NEWS_TTL + cache._company_news_cache["AAPL"].created_at = time.time() - cache.NEWS_TTL - 1 + result = cache.get_company_news("AAPL") + + assert result is None + + +class TestCacheInsiderTrades: + """Tests for insider trades caching functionality.""" + + def test_set_and_get_insider_trades(self): + """Test basic set and get for insider trades.""" + cache = Cache() + trades = [ + {"filing_date": "2024-01-01", "shares": 1000}, + {"filing_date": "2024-01-02", "shares": 500}, + ] + + cache.set_insider_trades("AAPL", trades) + result = cache.get_insider_trades("AAPL") + + assert result == trades + + def test_insider_trades_merge_deduplication(self): + """Test that duplicate insider trades are not added.""" + cache = Cache() + + initial_trades = [{"filing_date": "2024-01-01", "shares": 1000}] + cache.set_insider_trades("AAPL", initial_trades) + + new_trades = [ + {"filing_date": "2024-01-01", "shares": 1000}, # Duplicate + {"filing_date": "2024-01-02", "shares": 500}, # New + ] + cache.set_insider_trades("AAPL", new_trades) + + result = cache.get_insider_trades("AAPL") + + assert len(result) == 2 + + +class TestCacheLineItems: + """Tests for line items caching functionality.""" + + def test_set_and_get_line_items(self): + """Test basic set and get for line items.""" + cache = Cache() + line_items = [ + {"report_period": "2024-Q1", "net_income": 50000}, + {"report_period": "2024-Q2", "net_income": 55000}, + ] + + cache.set_line_items("AAPL", line_items) + result = cache.get_line_items("AAPL") + + assert result == line_items + + +class TestCacheClear: + """Tests for cache clearing functionality.""" + + def test_clear_removes_all_data(self): + """Test that clear() removes all cached data.""" + cache = Cache() + + # Populate all cache types + cache.set_prices("AAPL", [{"time": "2024-01-01", "close": 100}]) + cache.set_financial_metrics("AAPL", [{"report_period": "2024-Q1", "revenue": 1000}]) + cache.set_line_items("AAPL", [{"report_period": "2024-Q1", "net_income": 100}]) + cache.set_insider_trades("AAPL", [{"filing_date": "2024-01-01", "shares": 100}]) + 
cache.set_company_news("AAPL", [{"date": "2024-01-01", "title": "News"}]) + + cache.clear() + + assert cache.get_prices("AAPL") is None + assert cache.get_financial_metrics("AAPL") is None + assert cache.get_line_items("AAPL") is None + assert cache.get_insider_trades("AAPL") is None + assert cache.get_company_news("AAPL") is None + + +class TestCacheMergeData: + """Tests for the _merge_data helper method.""" + + def test_merge_with_none_existing(self): + """Test merge when existing data is None.""" + cache = Cache() + new_data = [{"time": "2024-01-01", "value": 1}] + + result = cache._merge_data(None, new_data, "time") + + assert result == new_data + + def test_merge_with_empty_existing(self): + """Test merge when existing data is empty list.""" + cache = Cache() + new_data = [{"time": "2024-01-01", "value": 1}] + + # Empty list is falsy, so treated same as None + result = cache._merge_data([], new_data, "time") + + assert result == new_data + + def test_merge_preserves_order(self): + """Test that merge preserves order of existing data.""" + cache = Cache() + existing = [ + {"time": "2024-01-01", "value": 1}, + {"time": "2024-01-02", "value": 2}, + ] + new_data = [{"time": "2024-01-03", "value": 3}] + + result = cache._merge_data(existing, new_data, "time") + + assert result[0]["time"] == "2024-01-01" + assert result[1]["time"] == "2024-01-02" + assert result[2]["time"] == "2024-01-03" + + +class TestCacheCleanupExpired: + """Tests for expired entry cleanup.""" + + def test_cleanup_removes_expired_entries(self): + """Test that cleanup removes expired entries from cache.""" + cache = Cache() + + # Add two tickers + cache.set_prices("AAPL", [{"time": "2024-01-01", "close": 100}]) + cache.set_prices("GOOGL", [{"time": "2024-01-01", "close": 200}]) + + # Expire only AAPL + cache._prices_cache["AAPL"].created_at = time.time() - cache.DEFAULT_TTL - 1 + + # Access GOOGL (this triggers cleanup) + result = cache.get_prices("GOOGL") + + # GOOGL should still be there + assert result is not None + # AAPL should be cleaned up + assert "AAPL" not in cache._prices_cache + + +class TestGlobalCache: + """Tests for the global cache instance.""" + + def test_get_cache_returns_cache_instance(self): + """Test that get_cache returns a Cache instance.""" + cache = get_cache() + + assert isinstance(cache, Cache) + + def test_get_cache_returns_same_instance(self): + """Test that get_cache returns the same global instance.""" + cache1 = get_cache() + cache2 = get_cache() + + assert cache1 is cache2 + + +class TestCacheMultipleTickers: + """Tests for caching data for multiple tickers.""" + + def test_different_tickers_stored_separately(self): + """Test that different tickers have separate cache entries.""" + cache = Cache() + + aapl_prices = [{"time": "2024-01-01", "close": 100}] + googl_prices = [{"time": "2024-01-01", "close": 200}] + + cache.set_prices("AAPL", aapl_prices) + cache.set_prices("GOOGL", googl_prices) + + assert cache.get_prices("AAPL")[0]["close"] == 100 + assert cache.get_prices("GOOGL")[0]["close"] == 200 + + def test_updating_one_ticker_doesnt_affect_others(self): + """Test that updating one ticker doesn't affect other tickers.""" + cache = Cache() + + cache.set_prices("AAPL", [{"time": "2024-01-01", "close": 100}]) + cache.set_prices("GOOGL", [{"time": "2024-01-01", "close": 200}]) + + # Update AAPL + cache.set_prices("AAPL", [{"time": "2024-01-02", "close": 110}]) + + # GOOGL should be unchanged + googl_result = cache.get_prices("GOOGL") + assert len(googl_result) == 1 + assert 
googl_result[0]["close"] == 200 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])