From 45fc94b7c715e22011491cdd77d1224ff323770a Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Tue, 26 Nov 2019 19:44:51 +0000 Subject: [PATCH 1/9] =?UTF-8?q?Rewrite=20the=20router=20using=20starlette?= =?UTF-8?q?=20=F0=9F=8C=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 5 ++-- router.py | 69 +++++++++++++++++++++++++----------------------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/requirements.txt b/requirements.txt index b888a07..d247598 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -Flask==1.0.3 psutil==5.6.3 -requests==2.22.0 +uvicorn==0.10.8 +starlette==0.13.0 +httpx==0.7.8 diff --git a/router.py b/router.py index 663076f..5b0a124 100644 --- a/router.py +++ b/router.py @@ -1,9 +1,12 @@ -from flask import Flask, Response, request, abort -from urllib.parse import urlparse -from werkzeug.routing import Rule import os + +import httpx import psutil -import requests +import uvicorn +from starlette.applications import Starlette +from starlette.middleware.cors import ALL_METHODS +from starlette.responses import Response, StreamingResponse +from starlette.routing import Route def get_routes(): @@ -18,42 +21,42 @@ def get_routes(): return routes -app = Flask(__name__, static_folder=None) -app.url_map.add(Rule("/", endpoint="proxy", defaults={"path": ""})) -app.url_map.add(Rule("/", endpoint="proxy")) - - -@app.endpoint("proxy") -def proxy(path): +async def proxy(request): routes = get_routes() - hostname = urlparse(request.base_url).hostname + hostname = request.url.hostname if hostname not in routes: - app.logger.warn(f"No backend for {hostname}") - abort(502) - - path = request.full_path if request.args else request.path + return Response(status_code=502) + path = request.url.path target_url = f"http://localhost:{routes[hostname]}{path}" - app.logger.info(f"Routing request to backend - {request.method} {hostname}{path}") + body = await request.body() + 
async with httpx.AsyncClient() as client: + upstream_response = await client.request( + method=request.method, + url=target_url, + data=body, + headers=request.headers.raw, + allow_redirects=False + ) + if not upstream_response.is_stream_consumed: + return StreamingResponse( + content=upstream_response.stream(), + status_code=upstream_response.status_code, + headers=upstream_response.headers, + ) + + upstream_response_body = await upstream_response.read() + return Response( + content=upstream_response_body, + status_code=upstream_response.status_code, + headers=upstream_response.headers, + ) + - downstream_response = requests.request( - method=request.method, - url=target_url, - headers=request.headers, - data=request.get_data(), - allow_redirects=False, - stream=True, - ) - return Response( - response=downstream_response.raw.data, - status=downstream_response.status_code, - headers=downstream_response.raw.headers.items(), - ) +app = Starlette(routes=[Route("/(.*)", endpoint=proxy, methods=ALL_METHODS)]) def start_on_port(port): - app.run( - port=port, debug=True, use_debugger=False, use_reloader=False, load_dotenv=False - ) + uvicorn.run(app, port=port, host="0.0.0.0") def run(): From 78c2d391ca0129a39c23ae731db87ea974c3e5ad Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Tue, 26 Nov 2019 20:58:28 +0000 Subject: [PATCH 2/9] Add better logging, and always stream responses --- router.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/router.py b/router.py index 5b0a124..2f24b1f 100644 --- a/router.py +++ b/router.py @@ -5,8 +5,10 @@ import uvicorn from starlette.applications import Starlette from starlette.middleware.cors import ALL_METHODS -from starlette.responses import Response, StreamingResponse +from starlette.responses import StreamingResponse, Response from starlette.routing import Route +from uvicorn.config import LOGGING_CONFIG as UVICORN_LOGGING_CONFIG +from uvicorn.logging import AccessFormatter def 
get_routes(): @@ -21,6 +23,17 @@ def get_routes(): return routes +class CustomAccessFormatter(AccessFormatter): + def get_client_addr(self, scope): + """ + _Pretend_ the client address is actually the hostname. + Makes the log messages much nicer! + """ + if 'headers' not in scope: + return super().get_client_addr(scope) + return httpx.Headers(scope['headers'])['Host'] + + async def proxy(request): routes = get_routes() hostname = request.url.hostname @@ -35,18 +48,11 @@ async def proxy(request): url=target_url, data=body, headers=request.headers.raw, - allow_redirects=False + allow_redirects=False, + stream=True ) - if not upstream_response.is_stream_consumed: - return StreamingResponse( - content=upstream_response.stream(), - status_code=upstream_response.status_code, - headers=upstream_response.headers, - ) - - upstream_response_body = await upstream_response.read() - return Response( - content=upstream_response_body, + return StreamingResponse( + content=upstream_response.raw(), status_code=upstream_response.status_code, headers=upstream_response.headers, ) @@ -56,6 +62,7 @@ async def proxy(request): def start_on_port(port): + UVICORN_LOGGING_CONFIG['formatters']['access']['()'] = CustomAccessFormatter uvicorn.run(app, port=port, host="0.0.0.0") From f52986a9f77cff6e03575f9df318775266ef189b Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Tue, 26 Nov 2019 21:05:51 +0000 Subject: [PATCH 3/9] Manually strip out some headers uvicorn forcefully adds It's not possible to configure this.
I might submit a PR for it --- router.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/router.py b/router.py index 2f24b1f..fe5949d 100644 --- a/router.py +++ b/router.py @@ -11,6 +11,9 @@ from uvicorn.logging import AccessFormatter +HEADERS_TO_STRIP = ['server', 'date'] + + def get_routes(): routes = {} for process in psutil.process_iter(attrs=["environ"]): @@ -51,10 +54,17 @@ async def proxy(request): allow_redirects=False, stream=True ) + + # Strip some headers which uvicorn forcefully adds + upstream_headers = upstream_response.headers + for header_name in HEADERS_TO_STRIP: + if header_name in upstream_headers: + del upstream_headers[header_name] + return StreamingResponse( content=upstream_response.raw(), status_code=upstream_response.status_code, - headers=upstream_response.headers, + headers=upstream_headers, ) From 8021503a118353f9dacabc5232772849a5320823 Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Tue, 26 Nov 2019 21:06:16 +0000 Subject: [PATCH 4/9] Change the server name It's a nice touch --- router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router.py b/router.py index fe5949d..664d4f6 100644 --- a/router.py +++ b/router.py @@ -73,7 +73,7 @@ async def proxy(request): def start_on_port(port): UVICORN_LOGGING_CONFIG['formatters']['access']['()'] = CustomAccessFormatter - uvicorn.run(app, port=port, host="0.0.0.0") + uvicorn.run(app, port=port, host="0.0.0.0", headers=[('server', 'crab')]) def run(): From daac186332c58321071bba61133f49ef448dbc5d Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Thu, 28 Nov 2019 12:02:03 +0000 Subject: [PATCH 5/9] Return _helpful_ error when there's no process to route to --- router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router.py b/router.py index 664d4f6..261d2ad 100644 --- a/router.py +++ b/router.py @@ -41,7 +41,7 @@ async def proxy(request): routes = get_routes() hostname = request.url.hostname if hostname not in routes: - return 
Response(status_code=502) + return Response(status_code=502, content=f"No backend found for {hostname}.") path = request.url.path target_url = f"http://localhost:{routes[hostname]}{path}" body = await request.body() From c7e3775d447d6b56f54e8afffb2eef0358d3157d Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Thu, 28 Nov 2019 12:19:13 +0000 Subject: [PATCH 6/9] Quotes and commas --- router.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/router.py b/router.py index 261d2ad..35c1441 100644 --- a/router.py +++ b/router.py @@ -11,7 +11,7 @@ from uvicorn.logging import AccessFormatter -HEADERS_TO_STRIP = ['server', 'date'] +HEADERS_TO_STRIP = ["server", "date"] def get_routes(): @@ -32,9 +32,9 @@ def get_client_addr(self, scope): _Pretend_ the client address is actually the hostname. Makes the log messages much nicer! """ - if 'headers' not in scope: + if "headers" not in scope: return super().get_client_addr(scope) - return httpx.Headers(scope['headers'])['Host'] + return httpx.Headers(scope["headers"])["Host"] async def proxy(request): @@ -52,7 +52,7 @@ async def proxy(request): data=body, headers=request.headers.raw, allow_redirects=False, - stream=True + stream=True, ) # Strip some headers which uvicorn forcefully adds @@ -72,8 +72,8 @@ async def proxy(request): def start_on_port(port): - UVICORN_LOGGING_CONFIG['formatters']['access']['()'] = CustomAccessFormatter - uvicorn.run(app, port=port, host="0.0.0.0", headers=[('server', 'crab')]) + UVICORN_LOGGING_CONFIG["formatters"]["access"]["()"] = CustomAccessFormatter + uvicorn.run(app, port=port, host="0.0.0.0", headers=[("server", "crab")]) def run(): From 45db656a308d61adb3ac70b3825c2260d1a3f767 Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Thu, 5 Dec 2019 10:07:26 +0000 Subject: [PATCH 7/9] Ensure querystring params are also passed upstream --- router.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/router.py b/router.py index 35c1441..205b213 100644 
--- a/router.py +++ b/router.py @@ -42,8 +42,9 @@ async def proxy(request): hostname = request.url.hostname if hostname not in routes: return Response(status_code=502, content=f"No backend found for {hostname}.") - path = request.url.path - target_url = f"http://localhost:{routes[hostname]}{path}" + target_url = f"http://localhost:{routes[hostname]}{request.url.path}" + if request.query_params: + target_url += f"?{request.query_params}" body = await request.body() async with httpx.AsyncClient() as client: upstream_response = await client.request( From 8d0620e58a23f9529c26365e218ff0d331d4b4ad Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Fri, 6 Dec 2019 14:58:16 +0000 Subject: [PATCH 8/9] Maintain client between connections This saves TCP connection times, and allows connection pooling. --- router.py | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/router.py b/router.py index 205b213..58b9016 100644 --- a/router.py +++ b/router.py @@ -14,6 +14,9 @@ HEADERS_TO_STRIP = ["server", "date"] +client = httpx.AsyncClient() + + def get_routes(): routes = {} for process in psutil.process_iter(attrs=["environ"]): @@ -46,27 +49,26 @@ async def proxy(request): if request.query_params: target_url += f"?{request.query_params}" body = await request.body() - async with httpx.AsyncClient() as client: - upstream_response = await client.request( - method=request.method, - url=target_url, - data=body, - headers=request.headers.raw, - allow_redirects=False, - stream=True, - ) - - # Strip some headers which uvicorn forcefully adds - upstream_headers = upstream_response.headers - for header_name in HEADERS_TO_STRIP: - if header_name in upstream_headers: - del upstream_headers[header_name] - - return StreamingResponse( - content=upstream_response.raw(), - status_code=upstream_response.status_code, - headers=upstream_headers, - ) + upstream_response = await client.request( + method=request.method, + url=target_url, + 
data=body, + headers=request.headers.raw, + allow_redirects=False, + stream=True, + ) + + # Strip some headers which uvicorn forcefully adds + upstream_headers = upstream_response.headers + for header_name in HEADERS_TO_STRIP: + if header_name in upstream_headers: + del upstream_headers[header_name] + + return StreamingResponse( + content=upstream_response.raw(), + status_code=upstream_response.status_code, + headers=upstream_headers, + ) app = Starlette(routes=[Route("/(.*)", endpoint=proxy, methods=ALL_METHODS)]) From 9dfa3da1a50ebf2ab22a39d511da045f8a8e3c2b Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Tue, 17 Dec 2019 17:21:25 +0000 Subject: [PATCH 9/9] Set no timeout --- router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router.py b/router.py index 58b9016..7341fb8 100644 --- a/router.py +++ b/router.py @@ -14,7 +14,7 @@ HEADERS_TO_STRIP = ["server", "date"] -client = httpx.AsyncClient() +client = httpx.AsyncClient(timeout=None) def get_routes():