diff --git a/mirrors/aria2_api.py b/mirrors/aria2_api.py
index 9b6ede4..83b426e 100644
--- a/mirrors/aria2_api.py
+++ b/mirrors/aria2_api.py
@@ -11,18 +11,20 @@ logger = logging.getLogger(__name__)
 
 
 async def send_request(method, params=None):
     payload = {
-        'jsonrpc': '2.0',
-        'id': uuid.uuid4().hex,
-        'method': method,
-        'params': [f'token:{RPC_SECRET}'] + (params or [])
+        "jsonrpc": "2.0",
+        "id": uuid.uuid4().hex,
+        "method": method,
+        "params": [f"token:{RPC_SECRET}"] + (params or []),
     }
     # specify the internal API call don't use proxy
-    async with httpx.AsyncClient(mounts={
-        "all://": httpx.AsyncHTTPTransport()
-    }) as client:
+    async with httpx.AsyncClient(
+        mounts={"all://": httpx.AsyncHTTPTransport()}
+    ) as client:
         response = await client.post(ARIA2_RPC_URL, json=payload)
-        logger.info(f"aria2 request: {method} {params} -> {response.status_code} {response.text}")
+        logger.info(
+            f"aria2 request: {method} {params} -> {response.status_code} {response.text}"
+        )
         try:
             return response.json()
         except json.JSONDecodeError as e:
@@ -30,38 +32,35 @@ async def send_request(method, params=None):
             raise e
 
 
-async def add_download(url, save_dir='/app/cache'):
-    method = 'aria2.addUri'
-    params = [[url],
-              {'dir': save_dir,
-               'header': []
-               }]
+async def add_download(url, save_dir="/app/cache"):
+    method = "aria2.addUri"
+    params = [[url], {"dir": save_dir, "header": []}]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]
 
 
 async def pause_download(gid):
-    method = 'aria2.pause'
+    method = "aria2.pause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]
 
 
 async def resume_download(gid):
-    method = 'aria2.unpause'
+    method = "aria2.unpause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]
 
 
 async def get_status(gid):
-    method = 'aria2.tellStatus'
+    method = "aria2.tellStatus"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]
 
 
 async def list_downloads():
-    method = 'aria2.tellActive'
+    method = "aria2.tellActive"
     response = await send_request(method)
-    return response['result']
+    return response["result"]
diff --git a/mirrors/config.py b/mirrors/config.py
index af1339f..7d7e1a0 100644
--- a/mirrors/config.py
+++ b/mirrors/config.py
@@ -1,8 +1,8 @@
 import os
 
-ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
-RPC_SECRET = os.environ.get("RPC_SECRET", '')
-BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
+ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
+RPC_SECRET = os.environ.get("RPC_SECRET", "")
+BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
 
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]
diff --git a/mirrors/proxy/__init__.py b/mirrors/proxy/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mirrors/proxy/direct.py b/mirrors/proxy/direct.py
index f865e11..c518362 100644
--- a/mirrors/proxy/direct.py
+++ b/mirrors/proxy/direct.py
@@ -9,24 +9,22 @@ from starlette.responses import Response
 
 SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
 AsyncPreProcessor = Callable[
-    [Request, HttpxRequest],
-    Coroutine[Request, HttpxRequest, HttpxRequest]
+    [Request, HttpxRequest], Coroutine[Request, HttpxRequest, HttpxRequest]
 ]
 
 SyncPostProcessor = Callable[[Request, Response], Response]
 AsyncPostProcessor = Callable[
-    [Request, Response],
-    Coroutine[Request, Response, Response]
+    [Request, Response], Coroutine[Request, Response, Response]
 ]
 
 
 async def direct_proxy(
-        request: Request,
-        target_url: str,
-        pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
-        post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
-        cache_ttl: int = 3600,
+    request: Request,
+    target_url: str,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    cache_ttl: int = 3600,
 ) -> Response:
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
@@ -36,7 +34,11 @@ async def direct_proxy(
         if key not in ["user-agent", "accept"]:
             del req_headers[key]
 
-    httpx_req: HttpxRequest = client.build_request(request.method, target_url, headers=req_headers, )
+    httpx_req: HttpxRequest = client.build_request(
+        request.method,
+        target_url,
+        headers=req_headers,
+    )
 
     if pre_process:
         new_httpx_req = pre_process(request, httpx_req)
@@ -56,7 +58,7 @@ async def direct_proxy(
     response = Response(
         headers=res_headers,
         content=content,
-        status_code=upstream_response.status_code
+        status_code=upstream_response.status_code,
     )
 
     if post_process:
diff --git a/mirrors/proxy/file_cache.py b/mirrors/proxy/file_cache.py
index dd3a393..727befc 100644
--- a/mirrors/proxy/file_cache.py
+++ b/mirrors/proxy/file_cache.py
@@ -13,7 +13,7 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TI
 
 from aria2_api import add_download
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from typing import Optional, Callable
 
 logger = logging.getLogger(__name__)
@@ -66,17 +66,17 @@ def make_cached_response(url):
 
 
 async def get_url_content_length(url):
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
+    async with httpx.AsyncClient() as client:
         head_response = await client.head(url)
-        content_len = (head_response.headers.get("content-length", None))
+        content_len = head_response.headers.get("content-length", None)
         return content_len
 
 
 async def try_file_based_cache(
-        request: Request,
-        target_url: str,
-        download_wait_time: int = 60,
-        post_process: Optional[Callable[[Request, Response], Response]] = None,
+    request: Request,
+    target_url: str,
+    download_wait_time: int = 60,
+    post_process: Optional[Callable[[Request, Response], Response]] = None,
 ) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
@@ -87,21 +87,26 @@ async def try_file_based_cache(
 
     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
-        return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                        status_code=HTTP_504_GATEWAY_TIMEOUT)
+        return Response(
+            content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+            status_code=HTTP_504_GATEWAY_TIMEOUT,
+        )
 
     assert cache_status == DownloadingStatus.NOT_FOUND
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
 
-    processed_url = quote(target_url, safe='/:?=&%')
+    processed_url = quote(target_url, safe="/:?=&%")
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
         logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
-        return Response(content=f"Failed to add download: {e}", status_code=HTTP_500_INTERNAL_SERVER_ERROR)
+        return Response(
+            content=f"Failed to add download: {e}",
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+        )
 
     # wait for download finished
     for _ in range(download_wait_time):
@@ -110,5 +115,7 @@ async def try_file_based_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
-    return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                    status_code=HTTP_504_GATEWAY_TIMEOUT)
+    return Response(
+        content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+        status_code=HTTP_504_GATEWAY_TIMEOUT,
+    )
diff --git a/mirrors/server.py b/mirrors/server.py
index 6a3f83b..56d7c32 100644
--- a/mirrors/server.py
+++ b/mirrors/server.py
@@ -10,7 +10,13 @@ from starlette.requests import Request
 from starlette.responses import RedirectResponse, Response
 from starlette.staticfiles import StaticFiles
 
-from config import BASE_DOMAIN, RPC_SECRET, EXTERNAL_URL_ARIA2, EXTERNAL_HOST_ARIA2, SCHEME
+from config import (
+    BASE_DOMAIN,
+    RPC_SECRET,
+    EXTERNAL_URL_ARIA2,
+    EXTERNAL_HOST_ARIA2,
+    SCHEME,
+)
 from sites.docker import docker
 from sites.npm import npm
 from sites.pypi import pypi
@@ -18,7 +24,11 @@ from sites.torch import torch
 
 app = FastAPI()
 
-app.mount("/aria2/", StaticFiles(directory="/wwwroot/"), name="static", )
+app.mount(
+    "/aria2/",
+    StaticFiles(directory="/wwwroot/"),
+    name="static",
+)
 
 
 async def aria2(request: Request, call_next):
@@ -26,17 +36,24 @@ async def aria2(request: Request, call_next):
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
         # dont use proxy for internal API
-        async with httpx.AsyncClient(mounts={
-            "all://": httpx.AsyncHTTPTransport()
-        }) as client:
-            data = (await request.body())
-            response = await client.request(url="http://aria2:6800/jsonrpc",
-                                            method=request.method,
-                                            headers=request.headers, content=data)
+        async with httpx.AsyncClient(
+            mounts={"all://": httpx.AsyncHTTPTransport()}
+        ) as client:
+            data = await request.body()
+            response = await client.request(
+                url="http://aria2:6800/jsonrpc",
+                method=request.method,
+                headers=request.headers,
+                content=data,
+            )
             headers = response.headers
             headers.pop("content-length", None)
             headers.pop("content-encoding", None)
-            return Response(content=response.content, status_code=response.status_code, headers=headers)
+            return Response(
+                content=response.content,
+                status_code=response.status_code,
+                headers=headers,
+            )
     return await call_next(request)
@@ -64,7 +81,7 @@ async def capture_request(request: Request, call_next: Callable):
     return await call_next(request)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
     print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
@@ -75,11 +92,11 @@ if __name__ == '__main__':
 
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
     params = {
-        'protocol': SCHEME,
-        'host': EXTERNAL_HOST_ARIA2,
-        'port': '443' if SCHEME == 'https' else '80',
-        'interface': 'jsonrpc',
-        'secret': aria2_secret
+        "protocol": SCHEME,
+        "host": EXTERNAL_HOST_ARIA2,
+        "port": "443" if SCHEME == "https" else "80",
+        "interface": "jsonrpc",
+        "secret": aria2_secret,
     }
 
     query_string = urllib.parse.urlencode(params)
@@ -88,4 +105,11 @@ if __name__ == '__main__':
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
     # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
-    uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
+    uvicorn.run(
+        app="server:app",
+        host="0.0.0.0",
+        port=port,
+        reload=True,
+        proxy_headers=True,
+        forwarded_allow_ips="*",
+    )
diff --git a/mirrors/sites/docker.py b/mirrors/sites/docker.py
index 90e4dd1..17ec6b8 100644
--- a/mirrors/sites/docker.py
+++ b/mirrors/sites/docker.py
@@ -13,9 +13,7 @@ from proxy.direct import direct_proxy
 
 BASE_URL = "https://registry-1.docker.io"
 
-cached_token: Dict[str, str] = {
-
-}
+cached_token: Dict[str, str] = {}
 
 # https://github.com/opencontainers/distribution-spec/blob/main/spec.md
 name_regex = "[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*"
@@ -40,7 +38,6 @@ def try_extract_image_name(path):
 def get_docker_token(name):
     cached = cached_token.get(name, {})
     exp = cached.get("exp", 0)
-
     if exp > time.time():
         return cached.get("token", 0)
 
@@ -50,12 +47,13 @@ def get_docker_token(name):
         "service": "registry.docker.io",
     }
 
