feat: support pre and post processors

Anonymous 2024-06-07 22:33:23 +08:00
parent 14ea13da43
commit 610cbf769c
8 changed files with 96 additions and 45 deletions

View File

@@ -4,7 +4,6 @@ ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
RPC_SECRET = os.environ.get("RPC_SECRET", '')
BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
PROXY = os.environ.get("PROXY", None)
SCHEME = os.environ.get("SCHEME", None)
assert SCHEME in ["http", "https"]

View File

@@ -1,47 +1,71 @@
import typing
from typing import Callable, Coroutine
import httpx
import starlette.requests
import starlette.responses
from httpx import Request as HttpxRequest
from starlette.requests import Request
from starlette.responses import Response
from config import PROXY
SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
AsyncPreProcessor = Callable[
[Request, HttpxRequest],
Coroutine[Request, HttpxRequest, HttpxRequest]
]
SyncPostProcessor = Callable[[Request, Response], Response]
AsyncPostProcessor = Callable[
[Request, Response],
Coroutine[Request, Response, Response]
]
async def direct_proxy(
request: starlette.requests.Request,
request: Request,
target_url: str,
pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
) -> typing.Optional[starlette.responses.Response]:
async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
headers = request.headers.mutablecopy()
for key in headers.keys():
pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
cache_ttl: int = 3600,
) -> Response:
# httpx will use the following environment variables to determine the proxy
# https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
async with httpx.AsyncClient() as client:
req_headers = request.headers.mutablecopy()
for key in req_headers.keys():
if key not in ["user-agent", "accept"]:
del headers[key]
del req_headers[key]
httpx_req = client.build_request(request.method, target_url, headers=headers, )
httpx_req: HttpxRequest = client.build_request(request.method, target_url, headers=req_headers, )
if pre_process:
httpx_req = pre_process(request, httpx_req)
new_httpx_req = pre_process(request, httpx_req)
if isinstance(new_httpx_req, HttpxRequest):
httpx_req = new_httpx_req
else:
httpx_req = await new_httpx_req
upstream_response = await client.send(httpx_req)
# TODO: move to post_process
if upstream_response.status_code == 307:
location = upstream_response.headers["location"]
print("catch redirect", location)
res_headers = upstream_response.headers
headers = upstream_response.headers
cl = headers.pop("content-length", None)
ce = headers.pop("content-encoding", None)
cl = res_headers.pop("content-length", None)
ce = res_headers.pop("content-encoding", None)
# print(target_url, cl, ce)
content = upstream_response.content
response = starlette.responses.Response(
headers=headers,
response = Response(
headers=res_headers,
content=content,
status_code=upstream_response.status_code)
status_code=upstream_response.status_code
)
if post_process:
response = post_process(request, response)
new_res = post_process(request, response)
if isinstance(new_res, Response):
final_res = new_res
elif isinstance(new_res, Coroutine):
final_res = await new_res
else:
final_res = response
return response
return final_res
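For illustration, here is a minimal sketch of what the new union types accept (the processor names are hypothetical, not part of this commit). A sync pre-processor returns the rewritten request directly; an async one returns a coroutine, which is why direct_proxy checks isinstance(new_httpx_req, HttpxRequest) and awaits otherwise:

from httpx import Request as HttpxRequest
from starlette.requests import Request

def tag_sync(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
    # matches SyncPreProcessor: the return value is used as-is
    httpx_req.headers["x-example"] = "sync"
    return httpx_req

async def tag_async(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
    # matches AsyncPreProcessor: calling it yields a coroutine,
    # which direct_proxy awaits before sending
    httpx_req.headers["x-example"] = "async"
    return httpx_req

Either function can be passed as pre_process; the same dual dispatch applies to post_process via the isinstance(new_res, Response) check.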

View File

@@ -14,16 +14,23 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
from aria2_api import add_download
from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
from typing import Optional, Callable
logger = logging.getLogger(__name__)
def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
parsed_url = urlparse(url)
hostname = parsed_url.hostname
path = parsed_url.path
assert hostname
assert path
base_dir = pathlib.Path(CACHE_DIR)
assert parsed_url.path[0] == "/"
assert parsed_url.path[-1] != "/"
cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
cache_file = (base_dir / hostname / path[1:]).resolve()
assert cache_file.is_relative_to(base_dir)
return str(cache_file), os.path.dirname(cache_file)
@@ -65,12 +72,12 @@ async def get_url_content_length(url):
return content_len
async def try_get_cache(
async def try_file_based_cache(
request: Request,
target_url: str,
download_wait_time: int = 60,
post_process: typing.Callable[[Request, Response], Response] = None,
) -> typing.Optional[Response]:
post_process: Optional[Callable[[Request, Response], Response]] = None,
) -> Response:
cache_status = lookup_cache(target_url)
if cache_status == DownloadingStatus.DOWNLOADED:
resp = make_cached_response(target_url)
@@ -88,7 +95,7 @@ async def try_get_cache(
cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
print("prepare to download", target_url, cache_file, cache_file_dir)
processed_url = quote(target_url, safe='/:?=&')
processed_url = quote(target_url, safe='/:?=&%')
try:
await add_download(processed_url, save_dir=cache_file_dir)
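Adding % to quote's safe set matters when target_url is already percent-encoded: without it, existing escapes are encoded a second time before being handed to aria2. A quick illustration (hypothetical URL):

from urllib.parse import quote

url = "https://files.pythonhosted.org/packages/foo%2Bbar/pkg.whl"
print(quote(url, safe='/:?=&'))   # ...foo%252Bbar... (double-encoded)
print(quote(url, safe='/:?=&%'))  # ...foo%2Bbar...   (preserved)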

View File

@@ -1,6 +1,7 @@
import base64
import signal
import urllib.parse
from typing import Callable
import httpx
import uvicorn
@@ -24,6 +25,7 @@ async def aria2(request: Request, call_next):
if request.url.path == "/":
return RedirectResponse("/aria2/index.html")
if request.url.path == "/jsonrpc":
# don't use proxy for internal API
async with httpx.AsyncClient(mounts={
"all://": httpx.AsyncHTTPTransport()
}) as client:
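Mounting an explicit AsyncHTTPTransport for all:// pins a direct transport for this client, which is how the internal RPC call sidesteps the proxy httpx would otherwise pick up from the environment. A standalone sketch (trust_env=False, not used in this commit, is the blunter alternative that disables HTTP_PROXY/ALL_PROXY handling entirely):

import httpx

# route every request through a direct, non-proxied transport
client = httpx.AsyncClient(mounts={"all://": httpx.AsyncHTTPTransport()})

# alternative: ignore proxy-related environment variables altogether
client = httpx.AsyncClient(trust_env=False)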
@@ -39,8 +41,11 @@ async def aria2(request: Request, call_next):
@app.middleware("http")
async def capture_request(request: Request, call_next: callable):
async def capture_request(request: Request, call_next: Callable):
hostname = request.url.hostname
if not hostname:
return Response(content="Bad Request", status_code=400)
if not hostname.endswith(f".{BASE_DOMAIN}"):
return await call_next(request)
@@ -81,6 +86,6 @@ if __name__ == '__main__':
aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
print(f"Download manager (Aria2) at {aria2_url_with_auth}")
# FIXME: only proxy headers if SCHME is https
# FIXME: only proxy headers if SCHEME is https
# reload only in dev mode
uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
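A sketch of what the FIXME is pointing at (hypothetical follow-up, not part of this commit): gate proxy_headers on the configured SCHEME so X-Forwarded-* headers are only trusted when the app actually sits behind a TLS-terminating reverse proxy:

from config import SCHEME

uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True,
            proxy_headers=(SCHEME == "https"), forwarded_allow_ips="*")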

View File

View File

@@ -2,16 +2,18 @@ import base64
import json
import re
import time
from typing import Dict
import httpx
from starlette.requests import Request
from starlette.responses import Response
from proxy.file_cache import try_file_based_cache
from proxy.direct import direct_proxy
BASE_URL = "https://registry-1.docker.io"
cached_token = {
cached_token: Dict[str, str] = {
}
@@ -69,6 +71,22 @@ def get_docker_token(name):
return token
def inject_token(name: str, req: Request, httpx_req: httpx.Request):
docker_token = get_docker_token(f"{name}")
httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
return httpx_req
async def post_process(request: Request, response: Response):
if response.status_code == 307:
location = response.headers["location"]
# TODO: logger
print("[redirect]", location)
return await try_file_based_cache(request, location)
return response
async def docker(request: Request):
path = request.url.path
print("[request]", request.method, request.url)
@@ -90,11 +108,9 @@ async def docker(request: Request):
target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
# logger
print('[PARSED]', path, name, operation, reference, target_url)
def inject_token(req, httpx_req):
docker_token = get_docker_token(f"{name}")
httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
return httpx_req
return await direct_proxy(request, target_url, pre_process=inject_token)
return await direct_proxy(request, target_url,
pre_process=lambda req, http_req: inject_token(name, req, http_req),
post_process=post_process)
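The lambda is there to bind the image name parsed inside docker() to the now module-level inject_token. An equivalent spelling with functools.partial (a hypothetical alternative, not what the commit uses):

from functools import partial

return await direct_proxy(request, target_url,
                          pre_process=partial(inject_token, name),
                          post_process=post_process)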

View File

@@ -4,7 +4,7 @@ from starlette.requests import Request
from starlette.responses import Response
from proxy.direct import direct_proxy
from proxy.cached import try_get_cache
from proxy.file_cache import try_file_based_cache
pypi_file_base_url = "https://files.pythonhosted.org"
pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
return Response(content="Not Found", status_code=404)
if path.endswith(".whl") or path.endswith(".tar.gz"):
return await try_get_cache(request, target_url)
return await try_file_based_cache(request, target_url)
return await direct_proxy(request, target_url, post_process=pypi_replace)

View File

@ -1,7 +1,7 @@
from starlette.requests import Request
from starlette.responses import Response
from proxy.cached import try_get_cache
from proxy.file_cache import try_file_based_cache
from proxy.direct import direct_proxy
BASE_URL = "https://download.pytorch.org"
@@ -19,6 +19,6 @@ async def torch(request: Request):
target_url = BASE_URL + path
if path.endswith(".whl") or path.endswith(".tar.gz"):
return await try_get_cache(request, target_url)
return await try_file_based_cache(request, target_url)
return await direct_proxy(request, target_url, )
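Taken together, a new upstream handler can wire both hooks through direct_proxy. A hypothetical end-to-end example (the upstream URL and handler names are illustrative only):

from starlette.requests import Request
from starlette.responses import Response

from proxy.direct import direct_proxy
from proxy.file_cache import try_file_based_cache

BASE_URL = "https://example.org"  # hypothetical upstream

def stamp(request: Request, httpx_req):
    # sync pre-processor: tag outbound requests
    httpx_req.headers["x-proxied-by"] = "homeinfra"
    return httpx_req

async def cache_archives(request: Request, response: Response) -> Response:
    # async post-processor: divert archive responses into the file cache,
    # mirroring what docker's post_process does for 307 redirects
    if request.url.path.endswith(".tar.gz"):
        return await try_file_based_cache(request, str(request.url))
    return response

async def example(request: Request) -> Response:
    return await direct_proxy(request, BASE_URL + request.url.path,
                              pre_process=stamp, post_process=cache_archives)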