Skip to content

Commit

Permalink
Add repeated tasks to refresh and clean cache
Browse files Browse the repository at this point in the history
  • Loading branch information
SinaKhalili committed Oct 8, 2024
1 parent f334c3e commit aef647e
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 10 deletions.
14 changes: 14 additions & 0 deletions backend/api/snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from backend.state import BackendRequest
from backend.state import BackendState
from fastapi import APIRouter
from fastapi import BackgroundTasks


router = APIRouter()


@router.get("/pickle")
async def pickle(request: BackendRequest, background_tasks: BackgroundTasks):
backend_state: BackendState = request.state.backend_state
background_tasks.add_task(backend_state.take_pickle_snapshot)
return {"result": "background task added"}
50 changes: 40 additions & 10 deletions backend/app.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,56 @@
from contextlib import asynccontextmanager
import glob
import os
import shutil

from backend.api import asset_liability
from backend.api import health
from backend.api import liquidation
from backend.api import metadata
from backend.api import price_shock
from backend.api import snapshot
from backend.middleware.cache_middleware import CacheMiddleware
from backend.middleware.readiness import ReadinessMiddleware
from backend.state import BackendState
from backend.utils.repeat_every import repeat_every
from dotenv import load_dotenv
from fastapi import BackgroundTasks
from fastapi import FastAPI
import pandas as pd
from fastapi.testclient import TestClient


load_dotenv()
state = BackendState()


@repeat_every(seconds=60 * 5)  # 5 minutes
async def repeatedly_retake_snapshot(state: BackendState) -> None:
    """Background task: take a fresh pickle snapshot of the backend state every 5 minutes."""
    await state.take_pickle_snapshot()


def _prune_oldest(pattern: str, keep: int, notice: str, remove) -> None:
    """Delete all but the `keep` most recently modified paths matching `pattern`.

    `remove` is the deletion callable (os.remove for files, shutil.rmtree for
    directories); `notice` is printed once before any deletion happens.
    """
    paths = glob.glob(pattern)
    if len(paths) > keep:
        print(notice)
        # Sort oldest-first so the slice below preserves the `keep` newest.
        paths.sort(key=os.path.getmtime)
        for stale in paths[:-keep]:
            print(f"deleting {stale}")
            remove(stale)


