diff --git a/src/mirrorsrun/aria2_api.py b/src/mirrorsrun/aria2_api.py
index 522ffe7..8b216c1 100644
--- a/src/mirrorsrun/aria2_api.py
+++ b/src/mirrorsrun/aria2_api.py
@@ -9,6 +9,7 @@ from mirrorsrun.config import RPC_SECRET, ARIA2_RPC_URL
 
 logger = logging.getLogger(__name__)
 
+# refer to https://aria2.github.io/manual/en/html/aria2c.html
 async def send_request(method, params=None):
     request_id = uuid.uuid4().hex
     payload = {
@@ -32,9 +33,19 @@ async def send_request(method, params=None):
         raise e
 
 
-async def add_download(url, save_dir="/app/cache"):
+async def add_download(url, save_dir="/app/cache", out_file=None):
+    logger.info(f"[Aria2] add_download {url=} {save_dir=} {out_file=}")
+
     method = "aria2.addUri"
-    params = [[url], {"dir": save_dir, "header": []}]
+    options = {
+        "dir": save_dir,
+        "header": [],
+    }
+
+    if out_file:
+        options["out"] = out_file
+
+    params = [[url], options]
     response = await send_request(method, params)
     return response["result"]
 
diff --git a/src/mirrorsrun/docker_utils.py b/src/mirrorsrun/docker_utils.py
deleted file mode 100644
index 80bdbd8..0000000
--- a/src/mirrorsrun/docker_utils.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import base64
-import json
-import re
-import time
-from typing import Dict
-import httpx
-
-
-class CachedToken:
-    token: str
-    exp: int
-
-    def __init__(self, token, exp):
-        self.token = token
-        self.exp = exp
-
-
-cached_tokens: Dict[str, CachedToken] = {}
-
-
-# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
-name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
-reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
-
-
-def try_extract_image_name(path):
-    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
-    match = re.search(pattern, path)
-
-    if match:
-        assert len(match.groups()) == 3
-        name, resource, reference = match.groups()
-        assert re.match(name_regex, name)
-        assert re.match(reference_regex, reference)
-        assert resource in ["manifests", "blobs", "tags"]
-        return name, resource, reference
-
-    return None, None, None
-
-
-def get_docker_token(name):
-    cached = cached_tokens.get(name, None)
-    if cached and cached.exp > time.time():
-        return cached.token
-
-    url = "https://auth.docker.io/token"
-    params = {
-        "scope": f"repository:{name}:pull",
-        "service": "registry.docker.io",
-    }
-
-    client = httpx.Client()
-    response = client.get(url, params=params)
-    response.raise_for_status()
-
-    token_data = response.json()
-    token = token_data["token"]
-    payload = token.split(".")[1]
-    padding = len(payload) % 4
-    payload += "=" * padding
-
-    payload = json.loads(base64.b64decode(payload))
-    assert payload["iss"] == "auth.docker.io"
-    assert len(payload["access"]) > 0
-
-    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
-
-    return token
diff --git a/src/mirrorsrun/proxy/direct.py b/src/mirrorsrun/proxy/direct.py
index d6981cf..f82cc28 100644
--- a/src/mirrorsrun/proxy/direct.py
+++ b/src/mirrorsrun/proxy/direct.py
@@ -59,13 +59,15 @@ async def direct_proxy(
     target_url: str,
     pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
     post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    follow_redirects: bool = True,
 ) -> Response:
+
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
     async with httpx.AsyncClient() as client:
         req_headers = request.headers.mutablecopy()
         for key in req_headers.keys():
-            if key not in ["user-agent", "accept"]:
+            if key not in ["user-agent", "accept", "authorization"]:
                 del req_headers[key]
 
         httpx_req: HttpxRequest = client.build_request(
@@ -76,7 +78,9 @@
 
         httpx_req = await pre_process_request(request, httpx_req, pre_process)
 
-        upstream_response = await client.send(httpx_req)
+        upstream_response = await client.send(
+            httpx_req, follow_redirects=follow_redirects
+        )
 
         res_headers = upstream_response.headers
 
diff --git a/src/mirrorsrun/proxy/file_cache.py b/src/mirrorsrun/proxy/file_cache.py
index c2f0ba0..aff487d 100644
--- a/src/mirrorsrun/proxy/file_cache.py
+++ b/src/mirrorsrun/proxy/file_cache.py
@@ -7,12 +7,13 @@ from enum import Enum
 from urllib.parse import urlparse, quote
 
 import httpx
-from mirrorsrun.aria2_api import add_download
-from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 
+from mirrorsrun.aria2_api import add_download
+from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
+
 logger = logging.getLogger(__name__)
 
 
@@ -80,7 +81,7 @@
         return make_cached_response(target_url)
 
     if cache_status == DownloadingStatus.DOWNLOADING:
-        logger.info(f"Download is not finished, return 503 for {target_url}")
+        logger.info(f"Download is not finished, return 504 for {target_url}")
         return Response(
             content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
             status_code=HTTP_504_GATEWAY_TIMEOUT,
@@ -94,8 +95,12 @@
     processed_url = quote(target_url, safe="/:?=&%")
 
     try:
-        logger.info(f"Start download {processed_url}")
-        await add_download(processed_url, save_dir=cache_file_dir)
+        # resolve redirect via aria2
+        await add_download(
+            processed_url,
+            save_dir=cache_file_dir,
+            out_file=os.path.basename(cache_file),
+        )
     except Exception as e:
         logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
         return Response(
@@ -110,7 +115,10 @@
         if cache_status == DownloadingStatus.DOWNLOADED:
             logger.info(f"Cache hit for {target_url}")
             return make_cached_response(target_url)
-    logger.info(f"Download is not finished, return 503 for {target_url}")
+
+    assert cache_status != DownloadingStatus.NOT_FOUND
+
+    logger.info(f"Download is not finished, return 504 for {target_url}")
     return Response(
         content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
         status_code=HTTP_504_GATEWAY_TIMEOUT,
diff --git a/src/mirrorsrun/server.py b/src/mirrorsrun/server.py
index c04633e..fd965b4 100644
--- a/src/mirrorsrun/server.py
+++ b/src/mirrorsrun/server.py
@@ -23,21 +23,27 @@ from mirrorsrun.config import (
     EXTERNAL_HOST_ARIA2,
     SCHEME,
 )
-from mirrorsrun.sites.docker import docker
+
 from mirrorsrun.sites.npm import npm
 from mirrorsrun.sites.pypi import pypi
 from mirrorsrun.sites.torch import torch
-from mirrorsrun.sites.k8s import k8s
+from mirrorsrun.sites.docker import dockerhub, k8s, quay, ghcr
+from mirrorsrun.sites.common import common
 
 subdomain_mapping = {
+    "mirrors": common,
     "pypi": pypi,
     "torch": torch,
-    "docker": docker,
     "npm": npm,
+    "docker": dockerhub,
     "k8s": k8s,
+    "ghcr": ghcr,
+    "quay": quay,
 }
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
 logger = logging.getLogger(__name__)
 
 
@@ -123,7 +129,7 @@
         app="server:app",
         host="0.0.0.0",
         port=port,
-        reload=True, # TODO: reload only in dev mode
+        reload=True,  # TODO: reload only in dev mode
         proxy_headers=True,  # trust x-forwarded-for etc.
         forwarded_allow_ips="*",
     )
diff --git a/src/mirrorsrun/sites/common.py b/src/mirrorsrun/sites/common.py
new file mode 100644
index 0000000..89616cb
--- /dev/null
+++ b/src/mirrorsrun/sites/common.py
@@ -0,0 +1,18 @@
+from starlette.requests import Request
+
+from mirrorsrun.proxy.direct import direct_proxy
+from starlette.responses import Response
+
+
+async def common(request: Request):
+    path = request.url.path
+    if path == "/":
+        return
+    if path.startswith("/alpine"):
+        return await direct_proxy(request, "https://dl-cdn.alpinelinux.org" + path)
+    if path.startswith("/ubuntu/"):
+        return await direct_proxy(request, "http://archive.ubuntu.com" + path)
+    if path.startswith("/ubuntu-ports/"):
+        return await direct_proxy(request, "http://ports.ubuntu.com" + path)
+
+    return Response("Not Found", status_code=404)
diff --git a/src/mirrorsrun/sites/docker.py b/src/mirrorsrun/sites/docker.py
index 787bf6d..d0627ef 100644
--- a/src/mirrorsrun/sites/docker.py
+++ b/src/mirrorsrun/sites/docker.py
@@ -1,60 +1,143 @@
 import logging
+import re
 
-import httpx
+from mirrorsrun.proxy.direct import direct_proxy
+from mirrorsrun.proxy.file_cache import try_file_based_cache
 from starlette.requests import Request
 from starlette.responses import Response
 
-from mirrorsrun.docker_utils import get_docker_token
-from mirrorsrun.proxy.direct import direct_proxy
-from mirrorsrun.proxy.file_cache import try_file_based_cache
-from mirrorsrun.sites.k8s import try_extract_image_name
-
 logger = logging.getLogger(__name__)
 
-BASE_URL = "https://registry-1.docker.io"
+HEADER_AUTH_KEY = "www-authenticate"
+
+service_realm_mapping = {}
+
+# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
+name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
+reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
 
 
-def inject_token(name: str, req: Request, httpx_req: httpx.Request):
-    docker_token = get_docker_token(f"{name}")
-    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-    return httpx_req
+def try_extract_image_name(path):
+    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
+    match = re.search(pattern, path)
+
+    if match:
+        assert len(match.groups()) == 3
+        name, resource, reference = match.groups()
+        assert re.match(name_regex, name)
+        assert re.match(reference_regex, reference)
+        assert resource in ["manifests", "blobs", "tags"]
+        return name, resource, reference
+
+    return None, None, None
 
 
-async def post_process(request: Request, response: Response):
-    if response.status_code == 307:
-        location = response.headers["location"]
-        return await try_file_based_cache(request, location)
+def patch_auth_realm(request: Request, response: Response):
+    # https://registry-1.docker.io/v2/
+    # < www-authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io"
+
+    auth = response.headers.get(HEADER_AUTH_KEY, "")
+    if auth.startswith("Bearer "):
+        parts = auth.removeprefix("Bearer ").split(",")
+
+        auth_values = {}
+        for value in parts:
+            key, value = value.split("=")
+            value = value.strip('"')
+            auth_values[key] = value
+
+        assert "realm" in auth_values
+        assert "service" in auth_values
+        service_realm_mapping[auth_values["service"]] = auth_values["realm"]
+
+        mirror_url = f"{request.url.scheme}://{request.url.netloc}"
+        new_token_url = mirror_url + "/token"
+        response.headers[HEADER_AUTH_KEY] = auth.replace(
+            auth_values["realm"], new_token_url
+        )
     return response
 
 
-async def docker(request: Request):
-    path = request.url.path
-    if not path.startswith("/v2/"):
-        return Response(content="Not Found", status_code=404)
-
-    if path == "/v2/":
-        return Response(content="OK")
-        # return await direct_proxy(request, BASE_URL + '/v2/')
-
-    name, resource, reference = try_extract_image_name(path)
-
-    if not name:
-        return Response(content="404 Not Found", status_code=404)
-
-    # support docker pull xxx which name without library
+def build_docker_registry_handler(base_url: str, name_mapper=lambda x: x):
+    async def handler(request: Request):
+        path = request.url.path
+        if path == "/token":
+            params = request.query_params
+            scope = params.get("scope", "")
+            service = params.get("service", "")
+
+            # NOTE(review): this guard must run before the asserts below;
+            # placed after them it would be dead code (asserts raise first).
+            if not scope or not service:
+                return Response(content="Bad Request", status_code=400)
+
+            parts = scope.split(":")
+            assert len(parts) == 3
+            assert parts[0] == "repository"
+            assert parts[1]  # name
+            assert parts[2] == "pull"
+            parts[1] = name_mapper(parts[1])
+            scope = ":".join(parts)
+
+            new_params = {
+                "scope": scope,
+                "service": service,
+            }
+            query = "&".join([f"{k}={v}" for k, v in new_params.items()])
+
+            return await direct_proxy(
+                request, service_realm_mapping[service] + "?" + query
+            )
+
+        if path == "/v2/":
+            return await direct_proxy(
+                request, base_url + "/v2/", post_process=patch_auth_realm
+            )
+
+        if not path.startswith("/v2/"):
+            return Response(content="Not Found", status_code=404)
+
+        name, resource, reference = try_extract_image_name(path)
+
+        if not name:
+            return Response(content="404 Not Found", status_code=404)
+
+        name = name_mapper(name)
+
+        target_url = base_url + f"/v2/{name}/{resource}/{reference}"
+
+        logger.info(
+            f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+        )
+
+        if resource == "blobs":
+            return await try_file_based_cache(request, target_url)
+
+        return await direct_proxy(
+            request,
+            target_url,
+        )
+
+    return handler
+
+
+def dockerhub_name_mapper(name):
+    # support docker pull xxx which name without library for dockerhub
     if "/" not in name:
-        name = f"library/{name}"
+        return f"library/{name}"
+    return name
 
-    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
-
-    logger.info(
-        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
-    )
-
-    return await direct_proxy(
-        request,
-        target_url,
-        pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,  # cache in post_process
-    )
+
+k8s = build_docker_registry_handler(
+    "https://registry.k8s.io",
+)
+quay = build_docker_registry_handler(
+    "https://quay.io",
+)
+ghcr = build_docker_registry_handler(
+    "https://ghcr.io",
+)
+dockerhub = build_docker_registry_handler(
+    "https://registry-1.docker.io", name_mapper=dockerhub_name_mapper
+)
 
diff --git a/src/mirrorsrun/sites/k8s.py b/src/mirrorsrun/sites/k8s.py
deleted file mode 100644
index 965a4cb..0000000
--- a/src/mirrorsrun/sites/k8s.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import logging
-
-from starlette.requests import Request
-from starlette.responses import Response
-
-from mirrorsrun.docker_utils import try_extract_image_name
-from mirrorsrun.proxy.direct import direct_proxy
-from mirrorsrun.proxy.file_cache import try_file_based_cache
-
-logger = logging.getLogger(__name__)
-
-BASE_URL = "https://registry.k8s.io"
-
-
-async def post_process(request: Request, response: Response):
-    if response.status_code == 307:
-        location = response.headers["location"]
-
-        if "/blobs/" in request.url.path:
-            return await try_file_based_cache(request, location)
-
-        return await direct_proxy(request, location)
-
-    return response
-
-
-async def k8s(request: Request):
-    path = request.url.path
-    if not path.startswith("/v2/"):
-        return Response(content="Not Found", status_code=404)
-
-    if path == "/v2/":
-        return Response(content="OK")
-
-    name, resource, reference = try_extract_image_name(path)
-
-    if not name:
-        return Response(content="404 Not Found", status_code=404)
-
-    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
-
-    logger.info(
-        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
-    )
-
-    return await direct_proxy(
-        request,
-        target_url,
-        post_process=post_process,
-    )
diff --git a/src/mirrorsrun/sites/npm.py b/src/mirrorsrun/sites/npm.py
index 3fc0beb..76c7213 100644
--- a/src/mirrorsrun/sites/npm.py
+++ b/src/mirrorsrun/sites/npm.py
@@ -2,7 +2,7 @@ from starlette.requests import Request
 
 from mirrorsrun.proxy.direct import direct_proxy
 
-BASE_URL = "https://registry.npmjs.org/"
+BASE_URL = "https://registry.npmjs.org"
 
 
 async def npm(request: Request):
diff --git a/test/mirrors_test.py b/test/mirrors_test.py
index 36c37f6..d472f4d 100644
--- a/test/mirrors_test.py
+++ b/test/mirrors_test.py
@@ -19,5 +19,11 @@ class TestPypi(unittest.TestCase):
     def test_dockerhub_pull(self):
         call(f"docker pull docker.local.homeinfra.org/alpine:3.12")
 
+    def test_ghcr_pull(self):
+        call(f"docker pull ghcr.local.homeinfra.org/linuxcontainers/alpine")
+
+    def test_quay_pull(self):
+        call(f"docker pull quay.local.homeinfra.org/quay/busybox")
+
     def test_k8s_pull(self):
         call(f"docker pull k8s.local.homeinfra.org/pause:3.5")