-    response = httpx.get(url, params=params, verify=False)
+    client = httpx.Client()
+    response = client.get(url, params=params)
     response.raise_for_status()
 
     token_data = response.json()
     token = token_data["token"]
 
-    payload = (token.split(".")[1])
+    payload = token.split(".")[1]
     padding = len(payload) % 4
     payload += "=" * padding
 
@@ -63,10 +61,7 @@ def get_docker_token(name):
     assert payload["iss"] == "auth.docker.io"
     assert len(payload["access"]) > 0
 
-    cached_token[name] = {
-        "exp": payload["exp"],
-        "token": token
-    }
+    cached_token[name] = {"exp": payload["exp"], "token": token}
 
     return token
 
@@ -100,17 +95,20 @@ async def docker(request: Request):
     name, operation, reference = try_extract_image_name(path)
 
     if not name:
-        return Response(content='404 Not Found', status_code=404)
+        return Response(content="404 Not Found", status_code=404)
 
     # support docker pull xxx which name without library
-    if '/' not in name:
+    if "/" not in name:
         name = f"library/{name}"
 
     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
     # logger
-    print('[PARSED]', path, name, operation, reference, target_url)
+    print("[PARSED]", path, name, operation, reference, target_url)
 
-    return await direct_proxy(request, target_url,
-                              pre_process=lambda req, http_req: inject_token(name, req, http_req),
-                              post_process=post_process)
+    return await direct_proxy(
+        request,
+        target_url,
+        pre_process=lambda req, http_req: inject_token(name, req, http_req),
+        post_process=post_process,
+    )
diff --git a/mirrors/sites/torch.py b/mirrors/sites/torch.py
index 1921380..d806966 100644
--- a/mirrors/sites/torch.py
+++ b/mirrors/sites/torch.py
@@ -21,4 +21,7 @@ async def torch(request: Request):
     if path.endswith(".whl") or path.endswith(".tar.gz"):
         return await try_file_based_cache(request, target_url)
 
-    return await direct_proxy(request, target_url, )
+    return await direct_proxy(
+        request,
+        target_url,
+    )