@repeat_every(seconds=60 * 5)  # 5 minutes
async def repeatedly_clean_cache(state: BackendState) -> None:
    """Background task: every 5 minutes, prune old snapshots and cache files.

    Keeps the 3 newest entries under pickles/ (removed with shutil.rmtree —
    snapshot entries appear to be directories; confirm) and the 20 newest
    files under cache/ (removed with os.remove). `state` is unused but
    required by the repeat_every task signature.
    """
    if not os.path.exists("pickles"):
        print("pickles folder does not exist")
        return

    _prune_oldest(
        "pickles/*",
        3,
        "pickles folder has more than 3 pickles, deleting old ones",
        shutil.rmtree,
    )
    _prune_oldest(
        "cache/*",
        20,
        "cache folder has more than 20 files, deleting old ones",
        os.remove,
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
url = os.getenv("RPC_URL")
Expand All @@ -33,16 +64,19 @@ async def lifespan(app: FastAPI):
if len(cached_vat_path) > 0:
print("Loading cached vat")
await state.load_pickle_snapshot(cached_vat_path[-1])
await repeatedly_clean_cache(state)
await repeatedly_retake_snapshot(state)
else:
print("No cached vat found, bootstrapping")
await state.bootstrap()

await state.take_pickle_snapshot()
await repeatedly_clean_cache(state)
await repeatedly_retake_snapshot(state)
state.ready = True
print("Starting app")
yield

# Cleanup
state.some_df = pd.DataFrame()
state.ready = False
await state.dc.unsubscribe()
await state.connection.close()
Expand All @@ -59,14 +93,10 @@ async def lifespan(app: FastAPI):
app.include_router(
asset_liability.router, prefix="/api/asset-liability", tags=["asset-liability"]
)
app.include_router(snapshot.router, prefix="/api/snapshot", tags=["snapshot"])


# NOTE: All other routes should be in /api/* within the /api folder. Routes outside of /api are not exposed in k8s
@app.get("/")
async def root():
return {"message": "Hello World"}


@app.get("/pickle")
async def pickle(background_tasks: BackgroundTasks):
background_tasks.add_task(state.take_pickle_snapshot)
return {"result": "background task added"}
return {"message": "risk dashboard backend is online"}
148 changes: 148 additions & 0 deletions backend/utils/repeat_every.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Periodic Task Execution Decorator
Provides a `repeat_every` decorator for periodic execution of tasks in asynchronous environments.
Modified from fastapi_utils library to support passing a state object to the repeated function.
Features:
- Configurable execution interval and initial delay
- Exception handling with optional callback
- Completion callback and maximum repetitions limit
- Supports both sync and async functions
Usage:
@repeat_every(seconds=60)
async def my_task(state):
# Task logic here
Note: Designed for use with asynchronous frameworks like FastAPI.
Original Source: https://github.com/dmontagu/fastapi-utils (MIT License)
"""

from __future__ import annotations

import asyncio
from functools import wraps
import logging
from traceback import format_exception
from typing import Any, Callable, Coroutine, TypeVar, Union
import warnings

from starlette.concurrency import run_in_threadpool


T = TypeVar("T")

ArgsReturnFuncT = Callable[[T], Any]
ArgsReturnAsyncFuncT = Callable[[T], Coroutine[Any, Any, Any]]
ExcArgNoReturnFuncT = Callable[[Exception], None]
ExcArgNoReturnAsyncFuncT = Callable[[Exception], Coroutine[Any, Any, None]]
ArgsReturnAnyFuncT = Union[ArgsReturnFuncT, ArgsReturnAsyncFuncT]
ExcArgNoReturnAnyFuncT = Union[ExcArgNoReturnFuncT, ExcArgNoReturnAsyncFuncT]
ArgsReturnDecorator = Callable[
[ArgsReturnAnyFuncT], Callable[[T], Coroutine[Any, Any, None]]
]


async def _handle_func(func: ArgsReturnAnyFuncT, arg: T) -> Any:
if asyncio.iscoroutinefunction(func):
return await func(arg)
else:
return await run_in_threadpool(func, arg)


async def _handle_exc(
exc: Exception, on_exception: ExcArgNoReturnAnyFuncT | None
) -> None:
if on_exception:
if asyncio.iscoroutinefunction(on_exception):
await on_exception(exc)
else:
await run_in_threadpool(on_exception, exc)


def repeat_every(
    *,
    seconds: float,
    wait_first: float | None = None,
    logger: logging.Logger | None = None,
    raise_exceptions: bool = False,
    max_repetitions: int | None = None,
    on_complete: ArgsReturnAnyFuncT | None = None,
    on_exception: ExcArgNoReturnAnyFuncT | None = None,
) -> ArgsReturnDecorator:
    """
    This function returns a decorator that modifies a function so it is periodically re-executed after its first call.
    The function it decorates should accept one argument and can return a value.
    Parameters
    ----------
    seconds: float
        The number of seconds to wait between repeated calls
    wait_first: float (default None)
        If not None, the function will wait for the given duration before the first call
    logger: Optional[logging.Logger] (default None)
        Warning: This parameter is deprecated and will be removed in the 1.0 release.
        The logger to use to log any exceptions raised by calls to the decorated function.
        If not provided, exceptions will not be logged by this function (though they may be handled by the event loop).
    raise_exceptions: bool (default False)
        Warning: This parameter is deprecated and will be removed in the 1.0 release.
        If True, errors raised by the decorated function will be raised to the event loop's exception handler.
        Note that if an error is raised, the repeated execution will stop.
        Otherwise, exceptions are just logged and the execution continues to repeat.
        See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.set_exception_handler for more info.
    max_repetitions: Optional[int] (default None)
        The maximum number of times to call the repeated function. If `None`, the function is repeated forever.
    on_complete: Optional[Callable[[T], Any]] (default None)
        A function to call after the final repetition of the decorated function.
    on_exception: Optional[Callable[[Exception], None]] (default None)
        A function to call when an exception is raised by the decorated function.
    """

    def decorator(func: ArgsReturnAnyFuncT) -> Callable[[T], Coroutine[Any, Any, None]]:
        """
        Converts the decorated function into a repeated, periodically-called version of itself.
        """

        @wraps(func)
        async def wrapped(arg: T) -> None:
            # Fire-and-forget: `wrapped` schedules the periodic loop and
            # returns immediately; the loop itself runs until max_repetitions
            # is reached (or forever when it is None).
            async def loop() -> None:
                if wait_first is not None:
                    await asyncio.sleep(wait_first)

                repetitions = 0
                while max_repetitions is None or repetitions < max_repetitions:
                    try:
                        await _handle_func(func, arg)

                    except Exception as exc:
                        if logger is not None:
                            warnings.warn(
                                "'logger' is to be deprecated in favor of 'on_exception' in the 1.0 release.",
                                DeprecationWarning,
                            )
                            formatted_exception = "".join(
                                format_exception(type(exc), exc, exc.__traceback__)
                            )
                            logger.error(formatted_exception)
                        if raise_exceptions:
                            warnings.warn(
                                "'raise_exceptions' is to be deprecated in favor of 'on_exception' in the 1.0 release.",
                                DeprecationWarning,
                            )
                            # NOTE: raising here ends the loop AND skips the
                            # on_exception callback below for this exception.
                            raise exc
                        await _handle_exc(exc, on_exception)

                    # A failed repetition (handled above) still counts toward
                    # max_repetitions, and the sleep still applies.
                    repetitions += 1
                    await asyncio.sleep(seconds)

                if on_complete:
                    await _handle_func(on_complete, arg)

            # NOTE(review): no reference to the task is retained; the event
            # loop may garbage-collect an unreferenced task — consider storing
            # the return value of ensure_future. TODO confirm.
            asyncio.ensure_future(loop())

        return wrapped

    return decorator

0 comments on commit aef647e

Please sign in to comment.