From 610cbf769c4bd5d4dc37a749385705df43ff7531 Mon Sep 17 00:00:00 2001
From: Anonymous <>
Date: Fri, 7 Jun 2024 22:33:23 +0800
Subject: [PATCH] feat: support pre and post processors

---
 mirrors/config.py                          |  1 -
 mirrors/proxy/direct.py                    | 76 ++++++++++++++--------
 mirrors/proxy/{cached.py => file_cache.py} | 17 +++--
 mirrors/server.py                          |  9 ++-
 mirrors/sites/__init__.py                  |  0
 mirrors/sites/docker.py                    | 30 +++++++--
 mirrors/sites/pypi.py                      |  4 +-
 mirrors/sites/torch.py                     |  4 +-
 8 files changed, 96 insertions(+), 45 deletions(-)
 rename mirrors/proxy/{cached.py => file_cache.py} (89%)
 create mode 100644 mirrors/sites/__init__.py

diff --git a/mirrors/config.py b/mirrors/config.py
index 1fb5e9b..af1339f 100644
--- a/mirrors/config.py
+++ b/mirrors/config.py
@@ -4,7 +4,6 @@
 ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
 RPC_SECRET = os.environ.get("RPC_SECRET", '')
 BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
-PROXY = os.environ.get("PROXY", None)
 
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]

diff --git a/mirrors/proxy/direct.py b/mirrors/proxy/direct.py
index aa42145..f865e11 100644
--- a/mirrors/proxy/direct.py
+++ b/mirrors/proxy/direct.py
@@ -1,47 +1,71 @@
 import typing
+from typing import Callable, Coroutine
 
 import httpx
-import starlette.requests
-import starlette.responses
+from httpx import Request as HttpxRequest
+from starlette.requests import Request
+from starlette.responses import Response
 
-from config import PROXY
+SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
+
+AsyncPreProcessor = Callable[
+    [Request, HttpxRequest],
+    Coroutine[typing.Any, typing.Any, HttpxRequest]
+]
+
+SyncPostProcessor = Callable[[Request, Response], Response]
+
+AsyncPostProcessor = Callable[
+    [Request, Response],
+    Coroutine[typing.Any, typing.Any, Response]
+]
 
 
 async def direct_proxy(
-        request: starlette.requests.Request,
+        request: Request,
         target_url: str,
-        pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
-        post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
-) -> typing.Optional[starlette.responses.Response]:
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
-
-        headers = request.headers.mutablecopy()
-        for key in headers.keys():
+        pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+        post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+        cache_ttl: int = 3600,
+) -> Response:
+    # httpx determines the proxy from the standard environment variables:
+    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
+    async with httpx.AsyncClient() as client:
+        req_headers = request.headers.mutablecopy()
+        for key in req_headers.keys():
             if key not in ["user-agent", "accept"]:
-                del headers[key]
+                del req_headers[key]
 
-        httpx_req = client.build_request(request.method, target_url, headers=headers, )
+        httpx_req: HttpxRequest = client.build_request(request.method, target_url, headers=req_headers)
 
         if pre_process:
-            httpx_req = pre_process(request, httpx_req)
+            new_httpx_req = pre_process(request, httpx_req)
+            if isinstance(new_httpx_req, HttpxRequest):
+                httpx_req = new_httpx_req
+            else:
+                httpx_req = await new_httpx_req
+
         upstream_response = await client.send(httpx_req)
 
-        # TODO: move to post_process
-        if upstream_response.status_code == 307:
-            location = upstream_response.headers["location"]
-            print("catch redirect", location)
+        res_headers = upstream_response.headers
 
-        headers = upstream_response.headers
-        cl = headers.pop("content-length", None)
-        ce = headers.pop("content-encoding", None)
+        cl = res_headers.pop("content-length", None)
+        ce = res_headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
 
         content = upstream_response.content
-        response = starlette.responses.Response(
-            headers=headers,
+        response = Response(
+            headers=res_headers,
             content=content,
-            status_code=upstream_response.status_code)
+            status_code=upstream_response.status_code
+        )
 
         if post_process:
-            response = post_process(request, response)
+            new_res = post_process(request, response)
+            if isinstance(new_res, Response):
+                response = new_res
+            elif isinstance(new_res, Coroutine):
+                # an async post-processor returns a coroutine; await it
+                response = await new_res
+            # any other return value is ignored; keep the original response
 
         return response

diff --git a/mirrors/proxy/cached.py b/mirrors/proxy/file_cache.py
similarity index 89%
rename from mirrors/proxy/cached.py
rename to mirrors/proxy/file_cache.py
index b527ba4..dd3a393 100644
--- a/mirrors/proxy/cached.py
+++ b/mirrors/proxy/file_cache.py
@@ -14,16 +14,23 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TI
 
 from aria2_api import add_download
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2
+from typing import Optional, Callable
 
 logger = logging.getLogger(__name__)
 
 
 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
+    hostname = parsed_url.hostname
+    path = parsed_url.path
+    assert hostname
+    assert path
+
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
+    cache_file = (base_dir / hostname / path[1:]).resolve()
+    assert cache_file.is_relative_to(base_dir)
 
     return str(cache_file), os.path.dirname(cache_file)

@@ -65,12 +72,12 @@ async def get_url_content_length(url):
     return content_len
 
 
-async def try_get_cache(
+async def try_file_based_cache(
         request: Request,
         target_url: str,
         download_wait_time: int = 60,
-        post_process: typing.Callable[[Request, Response], Response] = None,
-) -> typing.Optional[Response]:
+        post_process: Optional[Callable[[Request, Response], Response]] = None,
+) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -88,7 +95,7 @@
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
 
-    processed_url = quote(target_url, safe='/:?=&')
+    processed_url = quote(target_url, safe='/:?=&%')
 
     try:
         await add_download(processed_url, save_dir=cache_file_dir)

diff --git a/mirrors/server.py b/mirrors/server.py
index 6383224..6a3f83b 100644
--- a/mirrors/server.py
+++ b/mirrors/server.py
@@ -1,6 +1,7 @@
 import base64
 import signal
 import urllib.parse
+from typing import Callable
 
 import httpx
 import uvicorn
@@ -24,6 +25,7 @@ async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
+        # don't use the proxy for the internal API
         async with httpx.AsyncClient(mounts={
             "all://": httpx.AsyncHTTPTransport()
         }) as client:
@@ -39,8 +41,11 @@
 
 
 @app.middleware("http")
-async def capture_request(request: Request, call_next: callable):
+async def capture_request(request: Request, call_next: Callable):
     hostname = request.url.hostname
+    if not hostname:
+        return Response(content="Bad Request", status_code=400)
+
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
 
@@ -81,6 +86,6 @@ if __name__ == '__main__':
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
 
-    # FIXME: only proxy headers if SCHME is https
+    # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
     uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")

diff --git a/mirrors/sites/__init__.py b/mirrors/sites/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mirrors/sites/docker.py b/mirrors/sites/docker.py
index 4ab4c2c..90e4dd1 100644
--- a/mirrors/sites/docker.py
+++ b/mirrors/sites/docker.py
@@ -2,16 +2,18 @@ import base64
 import json
 import re
 import time
+from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://registry-1.docker.io"
 
-cached_token = {
+cached_token: Dict[str, str] = {
 }
 
@@ -69,6 +71,22 @@
     return token
 
 
+def inject_token(name: str, req: Request, httpx_req: httpx.Request) -> httpx.Request:
+    docker_token = get_docker_token(name)
+    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+    return httpx_req
+
+
+async def post_process(request: Request, response: Response) -> Response:
+    if response.status_code == 307:
+        location = response.headers["location"]
+        # TODO: use a logger instead of print
+        print("[redirect]", location)
+        return await try_file_based_cache(request, location)
+
+    return response
+
+
 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
@@ -90,11 +108,9 @@
     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
 
+    # TODO: use a logger instead of print
     print('[PARSED]', path, name, operation, reference, target_url)
 
-    def inject_token(req, httpx_req):
-        docker_token = get_docker_token(f"{name}")
-        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-        return httpx_req
-
-    return await direct_proxy(request, target_url, pre_process=inject_token)
+    return await direct_proxy(request, target_url,
+                              pre_process=lambda req, http_req: inject_token(name, req, http_req),
+                              post_process=post_process)

diff --git a/mirrors/sites/pypi.py b/mirrors/sites/pypi.py
index a85f3e5..712e928 100644
--- a/mirrors/sites/pypi.py
+++ b/mirrors/sites/pypi.py
@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response
 
 from proxy.direct import direct_proxy
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 
 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)
 
     return await direct_proxy(request, target_url, post_process=pypi_replace)

diff --git a/mirrors/sites/torch.py b/mirrors/sites/torch.py
index 17a2c04..1921380 100644
--- a/mirrors/sites/torch.py
+++ b/mirrors/sites/torch.py
@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://download.pytorch.org"
@@ -19,6 +19,6 @@ async def torch(request: Request):
     target_url = BASE_URL + path
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)
 
     return await direct_proxy(request, target_url, )
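
The new hooks accept either sync or async callables: direct_proxy inspects the
return value and awaits it when a coroutine comes back. A minimal usage sketch
of a site module wired up with both hooks follows; the example.com target and
the add_trace_header / strip_server_header helpers are illustrative names, not
part of this patch:

# hypothetical mirrors/sites/example.py, for illustration only
import httpx
from starlette.requests import Request
from starlette.responses import Response

from proxy.direct import direct_proxy

BASE_URL = "https://example.com"


def add_trace_header(request: Request, httpx_req: httpx.Request) -> httpx.Request:
    # sync pre-processor: returns an httpx.Request, which direct_proxy uses as-is
    httpx_req.headers["X-Trace-Id"] = request.headers.get("x-trace-id", "unknown")
    return httpx_req


async def strip_server_header(request: Request, response: Response) -> Response:
    # async post-processor: direct_proxy receives a coroutine and awaits it
    if "server" in response.headers:
        del response.headers["server"]
    return response


async def example(request: Request) -> Response:
    target_url = BASE_URL + request.url.path
    return await direct_proxy(
        request,
        target_url,
        pre_process=add_trace_header,
        post_process=strip_server_header,
    )

Mixing sync and async processors in one call works because the dispatch happens
per hook, exactly as the isinstance checks in direct_proxy above do for
HttpxRequest, Response, and Coroutine return values.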