-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dd19182
commit 877aa7b
Showing
6 changed files
with
568 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import hashlib | ||
import os | ||
import pickle | ||
|
||
from backend.state import BackendRequest | ||
from backend.state import BackendState | ||
from fastapi import HTTPException | ||
from fastapi import Response | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.types import ASGIApp | ||
|
||
|
||
class CacheMiddleware(BaseHTTPMiddleware): | ||
def __init__(self, app: ASGIApp, state: BackendState, cache_dir: str = "cache"): | ||
super().__init__(app) | ||
self.state = state | ||
self.cache_dir = cache_dir | ||
if not os.path.exists(self.cache_dir): | ||
os.makedirs(self.cache_dir) | ||
|
||
async def dispatch(self, request: BackendRequest, call_next): | ||
if not request.url.path.startswith("/api"): | ||
return await call_next(request) | ||
if self.state.current_pickle_path == "bootstrap": | ||
return await call_next(request) | ||
|
||
cache_key = self._generate_cache_key(request) | ||
cache_file = os.path.join(self.cache_dir, f"{cache_key}.pkl") | ||
|
||
if os.path.exists(cache_file): | ||
print(f"Cache hit for {request.url.path}") | ||
with open(cache_file, "rb") as f: | ||
response_data = pickle.load(f) | ||
return Response( | ||
content=response_data["content"], | ||
status_code=response_data["status_code"], | ||
headers=response_data["headers"], | ||
) | ||
|
||
print(f"Cache miss for {request.url.path}") | ||
response = await call_next(request) | ||
|
||
if response.status_code == 200: | ||
response_body = b"" | ||
async for chunk in response.body_iterator: | ||
response_body += chunk | ||
response_data = { | ||
"content": response_body, | ||
"status_code": response.status_code, | ||
"headers": dict(response.headers), | ||
} | ||
|
||
os.makedirs(os.path.dirname(cache_file), exist_ok=True) | ||
with open(cache_file, "wb") as f: | ||
pickle.dump(response_data, f) | ||
|
||
return Response( | ||
content=response_body, | ||
status_code=response.status_code, | ||
headers=dict(response.headers), | ||
) | ||
|
||
return response | ||
|
||
def _generate_cache_key(self, request: BackendRequest) -> str: | ||
current_pickle_path = self.state.current_pickle_path | ||
hash_input = f"{current_pickle_path}:{request.method}:{request.url.path}:{request.url.query}" | ||
print("Hash input: ", hash_input) | ||
return hashlib.md5(hash_input.encode()).hexdigest() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,19 @@ | ||
from backend.state import BackendRequest | ||
from backend.state import BackendState | ||
from fastapi import HTTPException | ||
from fastapi import Request | ||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.types import ASGIApp | ||
|
||
|
||
class ReadinessMiddleware(BaseHTTPMiddleware): | ||
def __init__(self, app, state: BackendState): | ||
def __init__(self, app: ASGIApp, state: BackendState): | ||
super().__init__(app) | ||
self.state = state | ||
|
||
async def dispatch(self, request: Request, call_next): | ||
async def dispatch(self, request: BackendRequest, call_next): | ||
if not self.state.ready and request.url.path != "/health": | ||
raise HTTPException(status_code=503, detail="Service is not ready") | ||
|
||
request.state.backend_state = self.state | ||
response = await call_next(request) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,127 @@ | ||
from typing import Optional | ||
from asyncio import create_task | ||
from asyncio import gather | ||
from datetime import datetime | ||
import os | ||
from typing import TypedDict | ||
|
||
from anchorpy import Wallet | ||
from backend.utils.vat import load_newest_files | ||
from backend.utils.waiting_for import waiting_for | ||
from driftpy.account_subscription_config import AccountSubscriptionConfig | ||
from driftpy.drift_client import DriftClient | ||
from driftpy.market_map.market_map import MarketMap | ||
from driftpy.market_map.market_map_config import ( | ||
WebsocketConfig as MarketMapWebsocketConfig, | ||
) | ||
from driftpy.market_map.market_map_config import MarketMapConfig | ||
from driftpy.pickle.vat import Vat | ||
from driftpy.types import MarketType | ||
from driftpy.user_map.user_map import UserMap | ||
from driftpy.user_map.user_map_config import ( | ||
WebsocketConfig as UserMapWebsocketConfig, | ||
) | ||
from driftpy.user_map.user_map_config import UserMapConfig | ||
from driftpy.user_map.user_map_config import UserStatsMapConfig | ||
from driftpy.user_map.userstats_map import UserStatsMap | ||
from fastapi import Request | ||
import pandas as pd | ||
from solana.rpc.async_api import AsyncClient | ||
|
||
|
||
class BackendState: | ||
some_df: pd.DataFrame | ||
connection: AsyncClient | ||
dc: DriftClient | ||
spot_map: MarketMap | ||
perp_map: MarketMap | ||
user_map: UserMap | ||
stats_map: UserStatsMap | ||
|
||
current_pickle_path: str | ||
vat: Vat | ||
ready: bool | ||
|
||
def initialize( | ||
self, url: str | ||
): # Not using __init__ because we need the rpc url to be passed in | ||
self.connection = AsyncClient(url) | ||
self.dc = DriftClient( | ||
self.connection, | ||
Wallet.dummy(), | ||
"mainnet", | ||
account_subscription=AccountSubscriptionConfig("cached"), | ||
) | ||
self.perp_map = MarketMap( | ||
MarketMapConfig( | ||
self.dc.program, | ||
MarketType.Perp(), | ||
MarketMapWebsocketConfig(), | ||
self.dc.connection, | ||
) | ||
) | ||
self.spot_map = MarketMap( | ||
MarketMapConfig( | ||
self.dc.program, | ||
MarketType.Spot(), | ||
MarketMapWebsocketConfig(), | ||
self.dc.connection, | ||
) | ||
) | ||
self.user_map = UserMap(UserMapConfig(self.dc, UserMapWebsocketConfig())) | ||
self.stats_map = UserStatsMap(UserStatsMapConfig(self.dc)) | ||
self.vat = Vat( | ||
self.dc, | ||
self.user_map, | ||
self.stats_map, | ||
self.spot_map, | ||
self.perp_map, | ||
) | ||
|
||
async def bootstrap(self): | ||
with waiting_for("drift client"): | ||
await self.dc.subscribe() | ||
with waiting_for("subscriptions"): | ||
await gather( | ||
create_task(self.spot_map.subscribe()), | ||
create_task(self.perp_map.subscribe()), | ||
create_task(self.user_map.subscribe()), | ||
create_task(self.stats_map.subscribe()), | ||
) | ||
self.current_pickle_path = "bootstrap" | ||
|
||
async def take_pickle_snapshot(self): | ||
now = datetime.now() | ||
folder_name = now.strftime("vat-%Y-%m-%d-%H-%M-%S") | ||
if not os.path.exists("pickles"): | ||
os.makedirs("pickles") | ||
path = os.path.join("pickles", folder_name, "") | ||
|
||
os.makedirs(path, exist_ok=True) | ||
with waiting_for("pickling"): | ||
result = await self.vat.pickle(path) | ||
with waiting_for("unpickling"): | ||
await self.load_pickle_snapshot(path) | ||
return result | ||
|
||
async def load_pickle_snapshot(self, directory: str): | ||
pickle_map = load_newest_files(directory) | ||
self.current_pickle_path = directory | ||
with waiting_for("unpickling"): | ||
await self.vat.unpickle( | ||
users_filename=pickle_map["usermap"], | ||
user_stats_filename=pickle_map["userstats"], | ||
spot_markets_filename=pickle_map["spot"], | ||
perp_markets_filename=pickle_map["perp"], | ||
spot_oracles_filename=pickle_map["spotoracles"], | ||
perp_oracles_filename=pickle_map["perporacles"], | ||
) | ||
return pickle_map | ||
|
||
|
||
class BackendRequest(Request): | ||
@property | ||
def backend_state(self) -> BackendState: | ||
return self.state.get("backend_state") | ||
|
||
@backend_state.setter | ||
def backend_state(self, value: BackendState): | ||
self.state["backend_state"] = value |
Oops, something went wrong.