Skip to content

Commit

Permalink
Update backend cache
Browse files Browse the repository at this point in the history
  • Loading branch information
SinaKhalili committed Oct 2, 2024
1 parent dd19182 commit 877aa7b
Show file tree
Hide file tree
Showing 6 changed files with 568 additions and 113 deletions.
132 changes: 24 additions & 108 deletions backend/app.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,19 @@
from asyncio import create_task
from asyncio import gather
from contextlib import asynccontextmanager
from datetime import datetime
import glob
import os

from anchorpy import Wallet
from backend.api import asset_liability
from backend.api import health
from backend.api import liquidation
from backend.api import metadata
from backend.api import price_shock
from backend.middleware.cache_middleware import CacheMiddleware
from backend.middleware.readiness import ReadinessMiddleware
from backend.state import BackendState
from backend.utils.vat import load_newest_files
from backend.utils.waiting_for import waiting_for
from dotenv import load_dotenv
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.drift_client import DriftClient
from driftpy.market_map.market_map import MarketMap
from driftpy.market_map.market_map_config import (
WebsocketConfig as MarketMapWebsocketConfig,
)
from driftpy.market_map.market_map_config import MarketMapConfig
from driftpy.pickle.vat import Vat
from driftpy.types import MarketType
from driftpy.user_map.user_map import UserMap
from driftpy.user_map.user_map_config import (
WebsocketConfig as UserMapWebsocketConfig,
)
from driftpy.user_map.user_map_config import UserMapConfig
from driftpy.user_map.user_map_config import UserStatsMapConfig
from driftpy.user_map.userstats_map import UserStatsMap
from fastapi import BackgroundTasks
from fastapi import FastAPI
import pandas as pd
from solana.rpc.async_api import AsyncClient


load_dotenv()
Expand All @@ -38,73 +22,21 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
global state
url = os.getenv("RPC_URL")
if not url:
raise ValueError("RPC_URL environment variable is not set.")

state.connection = AsyncClient(url)
state.dc = DriftClient(
state.connection,
Wallet.dummy(),
"mainnet",
account_subscription=AccountSubscriptionConfig("cached"),
)
state.perp_map = MarketMap(
MarketMapConfig(
state.dc.program,
MarketType.Perp(),
MarketMapWebsocketConfig(),
state.dc.connection,
)
)
state.spot_map = MarketMap(
MarketMapConfig(
state.dc.program,
MarketType.Spot(),
MarketMapWebsocketConfig(),
state.dc.connection,
)
)
state.user_map = UserMap(UserMapConfig(state.dc, UserMapWebsocketConfig()))
state.stats_map = UserStatsMap(UserStatsMapConfig(state.dc))
state.vat = Vat(
state.dc,
state.user_map,
state.stats_map,
state.spot_map,
state.perp_map,
)
global state
state.initialize(url)

print("Checking if cached vat exists")
cached_vat_path = sorted(glob.glob("pickles/*"))
if len(cached_vat_path) > 0:
print("Loading cached vat")
directory = cached_vat_path[-1]
pickle_map = load_newest_files(directory)
with waiting_for("unpickling"):
await state.vat.unpickle(
users_filename=pickle_map["usermap"],
user_stats_filename=pickle_map["userstats"],
spot_markets_filename=pickle_map["spot"],
perp_markets_filename=pickle_map["perp"],
spot_oracles_filename=pickle_map["spotoracles"],
perp_oracles_filename=pickle_map["perporacles"],
)
await state.load_pickle_snapshot(cached_vat_path[-1])
else:
print("No cached vat found")

# with waiting_for("drift client"):
# await state.dc.subscribe()
with waiting_for("subscriptions"):
await gather(
create_task(state.spot_map.subscribe()),
create_task(state.perp_map.subscribe()),
create_task(state.user_map.subscribe()),
create_task(state.stats_map.subscribe()),
)
print("No cached vat found, bootstrapping")
await state.bootstrap()

state.some_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
state.ready = True
print("Starting app")
yield
Expand All @@ -118,39 +50,23 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)
app.add_middleware(ReadinessMiddleware, state=state)
app.add_middleware(CacheMiddleware, state=state, cache_dir="cache")

app.include_router(health.router, prefix="/api/health", tags=["health"])
app.include_router(metadata.router, prefix="/api/metadata", tags=["metadata"])
app.include_router(liquidation.router, prefix="/api/liquidation", tags=["liquidation"])
app.include_router(price_shock.router, prefix="/api/price-shock", tags=["price-shock"])
app.include_router(
asset_liability.router, prefix="/api/asset-liability", tags=["asset-liability"]
)


@app.get("/")
async def root():
    """Basic liveness endpoint returning a static greeting."""
    greeting = {"message": "Hello World"}
    return greeting


@app.get("/df")
async def get_df():
return state.some_df.to_dict(orient="records")


@app.get("/users")
async def get_users():
users = [user.user_public_key for user in state.user_map.values()]
return users


