3 changes: 1 addition & 2 deletions app/backend/database/connection.py
@@ -1,6 +1,5 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, declarative_base
import os
from pathlib import Path

52 changes: 33 additions & 19 deletions app/backend/main.py
@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import logging
@@ -12,26 +13,11 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="AI Hedge Fund API", description="Backend API for AI Hedge Fund", version="0.1.0")

# Initialize database tables (this is safe to run multiple times)
Base.metadata.create_all(bind=engine)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Frontend URLs
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include all routes
app.include_router(api_router)

@app.on_event("startup")
async def startup_event():
"""Startup event to check Ollama availability."""
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
# Startup
try:
logger.info("Checking Ollama availability...")
status = await ollama_service.check_ollama_status()
@@ -53,3 +39,31 @@ async def startup_event():
    except Exception as e:
        logger.warning(f"Could not check Ollama status: {e}")
        logger.info("ℹ Ollama integration is available if you install it later")

    yield

    # Shutdown (cleanup if needed)
    logger.info("Shutting down AI Hedge Fund API...")


app = FastAPI(
title="AI Hedge Fund API",
description="Backend API for AI Hedge Fund",
version="0.1.0",
lifespan=lifespan
)

# Initialize database tables (this is safe to run multiple times)
Base.metadata.create_all(bind=engine)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Frontend URLs
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include all routes
app.include_router(api_router)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -48,7 +48,7 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 420
line-length = 120
target-version = ['py311']
include = '\.pyi?$'

156 changes: 111 additions & 45 deletions src/data/cache.py
@@ -1,12 +1,34 @@
import time
from typing import Any


class CacheEntry:
"""A cache entry with TTL support."""

def __init__(self, data: list[dict[str, Any]], ttl_seconds: int = 3600):
self.data = data
self.created_at = time.time()
self.ttl_seconds = ttl_seconds

def is_expired(self) -> bool:
"""Check if the cache entry has expired."""
return time.time() - self.created_at > self.ttl_seconds


class Cache:
"""In-memory cache for API responses."""
"""In-memory cache for API responses with TTL support."""

# Default TTL: 1 hour for most data, 24 hours for less volatile data
DEFAULT_TTL = 3600 # 1 hour
TTL Constants Could Be Configurable
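One way to make these configurable would be to read them from the environment, falling back to the hard-coded defaults used in this PR (a sketch only, not part of the diff; the CACHE_* variable names are hypothetical):

import os


class Cache:
    """In-memory cache for API responses with TTL support."""

    # Hypothetical environment overrides; the defaults match the constants in this PR.
    DEFAULT_TTL = int(os.getenv("CACHE_DEFAULT_TTL", "3600"))   # 1 hour
    METRICS_TTL = int(os.getenv("CACHE_METRICS_TTL", "86400"))  # 24 hours
    NEWS_TTL = int(os.getenv("CACHE_NEWS_TTL", "1800"))         # 30 minutes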

    METRICS_TTL = 86400 # 24 hours (metrics don't change frequently)
    NEWS_TTL = 1800 # 30 minutes (news is more time-sensitive)

    def __init__(self):
        self._prices_cache: dict[str, list[dict[str, any]]] = {}
        self._financial_metrics_cache: dict[str, list[dict[str, any]]] = {}
        self._line_items_cache: dict[str, list[dict[str, any]]] = {}
        self._insider_trades_cache: dict[str, list[dict[str, any]]] = {}
        self._company_news_cache: dict[str, list[dict[str, any]]] = {}
        self._prices_cache: dict[str, CacheEntry] = {}
Dictionary operations in Python aren't atomic for compound operations.
Consider adding:

    self._lock = threading.Lock()
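A minimal sketch of the write-path locking this comment is asking for, using a pared-down stand-in class (ThreadSafeCache and the simplified merge are illustrative, not part of this diff):

import threading
import time
from typing import Any


class CacheEntry:
    # Same shape as the CacheEntry introduced in this PR.
    def __init__(self, data: list[dict[str, Any]], ttl_seconds: int = 3600):
        self.data = data
        self.created_at = time.time()
        self.ttl_seconds = ttl_seconds


class ThreadSafeCache:
    DEFAULT_TTL = 3600

    def __init__(self):
        self._lock = threading.Lock()
        self._prices_cache: dict[str, CacheEntry] = {}

    def set_prices(self, ticker: str, data: list[dict[str, Any]]) -> None:
        # Read-merge-write is a compound operation, so it runs entirely under the lock.
        with self._lock:
            existing = self._prices_cache.get(ticker)
            merged = (existing.data if existing else []) + data  # dedup merge omitted for brevity
            self._prices_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL)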

        self._financial_metrics_cache: dict[str, CacheEntry] = {}
        self._line_items_cache: dict[str, CacheEntry] = {}
        self._insider_trades_cache: dict[str, CacheEntry] = {}
        self._company_news_cache: dict[str, CacheEntry] = {}

    def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_field: str) -> list[dict]:
        """Merge existing and new data, avoiding duplicates based on a key field."""
@@ -21,45 +43,89 @@ def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_fie
        merged.extend([item for item in new_data if item[key_field] not in existing_keys])
        return merged

    def get_prices(self, ticker: str) -> list[dict[str, any]] | None:
        """Get cached price data if available."""
        return self._prices_cache.get(ticker)

    def set_prices(self, ticker: str, data: list[dict[str, any]]):
        """Append new price data to cache."""
        self._prices_cache[ticker] = self._merge_data(self._prices_cache.get(ticker), data, key_field="time")

    def get_financial_metrics(self, ticker: str) -> list[dict[str, any]]:
        """Get cached financial metrics if available."""
        return self._financial_metrics_cache.get(ticker)

    def set_financial_metrics(self, ticker: str, data: list[dict[str, any]]):
        """Append new financial metrics to cache."""
        self._financial_metrics_cache[ticker] = self._merge_data(self._financial_metrics_cache.get(ticker), data, key_field="report_period")

    def get_line_items(self, ticker: str) -> list[dict[str, any]] | None:
        """Get cached line items if available."""
        return self._line_items_cache.get(ticker)

    def set_line_items(self, ticker: str, data: list[dict[str, any]]):
        """Append new line items to cache."""
        self._line_items_cache[ticker] = self._merge_data(self._line_items_cache.get(ticker), data, key_field="report_period")

    def get_insider_trades(self, ticker: str) -> list[dict[str, any]] | None:
        """Get cached insider trades if available."""
        return self._insider_trades_cache.get(ticker)

    def set_insider_trades(self, ticker: str, data: list[dict[str, any]]):
        """Append new insider trades to cache."""
        self._insider_trades_cache[ticker] = self._merge_data(self._insider_trades_cache.get(ticker), data, key_field="filing_date") # Could also use transaction_date if preferred

    def get_company_news(self, ticker: str) -> list[dict[str, any]] | None:
        """Get cached company news if available."""
        return self._company_news_cache.get(ticker)

    def set_company_news(self, ticker: str, data: list[dict[str, any]]):
        """Append new company news to cache."""
        self._company_news_cache[ticker] = self._merge_data(self._company_news_cache.get(ticker), data, key_field="date")
    def _cleanup_expired(self, cache_dict: dict[str, CacheEntry]) -> None:
        """Remove expired entries from a cache dictionary."""
        expired_keys = [key for key, entry in cache_dict.items() if entry.is_expired()]
        for key in expired_keys:
            del cache_dict[key]

    def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
With reference to the above comment:

def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
    with self._lock:
        self._cleanup_expired(self._prices_cache)
        entry = self._prices_cache.get(ticker)
        return entry.data if entry else None

Or use threading.RLock() if methods call each other
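To illustrate why RLock matters here: in this PR the set_* methods call the corresponding get_* methods, so if both acquired a plain threading.Lock the nested acquire would block forever; a reentrant lock lets the same thread re-acquire it (toy example, not part of the diff):

import threading

lock = threading.RLock()  # reentrant: the owning thread may acquire it again

def get_prices() -> list:
    with lock:  # nested acquire succeeds with RLock; a plain Lock would deadlock here
        return []

def set_prices(data: list) -> list:
    with lock:  # outer acquire
        return get_prices() + data

print(set_prices([{"time": "2024-01-01"}]))  # runs without deadlocking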

"""Get cached price data if available and not expired."""
self._cleanup_expired(self._prices_cache)
entry = self._prices_cache.get(ticker)
if entry and not entry.is_expired():
return entry.data
return None
Comment on lines +52 to +58
Copilot AI Dec 1, 2025

The is_expired() check on line 56 is redundant since _cleanup_expired() on line 54 already removes all expired entries from the cache. If an entry exists in the cache after cleanup, it's guaranteed to be non-expired. Consider simplifying to:

def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
    """Get cached price data if available and not expired."""
    self._cleanup_expired(self._prices_cache)
    entry = self._prices_cache.get(ticker)
    return entry.data if entry else None

This pattern applies to all get methods.

    def set_prices(self, ticker: str, data: list[dict[str, Any]]):
        """Append new price data to cache with TTL."""
        existing_data = self.get_prices(ticker)
        merged = self._merge_data(existing_data, data, key_field="time")
        self._prices_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL)

    def get_financial_metrics(self, ticker: str) -> list[dict[str, Any]] | None:
        """Get cached financial metrics if available and not expired."""
        self._cleanup_expired(self._financial_metrics_cache)
        entry = self._financial_metrics_cache.get(ticker)
        if entry and not entry.is_expired():
            return entry.data
        return None
Comment on lines +66 to +72
Copilot AI Dec 1, 2025

Same redundant is_expired() check issue as in get_prices().

    def set_financial_metrics(self, ticker: str, data: list[dict[str, Any]]):
        """Append new financial metrics to cache with TTL."""
        existing_data = self.get_financial_metrics(ticker)
        merged = self._merge_data(existing_data, data, key_field="report_period")
        self._financial_metrics_cache[ticker] = CacheEntry(merged, ttl_seconds=self.METRICS_TTL)

    def get_line_items(self, ticker: str) -> list[dict[str, Any]] | None:
        """Get cached line items if available and not expired."""
        self._cleanup_expired(self._line_items_cache)
        entry = self._line_items_cache.get(ticker)
        if entry and not entry.is_expired():
            return entry.data
        return None
Comment on lines +80 to +86
Copilot AI Dec 1, 2025

Same redundant is_expired() check issue as in get_prices().

    def set_line_items(self, ticker: str, data: list[dict[str, Any]]):
        """Append new line items to cache with TTL."""
        existing_data = self.get_line_items(ticker)
        merged = self._merge_data(existing_data, data, key_field="report_period")
        self._line_items_cache[ticker] = CacheEntry(merged, ttl_seconds=self.METRICS_TTL)

    def get_insider_trades(self, ticker: str) -> list[dict[str, Any]] | None:
        """Get cached insider trades if available and not expired."""
        self._cleanup_expired(self._insider_trades_cache)
        entry = self._insider_trades_cache.get(ticker)
        if entry and not entry.is_expired():
            return entry.data
        return None
Comment on lines +94 to +100
Copilot AI Dec 1, 2025

Same redundant is_expired() check issue as in get_prices().

    def set_insider_trades(self, ticker: str, data: list[dict[str, Any]]):
        """Append new insider trades to cache with TTL."""
        existing_data = self.get_insider_trades(ticker)
        merged = self._merge_data(existing_data, data, key_field="filing_date")
        self._insider_trades_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL)

    def get_company_news(self, ticker: str) -> list[dict[str, Any]] | None:
        """Get cached company news if available and not expired."""
        self._cleanup_expired(self._company_news_cache)
        entry = self._company_news_cache.get(ticker)
        if entry and not entry.is_expired():
            return entry.data
        return None
Comment on lines +108 to +114
Copilot AI Dec 1, 2025

Same redundant is_expired() check issue as in get_prices().

    def set_company_news(self, ticker: str, data: list[dict[str, Any]]):
        """Append new company news to cache with TTL."""
        existing_data = self.get_company_news(ticker)
        merged = self._merge_data(existing_data, data, key_field="date")
        self._company_news_cache[ticker] = CacheEntry(merged, ttl_seconds=self.NEWS_TTL)

    def clear(self) -> None:
        """Clear all caches."""
        self._prices_cache.clear()
        self._financial_metrics_cache.clear()
        self._line_items_cache.clear()
        self._insider_trades_cache.clear()
        self._company_news_cache.clear()


# Global cache instance