Update caching to use JSON instead of pkl
SinaKhalili committed Oct 21, 2024
1 parent a15467d commit 9a0f4d5
Showing 5 changed files with 107 additions and 56 deletions.
11 changes: 7 additions & 4 deletions backend/api/asset_liability.py
@@ -9,8 +9,11 @@
 router = APIRouter()
 
 
-async def get_asset_liability_matrix(
-    snapshot_path: str, vat: Vat, mode: int, perp_market_index: int
+async def _get_asset_liability_matrix(
+    snapshot_path: str,
+    vat: Vat,
+    mode: int,
+    perp_market_index: int,
 ) -> dict:
     print("==> Getting asset liability matrix...")
     res, df = await get_matrix(vat, mode, perp_market_index)
@@ -28,8 +31,8 @@ async def get_asset_liability_matrix(
 async def get_asset_liability_matrix(
     request: BackendRequest, mode: int, perp_market_index: int
 ):
-    return await get_asset_liability_matrix(
-        request.state.current_pickle_path,
+    return await _get_asset_liability_matrix(
+        request.state.backend_state.current_pickle_path,
         request.state.backend_state.vat,
         mode,
         perp_market_index,
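The underscore rename is the substantive fix in this file: previously the helper and the route handler both bound the module-level name get_asset_liability_matrix, so the handler's call resolved to itself rather than to the helper. A minimal sketch of that shadowing pitfall, with invented names:

    import asyncio

    async def get_data(x: int) -> int:      # helper, defined first
        return x * 2

    async def get_data(x: int) -> int:      # handler rebinds the same module-level name
        return await get_data(x)            # now resolves to itself -> infinite recursion

    # asyncio.run(get_data(2))  # would raise RecursionError at runtime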
10 changes: 8 additions & 2 deletions backend/app.py
@@ -52,15 +52,21 @@ def clean_cache(state: BackendState) -> None:
     pickles.sort(key=os.path.getmtime)
     for pickle in pickles[:-5]:
         print(f"deleting {pickle}")
-        shutil.rmtree(pickle)
+        try:
+            shutil.rmtree(pickle)
+        except Exception as e:
+            print(f"Error deleting {pickle}: {e}")
 
     cache_files = glob.glob("cache/*")
     if len(cache_files) > 35:
         print("cache folder has more than 35 files, deleting old ones")
         cache_files.sort(key=os.path.getmtime)
         for cache_file in cache_files[:-35]:
             print(f"deleting {cache_file}")
-            os.remove(cache_file)
+            try:
+                os.remove(cache_file)
+            except Exception as e:
+                print(f"Error deleting {cache_file}: {e}")
 
 
 @repeat_every(seconds=60 * 8, wait_first=True)
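Both cleanup loops are instances of one keep-newest-N pattern; a standalone sketch of it follows (the pickle glob pattern is an assumption, the keep counts are the ones used above):

    import glob
    import os
    import shutil

    def prune_oldest(pattern: str, keep: int) -> None:
        """Delete all but the `keep` most recently modified matches."""
        paths = sorted(glob.glob(pattern), key=os.path.getmtime)
        for path in paths[:-keep]:
            try:
                # Snapshot folders need rmtree; flat cache files need remove.
                if os.path.isdir(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            except OSError as e:
                print(f"Error deleting {path}: {e}")  # log and keep pruning

    # prune_oldest("pickles/*", keep=5)   # assumed pattern
    # prune_oldest("cache/*", keep=35)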
122 changes: 74 additions & 48 deletions backend/middleware/cache_middleware.py
@@ -1,9 +1,9 @@
 import asyncio
 import glob
 import hashlib
+import json
 import os
-import pickle
-from typing import Any, Callable, Dict, Optional
+from typing import Callable, Dict, Optional
 
 from backend.state import BackendRequest
 from backend.state import BackendState
@@ -23,67 +23,67 @@ def __init__(self, app: ASGIApp, state: BackendState, cache_dir: str = "cache"):
             os.makedirs(self.cache_dir)
 
     async def dispatch(self, request: BackendRequest, call_next: Callable):
         if request.url.path.startswith("/api/snapshot"):
             return await call_next(request)
         if request.url.path.startswith("/api/price_shock"):
             return await call_next(request)
         if not request.url.path.startswith("/api"):
             return await call_next(request)
         if self.state.current_pickle_path == "bootstrap":
             return await call_next(request)
 
         current_pickle = self.state.current_pickle_path
         previous_pickle = self._get_previous_pickle()
 
         # Try to serve data from the current (latest) pickle first
         current_cache_key = self._generate_cache_key(request, current_pickle)
-        current_cache_file = os.path.join(self.cache_dir, f"{current_cache_key}.pkl")
+        current_cache_file = os.path.join(self.cache_dir, f"{current_cache_key}.json")
 
         if os.path.exists(current_cache_file):
-            print(f"Serving latest data for {request.url.path}")
-            with open(current_cache_file, "rb") as f:
-                response_data = pickle.load(f)
-
-            return Response(
-                content=response_data["content"],
-                status_code=response_data["status_code"],
-                headers=dict(response_data["headers"], **{"X-Cache-Status": "Fresh"}),
-            )
+            return self._serve_cached_response(current_cache_file, "Fresh")
 
         # If no data in current pickle, try the previous pickle
         if previous_pickle:
             previous_cache_key = self._generate_cache_key(request, previous_pickle)
             previous_cache_file = os.path.join(
-                self.cache_dir, f"{previous_cache_key}.pkl"
+                self.cache_dir, f"{previous_cache_key}.json"
             )
 
             if os.path.exists(previous_cache_file):
-                print(f"Serving stale data for {request.url.path}")
-                with open(previous_cache_file, "rb") as f:
-                    response_data = pickle.load(f)
-
-                # Prepare background task
-                background_tasks = BackgroundTasks()
-                background_tasks.add_task(
-                    self._fetch_and_cache,
+                return await self._serve_stale_response(
+                    previous_cache_file,
                     request,
                     call_next,
                     current_cache_key,
                     current_cache_file,
                 )
 
-                response = Response(
-                    content=response_data["content"],
-                    status_code=response_data["status_code"],
-                    headers=dict(
-                        response_data["headers"], **{"X-Cache-Status": "Stale"}
-                    ),
-                )
-                response.background = background_tasks
-                return response
+        return await self._serve_miss_response(
+            request, call_next, current_cache_key, current_cache_file
+        )
 
-        # If no data available, return an empty response and fetch fresh data in the background
-        print(f"No data available for {request.url.path}")
+    def _serve_cached_response(self, cache_file: str, cache_status: str):
+        print(f"Serving {cache_status.lower()} data")
+        with open(cache_file, "r") as f:
+            response_data = json.load(f)
+
+        content = json.dumps(response_data["content"]).encode("utf-8")
+        headers = {
+            k: v
+            for k, v in response_data["headers"].items()
+            if k.lower() != "content-length"
+        }
+        headers["Content-Length"] = str(len(content))
+        headers["X-Cache-Status"] = cache_status
+
+        return Response(
+            content=content,
+            status_code=response_data["status_code"],
+            headers=headers,
+            media_type="application/json",
+        )
+
+    async def _serve_stale_response(
+        self,
+        cache_file: str,
+        request: BackendRequest,
+        call_next: Callable,
+        current_cache_key: str,
+        current_cache_file: str,
+    ):
+        response = self._serve_cached_response(cache_file, "Stale")
         background_tasks = BackgroundTasks()
         background_tasks.add_task(
             self._fetch_and_cache,
@@ -92,12 +92,32 @@ async def dispatch(self, request: BackendRequest, call_next: Callable):
             current_cache_key,
             current_cache_file,
         )
+        response.background = background_tasks
+        return response
 
+    async def _serve_miss_response(
+        self,
+        request: BackendRequest,
+        call_next: Callable,
+        cache_key: str,
+        cache_file: str,
+    ):
+        print(f"No data available for {request.url.path}")
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(
+            self._fetch_and_cache,
+            request,
+            call_next,
+            cache_key,
+            cache_file,
+        )
+        content = json.dumps({"result": "miss"}).encode("utf-8")
+
         # Return an empty response immediately
         response = Response(
-            content='{"result": "miss"}',
-            status_code=200,  # No Content
-            headers={"X-Cache-Status": "Miss"},
+            content=content,
+            status_code=200,
+            headers={"X-Cache-Status": "Miss", "Content-Length": str(len(content))},
+            media_type="application/json",
         )
         response.background = background_tasks
         return response
@@ -120,15 +140,21 @@ async def _fetch_and_cache(
             response_body = b""
             async for chunk in response.body_iterator:
                 response_body += chunk
+
+            body_content = json.loads(response_body.decode())
             response_data = {
-                "content": response_body,
+                "content": body_content,
                 "status_code": response.status_code,
-                "headers": dict(response.headers),
+                "headers": {
+                    k: v
+                    for k, v in response.headers.items()
+                    if k.lower() != "content-length"
+                },
             }
 
             os.makedirs(os.path.dirname(cache_file), exist_ok=True)
-            with open(cache_file, "wb") as f:
-                pickle.dump(response_data, f)
+            with open(cache_file, "w") as f:
+                json.dump(response_data, f)
             print(f"Cached fresh data for {request.url.path}")
         else:
             print(
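The switch away from pickle is also why _fetch_and_cache strips Content-Length before persisting: the body is stored as decoded JSON, and re-serializing it with json.dumps is not guaranteed to reproduce the original bytes, so the header is recomputed from the re-encoded body at serve time. A minimal round-trip sketch of one cache entry (file name and body values are invented):

    import json
    import os

    entry = {
        "content": {"res": [1, 2, 3]},                    # decoded JSON body, not raw bytes
        "status_code": 200,
        "headers": {"content-type": "application/json"},  # Content-Length already stripped
    }

    os.makedirs("cache", exist_ok=True)
    with open("cache/example.json", "w") as f:            # hypothetical cache file
        json.dump(entry, f)

    with open("cache/example.json") as f:
        loaded = json.load(f)

    body = json.dumps(loaded["content"]).encode("utf-8")  # re-encoded at serve time
    headers = dict(loaded["headers"])
    headers["Content-Length"] = str(len(body))            # recomputed, as in _serve_cached_response
    headers["X-Cache-Status"] = "Fresh"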
4 changes: 4 additions & 0 deletions gunicorn_config.py
@@ -8,3 +8,7 @@
 accesslog = "-"
 errorlog = "-"
 loglevel = "info"
+
+# Restart workers that die unexpectedly
+worker_exit_on_restart = True
+worker_restart_delay = 2
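Only the tail of gunicorn_config.py appears in this diff; for context, a sketch of the full file's likely shape, where everything above the logging lines is an assumption rather than repository content:

    # Assumed values; only the last few lines below are visible in the diff.
    bind = "0.0.0.0:8000"
    workers = 4
    worker_class = "uvicorn.workers.UvicornWorker"  # assumed, for a FastAPI/Starlette app

    accesslog = "-"
    errorlog = "-"
    loglevel = "info"

    # Restart workers that die unexpectedly
    worker_exit_on_restart = True
    worker_restart_delay = 2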
16 changes: 14 additions & 2 deletions src/page/asset_liability.py
@@ -1,7 +1,10 @@
+import json
+
 from driftpy.constants.perp_markets import mainnet_perp_market_configs
 from driftpy.constants.spot_markets import mainnet_spot_market_configs
 from lib.api import api
 import pandas as pd
+from requests.exceptions import JSONDecodeError
 import streamlit as st


@@ -50,8 +53,17 @@ def asset_liab_matrix_page():
         st.stop()
 
     except Exception as e:
-        st.write(e)
-        st.stop()
+        if type(e) == JSONDecodeError:
+            print("HIT A JSONDecodeError...", e)
+            st.write("Fetching data for the first time...")
+            st.image(
+                "https://i.gifer.com/origin/8a/8a47f769c400b0b7d81a8f6f8e09a44a_w200.gif"
+            )
+            st.write("Check again in one minute!")
+            st.stop()
+        else:
+            st.write(e)
+            st.stop()
 
     res = pd.DataFrame(result["res"])
     df = pd.DataFrame(result["df"])
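With the middleware now answering cold requests with a {"result": "miss"} placeholder, the page treats a JSONDecodeError as "data is still being prepared" rather than a hard failure. A hedged sketch of that client-side pattern, using requests with an illustrative URL rather than the app's lib.api helper:

    import requests
    from requests.exceptions import JSONDecodeError

    def fetch_or_wait(url: str):
        """Return decoded data, or None while the backend is still warming its cache."""
        resp = requests.get(url)
        if resp.headers.get("X-Cache-Status") == "Miss":
            return None  # the middleware's {"result": "miss"} placeholder
        try:
            return resp.json()
        except JSONDecodeError:
            return None  # undecodable body: check again shortly

    data = fetch_or_wait("http://localhost:8000/api/asset-liability/matrix/0/-1")  # illustrative URL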