@app.get("/pickle")
async def pickle():
now = datetime.now()
folder_name = now.strftime("vat-%Y-%m-%d-%H-%M-%S")
if not os.path.exists("pickles"):
os.makedirs("pickles")
path = os.path.join("pickles", folder_name, "")

os.makedirs(path, exist_ok=True)
with waiting_for("pickling"):
result = await state.vat.pickle(path)

return {"result": result}


@app.get("/health")
async def health_check():
return {"status": "healthy" if state.ready else "initializing"}
async def pickle(background_tasks: BackgroundTasks):
background_tasks.add_task(state.take_pickle_snapshot)
return {"result": "background task added"}
69 changes: 69 additions & 0 deletions backend/middleware/cache_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import hashlib
import os
import pickle

from backend.state import BackendRequest
from backend.state import BackendState
from fastapi import HTTPException
from fastapi import Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


class CacheMiddleware(BaseHTTPMiddleware):
    """Disk-backed response cache for `/api` routes.

    Successful (HTTP 200) responses are pickled to `cache_dir`, keyed by the
    currently loaded pickle snapshot plus the request method, path, and query
    string, so the cache is implicitly invalidated whenever a new snapshot is
    loaded. Caching is skipped entirely while the backend is still running off
    live subscriptions (the "bootstrap" sentinel).
    """

    def __init__(self, app: ASGIApp, state: BackendState, cache_dir: str = "cache"):
        super().__init__(app)
        self.state = state
        self.cache_dir = cache_dir
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

    async def dispatch(self, request: BackendRequest, call_next):
        # Only API traffic is cacheable.
        if not request.url.path.startswith("/api"):
            return await call_next(request)
        # No stable snapshot to key on yet — serve live.
        if self.state.current_pickle_path == "bootstrap":
            return await call_next(request)

        key = self._generate_cache_key(request)
        cache_file = os.path.join(self.cache_dir, f"{key}.pkl")

        if os.path.exists(cache_file):
            print(f"Cache hit for {request.url.path}")
            with open(cache_file, "rb") as fh:
                cached = pickle.load(fh)
            return Response(
                content=cached["content"],
                status_code=cached["status_code"],
                headers=cached["headers"],
            )

        print(f"Cache miss for {request.url.path}")
        response = await call_next(request)

        if response.status_code != 200:
            # Never cache errors; pass them straight through.
            return response

        # Drain the streaming body so it can be persisted and re-sent.
        chunks = [chunk async for chunk in response.body_iterator]
        body = b"".join(chunks)
        payload = {
            "content": body,
            "status_code": response.status_code,
            "headers": dict(response.headers),
        }

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, "wb") as fh:
            pickle.dump(payload, fh)

        # The original iterator is consumed, so rebuild the response.
        return Response(
            content=body,
            status_code=response.status_code,
            headers=dict(response.headers),
        )

    def _generate_cache_key(self, request: BackendRequest) -> str:
        """Derive a cache key from the active snapshot + request identity."""
        snapshot = self.state.current_pickle_path
        hash_input = f"{snapshot}:{request.method}:{request.url.path}:{request.url.query}"
        print("Hash input: ", hash_input)
        return hashlib.md5(hash_input.encode()).hexdigest()
9 changes: 6 additions & 3 deletions backend/middleware/readiness.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from backend.state import BackendRequest
from backend.state import BackendState
from fastapi import HTTPException
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


class ReadinessMiddleware(BaseHTTPMiddleware):
    """Blocks all traffic (503) until the backend finishes startup.

    The `/health` probe is exempt so orchestrators can poll it while the app
    is still initializing. Also attaches the shared BackendState to each
    request so downstream handlers can reach it via `request.state`.
    """

    def __init__(self, app: ASGIApp, state: BackendState):
        super().__init__(app)
        self.state = state

    async def dispatch(self, request: BackendRequest, call_next):
        # Reject everything except the health probe until state.ready is set.
        if not self.state.ready and request.url.path != "/health":
            raise HTTPException(status_code=503, detail="Service is not ready")

        # Expose the shared backend state to downstream handlers/middleware.
        request.state.backend_state = self.state
        response = await call_next(request)
        return response
109 changes: 107 additions & 2 deletions backend/state.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,127 @@
from typing import Optional
from asyncio import create_task
from asyncio import gather
from datetime import datetime
import os
from typing import TypedDict

from anchorpy import Wallet
from backend.utils.vat import load_newest_files
from backend.utils.waiting_for import waiting_for
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.drift_client import DriftClient
from driftpy.market_map.market_map import MarketMap
from driftpy.market_map.market_map_config import (
WebsocketConfig as MarketMapWebsocketConfig,
)
from driftpy.market_map.market_map_config import MarketMapConfig
from driftpy.pickle.vat import Vat
from driftpy.types import MarketType
from driftpy.user_map.user_map import UserMap
from driftpy.user_map.user_map_config import (
WebsocketConfig as UserMapWebsocketConfig,
)
from driftpy.user_map.user_map_config import UserMapConfig
from driftpy.user_map.user_map_config import UserStatsMapConfig
from driftpy.user_map.userstats_map import UserStatsMap
from fastapi import Request
import pandas as pd
from solana.rpc.async_api import AsyncClient


