Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite router using starlette #17

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
77 changes: 50 additions & 27 deletions crab/router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from flask import Flask, Response, request, abort
from urllib.parse import urlparse
from werkzeug.routing import Rule
import os

import httpx
import psutil
import requests
import uvicorn
from starlette.applications import Starlette
from starlette.middleware.cors import ALL_METHODS
from starlette.responses import StreamingResponse, Response
from starlette.routing import Route
from uvicorn.config import LOGGING_CONFIG as UVICORN_LOGGING_CONFIG
from uvicorn.logging import AccessFormatter


HEADERS_TO_STRIP = ["server", "date"]


client = httpx.AsyncClient(timeout=None)


def get_routes():
Expand All @@ -18,42 +29,54 @@ def get_routes():
return routes


app = Flask(__name__, static_folder=None)
app.url_map.add(Rule("/", endpoint="proxy", defaults={"path": ""}))
app.url_map.add(Rule("/<path:path>", endpoint="proxy"))
class CustomAccessFormatter(AccessFormatter):
def get_client_addr(self, scope):
"""
_Pretend_ the client address is actually the hostname.
Makes the log messages much nicer!
"""
if "headers" not in scope:
return super().get_client_addr(scope)
return httpx.Headers(scope["headers"])["Host"]


@app.endpoint("proxy")
def proxy(path):
async def proxy(request):
routes = get_routes()
hostname = urlparse(request.base_url).hostname
hostname = request.url.hostname
if hostname not in routes:
app.logger.warn(f"No backend for {hostname}")
abort(502)

path = request.full_path if request.args else request.path
target_url = f"http://localhost:{routes[hostname]}{path}"
app.logger.info(f"Routing request to backend - {request.method} {hostname}{path}")

downstream_response = requests.request(
return Response(status_code=502, content=f"No backend found for {hostname}.")
target_url = f"http://localhost:{routes[hostname]}{request.url.path}"
if request.query_params:
target_url += f"?{request.query_params}"
body = await request.body()
upstream_response = await client.request(
method=request.method,
url=target_url,
headers=request.headers,
data=request.get_data(),
data=body,
headers=request.headers.raw,
allow_redirects=False,
stream=True,
)
return Response(
response=downstream_response.raw.data,
status=downstream_response.status_code,
headers=downstream_response.raw.headers.items(),

# Strip some headers which uvicorn forcefully adds
upstream_headers = upstream_response.headers
for header_name in HEADERS_TO_STRIP:
if header_name in upstream_headers:
del upstream_headers[header_name]

return StreamingResponse(
content=upstream_response.raw(),
status_code=upstream_response.status_code,
headers=upstream_headers,
)


app = Starlette(routes=[Route("/(.*)", endpoint=proxy, methods=ALL_METHODS)])


def start_on_port(port):
app.run(
port=port, debug=True, use_debugger=False, use_reloader=False, load_dotenv=False
)
UVICORN_LOGGING_CONFIG["formatters"]["access"]["()"] = CustomAccessFormatter
uvicorn.run(app, port=port, host="0.0.0.0", headers=[("server", "crab")])


def run():
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Flask==1.0.3
psutil==5.6.3
requests==2.22.0
uvicorn==0.10.8
starlette==0.13.0
httpx==0.7.8