diff --git a/backend/app/api/escrow.py b/backend/app/api/escrow.py index c3fdacaa..4d3d616a 100644 --- a/backend/app/api/escrow.py +++ b/backend/app/api/escrow.py @@ -10,6 +10,8 @@ from __future__ import annotations +import logging + from fastapi import APIRouter, HTTPException, status from app.exceptions import ( @@ -34,6 +36,9 @@ refund_escrow, release_escrow, ) +from app.services.onchain_cache import cache_get, cache_invalidate, cache_set + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/escrow", tags=["escrow"]) @@ -67,6 +72,7 @@ async def fund_escrow(body: EscrowFundRequest) -> EscrowResponse: ) # Auto-activate after successful funding escrow = await activate_escrow(body.bounty_id) + await cache_invalidate("escrow", body.bounty_id) return escrow except EscrowAlreadyExistsError as exc: raise HTTPException(status_code=409, detail=str(exc)) from exc @@ -96,10 +102,12 @@ async def release_escrow_endpoint(body: EscrowReleaseRequest) -> EscrowResponse: moves the escrow to COMPLETED state. """ try: - return await release_escrow( + result = await release_escrow( bounty_id=body.bounty_id, winner_wallet=body.winner_wallet, ) + await cache_invalidate("escrow", body.bounty_id) + return result except EscrowNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc except InvalidEscrowTransitionError as exc: @@ -123,7 +131,9 @@ async def release_escrow_endpoint(body: EscrowReleaseRequest) -> EscrowResponse: async def refund_escrow_endpoint(body: EscrowRefundRequest) -> EscrowResponse: """Return escrowed $FNDRY to the bounty creator on timeout or cancellation.""" try: - return await refund_escrow(bounty_id=body.bounty_id) + result = await refund_escrow(bounty_id=body.bounty_id) + await cache_invalidate("escrow", body.bounty_id) + return result except EscrowNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc except InvalidEscrowTransitionError as exc: @@ -141,8 +151,18 @@ async def refund_escrow_endpoint(body: EscrowRefundRequest) -> EscrowResponse: }, ) async def get_escrow(bounty_id: str) -> EscrowStatusResponse: - """Return the current escrow state, locked balance, and full audit trail.""" + """Return the current escrow state, locked balance, and full audit trail. + + Results are cached in Redis for 30 seconds to reduce database load. + """ + cached = await cache_get("escrow", bounty_id) + if cached is not None: + return EscrowStatusResponse.model_validate(cached) + try: - return await get_escrow_status(bounty_id=bounty_id) + result = await get_escrow_status(bounty_id=bounty_id) except EscrowNotFoundError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc + + await cache_set("escrow", bounty_id, result.model_dump(mode="json")) + return result diff --git a/backend/app/api/onchain.py b/backend/app/api/onchain.py new file mode 100644 index 00000000..54debb85 --- /dev/null +++ b/backend/app/api/onchain.py @@ -0,0 +1,260 @@ +"""On-chain data REST API endpoints with Redis caching. + +Provides read-only endpoints that aggregate on-chain Solana state with a +30-second Redis TTL cache to limit RPC calls. + +Endpoints: +- ``GET /reputation/{wallet}`` -- Reputation summary for a wallet address. +- ``GET /staking/{wallet}`` -- Staking info (SOL + FNDRY balance) for wallet. +- ``GET /treasury/stats`` -- Treasury SOL/FNDRY balance and aggregate stats. +- ``POST /webhooks/helius`` -- Cache invalidation webhook for Helius/Shyft. +""" + +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +from typing import Annotated, Any + +from fastapi import APIRouter, Header, HTTPException, Query, status +from pydantic import BaseModel, Field + +from app.models.payout import TreasuryStats +from app.models.reputation import ReputationSummary +from app.services import reputation_service +from app.services.onchain_cache import ( + cache_get, + cache_invalidate, + cache_invalidate_prefix, + cache_set, +) +from app.services.solana_client import ( + SolanaRPCError, + get_sol_balance, + get_token_balance, +) +from app.services.treasury_service import get_treasury_stats + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["onchain"]) + +_HELIUS_WEBHOOK_SECRET = os.getenv("HELIUS_WEBHOOK_SECRET", "") + + +# --------------------------------------------------------------------------- +# Response models +# --------------------------------------------------------------------------- + + +class StakingInfo(BaseModel): + """On-chain balances for a given wallet address.""" + + wallet: str + sol_balance: float = Field(..., description="Native SOL balance") + fndry_balance: float = Field(..., description="$FNDRY SPL token balance") + cached: bool = Field(False, description="True when served from cache") + + +class HeliusWebhookPayload(BaseModel): + """Minimal Helius / Shyft webhook payload.""" + + type: str = Field("", description="Transaction type or event name") + accounts: list[str] = Field(default_factory=list) + + +class CacheInvalidationResponse(BaseModel): + keys_removed: int + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _get_reputation_by_wallet(wallet: str) -> ReputationSummary | None: + """Look up reputation for the contributor whose wallet matches *wallet*. + + Returns ``None`` when no contributor has verified the given wallet. + """ + from app.database import async_session_factory + from app.models.user import User + from sqlalchemy import select + + try: + async with async_session_factory() as session: + result = await session.execute( + select(User).where( + User.wallet_address == wallet, + User.wallet_verified.is_(True), + ) + ) + user = result.scalars().first() + if user is None: + return None + # contributor_id == username for the reputation store + return await reputation_service.get_reputation(user.username) + except Exception as exc: + logger.warning("wallet→reputation lookup failed: %s", exc) + return None + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@router.get( + "/reputation/{wallet}", + response_model=ReputationSummary, + summary="Get reputation for a wallet address", + responses={ + 404: {"description": "No verified contributor found for this wallet"}, + 503: {"description": "Upstream RPC or DB unavailable"}, + }, +) +async def get_reputation_by_wallet( + wallet: str, + skip: Annotated[int, Query(ge=0)] = 0, + limit: Annotated[int, Query(ge=1, le=100)] = 10, +) -> ReputationSummary: + """Return the reputation profile of the contributor who owns *wallet*. + + Results are cached for 30 seconds in Redis. The ``skip``/``limit`` + parameters paginate the embedded history entries. + """ + cached: Any = await cache_get("reputation", wallet) + if cached is not None: + summary = ReputationSummary.model_validate(cached) + summary.history = summary.history[skip : skip + limit] + return summary + + summary = await _get_reputation_by_wallet(wallet) + if summary is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No verified contributor found for wallet {wallet}", + ) + + await cache_set("reputation", wallet, summary.model_dump(mode="json")) + + summary.history = summary.history[skip : skip + limit] + return summary + + +@router.get( + "/staking/{wallet}", + response_model=StakingInfo, + summary="Get on-chain staking balances for a wallet", + responses={ + 502: {"description": "Solana RPC request failed"}, + }, +) +async def get_staking_info(wallet: str) -> StakingInfo: + """Return native SOL and $FNDRY balances held by *wallet*. + + Results are cached for 30 seconds in Redis. + """ + cached: Any = await cache_get("staking", wallet) + if cached is not None: + return StakingInfo(**cached, cached=True) + + try: + sol = await get_sol_balance(wallet) + fndry = await get_token_balance(wallet) + except SolanaRPCError as exc: + logger.error("Solana RPC error for staking/%s: %s", wallet, exc) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Solana RPC error: {exc}", + ) from exc + + payload = {"wallet": wallet, "sol_balance": sol, "fndry_balance": fndry} + await cache_set("staking", wallet, payload) + return StakingInfo(**payload) + + +@router.get( + "/treasury/stats", + response_model=TreasuryStats, + summary="Get live treasury statistics", + responses={ + 503: {"description": "Treasury data unavailable"}, + }, +) +async def get_treasury_stats_endpoint() -> TreasuryStats: + """Return treasury SOL/FNDRY balances and aggregate payout totals. + + Results are cached for 30 seconds in Redis (the treasury service also + maintains its own 60-second in-memory cache as a secondary layer). + """ + cached: Any = await cache_get("treasury", "stats") + if cached is not None: + return TreasuryStats.model_validate(cached) + + try: + stats = await get_treasury_stats() + except Exception as exc: + logger.error("Failed to fetch treasury stats: %s", exc) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Treasury data temporarily unavailable", + ) from exc + + await cache_set("treasury", "stats", stats.model_dump(mode="json")) + return stats + + +@router.post( + "/webhooks/helius", + response_model=CacheInvalidationResponse, + summary="Helius / Shyft webhook for cache invalidation", + status_code=status.HTTP_200_OK, +) +async def helius_webhook( + payload: HeliusWebhookPayload, + x_helius_signature: Annotated[str | None, Header()] = None, +) -> CacheInvalidationResponse: + """Invalidate on-chain cache entries when Helius reports new transactions. + + If ``HELIUS_WEBHOOK_SECRET`` is set, the ``X-Helius-Signature`` header + is verified with HMAC-SHA256. Requests with invalid signatures are + rejected with 401. + + The affected cache namespaces are derived from the ``accounts`` list in + the payload: staking entries for each account are purged, and the + treasury stats key is always cleared. + """ + if _HELIUS_WEBHOOK_SECRET: + if not x_helius_signature: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing signature" + ) + expected = hmac.new( + _HELIUS_WEBHOOK_SECRET.encode(), + msg=payload.model_dump_json().encode(), + digestmod=hashlib.sha256, + ).hexdigest() + if not hmac.compare_digest(expected, x_helius_signature): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid signature" + ) + + removed = 0 + for account in payload.accounts: + await cache_invalidate("staking", account) + await cache_invalidate("reputation", account) + removed += 1 + + # Always bust the treasury cache on any relevant transaction + removed += await cache_invalidate_prefix("treasury") + + logger.info( + "Helius webhook processed: type=%s accounts=%d removed=%d", + payload.type, + len(payload.accounts), + removed, + ) + return CacheInvalidationResponse(keys_removed=removed) diff --git a/backend/app/main.py b/backend/app/main.py index 086458ad..94d84453 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -53,6 +53,7 @@ from app.database import init_db, close_db from app.api.og import router as og_router from app.api.contributor_webhooks import router as contributor_webhooks_router +from app.api.onchain import router as onchain_router from app.middleware.security import SecurityHeadersMiddleware from app.middleware.sanitization import InputSanitizationMiddleware from app.services.config_validator import install_log_filter, validate_secrets @@ -399,6 +400,9 @@ async def value_error_handler(request: Request, exc: ValueError): # Admin Dashboard: /api/admin/* (protected by ADMIN_API_KEY) app.include_router(admin_router) +# On-chain data: /api/reputation/*, /api/staking/*, /api/treasury/*, /api/webhooks/helius +app.include_router(onchain_router, prefix="/api") + @app.post("/api/sync", tags=["admin"]) async def trigger_sync(): diff --git a/backend/app/services/onchain_cache.py b/backend/app/services/onchain_cache.py new file mode 100644 index 00000000..48cb9660 --- /dev/null +++ b/backend/app/services/onchain_cache.py @@ -0,0 +1,76 @@ +"""Redis-backed cache for on-chain data with 30-second TTL. + +Provides a thin wrapper around the shared Redis client with graceful +degradation: on any Redis error the cache functions log a warning and +return ``None`` so callers can fall back to a live RPC query. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from app.core.redis import get_redis + +logger = logging.getLogger(__name__) + +CACHE_TTL: int = 30 # seconds +KEY_PREFIX = "onchain:" + + +def _key(namespace: str, identifier: str) -> str: + return f"{KEY_PREFIX}{namespace}:{identifier}" + + +async def cache_get(namespace: str, identifier: str) -> Any | None: + """Return a cached value or ``None`` on miss / Redis error.""" + try: + redis = await get_redis() + raw = await redis.get(_key(namespace, identifier)) + if raw is None: + return None + return json.loads(raw) + except Exception as exc: + logger.warning( + "onchain_cache get failed (%s/%s): %s", namespace, identifier, exc + ) + return None + + +async def cache_set(namespace: str, identifier: str, value: Any) -> None: + """Persist *value* with a 30-second TTL. Silently ignores Redis errors.""" + try: + redis = await get_redis() + await redis.setex(_key(namespace, identifier), CACHE_TTL, json.dumps(value)) + except Exception as exc: + logger.warning( + "onchain_cache set failed (%s/%s): %s", namespace, identifier, exc + ) + + +async def cache_invalidate(namespace: str, identifier: str) -> None: + """Delete a single cache entry. Silently ignores Redis errors.""" + try: + redis = await get_redis() + await redis.delete(_key(namespace, identifier)) + except Exception as exc: + logger.warning( + "onchain_cache invalidate failed (%s/%s): %s", namespace, identifier, exc + ) + + +async def cache_invalidate_prefix(namespace: str) -> int: + """Delete all keys under *namespace*. Returns the number of keys removed.""" + try: + redis = await get_redis() + pattern = _key(namespace, "*") + keys = await redis.keys(pattern) + if keys: + return await redis.delete(*keys) + return 0 + except Exception as exc: + logger.warning( + "onchain_cache invalidate_prefix failed (%s): %s", namespace, exc + ) + return 0 diff --git a/backend/tests/test_onchain_api.py b/backend/tests/test_onchain_api.py new file mode 100644 index 00000000..e6f6fd5c --- /dev/null +++ b/backend/tests/test_onchain_api.py @@ -0,0 +1,500 @@ +"""Tests for the on-chain data REST API endpoints. + +All tests use the full FastAPI app over httpx.ASGITransport with: +- Mocked Redis (cache always misses by default, verifies writes) +- Mocked Solana RPC (get_sol_balance, get_token_balance) +- Mocked treasury_service.get_treasury_stats +- Mocked reputation_service.get_reputation / wallet lookup +""" + +from __future__ import annotations + +import os + +os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:") +os.environ.setdefault("SECRET_KEY", "test-secret-key-for-ci") + +import asyncio +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient + +from app.main import app + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest_asyncio.fixture +async def client(): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as c: + yield c + + +def _make_reputation_summary() -> dict: + return { + "contributor_id": "alice", + "username": "alice", + "display_name": "Alice", + "reputation_score": 42.5, + "badge": None, + "tier_progression": { + "current_tier": "T1", + "t1_completions": 5, + "t2_completions": 0, + "t3_completions": 0, + "t1_required": 3, + "t2_required": 5, + "t3_required": 3, + "next_tier": "T2", + "progress_pct": 100.0, + }, + "is_veteran": False, + "total_bounties_completed": 5, + "average_review_score": 8.1, + "history": [], + } + + +def _make_treasury_stats() -> dict: + return { + "sol_balance": 12.5, + "fndry_balance": 500_000.0, + "treasury_wallet": "AqqW7hFLau8oH8nDuZp5jPjM3EXUrD7q3SxbcNE8YTN1", + "total_paid_out_fndry": 10_000.0, + "total_paid_out_sol": 0.5, + "total_payouts": 20, + "total_buyback_amount": 1.0, + "total_buybacks": 3, + "last_updated": datetime.now(timezone.utc).isoformat(), + } + + +# --------------------------------------------------------------------------- +# Cache helpers used across tests +# --------------------------------------------------------------------------- + + +def _miss_cache() -> tuple[AsyncMock, AsyncMock]: + """Return (cache_get, cache_set) mocks where get always misses.""" + get_mock = AsyncMock(return_value=None) + set_mock = AsyncMock() + return get_mock, set_mock + + +def _hit_cache(value) -> tuple[AsyncMock, AsyncMock]: + """Return (cache_get, cache_set) mocks where get returns *value*.""" + get_mock = AsyncMock(return_value=value) + set_mock = AsyncMock() + return get_mock, set_mock + + +# --------------------------------------------------------------------------- +# GET /api/reputation/{wallet} +# --------------------------------------------------------------------------- + + +class TestReputationEndpoint: + @pytest.mark.asyncio + async def test_returns_404_when_no_contributor_for_wallet(self, client): + with ( + patch("app.api.onchain.cache_get", AsyncMock(return_value=None)), + patch("app.api.onchain.cache_set", AsyncMock()), + patch( + "app.api.onchain._get_reputation_by_wallet", + AsyncMock(return_value=None), + ), + ): + resp = await client.get("/api/reputation/UnknownWallet123") + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_returns_reputation_from_rpc_on_cache_miss(self, client): + from app.models.reputation import ReputationSummary + + summary = ReputationSummary.model_validate(_make_reputation_summary()) + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch( + "app.api.onchain._get_reputation_by_wallet", + AsyncMock(return_value=summary), + ), + ): + resp = await client.get("/api/reputation/ValidWallet1234") + assert resp.status_code == 200 + data = resp.json() + assert data["username"] == "alice" + assert data["reputation_score"] == 42.5 + + @pytest.mark.asyncio + async def test_writes_to_cache_on_miss(self, client): + from app.models.reputation import ReputationSummary + + summary = ReputationSummary.model_validate(_make_reputation_summary()) + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch( + "app.api.onchain._get_reputation_by_wallet", + AsyncMock(return_value=summary), + ), + ): + await client.get("/api/reputation/ValidWallet1234") + set_mock.assert_awaited_once() + args = set_mock.call_args[0] + assert args[0] == "reputation" + assert args[1] == "ValidWallet1234" + + @pytest.mark.asyncio + async def test_returns_cached_value_without_rpc(self, client): + cached = _make_reputation_summary() + get_mock, set_mock = _hit_cache(cached) + rpc_mock = AsyncMock() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch("app.api.onchain._get_reputation_by_wallet", rpc_mock), + ): + resp = await client.get("/api/reputation/CachedWallet1234") + assert resp.status_code == 200 + rpc_mock.assert_not_awaited() + set_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_pagination_skip_limit(self, client): + summary_data = _make_reputation_summary() + summary_data["history"] = [ + { + "id": str(i), + "contributor_id": "alice", + "bounty_id": f"b{i}", + "bounty_title": f"Bounty {i}", + "bounty_tier": 1, + "review_score": 8.0, + "earned_reputation": 5.0, + "is_veteran_penalty": False, + "created_at": datetime.now(timezone.utc).isoformat(), + } + for i in range(5) + ] + + get_mock, set_mock = _hit_cache(summary_data) + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + ): + resp = await client.get("/api/reputation/SomeWallet123?skip=2&limit=2") + assert resp.status_code == 200 + assert len(resp.json()["history"]) == 2 + + +# --------------------------------------------------------------------------- +# GET /api/staking/{wallet} +# --------------------------------------------------------------------------- + + +class TestStakingEndpoint: + @pytest.mark.asyncio + async def test_returns_balances_from_rpc(self, client): + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch("app.api.onchain.get_sol_balance", AsyncMock(return_value=1.5)), + patch( + "app.api.onchain.get_token_balance", AsyncMock(return_value=25_000.0) + ), + ): + resp = await client.get("/api/staking/DevWallet12345678") + assert resp.status_code == 200 + data = resp.json() + assert data["sol_balance"] == 1.5 + assert data["fndry_balance"] == 25_000.0 + assert data["cached"] is False + + @pytest.mark.asyncio + async def test_returns_cached_balances(self, client): + cached = { + "wallet": "DevWallet12345678", + "sol_balance": 2.0, + "fndry_balance": 5_000.0, + } + get_mock, set_mock = _hit_cache(cached) + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + ): + resp = await client.get("/api/staking/DevWallet12345678") + assert resp.status_code == 200 + assert resp.json()["cached"] is True + assert resp.json()["sol_balance"] == 2.0 + + @pytest.mark.asyncio + async def test_writes_cache_on_rpc_hit(self, client): + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch("app.api.onchain.get_sol_balance", AsyncMock(return_value=0.0)), + patch("app.api.onchain.get_token_balance", AsyncMock(return_value=0.0)), + ): + await client.get("/api/staking/ZeroWallet12345678") + set_mock.assert_awaited_once() + assert set_mock.call_args[0][0] == "staking" + + @pytest.mark.asyncio + async def test_returns_502_on_rpc_error(self, client): + from app.services.solana_client import SolanaRPCError + + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch( + "app.api.onchain.get_sol_balance", + AsyncMock(side_effect=SolanaRPCError("node unavailable")), + ), + ): + resp = await client.get("/api/staking/BadWallet123456") + assert resp.status_code == 502 + + +# --------------------------------------------------------------------------- +# GET /api/treasury/stats +# --------------------------------------------------------------------------- + + +class TestTreasuryStatsEndpoint: + @pytest.mark.asyncio + async def test_returns_stats_from_service(self, client): + from app.models.payout import TreasuryStats + + stats = TreasuryStats.model_validate(_make_treasury_stats()) + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch("app.api.onchain.get_treasury_stats", AsyncMock(return_value=stats)), + ): + resp = await client.get("/api/treasury/stats") + assert resp.status_code == 200 + data = resp.json() + assert data["sol_balance"] == 12.5 + assert data["total_payouts"] == 20 + + @pytest.mark.asyncio + async def test_serves_from_cache(self, client): + cached = _make_treasury_stats() + get_mock, set_mock = _hit_cache(cached) + service_mock = AsyncMock() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch("app.api.onchain.get_treasury_stats", service_mock), + ): + resp = await client.get("/api/treasury/stats") + assert resp.status_code == 200 + service_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_returns_503_on_service_error(self, client): + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch( + "app.api.onchain.get_treasury_stats", + AsyncMock(side_effect=RuntimeError("RPC down")), + ), + ): + resp = await client.get("/api/treasury/stats") + assert resp.status_code == 503 + + @pytest.mark.asyncio + async def test_writes_cache_on_service_hit(self, client): + from app.models.payout import TreasuryStats + + stats = TreasuryStats.model_validate(_make_treasury_stats()) + get_mock, set_mock = _miss_cache() + + with ( + patch("app.api.onchain.cache_get", get_mock), + patch("app.api.onchain.cache_set", set_mock), + patch("app.api.onchain.get_treasury_stats", AsyncMock(return_value=stats)), + ): + await client.get("/api/treasury/stats") + set_mock.assert_awaited_once() + assert set_mock.call_args[0] == ("treasury", "stats") + + +# --------------------------------------------------------------------------- +# POST /api/webhooks/helius +# --------------------------------------------------------------------------- + + +class TestHeliusWebhook: + @pytest.mark.asyncio + async def test_invalidates_staking_and_reputation_per_account(self, client): + invalidate_mock = AsyncMock() + prefix_mock = AsyncMock(return_value=1) + + with ( + patch("app.api.onchain.cache_invalidate", invalidate_mock), + patch("app.api.onchain.cache_invalidate_prefix", prefix_mock), + patch("app.api.onchain._HELIUS_WEBHOOK_SECRET", ""), + ): + resp = await client.post( + "/api/webhooks/helius", + json={"type": "TRANSFER", "accounts": ["walletA", "walletB"]}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["keys_removed"] == 3 # 2 accounts + 1 treasury prefix + + @pytest.mark.asyncio + async def test_always_busts_treasury_cache(self, client): + prefix_mock = AsyncMock(return_value=2) + + with ( + patch("app.api.onchain.cache_invalidate", AsyncMock()), + patch("app.api.onchain.cache_invalidate_prefix", prefix_mock), + patch("app.api.onchain._HELIUS_WEBHOOK_SECRET", ""), + ): + await client.post( + "/api/webhooks/helius", json={"type": "SWAP", "accounts": []} + ) + prefix_mock.assert_awaited_once_with("treasury") + + @pytest.mark.asyncio + async def test_rejects_missing_signature_when_secret_set(self, client): + with patch("app.api.onchain._HELIUS_WEBHOOK_SECRET", "supersecret"): + resp = await client.post( + "/api/webhooks/helius", + json={"type": "TRANSFER", "accounts": []}, + ) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_rejects_invalid_signature(self, client): + with patch("app.api.onchain._HELIUS_WEBHOOK_SECRET", "supersecret"): + resp = await client.post( + "/api/webhooks/helius", + headers={"X-Helius-Signature": "badsig"}, + json={"type": "TRANSFER", "accounts": []}, + ) + assert resp.status_code == 401 + + @pytest.mark.asyncio + async def test_accepts_valid_signature(self, client): + import hashlib + import hmac + + secret = "supersecret" + with ( + patch("app.api.onchain._HELIUS_WEBHOOK_SECRET", secret), + patch("app.api.onchain.cache_invalidate", AsyncMock()), + patch("app.api.onchain.cache_invalidate_prefix", AsyncMock(return_value=0)), + ): + # Build exact JSON that Pydantic will serialize + from app.api.onchain import HeliusWebhookPayload + + body = HeliusWebhookPayload(type="TRANSFER", accounts=["wallet1"]) + correct_sig = hmac.new( + secret.encode(), + msg=body.model_dump_json().encode(), + digestmod=hashlib.sha256, + ).hexdigest() + + resp = await client.post( + "/api/webhooks/helius", + headers={"X-Helius-Signature": correct_sig}, + json={"type": "TRANSFER", "accounts": ["wallet1"]}, + ) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Escrow cache integration +# --------------------------------------------------------------------------- + + +class TestEscrowCaching: + @pytest.mark.asyncio + async def test_get_escrow_writes_cache_on_miss(self, client): + from app.models.escrow import EscrowState, EscrowStatusResponse + + mock_status = EscrowStatusResponse( + bounty_id="b1", + state=EscrowState.ACTIVE, + amount=1000.0, + creator_wallet="creator", + winner_wallet=None, + expires_at=None, + ledger=[], + ) + + set_mock = AsyncMock() + + with ( + patch("app.api.escrow.cache_get", AsyncMock(return_value=None)), + patch("app.api.escrow.cache_set", set_mock), + patch( + "app.api.escrow.get_escrow_status", + AsyncMock(return_value=mock_status), + ), + ): + resp = await client.get("/api/escrow/b1") + assert resp.status_code == 200 + set_mock.assert_awaited_once() + assert set_mock.call_args[0][:2] == ("escrow", "b1") + + @pytest.mark.asyncio + async def test_get_escrow_serves_from_cache(self, client): + from app.models.escrow import EscrowState + + cached = { + "bounty_id": "b1", + "state": EscrowState.ACTIVE, + "amount": 1000.0, + "creator_wallet": "creator", + "winner_wallet": None, + "expires_at": None, + "ledger": [], + } + service_mock = AsyncMock() + + with ( + patch("app.api.escrow.cache_get", AsyncMock(return_value=cached)), + patch("app.api.escrow.get_escrow_status", service_mock), + ): + resp = await client.get("/api/escrow/b1") + assert resp.status_code == 200 + service_mock.assert_not_awaited()