feat: add common docker handler and common mirrors

add `follow_redirects` option to `direct_proxy`
add `out` param for `aria2.addUri`
Anonymous 2024-06-13 20:43:57 +08:00
parent d06c749f02
commit 833a31ae80
10 changed files with 191 additions and 172 deletions

View File: mirrorsrun/aria2_api.py

@@ -9,6 +9,7 @@ from mirrorsrun.config import RPC_SECRET, ARIA2_RPC_URL
 logger = logging.getLogger(__name__)

+# refer to https://aria2.github.io/manual/en/html/aria2c.html
 async def send_request(method, params=None):
     request_id = uuid.uuid4().hex
     payload = {
@@ -32,9 +33,20 @@ async def send_request(method, params=None):
         raise e


-async def add_download(url, save_dir="/app/cache"):
+async def add_download(url, save_dir="/app/cache", out_file=None):
+    logger.info(f"[Aria2] add_download {url=} {save_dir=} {out_file=}")
     method = "aria2.addUri"
-    params = [[url], {"dir": save_dir, "header": []}]
+    options = {
+        "dir": save_dir,
+        "header": [],
+    }
+    if out_file:
+        options["out"] = out_file
+    params = [[url], options]
     response = await send_request(method, params)
     return response["result"]
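For reference, a minimal sketch of the JSON-RPC body this change produces for `aria2.addUri`. The `token:` secret prefix follows the aria2 RPC convention; whether `send_request` injects it exactly this way is an assumption, since that code is outside the hunk:

import uuid

RPC_SECRET = "example-secret"  # assumed placeholder; the real value comes from mirrorsrun.config

def build_add_uri_payload(url, save_dir="/app/cache", out_file=None):
    # aria2.addUri takes [secret, [uris], options]; "out" pins the file name
    # of the download relative to "dir".
    options = {"dir": save_dir, "header": []}
    if out_file:
        options["out"] = out_file
    return {
        "jsonrpc": "2.0",
        "id": uuid.uuid4().hex,
        "method": "aria2.addUri",
        "params": [f"token:{RPC_SECRET}", [url], options],
    }

print(build_add_uri_payload("https://example.com/a.tar.gz", out_file="a.tar.gz"))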

View File: mirrorsrun/docker_utils.py

@@ -1,68 +0,0 @@
-import base64
-import json
-import re
-import time
-from typing import Dict
-
-import httpx
-
-
-class CachedToken:
-    token: str
-    exp: int
-
-    def __init__(self, token, exp):
-        self.token = token
-        self.exp = exp
-
-
-cached_tokens: Dict[str, CachedToken] = {}
-
-# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
-name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
-reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
-
-
-def try_extract_image_name(path):
-    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
-    match = re.search(pattern, path)
-
-    if match:
-        assert len(match.groups()) == 3
-        name, resource, reference = match.groups()
-        assert re.match(name_regex, name)
-        assert re.match(reference_regex, reference)
-        assert resource in ["manifests", "blobs", "tags"]
-        return name, resource, reference
-
-    return None, None, None
-
-
-def get_docker_token(name):
-    cached = cached_tokens.get(name, None)
-    if cached and cached.exp > time.time():
-        return cached.token
-
-    url = "https://auth.docker.io/token"
-    params = {
-        "scope": f"repository:{name}:pull",
-        "service": "registry.docker.io",
-    }
-
-    client = httpx.Client()
-    response = client.get(url, params=params)
-    response.raise_for_status()
-
-    token_data = response.json()
-    token = token_data["token"]
-
-    payload = token.split(".")[1]
-    padding = len(payload) % 4
-    payload += "=" * padding
-    payload = json.loads(base64.b64decode(payload))
-
-    assert payload["iss"] == "auth.docker.io"
-    assert len(payload["access"]) > 0
-
-    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
-
-    return token
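An aside on the removed token decoder: JWT segments are base64url-encoded without padding, and `len(payload) % 4` only yields the right number of `=` characters when the remainder happens to be 2. A minimal sketch of the usual padding fix, for illustration only (not part of this commit):

import base64
import json

def decode_jwt_payload(token: str) -> dict:
    # -len % 4 gives exactly the number of "=" characters needed to
    # reach a multiple of 4 (0, 1, or 2 for well-formed base64url).
    payload = token.split(".")[1]
    payload += "=" * (-len(payload) % 4)
    return json.loads(base64.urlsafe_b64decode(payload))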

View File: mirrorsrun/proxy/direct.py

@@ -59,13 +59,15 @@ async def direct_proxy(
     target_url: str,
     pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
     post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    follow_redirects: bool = True,
 ) -> Response:
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
     async with httpx.AsyncClient() as client:
         req_headers = request.headers.mutablecopy()
         for key in req_headers.keys():
-            if key not in ["user-agent", "accept"]:
+            if key not in ["user-agent", "accept", "authorization"]:
                 del req_headers[key]

         httpx_req: HttpxRequest = client.build_request(
@@ -76,7 +78,9 @@ async def direct_proxy(
         httpx_req = await pre_process_request(request, httpx_req, pre_process)

-        upstream_response = await client.send(httpx_req)
+        upstream_response = await client.send(
+            httpx_req, follow_redirects=follow_redirects
+        )

         res_headers = upstream_response.headers
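A sketch of how a site handler might use the new flag to surface upstream redirects to the client instead of following them server-side. The handler and target host here are hypothetical:

from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.proxy.direct import direct_proxy

async def passthrough(request: Request) -> Response:
    # With follow_redirects=False the upstream 30x (and its Location header)
    # is relayed to the caller rather than resolved inside the proxy.
    return await direct_proxy(
        request,
        "https://example.com" + request.url.path,
        follow_redirects=False,
    )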

View File: mirrorsrun/proxy/file_cache.py

@@ -7,12 +7,13 @@ from enum import Enum
 from urllib.parse import urlparse, quote

 import httpx
-from mirrorsrun.aria2_api import add_download
-from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT

+from mirrorsrun.aria2_api import add_download
+from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
+
 logger = logging.getLogger(__name__)
@@ -80,7 +81,7 @@ async def try_file_based_cache(
         return make_cached_response(target_url)

     if cache_status == DownloadingStatus.DOWNLOADING:
-        logger.info(f"Download is not finished, return 503 for {target_url}")
+        logger.info(f"Download is not finished, return 504 for {target_url}")
         return Response(
             content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
             status_code=HTTP_504_GATEWAY_TIMEOUT,
@@ -94,8 +95,12 @@ async def try_file_based_cache(
     processed_url = quote(target_url, safe="/:?=&%")

     try:
-        logger.info(f"Start download {processed_url}")
-        await add_download(processed_url, save_dir=cache_file_dir)
+        # resolve redirect via aria2
+        await add_download(
+            processed_url,
+            save_dir=cache_file_dir,
+            out_file=os.path.basename(cache_file),
+        )
     except Exception as e:
         logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
         return Response(
@@ -110,7 +115,10 @@ async def try_file_based_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             logger.info(f"Cache hit for {target_url}")
             return make_cached_response(target_url)
-    logger.info(f"Download is not finished, return 503 for {target_url}")
+
+    assert cache_status != DownloadingStatus.NOT_FOUND
+
+    logger.info(f"Download is not finished, return 504 for {target_url}")
     return Response(
         content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
         status_code=HTTP_504_GATEWAY_TIMEOUT,
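Why `out_file` matters here: without it, aria2 names the download after the final URL it lands on, so a redirect to a CDN could leave the cache entry with a basename `try_file_based_cache` never looks for. A tiny illustration with a made-up cache path:

import os

cache_file = "/app/cache/example.com/downloads/foo.whl"  # hypothetical cache path
# Passing out=basename(cache_file) pins the on-disk name even if the URL
# redirects to e.g. https://cdn.example.com/ab12cd34 before downloading.
print(os.path.basename(cache_file))  # foo.whl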

View File

@@ -23,21 +23,27 @@ from mirrorsrun.config import (
     EXTERNAL_HOST_ARIA2,
     SCHEME,
 )
-from mirrorsrun.sites.docker import docker
 from mirrorsrun.sites.npm import npm
 from mirrorsrun.sites.pypi import pypi
 from mirrorsrun.sites.torch import torch
-from mirrorsrun.sites.k8s import k8s
+from mirrorsrun.sites.docker import dockerhub, k8s, quay, ghcr
+from mirrorsrun.sites.common import common

 subdomain_mapping = {
+    "mirrors": common,
     "pypi": pypi,
     "torch": torch,
-    "docker": docker,
     "npm": npm,
+    "docker": dockerhub,
     "k8s": k8s,
+    "ghcr": ghcr,
+    "quay": quay,
 }

-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
 logger = logging.getLogger(__name__)
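For context, a minimal sketch of how a request's subdomain picks a handler out of this mapping. The dispatch code itself is outside the hunk, so this is an assumption about its shape:

def pick_handler(host: str):
    # "docker.local.homeinfra.org" -> subdomain "docker" -> dockerhub handler
    subdomain = host.split(".")[0]
    return subdomain_mapping.get(subdomain)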

View File: mirrorsrun/sites/common.py

@@ -0,0 +1,18 @@
+from starlette.requests import Request
+from mirrorsrun.proxy.direct import direct_proxy
+from starlette.responses import Response
+
+
+async def common(request: Request):
+    path = request.url.path
+    if path == "/":
+        return
+    if path.startswith("/alpine"):
+        return await direct_proxy(request, "https://dl-cdn.alpinelinux.org" + path)
+    if path.startswith("/ubuntu/"):
+        return await direct_proxy(request, "http://archive.ubuntu.com" + path)
+    if path.startswith("/ubuntu-ports/"):
+        return await direct_proxy(request, "http://ports.ubuntu.com" + path)
+    return Response("Not Found", status_code=404)
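A quick smoke test against the new handler. The hostname follows the "mirrors" subdomain in the server mapping and the `*.local.homeinfra.org` naming from the tests, so treat it as an assumption about the deployment:

import httpx

base = "https://mirrors.local.homeinfra.org"  # assumed deployment hostname
# Proxied to https://dl-cdn.alpinelinux.org/alpine/... by the common handler.
r = httpx.get(f"{base}/alpine/latest-stable/main/x86_64/APKINDEX.tar.gz")
print(r.status_code, len(r.content))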

View File: mirrorsrun/sites/docker.py

@@ -1,60 +1,143 @@
 import logging
+import re

-import httpx
+from mirrorsrun.proxy.direct import direct_proxy
+from mirrorsrun.proxy.file_cache import try_file_based_cache
 from starlette.requests import Request
 from starlette.responses import Response

-from mirrorsrun.docker_utils import get_docker_token
-from mirrorsrun.proxy.direct import direct_proxy
-from mirrorsrun.proxy.file_cache import try_file_based_cache
-from mirrorsrun.sites.k8s import try_extract_image_name
-
 logger = logging.getLogger(__name__)

-BASE_URL = "https://registry-1.docker.io"
+HEADER_AUTH_KEY = "www-authenticate"
+
+service_realm_mapping = {}
+
+# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
+name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
+reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"


-def inject_token(name: str, req: Request, httpx_req: httpx.Request):
-    docker_token = get_docker_token(f"{name}")
-    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-    return httpx_req
+def try_extract_image_name(path):
+    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
+    match = re.search(pattern, path)
+
+    if match:
+        assert len(match.groups()) == 3
+        name, resource, reference = match.groups()
+        assert re.match(name_regex, name)
+        assert re.match(reference_regex, reference)
+        assert resource in ["manifests", "blobs", "tags"]
+        return name, resource, reference
+
+    return None, None, None


-async def post_process(request: Request, response: Response):
-    if response.status_code == 307:
-        location = response.headers["location"]
-        return await try_file_based_cache(request, location)
+def patch_auth_realm(request: Request, response: Response):
+    # https://registry-1.docker.io/v2/
+    # < www-authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io"
+    auth = response.headers.get(HEADER_AUTH_KEY, "")
+    if auth.startswith("Bearer "):
+        parts = auth.removeprefix("Bearer ").split(",")
+        auth_values = {}
+        for value in parts:
+            key, value = value.split("=")
+            value = value.strip('"')
+            auth_values[key] = value
+
+        assert "realm" in auth_values
+        assert "service" in auth_values
+
+        service_realm_mapping[auth_values["service"]] = auth_values["realm"]
+
+        mirror_url = f"{request.url.scheme}://{request.url.netloc}"
+        new_token_url = mirror_url + "/token"
+
+        response.headers[HEADER_AUTH_KEY] = auth.replace(
+            auth_values["realm"], new_token_url
+        )
     return response


-async def docker(request: Request):
-    path = request.url.path
-    if not path.startswith("/v2/"):
-        return Response(content="Not Found", status_code=404)
-    if path == "/v2/":
-        return Response(content="OK")
-        # return await direct_proxy(request, BASE_URL + '/v2/')
-    name, resource, reference = try_extract_image_name(path)
-    if not name:
-        return Response(content="404 Not Found", status_code=404)
-    # support docker pull xxx which name without library
-    if "/" not in name:
-        name = f"library/{name}"
-    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
-    logger.info(
-        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
-    )
-    return await direct_proxy(
-        request,
-        target_url,
-        pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,  # cache in post_process
-    )
+def build_docker_registry_handler(base_url: str, name_mapper=lambda x: x):
+    async def handler(request: Request):
+        path = request.url.path
+
+        if path == "/token":
+            params = request.query_params
+            scope = params.get("scope", "")
+            service = params.get("service", "")
+
+            parts = scope.split(":")
+            assert service
+            assert len(parts) == 3
+            assert parts[0] == "repository"
+            assert parts[1]  # name
+            assert parts[2] == "pull"
+            parts[1] = name_mapper(parts[1])
+            scope = ":".join(parts)
+
+            if not scope or not service:
+                return Response(content="Bad Request", status_code=400)
+
+            new_params = {
+                "scope": scope,
+                "service": service,
+            }
+            query = "&".join([f"{k}={v}" for k, v in new_params.items()])
+            return await direct_proxy(
+                request, service_realm_mapping[service] + "?" + query
+            )
+
+        if path == "/v2/":
+            return await direct_proxy(
+                request, base_url + "/v2/", post_process=patch_auth_realm
+            )
+
+        if not path.startswith("/v2/"):
+            return Response(content="Not Found", status_code=404)
+
+        name, resource, reference = try_extract_image_name(path)
+        if not name:
+            return Response(content="404 Not Found", status_code=404)
+
+        name = name_mapper(name)
+
+        target_url = base_url + f"/v2/{name}/{resource}/{reference}"
+
+        logger.info(
+            f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+        )
+
+        if resource == "blobs":
+            return await try_file_based_cache(request, target_url)
+
+        return await direct_proxy(
+            request,
+            target_url,
+        )
+
+    return handler
+
+
+def dockerhub_name_mapper(name):
+    # support docker pull xxx which name without library for dockerhub
+    if "/" not in name:
+        return f"library/{name}"
+    return name
+
+
+k8s = build_docker_registry_handler(
+    "https://registry.k8s.io",
+)
+quay = build_docker_registry_handler(
+    "https://quay.io",
+)
+ghcr = build_docker_registry_handler(
+    "https://ghcr.io",
+)
+dockerhub = build_docker_registry_handler(
+    "https://registry-1.docker.io", name_mapper=dockerhub_name_mapper
+)
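The heart of the token flow is rewriting the `www-authenticate` challenge so clients fetch tokens from the mirror's own `/token` endpoint. A standalone sketch of that parse-and-rewrite step, using the Docker Hub challenge quoted in the code comments (the mirror URL is illustrative):

auth = 'Bearer realm="https://auth.docker.io/token",service="registry.docker.io"'
auth_values = {}
for part in auth.removeprefix("Bearer ").split(","):
    key, value = part.split("=")
    auth_values[key] = value.strip('"')

new_token_url = "https://docker.local.homeinfra.org/token"  # assumed mirror URL
print(auth.replace(auth_values["realm"], new_token_url))
# Bearer realm="https://docker.local.homeinfra.org/token",service="registry.docker.io"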

View File: mirrorsrun/sites/k8s.py

@@ -1,50 +0,0 @@
-import logging
-
-from starlette.requests import Request
-from starlette.responses import Response
-
-from mirrorsrun.docker_utils import try_extract_image_name
-from mirrorsrun.proxy.direct import direct_proxy
-from mirrorsrun.proxy.file_cache import try_file_based_cache
-
-logger = logging.getLogger(__name__)
-
-BASE_URL = "https://registry.k8s.io"
-
-
-async def post_process(request: Request, response: Response):
-    if response.status_code == 307:
-        location = response.headers["location"]
-        if "/blobs/" in request.url.path:
-            return await try_file_based_cache(request, location)
-        return await direct_proxy(request, location)
-    return response
-
-
-async def k8s(request: Request):
-    path = request.url.path
-    if not path.startswith("/v2/"):
-        return Response(content="Not Found", status_code=404)
-    if path == "/v2/":
-        return Response(content="OK")
-    name, resource, reference = try_extract_image_name(path)
-    if not name:
-        return Response(content="404 Not Found", status_code=404)
-    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
-    logger.info(
-        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
-    )
-    return await direct_proxy(
-        request,
-        target_url,
-        post_process=post_process,
-    )

View File: mirrorsrun/sites/npm.py

@@ -2,7 +2,7 @@ from starlette.requests import Request

 from mirrorsrun.proxy.direct import direct_proxy

-BASE_URL = "https://registry.npmjs.org/"
+BASE_URL = "https://registry.npmjs.org"


 async def npm(request: Request):
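Dropping the trailing slash matters if the handler builds targets as `BASE_URL + path`, the way the other site handlers do (an assumption; the concatenation itself is outside this hunk):

BASE_URL = "https://registry.npmjs.org"
path = "/lodash"  # request paths always start with "/"
# With the old trailing slash this produced "https://registry.npmjs.org//lodash".
print(BASE_URL + path)  # https://registry.npmjs.org/lodash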

View File

@@ -19,5 +19,11 @@ class TestPypi(unittest.TestCase):
     def test_dockerhub_pull(self):
         call(f"docker pull docker.local.homeinfra.org/alpine:3.12")

+    def test_ghcr_pull(self):
+        call(f"docker pull ghcr.local.homeinfra.org/linuxcontainers/alpine")
+
+    def test_quay_pull(self):
+        call(f"docker pull quay.local.homeinfra.org/quay/busybox")
+
     def test_k8s_pull(self):
         call(f"docker pull k8s.local.homeinfra.org/pause:3.5")