class BackendState:
    """Application-wide mutable state shared by the FastAPI app and middleware.

    Owns the RPC connection, Drift client, market/user map subscriptions, and
    the Vat snapshot machinery. A single module-level instance is populated via
    `initialize()` during the app's lifespan startup.
    """

    # Demo DataFrame served by the /df endpoint.
    some_df: pd.DataFrame

    connection: AsyncClient
    dc: DriftClient
    spot_map: MarketMap
    perp_map: MarketMap
    user_map: UserMap
    stats_map: UserStatsMap

    # Directory of the snapshot currently backing the Vat, or the sentinel
    # "bootstrap" when data comes from live websocket subscriptions.
    current_pickle_path: str
    vat: Vat
    ready: bool

    def initialize(
        self, url: str
    ):  # Not using __init__ because we need the rpc url to be passed in
        """Wire up the RPC connection, Drift client, market/user maps, and Vat.

        Args:
            url: Solana RPC endpoint (read from the environment at startup).
        """
        self.connection = AsyncClient(url)
        self.dc = DriftClient(
            self.connection,
            Wallet.dummy(),
            "mainnet",
            account_subscription=AccountSubscriptionConfig("cached"),
        )
        self.perp_map = MarketMap(
            MarketMapConfig(
                self.dc.program,
                MarketType.Perp(),
                MarketMapWebsocketConfig(),
                self.dc.connection,
            )
        )
        self.spot_map = MarketMap(
            MarketMapConfig(
                self.dc.program,
                MarketType.Spot(),
                MarketMapWebsocketConfig(),
                self.dc.connection,
            )
        )
        self.user_map = UserMap(UserMapConfig(self.dc, UserMapWebsocketConfig()))
        self.stats_map = UserStatsMap(UserStatsMapConfig(self.dc))
        self.vat = Vat(
            self.dc,
            self.user_map,
            self.stats_map,
            self.spot_map,
            self.perp_map,
        )

    async def bootstrap(self):
        """Subscribe to live data sources (used when no snapshot exists)."""
        with waiting_for("drift client"):
            await self.dc.subscribe()
        with waiting_for("subscriptions"):
            await gather(
                create_task(self.spot_map.subscribe()),
                create_task(self.perp_map.subscribe()),
                create_task(self.user_map.subscribe()),
                create_task(self.stats_map.subscribe()),
            )
        # Sentinel: tells CacheMiddleware there is no snapshot to key on.
        self.current_pickle_path = "bootstrap"

    async def take_pickle_snapshot(self):
        """Pickle the current Vat to a timestamped folder, then reload from it.

        Returns whatever `Vat.pickle` returns (the written files).
        """
        now = datetime.now()
        folder_name = now.strftime("vat-%Y-%m-%d-%H-%M-%S")
        path = os.path.join("pickles", folder_name, "")

        # makedirs with exist_ok creates the "pickles" parent as needed, so no
        # separate existence check is required.
        os.makedirs(path, exist_ok=True)
        with waiting_for("pickling"):
            result = await self.vat.pickle(path)
        # load_pickle_snapshot already wraps unpickling in its own
        # waiting_for("unpickling") context — don't nest a duplicate one here.
        await self.load_pickle_snapshot(path)
        return result

    async def load_pickle_snapshot(self, directory: str):
        """Load the newest pickle files from `directory` into the Vat.

        Also records `directory` as the active snapshot path, which the cache
        middleware folds into its cache key.

        Returns the filename map produced by `load_newest_files`.
        """
        pickle_map = load_newest_files(directory)
        self.current_pickle_path = directory
        with waiting_for("unpickling"):
            await self.vat.unpickle(
                users_filename=pickle_map["usermap"],
                user_stats_filename=pickle_map["userstats"],
                spot_markets_filename=pickle_map["spot"],
                perp_markets_filename=pickle_map["perp"],
                spot_oracles_filename=pickle_map["spotoracles"],
                perp_oracles_filename=pickle_map["perporacles"],
            )
        return pickle_map


class BackendRequest(Request):
    # Typed Request subclass so handlers/middleware can reach the shared
    # BackendState attached by ReadinessMiddleware during dispatch.
    @property
    def backend_state(self) -> BackendState:
        # NOTE(review): uses dict-style `.get(...)` on `self.state`; Starlette's
        # request state is conventionally attribute-based — confirm the state
        # object in the pinned Starlette version supports `.get`/item access.
        return self.state.get("backend_state")

    @backend_state.setter
    def backend_state(self, value: BackendState):
        # Mirrors the getter's dict-style access; see note above the getter.
        self.state["backend_state"] = value
Loading

0 comments on commit 877aa7b

Please sign in to comment.