feat: add k8s

Anonymous 2024-06-11 23:13:25 +08:00
parent b6a13cf1ff
commit d06c749f02
9 changed files with 208 additions and 116 deletions

View File

@@ -4,7 +4,7 @@ ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
 RPC_SECRET = os.environ.get("RPC_SECRET", "")
 BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
-SCHEME = os.environ.get("SCHEME", None)
+SCHEME = os.environ.get("SCHEME", "http").lower()
 assert SCHEME in ["http", "https"]
 CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache/")
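Previously the default was None, which tripped the assertion whenever SCHEME was unset; the new default also normalizes case, so SCHEME=HTTPS in the environment is accepted. A minimal illustration (value hypothetical):

    os.environ["SCHEME"] = "HTTPS"
    SCHEME = os.environ.get("SCHEME", "http").lower()  # -> "https"
    assert SCHEME in ["http", "https"]  # anything else still fails fast at startup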

View File

@@ -0,0 +1,68 @@
+import base64
+import json
+import re
+import time
+from typing import Dict
+
+import httpx
+
+
+class CachedToken:
+    token: str
+    exp: int
+
+    def __init__(self, token, exp):
+        self.token = token
+        self.exp = exp
+
+
+cached_tokens: Dict[str, CachedToken] = {}
+
+# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
+name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
+reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
+
+
+def try_extract_image_name(path):
+    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
+    match = re.search(pattern, path)
+    if match:
+        assert len(match.groups()) == 3
+        name, resource, reference = match.groups()
+        assert re.match(name_regex, name)
+        assert re.match(reference_regex, reference)
+        assert resource in ["manifests", "blobs", "tags"]
+        return name, resource, reference
+    return None, None, None
+
+
+def get_docker_token(name):
+    cached = cached_tokens.get(name, None)
+    if cached and cached.exp > time.time():
+        return cached.token
+
+    url = "https://auth.docker.io/token"
+    params = {
+        "scope": f"repository:{name}:pull",
+        "service": "registry.docker.io",
+    }
+
+    client = httpx.Client()
+    response = client.get(url, params=params)
+    response.raise_for_status()
+
+    token_data = response.json()
+    token = token_data["token"]
+
+    payload = token.split(".")[1]
+    padding = len(payload) % 4
+    payload += "=" * padding
+    payload = json.loads(base64.b64decode(payload))
+
+    assert payload["iss"] == "auth.docker.io"
+    assert len(payload["access"]) > 0
+
+    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
+    return token
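For reference, a quick illustration of how the new path parser behaves (the paths and image name below are made up):

    # A manifest request for library/alpine at tag 3.12:
    name, resource, reference = try_extract_image_name("/v2/library/alpine/manifests/3.12")
    # -> ("library/alpine", "manifests", "3.12")

    # Anything that is not a /v2/<name>/<resource>/<reference> path yields Nones:
    assert try_extract_image_name("/healthz") == (None, None, None)

One note on the JWT decoding in get_docker_token: token segments are base64url-encoded, so a more defensive variant (an observation, not part of this commit) would pad with "=" * (-len(payload) % 4) and decode via base64.urlsafe_b64decode; the len(payload) % 4 form over-pads when the segment length is 3 mod 4.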

View File

@@ -1,3 +1,4 @@
+import logging
 import typing
 from typing import Callable, Coroutine
@@ -18,13 +19,46 @@ AsyncPostProcessor = Callable[
     [Request, Response], Coroutine[Request, Response, Response]
 ]
+PreProcessor = typing.Union[SyncPreProcessor, AsyncPreProcessor, None]
+PostProcessor = typing.Union[SyncPostProcessor, AsyncPostProcessor, None]
+
+logger = logging.getLogger(__name__)
+
+
+async def pre_process_request(
+    request: Request,
+    httpx_req: HttpxRequest,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+):
+    if pre_process:
+        new_httpx_req = pre_process(request, httpx_req)
+        if isinstance(new_httpx_req, HttpxRequest):
+            httpx_req = new_httpx_req
+        else:
+            httpx_req = await new_httpx_req
+    return httpx_req
+
+
+async def post_process_response(
+    request: Request,
+    response: Response,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+):
+    if post_process:
+        new_res = post_process(request, response)
+        if isinstance(new_res, Response):
+            return new_res
+        elif isinstance(new_res, Coroutine):
+            return await new_res
+    else:
+        return response
+
+
 async def direct_proxy(
     request: Request,
     target_url: str,
     pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
     post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
     cache_ttl: int = 3600,
 ) -> Response:
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
@@ -40,12 +74,7 @@ async def direct_proxy(
         headers=req_headers,
     )
-    if pre_process:
-        new_httpx_req = pre_process(request, httpx_req)
-        if isinstance(new_httpx_req, HttpxRequest):
-            httpx_req = new_httpx_req
-        else:
-            httpx_req = await new_httpx_req
+    httpx_req = await pre_process_request(request, httpx_req, pre_process)
 
     upstream_response = await client.send(httpx_req)
@@ -54,6 +83,10 @@ async def direct_proxy(
     res_headers.pop("content-length", None)
     res_headers.pop("content-encoding", None)
 
+    logger.info(
+        f"proxy {request.url} to {target_url} {upstream_response.status_code}"
+    )
+
     content = upstream_response.content
     response = Response(
         headers=res_headers,
@@ -61,13 +94,6 @@ async def direct_proxy(
         status_code=upstream_response.status_code,
     )
 
-    if post_process:
-        new_res = post_process(request, response)
-        if isinstance(new_res, Response):
-            final_res = new_res
-        elif isinstance(new_res, Coroutine):
-            final_res = await new_res
-        else:
-            final_res = response
+    response = await post_process_response(request, response, post_process)
 
-    return final_res
+    return response
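To illustrate the hook contract these helpers now centralize (the handler and header below are hypothetical): a pre-processor receives the incoming Starlette request plus the outgoing httpx request, may mutate or replace the latter, and may be sync or async:

    from httpx import Request as HttpxRequest
    from starlette.requests import Request
    from mirrorsrun.proxy.direct import direct_proxy

    def add_header(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
        httpx_req.headers["x-example"] = "1"  # mutate the outgoing request
        return httpx_req

    async def handler(request: Request):
        target = "https://example.org" + request.url.path
        return await direct_proxy(request, target, pre_process=add_header)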

View File

@@ -7,15 +7,12 @@ from enum import Enum
 from urllib.parse import urlparse, quote
 
 import httpx
+from mirrorsrun.aria2_api import add_download
+from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 
-from mirrorsrun.aria2_api import add_download
-from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
-from typing import Optional, Callable
 
 logger = logging.getLogger(__name__)
@@ -76,14 +73,11 @@ async def try_file_based_cache(
     request: Request,
     target_url: str,
     download_wait_time: int = 60,
-    post_process: Optional[Callable[[Request, Response], Response]] = None,
 ) -> Response:
     cache_status = lookup_cache(target_url)
 
     if cache_status == DownloadingStatus.DOWNLOADED:
-        resp = make_cached_response(target_url)
-        if post_process:
-            resp = post_process(request, resp)
-        return resp
+        logger.info(f"Cache hit for {target_url}")
+        return make_cached_response(target_url)
 
     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
@@ -95,14 +89,15 @@ async def try_file_based_cache(
     assert cache_status == DownloadingStatus.NOT_FOUND
 
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
-    print("prepare to download", target_url, cache_file, cache_file_dir)
+    logger.info(f"prepare to cache, {target_url=} {cache_file=} {cache_file_dir=}")
 
     processed_url = quote(target_url, safe="/:?=&%")
 
     try:
         logger.info(f"Start download {processed_url}")
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
-        logger.error(f"Download error, return 503 for {target_url}", exc_info=e)
+        logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
         return Response(
             content=f"Failed to add download: {e}",
             status_code=HTTP_500_INTERNAL_SERVER_ERROR,
@@ -113,6 +108,7 @@ async def try_file_based_cache(
         await sleep(1)
         cache_status = lookup_cache(target_url)
         if cache_status == DownloadingStatus.DOWNLOADED:
+            logger.info(f"Cache hit for {target_url}")
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
     return Response(

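The resulting contract for callers: the first request for an uncached URL enqueues an aria2 download, and the mirror answers 503 until the file lands in CACHE_DIR, so clients can simply retry. A minimal client-side sketch (URL hypothetical):

    import time
    import httpx

    def fetch_with_retry(url: str, attempts: int = 5) -> httpx.Response:
        resp = httpx.get(url)
        for _ in range(attempts):
            if resp.status_code != 503:
                break  # cached response or a real error
            time.sleep(1)  # the mirror is still downloading
            resp = httpx.get(url)
        return resp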
View File

@@ -1,11 +1,13 @@
 import os
 import sys
 
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # noqa: E402
 
 import base64
 import signal
 import urllib.parse
 from typing import Callable
+import logging
 
 import httpx
 import uvicorn
@@ -25,6 +27,19 @@ from mirrorsrun.sites.docker import docker
 from mirrorsrun.sites.npm import npm
 from mirrorsrun.sites.pypi import pypi
 from mirrorsrun.sites.torch import torch
+from mirrorsrun.sites.k8s import k8s
+
+subdomain_mapping = {
+    "pypi": pypi,
+    "torch": torch,
+    "docker": docker,
+    "npm": npm,
+    "k8s": k8s,
+}
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 app = FastAPI()
@@ -73,14 +88,10 @@ async def capture_request(request: Request, call_next: Callable):
     if hostname.startswith("aria2."):
         return await aria2(request, call_next)
 
-    if hostname.startswith("pypi."):
-        return await pypi(request)
-    if hostname.startswith("torch."):
-        return await torch(request)
-    if hostname.startswith("docker."):
-        return await docker(request)
-    if hostname.startswith("npm."):
-        return await npm(request)
+    subdomain = hostname.split(".")[0]
+    if subdomain in subdomain_mapping:
+        return await subdomain_mapping[subdomain](request)
 
     return await call_next(request)
@@ -88,10 +99,10 @@ async def capture_request(request: Request, call_next: Callable):
 if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
-    print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
+    logger.info(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
-    for dn in ["pypi", "torch", "docker", "npm"]:
-        print(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
+    for dn in subdomain_mapping.keys():
+        logger.info(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
 
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
@@ -106,14 +117,13 @@ if __name__ == "__main__":
     query_string = urllib.parse.urlencode(params)
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
-    print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-
-    # FIXME: only proxy headers if SCHEME is https
-    # reload only in dev mode
+    logger.info(f"Download manager (Aria2) at {aria2_url_with_auth}")
 
     uvicorn.run(
         app="server:app",
         host="0.0.0.0",
         port=port,
-        reload=True,
-        proxy_headers=True,
+        reload=True,  # TODO: reload only in dev mode
+        proxy_headers=True,  # trust x-forwarded-for etc.
         forwarded_allow_ips="*",
     )
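With the lookup table in place, registering another mirror is one import and one dictionary entry; any async (Request) -> Response handler fits. A hypothetical example:

    from starlette.requests import Request
    from starlette.responses import Response

    async def example_site(request: Request) -> Response:
        return Response(content="hello")

    subdomain_mapping["example"] = example_site  # served at {SCHEME}://example.{BASE_DOMAIN}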

View File

@@ -1,83 +1,19 @@
-import base64
-import json
 import logging
-import re
-import time
-from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
+from mirrorsrun.docker_utils import get_docker_token
 from mirrorsrun.proxy.direct import direct_proxy
 from mirrorsrun.proxy.file_cache import try_file_based_cache
+from mirrorsrun.sites.k8s import try_extract_image_name
 
 logger = logging.getLogger(__name__)
 
 BASE_URL = "https://registry-1.docker.io"
 
-class CachedToken:
-    token: str
-    exp: int
-
-    def __init__(self, token, exp):
-        self.token = token
-        self.exp = exp
-
-cached_tokens: Dict[str, CachedToken] = {}
-
-# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
-name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
-reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
-
-def try_extract_image_name(path):
-    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
-    match = re.search(pattern, path)
-    if match:
-        assert len(match.groups()) == 3
-        name, resource, reference = match.groups()
-        assert re.match(name_regex, name)
-        assert re.match(reference_regex, reference)
-        assert resource in ["manifests", "blobs", "tags"]
-        return name, resource, reference
-    return None, None, None
-
-def get_docker_token(name):
-    cached = cached_tokens.get(name, None)
-    if cached and cached.exp > time.time():
-        return cached.token
-
-    url = "https://auth.docker.io/token"
-    params = {
-        "scope": f"repository:{name}:pull",
-        "service": "registry.docker.io",
-    }
-
-    client = httpx.Client()
-    response = client.get(url, params=params)
-    response.raise_for_status()
-
-    token_data = response.json()
-    token = token_data["token"]
-    payload = token.split(".")[1]
-    padding = len(payload) % 4
-    payload += "=" * padding
-    payload = json.loads(base64.b64decode(payload))
-    assert payload["iss"] == "auth.docker.io"
-    assert len(payload["access"]) > 0
-    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
-    return token
-
 def inject_token(name: str, req: Request, httpx_req: httpx.Request):
     docker_token = get_docker_token(f"{name}")
     httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
@@ -112,11 +48,13 @@ async def docker(request: Request):
     target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
-    logger.info(f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}")
+    logger.info(
+        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )
 
     return await direct_proxy(
         request,
         target_url,
         pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,
+        post_process=post_process,  # cache in post_process
     )
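For reference, the anonymous pull token that inject_token attaches is fetched with the same endpoint and parameters as get_docker_token; a standalone equivalent (image name illustrative):

    import httpx

    resp = httpx.get(
        "https://auth.docker.io/token",
        params={"scope": "repository:library/alpine:pull", "service": "registry.docker.io"},
    )
    token = resp.json()["token"]  # sent upstream as "Authorization: Bearer <token>"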

View File

@@ -0,0 +1,50 @@
+import logging
+
+from starlette.requests import Request
+from starlette.responses import Response
+
+from mirrorsrun.docker_utils import try_extract_image_name
+from mirrorsrun.proxy.direct import direct_proxy
+from mirrorsrun.proxy.file_cache import try_file_based_cache
+
+logger = logging.getLogger(__name__)
+
+BASE_URL = "https://registry.k8s.io"
+
+
+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+        if "/blobs/" in request.url.path:
+            return await try_file_based_cache(request, location)
+        return await direct_proxy(request, location)
+    return response
+
+
+async def k8s(request: Request):
+    path = request.url.path
+    if not path.startswith("/v2/"):
+        return Response(content="Not Found", status_code=404)
+
+    if path == "/v2/":
+        return Response(content="OK")
+
+    name, resource, reference = try_extract_image_name(path)
+    if not name:
+        return Response(content="404 Not Found", status_code=404)
+
+    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"
+
+    logger.info(
+        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )
+
+    return await direct_proxy(
+        request,
+        target_url,
+        post_process=post_process,
+    )
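registry.k8s.io answers at least blob downloads with an HTTP 307 redirect to its backing storage, which is why post_process hands blob redirects to the file cache (so each layer is fetched once by aria2) and proxies any other redirect inline. End to end, a pull through the mirror looks like the new test case:

    docker pull k8s.local.homeinfra.org/pause:3.5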

View File

@@ -1,2 +1,3 @@
 [flake8]
-max-line-length = 99
+max-line-length = 99
+ignore = E402

View File

@@ -16,5 +16,8 @@ class TestPypi(unittest.TestCase):
     def test_torch_http(self):
         call(f"pip download -i {TORCH_INDEX} tqdm --trusted-host {TORCH_HOST} --dest /tmp/torch/")
 
-    def test_docker_pull(self):
+    def test_dockerhub_pull(self):
         call(f"docker pull docker.local.homeinfra.org/alpine:3.12")
+
+    def test_k8s_pull(self):
+        call(f"docker pull k8s.local.homeinfra.org/pause:3.5")