feat: add k8s

Anonymous 2024-06-11 23:13:25 +08:00
parent b6a13cf1ff
commit d06c749f02
9 changed files with 208 additions and 116 deletions

View File

@@ -4,7 +4,7 @@ ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
 RPC_SECRET = os.environ.get("RPC_SECRET", "")
 BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
-SCHEME = os.environ.get("SCHEME", None)
+SCHEME = os.environ.get("SCHEME", "http").lower()
 assert SCHEME in ["http", "https"]
 CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache/")

View File

@@ -0,0 +1,68 @@
+import base64
+import json
+import re
+import time
+from typing import Dict
+
+import httpx
+
+
+class CachedToken:
+    token: str
+    exp: int
+
+    def __init__(self, token, exp):
+        self.token = token
+        self.exp = exp
+
+
+cached_tokens: Dict[str, CachedToken] = {}
+
+# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
+name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
+reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
+
+
+def try_extract_image_name(path):
+    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
+    match = re.search(pattern, path)
+
+    if match:
+        assert len(match.groups()) == 3
+        name, resource, reference = match.groups()
+        assert re.match(name_regex, name)
+        assert re.match(reference_regex, reference)
+        assert resource in ["manifests", "blobs", "tags"]
+        return name, resource, reference
+
+    return None, None, None
+
+
+def get_docker_token(name):
+    cached = cached_tokens.get(name, None)
+    if cached and cached.exp > time.time():
+        return cached.token
+
+    url = "https://auth.docker.io/token"
+    params = {
+        "scope": f"repository:{name}:pull",
+        "service": "registry.docker.io",
+    }
+
+    client = httpx.Client()
+    response = client.get(url, params=params)
+    response.raise_for_status()
+
+    token_data = response.json()
+    token = token_data["token"]
+
+    payload = token.split(".")[1]
+    padding = len(payload) % 4
+    payload += "=" * padding
+    payload = json.loads(base64.b64decode(payload))
+
+    assert payload["iss"] == "auth.docker.io"
+    assert len(payload["access"]) > 0
+
+    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
+    return token
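As a quick sanity check of the new parser, here is an illustrative snippet. The module path mirrorsrun.docker_utils is the one the other files in this commit import from, and the package must be on PYTHONPATH; the image names and digest are made-up examples, not values from the commit.

# Illustrative only: exercises try_extract_image_name from the new module above.
from mirrorsrun.docker_utils import try_extract_image_name

# A manifest request for an example image and tag
print(try_extract_image_name("/v2/library/nginx/manifests/1.27"))
# -> ('library/nginx', 'manifests', '1.27')

# A blob request addressed by an example digest
print(try_extract_image_name("/v2/library/nginx/blobs/sha256:abc123"))
# -> ('library/nginx', 'blobs', 'sha256:abc123')

# Anything that is not /v2/<name>/<resource>/<reference> yields (None, None, None)
print(try_extract_image_name("/healthz"))
# -> (None, None, None)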

View File

@@ -1,3 +1,4 @@
+import logging
 import typing
 from typing import Callable, Coroutine
@@ -18,13 +19,46 @@ AsyncPostProcessor = Callable[
     [Request, Response], Coroutine[Request, Response, Response]
 ]

+PreProcessor = typing.Union[SyncPreProcessor, AsyncPreProcessor, None]
+PostProcessor = typing.Union[SyncPostProcessor, AsyncPostProcessor, None]
+
+logger = logging.getLogger(__name__)
+
+
+async def pre_process_request(
+    request: Request,
+    httpx_req: HttpxRequest,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+):
+    if pre_process:
+        new_httpx_req = pre_process(request, httpx_req)
+        if isinstance(new_httpx_req, HttpxRequest):
+            httpx_req = new_httpx_req
+        else:
+            httpx_req = await new_httpx_req
+    return httpx_req
+
+
+async def post_process_response(
+    request: Request,
+    response: Response,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+):
+    if post_process:
+        new_res = post_process(request, response)
+        if isinstance(new_res, Response):
+            return new_res
+        elif isinstance(new_res, Coroutine):
+            return await new_res
+    else:
+        return response
+
+
 async def direct_proxy(
     request: Request,
     target_url: str,
     pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
     post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    cache_ttl: int = 3600,
 ) -> Response:
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
@@ -40,12 +74,7 @@ async def direct_proxy(
         headers=req_headers,
     )

-    if pre_process:
-        new_httpx_req = pre_process(request, httpx_req)
-        if isinstance(new_httpx_req, HttpxRequest):
-            httpx_req = new_httpx_req
-        else:
-            httpx_req = await new_httpx_req
+    httpx_req = await pre_process_request(request, httpx_req, pre_process)

     upstream_response = await client.send(httpx_req)
@@ -54,6 +83,10 @@ async def direct_proxy(
     res_headers.pop("content-length", None)
     res_headers.pop("content-encoding", None)

+    logger.info(
+        f"proxy {request.url} to {target_url} {upstream_response.status_code}"
+    )
+
     content = upstream_response.content
     response = Response(
         headers=res_headers,
@@ -61,13 +94,6 @@
         status_code=upstream_response.status_code,
     )

-    if post_process:
-        new_res = post_process(request, response)
-        if isinstance(new_res, Response):
-            final_res = new_res
-        elif isinstance(new_res, Coroutine):
-            final_res = await new_res
-    else:
-        final_res = response
+    response = await post_process_response(request, response, post_process)

-    return final_res
+    return response
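For context, a minimal sketch of how a site handler plugs into these hook points, in the same style as sites/docker.py later in this commit. The handler names, upstream URL, and header value here are hypothetical examples, not part of the commit.

# Illustrative sketch: wiring pre_process/post_process into direct_proxy.
import httpx
from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.proxy.direct import direct_proxy


def add_upstream_auth(request: Request, httpx_req: httpx.Request) -> httpx.Request:
    # Sync pre-processor: mutate the outgoing httpx request before it is sent.
    httpx_req.headers["Authorization"] = "Bearer <token>"  # placeholder value
    return httpx_req


async def inspect_upstream(request: Request, response: Response) -> Response:
    # Async post-processor: look at the upstream response before returning it,
    # e.g. to follow or cache a redirect as the docker/k8s handlers do.
    if response.status_code == 307:
        pass
    return response


async def my_site(request: Request) -> Response:
    # "https://example.com" stands in for a real upstream base URL.
    return await direct_proxy(
        request,
        "https://example.com" + request.url.path,
        pre_process=add_upstream_auth,
        post_process=inspect_upstream,
    )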

View File

@@ -7,15 +7,12 @@ from enum import Enum
 from urllib.parse import urlparse, quote

 import httpx
+
+from mirrorsrun.aria2_api import add_download
+from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT

-from mirrorsrun.aria2_api import add_download
-from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
-from typing import Optional, Callable

 logger = logging.getLogger(__name__)
@@ -76,14 +73,11 @@ async def try_file_based_cache(
     request: Request,
     target_url: str,
     download_wait_time: int = 60,
-    post_process: Optional[Callable[[Request, Response], Response]] = None,
 ) -> Response:
     cache_status = lookup_cache(target_url)

     if cache_status == DownloadingStatus.DOWNLOADED:
-        resp = make_cached_response(target_url)
-        if post_process:
-            resp = post_process(request, resp)
-        return resp
+        logger.info(f"Cache hit for {target_url}")
+        return make_cached_response(target_url)

     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
@@ -95,14 +89,15 @@ async def try_file_based_cache(
     assert cache_status == DownloadingStatus.NOT_FOUND

     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
-    print("prepare to download", target_url, cache_file, cache_file_dir)
+    logger.info(f"prepare to cache, {target_url=} {cache_file=} {cache_file_dir=}")

     processed_url = quote(target_url, safe="/:?=&%")

     try:
+        logger.info(f"Start download {processed_url}")
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
-        logger.error(f"Download error, return 503 for {target_url}", exc_info=e)
+        logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
         return Response(
             content=f"Failed to add download: {e}",
             status_code=HTTP_500_INTERNAL_SERVER_ERROR,
@@ -113,6 +108,7 @@ async def try_file_based_cache(
         await sleep(1)
         cache_status = lookup_cache(target_url)
         if cache_status == DownloadingStatus.DOWNLOADED:
+            logger.info(f"Cache hit for {target_url}")
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
     return Response(

View File

@@ -1,11 +1,13 @@
 import os
 import sys

-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # noqa: E402

 import base64
 import signal
 import urllib.parse
 from typing import Callable
+import logging

 import httpx
 import uvicorn
@@ -25,6 +27,19 @@ from mirrorsrun.sites.docker import docker
 from mirrorsrun.sites.npm import npm
 from mirrorsrun.sites.pypi import pypi
 from mirrorsrun.sites.torch import torch
+from mirrorsrun.sites.k8s import k8s
+
+subdomain_mapping = {
+    "pypi": pypi,
+    "torch": torch,
+    "docker": docker,
+    "npm": npm,
+    "k8s": k8s,
+}
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 app = FastAPI()
@@ -73,14 +88,10 @@ async def capture_request(request: Request, call_next: Callable):
     if hostname.startswith("aria2."):
         return await aria2(request, call_next)

-    if hostname.startswith("pypi."):
-        return await pypi(request)
-    if hostname.startswith("torch."):
-        return await torch(request)
-    if hostname.startswith("docker."):
-        return await docker(request)
-    if hostname.startswith("npm."):
-        return await npm(request)
+    subdomain = hostname.split(".")[0]
+
+    if subdomain in subdomain_mapping:
+        return await subdomain_mapping[subdomain](request)

     return await call_next(request)
@@ -88,10 +99,10 @@
 if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
-    print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
-    for dn in ["pypi", "torch", "docker", "npm"]:
-        print(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
+    logger.info(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
+    for dn in subdomain_mapping.keys():
+        logger.info(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")

     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
@@ -106,14 +117,13 @@
     query_string = urllib.parse.urlencode(params)
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
-    print(f"Download manager (Aria2) at {aria2_url_with_auth}")
+    logger.info(f"Download manager (Aria2) at {aria2_url_with_auth}")

-    # FIXME: only proxy headers if SCHEME is https
-    # reload only in dev mode
     uvicorn.run(
         app="server:app",
         host="0.0.0.0",
         port=port,
-        reload=True,
-        proxy_headers=True,
+        reload=True,  # TODO: reload only in dev mode
+        proxy_headers=True,  # trust x-forwarded-for etc.
         forwarded_allow_ips="*",
     )
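With the dispatch table in place, wiring up another upstream later should only need a handler plus one entry in subdomain_mapping. A hypothetical sketch (the `example` names and upstream URL are made up, not part of this commit):

# Hypothetical: how a new mirror could be registered after this refactor.
from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.proxy.direct import direct_proxy


async def example_handler(request: Request) -> Response:
    # Proxy everything under this subdomain to a fictional upstream.
    return await direct_proxy(request, "https://registry.example.org" + request.url.path)


# In server.py, the handler is picked by the left-most DNS label of the Host header:
#   subdomain_mapping["example"] = example_handler
# A request to example.local.homeinfra.org/... would then be routed to example_handler.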

View File

@@ -1,83 +1,19 @@
-import base64
-import json
 import logging
-import re
-import time
-from typing import Dict

 import httpx
 from starlette.requests import Request
 from starlette.responses import Response

+from mirrorsrun.docker_utils import get_docker_token
 from mirrorsrun.proxy.direct import direct_proxy
 from mirrorsrun.proxy.file_cache import try_file_based_cache
+from mirrorsrun.sites.k8s import try_extract_image_name

 logger = logging.getLogger(__name__)

 BASE_URL = "https://registry-1.docker.io"

-
-class CachedToken:
-    token: str
-    exp: int
-
-    def __init__(self, token, exp):
-        self.token = token
-        self.exp = exp
-
-
-cached_tokens: Dict[str, CachedToken] = {}
-
-# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
-name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
-reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
-
-
-def try_extract_image_name(path):
-    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
-    match = re.search(pattern, path)
-
-    if match:
-        assert len(match.groups()) == 3
-        name, resource, reference = match.groups()
-        assert re.match(name_regex, name)
-        assert re.match(reference_regex, reference)
-        assert resource in ["manifests", "blobs", "tags"]
-        return name, resource, reference
-
-    return None, None, None
-
-
-def get_docker_token(name):
-    cached = cached_tokens.get(name, None)
-    if cached and cached.exp > time.time():
-        return cached.token
-
-    url = "https://auth.docker.io/token"
-    params = {
-        "scope": f"repository:{name}:pull",
-        "service": "registry.docker.io",
-    }
-
-    client = httpx.Client()
-    response = client.get(url, params=params)
-    response.raise_for_status()
-
-    token_data = response.json()
-    token = token_data["token"]
-
-    payload = token.split(".")[1]
-    padding = len(payload) % 4
-    payload += "=" * padding
-    payload = json.loads(base64.b64decode(payload))
-
-    assert payload["iss"] == "auth.docker.io"
-    assert len(payload["access"]) > 0
-
-    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
-    return token
-

 def inject_token(name: str, req: Request, httpx_req: httpx.Request):
     docker_token = get_docker_token(f"{name}")
     httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
@@ -112,11 +48,13 @@ async def docker(request: Request):
     target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"

-    logger.info(f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}")
+    logger.info(
+        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )

     return await direct_proxy(
         request,
         target_url,
         pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,
+        post_process=post_process,  # cache in post_process
     )

View File

@@ -0,0 +1,50 @@
+import logging
+
+from starlette.requests import Request
+from starlette.responses import Response
+
+from mirrorsrun.docker_utils import try_extract_image_name
+from mirrorsrun.proxy.direct import direct_proxy
+from mirrorsrun.proxy.file_cache import try_file_based_cache
+
+logger = logging.getLogger(__name__)
+
+BASE_URL = "https://registry.k8s.io"
+
+
+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+        if "/blobs/" in request.url.path:
+            return await try_file_based_cache(request, location)
+        return await direct_proxy(request, location)
+
+    return response
+
+
+async def k8s(request: Request):
+    path = request.url.path
+    if not path.startswith("/v2/"):
+        return Response(content="Not Found", status_code=404)
+
+    if path == "/v2/":
+        return Response(content="OK")
+
+    name, resource, reference = try_extract_image_name(path)
+
+    if not name:
+        return Response(content="404 Not Found", status_code=404)
+
+    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
+
+    logger.info(
+        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )
+
+    return await direct_proxy(
+        request,
+        target_url,
+        post_process=post_process,
+    )
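The 307 branch exists because registry.k8s.io typically answers blob requests with a redirect to a separate blob store rather than serving the bytes itself. A standalone sketch of the decision the new post_process makes, with made-up paths and URLs (it does not call the proxy code):

# Illustrative only: mirrors the branch logic of post_process above.
from typing import Optional


def describe_k8s_response(path: str, status_code: int, location: Optional[str]) -> str:
    if status_code == 307 and location:
        if "/blobs/" in path:
            # Blob bytes: hand the redirect target to try_file_based_cache
            # (aria2 downloads it once, later requests hit the disk cache).
            return f"cache {location}"
        # Manifests/tags: just follow the redirect through direct_proxy.
        return f"proxy {location}"
    # Anything else is returned to the client unchanged.
    return "pass through"


print(describe_k8s_response("/v2/pause/blobs/sha256:feedface", 307,
                            "https://blob-store.example.org/x"))    # -> cache ...
print(describe_k8s_response("/v2/pause/manifests/3.5", 307,
                            "https://other.example.org/m"))         # -> proxy ...
print(describe_k8s_response("/v2/pause/manifests/3.5", 200, None))  # -> pass through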

View File

@@ -1,2 +1,3 @@
 [flake8]
 max-line-length = 99
+ignore = E402

View File

@@ -16,5 +16,8 @@ class TestPypi(unittest.TestCase):
     def test_torch_http(self):
         call(f"pip download -i {TORCH_INDEX} tqdm --trusted-host {TORCH_HOST} --dest /tmp/torch/")

-    def test_docker_pull(self):
+    def test_dockerhub_pull(self):
         call(f"docker pull docker.local.homeinfra.org/alpine:3.12")
+
+    def test_k8s_pull(self):
+        call(f"docker pull k8s.local.homeinfra.org/pause:3.5")