diff --git a/src/mirrorsrun/server.py b/src/mirrorsrun/server.py index fd965b4..ff9e01f 100644 --- a/src/mirrorsrun/server.py +++ b/src/mirrorsrun/server.py @@ -27,7 +27,7 @@ from mirrorsrun.config import ( from mirrorsrun.sites.npm import npm from mirrorsrun.sites.pypi import pypi from mirrorsrun.sites.torch import torch -from mirrorsrun.sites.docker import dockerhub, k8s, quay, ghcr +from mirrorsrun.sites.docker import dockerhub, k8s, quay, ghcr, nvcr from mirrorsrun.sites.common import common subdomain_mapping = { @@ -39,6 +39,7 @@ subdomain_mapping = { "k8s": k8s, "ghcr": ghcr, "quay": quay, + "nvcr": nvcr, } logging.basicConfig( diff --git a/src/mirrorsrun/sites/docker.py b/src/mirrorsrun/sites/docker.py index d0627ef..ba3c6d3 100644 --- a/src/mirrorsrun/sites/docker.py +++ b/src/mirrorsrun/sites/docker.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) HEADER_AUTH_KEY = "www-authenticate" -service_realm_mapping = {} +mirror_root_realm_mapping = {} # https://github.com/opencontainers/distribution-spec/blob/main/spec.md name_regex = "[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*(/[a-z0-9]+((.|_|__|-+)[a-z0-9]+)*)*" @@ -46,15 +46,14 @@ def patch_auth_realm(request: Request, response: Response): value = value.strip('"') auth_values[key] = value - assert "realm" in auth_values - assert "service" in auth_values - service_realm_mapping[auth_values["service"]] = auth_values["realm"] + realm = auth_values.get("realm", "") + assert realm, f"realm not found in {auth}" - mirror_url = f"{request.url.scheme}://{request.url.netloc}" - new_token_url = mirror_url + "/token" - response.headers[HEADER_AUTH_KEY] = auth.replace( - auth_values["realm"], new_token_url - ) + mirror_root = f"{request.url.scheme}://{request.url.netloc}" + mirror_root_realm_mapping[mirror_root] = realm + + new_token_url = mirror_root + "/token" + response.headers[HEADER_AUTH_KEY] = auth.replace(realm, new_token_url) return response @@ -68,7 +67,6 @@ def build_docker_registry_handler(base_url: str, name_mapper=lambda x: x): scope = params.get("scope", "") service = params.get("service", "") parts = scope.split(":") - assert service assert len(parts) == 3 assert parts[0] == "repository" assert parts[1] # name @@ -77,18 +75,21 @@ def build_docker_registry_handler(base_url: str, name_mapper=lambda x: x): scope = ":".join(parts) - if not scope or not service: + if not scope: return Response(content="Bad Request", status_code=400) new_params = { "scope": scope, - "service": service, } + if service: + new_params["service"] = service + query = "&".join([f"{k}={v}" for k, v in new_params.items()]) - return await direct_proxy( - request, service_realm_mapping[service] + "?" + query - ) + mirror_root = f"{request.url.scheme}://{request.url.netloc}" + realm = mirror_root_realm_mapping[mirror_root] + + return await direct_proxy(request, realm + "?" + query) if path == "/v2/": return await direct_proxy( @@ -138,6 +139,7 @@ quay = build_docker_registry_handler( ghcr = build_docker_registry_handler( "https://ghcr.io", ) +nvcr = build_docker_registry_handler("https://nvcr.io") dockerhub = build_docker_registry_handler( "https://registry-1.docker.io", name_mapper=dockerhub_name_mapper ) diff --git a/test/mirrors_test.py b/test/mirrors_test.py index d472f4d..82397a9 100644 --- a/test/mirrors_test.py +++ b/test/mirrors_test.py @@ -27,3 +27,6 @@ class TestPypi(unittest.TestCase): def test_k8s_pull(self): call(f"docker pull k8s.local.homeinfra.org/pause:3.5") + + def test_nvcr_pull(self): + call(f"docker pull nvcr.local.homeinfra.org/nvidia/cuda")