Mirror of https://github.com/NoCLin/LightMirrors
Synced 2025-06-17 09:25:25 +08:00
feat: add k8s
This commit is contained in:
parent b6a13cf1ff
commit d06c749f02
@@ -4,7 +4,7 @@ ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
RPC_SECRET = os.environ.get("RPC_SECRET", "")
BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")

-SCHEME = os.environ.get("SCHEME", None)
+SCHEME = os.environ.get("SCHEME", "http").lower()
assert SCHEME in ["http", "https"]

CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache/")
src/mirrorsrun/docker_utils.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import base64
import json
import re
import time
from typing import Dict
import httpx


class CachedToken:
    token: str
    exp: int

    def __init__(self, token, exp):
        self.token = token
        self.exp = exp


cached_tokens: Dict[str, CachedToken] = {}


# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"


def try_extract_image_name(path):
    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
    match = re.search(pattern, path)

    if match:
        assert len(match.groups()) == 3
        name, resource, reference = match.groups()
        assert re.match(name_regex, name)
        assert re.match(reference_regex, reference)
        assert resource in ["manifests", "blobs", "tags"]
        return name, resource, reference

    return None, None, None


def get_docker_token(name):
    cached = cached_tokens.get(name, None)
    if cached and cached.exp > time.time():
        return cached.token

    url = "https://auth.docker.io/token"
    params = {
        "scope": f"repository:{name}:pull",
        "service": "registry.docker.io",
    }

    client = httpx.Client()
    response = client.get(url, params=params)
    response.raise_for_status()

    token_data = response.json()
    token = token_data["token"]
    payload = token.split(".")[1]
    padding = len(payload) % 4
    payload += "=" * padding

    payload = json.loads(base64.b64decode(payload))
    assert payload["iss"] == "auth.docker.io"
    assert len(payload["access"]) > 0

    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)

    return token
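For orientation, a minimal sketch of how these two helpers fit together; the registry path and image name are illustrative, and running it needs network access to auth.docker.io. This is not part of the commit.

# Sketch only: exercising try_extract_image_name and get_docker_token
# with an illustrative registry API path.
from mirrorsrun.docker_utils import get_docker_token, try_extract_image_name

name, resource, reference = try_extract_image_name("/v2/library/alpine/manifests/3.12")
assert (name, resource, reference) == ("library/alpine", "manifests", "3.12")

# Anonymous pull token for the repository, cached in-process until its "exp" claim.
token = get_docker_token(name)
print(f"got token for {name}, {len(token)} chars")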
@@ -1,3 +1,4 @@
+import logging
import typing
from typing import Callable, Coroutine

@@ -18,13 +19,46 @@ AsyncPostProcessor = Callable[
    [Request, Response], Coroutine[Request, Response, Response]
]

+PreProcessor = typing.Union[SyncPreProcessor, AsyncPreProcessor, None]
+PostProcessor = typing.Union[SyncPostProcessor, AsyncPostProcessor, None]
+
+logger = logging.getLogger(__name__)
+
+
+async def pre_process_request(
+    request: Request,
+    httpx_req: HttpxRequest,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+):
+    if pre_process:
+        new_httpx_req = pre_process(request, httpx_req)
+        if isinstance(new_httpx_req, HttpxRequest):
+            httpx_req = new_httpx_req
+        else:
+            httpx_req = await new_httpx_req
+    return httpx_req
+
+
+async def post_process_response(
+    request: Request,
+    response: Response,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+):
+    if post_process:
+        new_res = post_process(request, response)
+        if isinstance(new_res, Response):
+            return new_res
+        elif isinstance(new_res, Coroutine):
+            return await new_res
+    else:
+        return response
+
+
async def direct_proxy(
    request: Request,
    target_url: str,
    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
    cache_ttl: int = 3600,
) -> Response:
    # httpx will use the following environment variables to determine the proxy
    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
@@ -40,12 +74,7 @@ async def direct_proxy(
        headers=req_headers,
    )

-    if pre_process:
-        new_httpx_req = pre_process(request, httpx_req)
-        if isinstance(new_httpx_req, HttpxRequest):
-            httpx_req = new_httpx_req
-        else:
-            httpx_req = await new_httpx_req
+    httpx_req = await pre_process_request(request, httpx_req, pre_process)

    upstream_response = await client.send(httpx_req)

@@ -54,6 +83,10 @@ async def direct_proxy(
    res_headers.pop("content-length", None)
    res_headers.pop("content-encoding", None)

+    logger.info(
+        f"proxy {request.url} to {target_url} {upstream_response.status_code}"
+    )
+
    content = upstream_response.content
    response = Response(
        headers=res_headers,
@@ -61,13 +94,6 @@ async def direct_proxy(
        status_code=upstream_response.status_code,
    )

-    if post_process:
-        new_res = post_process(request, response)
-        if isinstance(new_res, Response):
-            final_res = new_res
-        elif isinstance(new_res, Coroutine):
-            final_res = await new_res
-    else:
-        final_res = response
+    response = await post_process_response(request, response, post_process)

-    return final_res
+    return response
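The two helpers above normalize pre_process/post_process hooks that may be plain functions or coroutines. A minimal sketch of how a caller might wire both styles into direct_proxy; the upstream URL, header name, and handler are illustrative assumptions, not part of this commit.

# Sketch only: a sync pre-processor and an async post-processor for direct_proxy.
import httpx
from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.proxy.direct import direct_proxy


def add_example_header(request: Request, httpx_req: httpx.Request) -> httpx.Request:
    # Synchronous hook: adjust the outgoing upstream request.
    httpx_req.headers["x-example"] = "1"
    return httpx_req


async def log_status(request: Request, response: Response) -> Response:
    # Asynchronous hook: inspect (or replace) the proxied response.
    print("upstream returned", response.status_code)
    return response


async def example_handler(request: Request) -> Response:
    # Hypothetical upstream; both hook styles are accepted by the helpers above.
    return await direct_proxy(
        request,
        "https://example.org" + request.url.path,
        pre_process=add_example_header,
        post_process=log_status,
    )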
@@ -7,15 +7,12 @@ from enum import Enum
from urllib.parse import urlparse, quote

import httpx
+from mirrorsrun.aria2_api import add_download
+from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
-
-from mirrorsrun.aria2_api import add_download
-
-from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
-from typing import Optional, Callable

logger = logging.getLogger(__name__)

@@ -76,14 +73,11 @@ async def try_file_based_cache(
    request: Request,
    target_url: str,
    download_wait_time: int = 60,
-    post_process: Optional[Callable[[Request, Response], Response]] = None,
) -> Response:
    cache_status = lookup_cache(target_url)
    if cache_status == DownloadingStatus.DOWNLOADED:
-        resp = make_cached_response(target_url)
-        if post_process:
-            resp = post_process(request, resp)
-        return resp
+        logger.info(f"Cache hit for {target_url}")
+        return make_cached_response(target_url)

    if cache_status == DownloadingStatus.DOWNLOADING:
        logger.info(f"Download is not finished, return 503 for {target_url}")
@@ -95,14 +89,15 @@ async def try_file_based_cache(
    assert cache_status == DownloadingStatus.NOT_FOUND

    cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
-    print("prepare to download", target_url, cache_file, cache_file_dir)
+    logger.info(f"prepare to cache, {target_url=} {cache_file=} {cache_file_dir=}")

    processed_url = quote(target_url, safe="/:?=&%")

    try:
+        logger.info(f"Start download {processed_url}")
        await add_download(processed_url, save_dir=cache_file_dir)
    except Exception as e:
-        logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
+        logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
        return Response(
            content=f"Failed to add download: {e}",
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
@@ -113,6 +108,7 @@ async def try_file_based_cache(
        await sleep(1)
        cache_status = lookup_cache(target_url)
        if cache_status == DownloadingStatus.DOWNLOADED:
+            logger.info(f"Cache hit for {target_url}")
            return make_cached_response(target_url)
    logger.info(f"Download is not finished, return 503 for {target_url}")
    return Response(
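As the hunks above show, try_file_based_cache has three outcomes: serve a finished download from CACHE_DIR, report that aria2 is still fetching, or queue a new download and poll for up to download_wait_time seconds. A minimal sketch of handing a large artifact to it from a site handler; the upstream URL and handler name are illustrative, not part of the commit.

# Sketch only: route a large artifact through the aria2-backed file cache.
from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.proxy.file_cache import try_file_based_cache


async def example_blob_handler(request: Request) -> Response:
    # The first call typically queues the download and answers 503/504 while
    # aria2 works; later calls are served from the on-disk cache once finished.
    upstream = "https://example.org/artifacts/big-layer.tar.gz"
    return await try_file_based_cache(request, upstream, download_wait_time=60)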
@@ -1,11 +1,13 @@
import os
import sys
-sys.path.append(os.path.dirname(os.path.dirname(__file__)))

+sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # noqa: E402

import base64
import signal
import urllib.parse
from typing import Callable
import logging

import httpx
import uvicorn
@@ -25,6 +27,19 @@ from mirrorsrun.sites.docker import docker
from mirrorsrun.sites.npm import npm
from mirrorsrun.sites.pypi import pypi
from mirrorsrun.sites.torch import torch
+from mirrorsrun.sites.k8s import k8s
+
+subdomain_mapping = {
+    "pypi": pypi,
+    "torch": torch,
+    "docker": docker,
+    "npm": npm,
+    "k8s": k8s,
+}
+
+logging.basicConfig(level=logging.INFO)
+
+logger = logging.getLogger(__name__)

app = FastAPI()

@@ -73,14 +88,10 @@ async def capture_request(request: Request, call_next: Callable):
    if hostname.startswith("aria2."):
        return await aria2(request, call_next)

-    if hostname.startswith("pypi."):
-        return await pypi(request)
-    if hostname.startswith("torch."):
-        return await torch(request)
-    if hostname.startswith("docker."):
-        return await docker(request)
-    if hostname.startswith("npm."):
-        return await npm(request)
+    subdomain = hostname.split(".")[0]
+
+    if subdomain in subdomain_mapping:
+        return await subdomain_mapping[subdomain](request)

    return await call_next(request)

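With the table-driven dispatch above, adding another mirror is a one-line change: register a handler under its subdomain and capture_request routes <subdomain>.<BASE_DOMAIN> to it. A hedged sketch; the "example" subdomain and upstream are invented for illustration.

# Sketch only: registering a hypothetical extra mirror in the dispatch table.
from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.proxy.direct import direct_proxy
from server import subdomain_mapping  # the mapping introduced above


async def example_mirror(request: Request) -> Response:
    # Hypothetical upstream; the real handlers (pypi, docker, ...) live in mirrorsrun.sites.
    return await direct_proxy(request, "https://example.org" + request.url.path)


# Requests to example.<BASE_DOMAIN> would now be routed by capture_request.
subdomain_mapping["example"] = example_mirror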
@@ -88,10 +99,10 @@ async def capture_request(request: Request, call_next: Callable):
if __name__ == "__main__":
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    port = 80
-    print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
+    logger.info(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")

-    for dn in ["pypi", "torch", "docker", "npm"]:
-        print(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")
+    for dn in subdomain_mapping.keys():
+        logger.info(f" - {SCHEME}://{dn}.{BASE_DOMAIN}")

    aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()

@@ -106,14 +117,13 @@ if __name__ == "__main__":
    query_string = urllib.parse.urlencode(params)
    aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string

-    print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHEME is https
-    # reload only in dev mode
+    logger.info(f"Download manager (Aria2) at {aria2_url_with_auth}")

    uvicorn.run(
        app="server:app",
        host="0.0.0.0",
        port=port,
-        reload=True,
-        proxy_headers=True,
+        reload=True,  # TODO: reload only in dev mode
+        proxy_headers=True,  # trust x-forwarded-for etc.
        forwarded_allow_ips="*",
    )

@@ -1,83 +1,19 @@
-import base64
-import json
import logging
-import re
-import time
-from typing import Dict

import httpx
from starlette.requests import Request
from starlette.responses import Response

+from mirrorsrun.docker_utils import get_docker_token
from mirrorsrun.proxy.direct import direct_proxy
from mirrorsrun.proxy.file_cache import try_file_based_cache
+from mirrorsrun.sites.k8s import try_extract_image_name

logger = logging.getLogger(__name__)

BASE_URL = "https://registry-1.docker.io"


-class CachedToken:
-    token: str
-    exp: int
-
-    def __init__(self, token, exp):
-        self.token = token
-        self.exp = exp
-
-
-cached_tokens: Dict[str, CachedToken] = {}
-
-# https://github.com/opencontainers/distribution-spec/blob/main/spec.md
-name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*"
-reference_regex = "[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}"
-
-
-def try_extract_image_name(path):
-    pattern = r"^/v2/(.*)/([a-zA-Z]+)/(.*)$"
-    match = re.search(pattern, path)
-
-    if match:
-        assert len(match.groups()) == 3
-        name, resource, reference = match.groups()
-        assert re.match(name_regex, name)
-        assert re.match(reference_regex, reference)
-        assert resource in ["manifests", "blobs", "tags"]
-        return name, resource, reference
-
-    return None, None, None
-
-
-def get_docker_token(name):
-    cached = cached_tokens.get(name, None)
-    if cached and cached.exp > time.time():
-        return cached.token
-
-    url = "https://auth.docker.io/token"
-    params = {
-        "scope": f"repository:{name}:pull",
-        "service": "registry.docker.io",
-    }
-
-    client = httpx.Client()
-    response = client.get(url, params=params)
-    response.raise_for_status()
-
-    token_data = response.json()
-    token = token_data["token"]
-    payload = token.split(".")[1]
-    padding = len(payload) % 4
-    payload += "=" * padding
-
-    payload = json.loads(base64.b64decode(payload))
-    assert payload["iss"] == "auth.docker.io"
-    assert len(payload["access"]) > 0
-
-    cached_tokens[name] = CachedToken(exp=payload["exp"], token=token)
-
-    return token
-
-
def inject_token(name: str, req: Request, httpx_req: httpx.Request):
    docker_token = get_docker_token(f"{name}")
    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
@@ -112,11 +48,13 @@ async def docker(request: Request):

    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"

-    logger.info(f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}")
+    logger.info(
+        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
+    )

    return await direct_proxy(
        request,
        target_url,
        pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,
+        post_process=post_process,  # cache in post_process
    )
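The lambda simply binds the image name so that every upstream registry request carries a Bearer token for that repository; an equivalent, arguably more explicit spelling would use functools.partial. Sketch only, with an illustrative image name; not part of the commit.

# Sketch only: functools.partial as an alternative to the lambda above.
from functools import partial

from mirrorsrun.sites.docker import inject_token

name = "library/alpine"  # illustrative image name
pre_process = partial(inject_token, name)
# pre_process(request, httpx_req) now behaves like
# lambda req, http_req: inject_token(name, req, http_req)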
src/mirrorsrun/sites/k8s.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import logging

from starlette.requests import Request
from starlette.responses import Response

from mirrorsrun.docker_utils import try_extract_image_name
from mirrorsrun.proxy.direct import direct_proxy
from mirrorsrun.proxy.file_cache import try_file_based_cache

logger = logging.getLogger(__name__)

BASE_URL = "https://registry.k8s.io"


async def post_process(request: Request, response: Response):
    if response.status_code == 307:
        location = response.headers["location"]

        if "/blobs/" in request.url.path:
            return await try_file_based_cache(request, location)

        return await direct_proxy(request, location)

    return response


async def k8s(request: Request):
    path = request.url.path
    if not path.startswith("/v2/"):
        return Response(content="Not Found", status_code=404)

    if path == "/v2/":
        return Response(content="OK")

    name, resource, reference = try_extract_image_name(path)

    if not name:
        return Response(content="404 Not Found", status_code=404)

    target_url = BASE_URL + f"/v2/{name}/{resource}/{reference}"

    logger.info(
        f"got docker request, {path=} {name=} {resource=} {reference=} {target_url=}"
    )

    return await direct_proxy(
        request,
        target_url,
        post_process=post_process,
    )
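End to end, this handler answers the registry API at k8s.<BASE_DOMAIN>: manifests are proxied straight to registry.k8s.io, while 307-redirected blobs are fetched once by aria2 and then served from the file cache. A quick smoke test in the same style as the test changes further down; it assumes the stack is running with the default BASE_DOMAIN and that the hostname resolves, and the first pull may need a retry while the blob downloads.

# Sketch only: pull a small image through the k8s mirror.
from subprocess import call

call("docker pull k8s.local.homeinfra.org/pause:3.5", shell=True)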
@@ -1,2 +1,3 @@
[flake8]
max-line-length = 99
+ignore = E402
@@ -16,5 +16,8 @@ class TestPypi(unittest.TestCase):
    def test_torch_http(self):
        call(f"pip download -i {TORCH_INDEX} tqdm --trusted-host {TORCH_HOST} --dest /tmp/torch/")

-    def test_docker_pull(self):
+    def test_dockerhub_pull(self):
        call(f"docker pull docker.local.homeinfra.org/alpine:3.12")
+
+    def test_k8s_pull(self):
+        call(f"docker pull k8s.local.homeinfra.org/pause:3.5")