Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ pandas
drf-spectacular
qdrant-client==1.10.0
presidio-analyzer
presidio-anonymizer
presidio-anonymizer
func_timeout
74 changes: 68 additions & 6 deletions swirl/middleware.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from rest_framework.authtoken.models import Token
from django.http import HttpResponseForbidden, HttpResponse
from swirl.authenticators import *
import json
import yaml
import jwt
import logging

import jwt
import yaml
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseForbidden
from func_timeout import FunctionTimedOut, func_timeout
from rest_framework.authtoken.models import Token

from swirl.authenticators import *

logger = logging.getLogger(__name__)


SWIRL_API_SEARCH_URLS = ["/api/swirl/search/", "/swirl/search/"]
SWIRL_API_RAG_URLS = ["/api/swirl/rag-search/", "/api/swirl/sapi/detail-search-rag/"]


class TokenMiddleware:
def __init__(self, get_response):
self.get_response = get_response
Expand Down Expand Up @@ -68,4 +76,58 @@ def __call__(self, request):
response = HttpResponse(yaml_content, content_type='text/yaml')
return response
return self.get_response(request)
return self.get_response(request)
return self.get_response(request)

class TimeoutMiddleware:
def __init__(self, get_response):
self.get_response = get_response

def __call__(self, request):
min_timeout = 1
max_timeout = 180
timeout_param = request.GET.get("rag_timeout")
is_rag_url = request.path in SWIRL_API_RAG_URLS
is_search_url = request.path in SWIRL_API_SEARCH_URLS
has_search_rag_tag = (
request.GET.get("rag", False)
or request.GET.get("do_rag", "").lower() == "true"
)

logger.info(
f"TimeoutMiddleware - init path {request.path} rag_timeout {timeout_param} rag {has_search_rag_tag} (rag:{request.GET.get('rag','<unset>')} or do_rag:{request.GET.get('do_rag','<unset>')})"
)

if timeout_param and ((is_search_url and has_search_rag_tag) or is_rag_url):
logger.debug(
f"Enabling RAG timeout for {request.path} and {timeout_param} seconds"
)

## little method to wrap the request execution
def execute_request_with_timeout():
return self.get_response(request)

## parse the timeout value or fail the request
try:
timeout_duration = int(timeout_param)
except ValueError:
return HttpResponseBadRequest("Invalid timeout value provided")

## validate the timeout value
if timeout_duration < min_timeout or timeout_duration > max_timeout:
return HttpResponseBadRequest(
f"Timeout value must be between {min_timeout} and {max_timeout} seconds"
)

try:
logger.info(f"Request timeout set to {timeout_duration} seconds")
response = func_timeout(timeout_duration, execute_request_with_timeout)
except FunctionTimedOut:
logger.debug(
f"Raise timeout for {request.path} after {timeout_duration} seconds"
)
response = HttpResponse("Rag timed out", status=504)
else:
logger.debug(f"Disabling RAG timeout for {request.path}")
response = self.get_response(request)

return response
1 change: 1 addition & 0 deletions swirl_server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
'swirl.middleware.TokenMiddleware',
'swirl.middleware.SpyglassAuthenticatorsMiddleware',
'swirl.middleware.SwaggerMiddleware',
'swirl.middleware.TimeoutMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
Expand Down