mirror of https://github.com/NoCLin/LightMirrors
synced 2025-08-04 09:12:49 +08:00

Compare commits: ff41a826f3...14ea13da43

No commits in common. "ff41a826f31b1a074bff8757901115649fdd3981" and "14ea13da43a510c4338e95bc97ea36d49fb4bdfe" have entirely different histories.
--- aria2_api.py
+++ aria2_api.py
@@ -11,20 +11,18 @@ logger = logging.getLogger(__name__)
 
 async def send_request(method, params=None):
     payload = {
-        "jsonrpc": "2.0",
-        "id": uuid.uuid4().hex,
-        "method": method,
-        "params": [f"token:{RPC_SECRET}"] + (params or []),
+        'jsonrpc': '2.0',
+        'id': uuid.uuid4().hex,
+        'method': method,
+        'params': [f'token:{RPC_SECRET}'] + (params or [])
     }
 
     # specify the internal API call don't use proxy
-    async with httpx.AsyncClient(
-        mounts={"all://": httpx.AsyncHTTPTransport()}
-    ) as client:
+    async with httpx.AsyncClient(mounts={
+        "all://": httpx.AsyncHTTPTransport()
+    }) as client:
         response = await client.post(ARIA2_RPC_URL, json=payload)
-        logger.info(
-            f"aria2 request: {method} {params} -> {response.status_code} {response.text}"
-        )
+        logger.info(f"aria2 request: {method} {params} -> {response.status_code} {response.text}")
         try:
             return response.json()
         except json.JSONDecodeError as e:
@@ -32,35 +30,38 @@ async def send_request(method, params=None):
             raise e
 
 
-async def add_download(url, save_dir="/app/cache"):
-    method = "aria2.addUri"
-    params = [[url], {"dir": save_dir, "header": []}]
+async def add_download(url, save_dir='/app/cache'):
+    method = 'aria2.addUri'
+    params = [[url],
+              {'dir': save_dir,
+               'header': []
+               }]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def pause_download(gid):
-    method = "aria2.pause"
+    method = 'aria2.pause'
     params = [gid]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def resume_download(gid):
-    method = "aria2.unpause"
+    method = 'aria2.unpause'
     params = [gid]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def get_status(gid):
-    method = "aria2.tellStatus"
+    method = 'aria2.tellStatus'
     params = [gid]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def list_downloads():
-    method = "aria2.tellActive"
+    method = 'aria2.tellActive'
     response = await send_request(method)
-    return response["result"]
+    return response['result']
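
Note on the hunks above: both sides build the same aria2 JSON-RPC 2.0 envelope; only quoting style and layout change. For reference, a minimal standalone sketch of the same aria2.addUri call (the daemon address and secret here are placeholders, not values from the repo):

    import uuid

    import httpx

    payload = {
        "jsonrpc": "2.0",
        "id": uuid.uuid4().hex,
        "method": "aria2.addUri",
        # aria2 expects the secret token first, then the method's own
        # params: a list of URIs and an options dict.
        "params": ["token:s3cret", ["https://example.com/file.bin"], {"dir": "/tmp"}],
    }
    response = httpx.post("http://localhost:6800/jsonrpc", json=payload)
    print(response.json()["result"])  # GID of the queued download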
--- config.py
+++ config.py
@@ -1,9 +1,10 @@
 import os
 
-ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
-RPC_SECRET = os.environ.get("RPC_SECRET", "")
-BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
+ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
+RPC_SECRET = os.environ.get("RPC_SECRET", '')
+BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
 
+PROXY = os.environ.get("PROXY", None)
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]
 
--- proxy/file_cache.py
+++ proxy/cached.py
@@ -13,24 +13,17 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 
 from aria2_api import add_download
 
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2
-from typing import Optional, Callable
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
 
 logger = logging.getLogger(__name__)
 
 
 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
-    hostname = parsed_url.hostname
-    path = parsed_url.path
-    assert hostname
-    assert path
-
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / hostname / path[1:]).resolve()
+    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
 
     assert cache_file.is_relative_to(base_dir)
 
     return str(cache_file), os.path.dirname(cache_file)
@@ -66,18 +59,18 @@ def make_cached_response(url):
 
 
 async def get_url_content_length(url):
-    async with httpx.AsyncClient() as client:
+    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
         head_response = await client.head(url)
-        content_len = head_response.headers.get("content-length", None)
+        content_len = (head_response.headers.get("content-length", None))
         return content_len
 
 
-async def try_file_based_cache(
+async def try_get_cache(
     request: Request,
     target_url: str,
     download_wait_time: int = 60,
-    post_process: Optional[Callable[[Request, Response], Response]] = None,
-) -> Response:
+    post_process: typing.Callable[[Request, Response], Response] = None,
+) -> typing.Optional[Response]:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -87,26 +80,21 @@ async def try_file_based_cache(
 
     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
-        return Response(
-            content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-            status_code=HTTP_504_GATEWAY_TIMEOUT,
-        )
+        return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+                        status_code=HTTP_504_GATEWAY_TIMEOUT)
 
     assert cache_status == DownloadingStatus.NOT_FOUND
 
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
 
-    processed_url = quote(target_url, safe="/:?=&%")
+    processed_url = quote(target_url, safe='/:?=&')
 
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
         logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
-        return Response(
-            content=f"Failed to add download: {e}",
-            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
-        )
+        return Response(content=f"Failed to add download: {e}", status_code=HTTP_500_INTERNAL_SERVER_ERROR)
 
     # wait for download finished
     for _ in range(download_wait_time):
@@ -115,7 +103,5 @@ async def try_file_based_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
-    return Response(
-        content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-        status_code=HTTP_504_GATEWAY_TIMEOUT,
-    )
+    return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+                    status_code=HTTP_504_GATEWAY_TIMEOUT)
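
Note on try_file_based_cache/try_get_cache above: the one behavioral change besides the rename is the safe set passed to quote(). Keeping "%" in the safe set (left) lets already-percent-encoded URLs pass through untouched; dropping it (right) re-encodes every "%". A quick standard-library illustration with a made-up URL:

    from urllib.parse import quote

    url = "https://example.com/pkg/foo%20bar.whl?x=1&y=2"

    # '%' in the safe set: pre-encoded sequences like %20 survive as-is.
    print(quote(url, safe="/:?=&%"))  # ...foo%20bar.whl?x=1&y=2

    # '%' not in the safe set: each '%' becomes %25, so %20 turns into
    # %2520 and the URL now names a different path.
    print(quote(url, safe="/:?=&"))   # ...foo%2520bar.whl?x=1&y=2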
--- proxy/direct.py
+++ proxy/direct.py
@@ -1,73 +1,47 @@
 import typing
-from typing import Callable, Coroutine
 
 import httpx
-from httpx import Request as HttpxRequest
-from starlette.requests import Request
-from starlette.responses import Response
+import starlette.requests
+import starlette.responses
 
-SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
-
-AsyncPreProcessor = Callable[
-    [Request, HttpxRequest], Coroutine[Request, HttpxRequest, HttpxRequest]
-]
-
-SyncPostProcessor = Callable[[Request, Response], Response]
-
-AsyncPostProcessor = Callable[
-    [Request, Response], Coroutine[Request, Response, Response]
-]
+from config import PROXY
 
 
 async def direct_proxy(
-    request: Request,
+    request: starlette.requests.Request,
     target_url: str,
-    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
-    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
-    cache_ttl: int = 3600,
-) -> Response:
-    # httpx will use the following environment variables to determine the proxy
-    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
-    async with httpx.AsyncClient() as client:
-        req_headers = request.headers.mutablecopy()
-        for key in req_headers.keys():
-            if key not in ["user-agent", "accept"]:
-                del req_headers[key]
+    pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
+    post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
+) -> typing.Optional[starlette.responses.Response]:
+    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
+        headers = request.headers.mutablecopy()
+        for key in headers.keys():
+            if key not in ["user-agent", "accept"]:
+                del headers[key]
 
-        httpx_req: HttpxRequest = client.build_request(
-            request.method,
-            target_url,
-            headers=req_headers,
-        )
+        httpx_req = client.build_request(request.method, target_url, headers=headers, )
 
         if pre_process:
-            new_httpx_req = pre_process(request, httpx_req)
-            if isinstance(new_httpx_req, HttpxRequest):
-                httpx_req = new_httpx_req
-            else:
-                httpx_req = await new_httpx_req
+            httpx_req = pre_process(request, httpx_req)
 
         upstream_response = await client.send(httpx_req)
 
-        res_headers = upstream_response.headers
+        # TODO: move to post_process
+        if upstream_response.status_code == 307:
+            location = upstream_response.headers["location"]
+            print("catch redirect", location)
 
-        cl = res_headers.pop("content-length", None)
-        ce = res_headers.pop("content-encoding", None)
+        headers = upstream_response.headers
+        cl = headers.pop("content-length", None)
+        ce = headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
         content = upstream_response.content
-        response = Response(
-            headers=res_headers,
+        response = starlette.responses.Response(
+            headers=headers,
             content=content,
-            status_code=upstream_response.status_code,
-        )
+            status_code=upstream_response.status_code)
 
         if post_process:
-            new_res = post_process(request, response)
-            if isinstance(new_res, Response):
-                final_res = new_res
-            elif isinstance(new_res, Coroutine):
-                final_res = await new_res
-        else:
-            final_res = response
+            response = post_process(request, response)
 
-        return final_res
+        return response
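
Note on direct_proxy above: the two versions configure the outbound httpx client differently. The left side passes nothing and lets httpx honor the HTTP_PROXY/HTTPS_PROXY/ALL_PROXY environment variables (per the comment it carried); the right side passes the new PROXY setting explicitly and disables TLS verification. A hedged sketch of the three client configurations that appear across this compare (the proxy URL is a placeholder):

    import httpx

    PROXY = "http://127.0.0.1:7890"  # placeholder; the repo reads it from the environment

    # Left side: no arguments, so httpx falls back to HTTP_PROXY / HTTPS_PROXY / ALL_PROXY.
    client_env = httpx.AsyncClient()

    # Right side: explicit proxy for all requests, TLS verification off.
    # (Recent httpx spells this `proxy=`; older releases used `proxies=`.)
    client_explicit = httpx.AsyncClient(proxy=PROXY, verify=False)

    # Internal calls: mounting a plain transport on "all://" bypasses any proxy,
    # which is what send_request() and the /jsonrpc passthrough do for aria2.
    client_no_proxy = httpx.AsyncClient(mounts={"all://": httpx.AsyncHTTPTransport()})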
--- server.py
+++ server.py
@@ -1,7 +1,6 @@
 import base64
 import signal
 import urllib.parse
-from typing import Callable
 
 import httpx
 import uvicorn
@@ -10,13 +9,7 @@ from starlette.requests import Request
 from starlette.responses import RedirectResponse, Response
 from starlette.staticfiles import StaticFiles
 
-from config import (
-    BASE_DOMAIN,
-    RPC_SECRET,
-    EXTERNAL_URL_ARIA2,
-    EXTERNAL_HOST_ARIA2,
-    SCHEME,
-)
+from config import BASE_DOMAIN, RPC_SECRET, EXTERNAL_URL_ARIA2, EXTERNAL_HOST_ARIA2, SCHEME
 from sites.docker import docker
 from sites.npm import npm
 from sites.pypi import pypi
@@ -24,45 +17,30 @@ from sites.torch import torch
 
 app = FastAPI()
 
-app.mount(
-    "/aria2/",
-    StaticFiles(directory="/wwwroot/"),
-    name="static",
-)
+app.mount("/aria2/", StaticFiles(directory="/wwwroot/"), name="static", )
 
 
 async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
-        # dont use proxy for internal API
-        async with httpx.AsyncClient(
-            mounts={"all://": httpx.AsyncHTTPTransport()}
-        ) as client:
-            data = await request.body()
-            response = await client.request(
-                url="http://aria2:6800/jsonrpc",
-                method=request.method,
-                headers=request.headers,
-                content=data,
-            )
+        async with httpx.AsyncClient(mounts={
+            "all://": httpx.AsyncHTTPTransport()
+        }) as client:
+            data = (await request.body())
+            response = await client.request(url="http://aria2:6800/jsonrpc",
+                                            method=request.method,
+                                            headers=request.headers, content=data)
             headers = response.headers
             headers.pop("content-length", None)
            headers.pop("content-encoding", None)
-            return Response(
-                content=response.content,
-                status_code=response.status_code,
-                headers=headers,
-            )
+            return Response(content=response.content, status_code=response.status_code, headers=headers)
    return await call_next(request)
 
 
 @app.middleware("http")
-async def capture_request(request: Request, call_next: Callable):
+async def capture_request(request: Request, call_next: callable):
     hostname = request.url.hostname
-    if not hostname:
-        return Response(content="Bad Request", status_code=400)
-
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
 
@@ -81,7 +59,7 @@ async def capture_request(request: Request, call_next: Callable):
     return await call_next(request)
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
     print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
@@ -92,24 +70,17 @@ if __name__ == "__main__":
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
 
     params = {
-        "protocol": SCHEME,
-        "host": EXTERNAL_HOST_ARIA2,
-        "port": "443" if SCHEME == "https" else "80",
-        "interface": "jsonrpc",
-        "secret": aria2_secret,
+        'protocol': SCHEME,
+        'host': EXTERNAL_HOST_ARIA2,
+        'port': '443' if SCHEME == 'https' else '80',
+        'interface': 'jsonrpc',
+        'secret': aria2_secret
     }
 
     query_string = urllib.parse.urlencode(params)
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
 
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHEME is https
+    # FIXME: only proxy headers if SCHME is https
     # reload only in dev mode
-    uvicorn.run(
-        app="server:app",
-        host="0.0.0.0",
-        port=port,
-        reload=True,
-        proxy_headers=True,
-        forwarded_allow_ips="*",
-    )
+    uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
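
Note on capture_request above: it is registered as FastAPI "http" middleware and dispatches on the subdomain under BASE_DOMAIN. A minimal sketch of the same pattern, with the handler table reduced to one stand-in entry (the pypi body here is hypothetical):

    from fastapi import FastAPI
    from starlette.requests import Request
    from starlette.responses import Response

    app = FastAPI()
    BASE_DOMAIN = "local.homeinfra.org"  # mirrors the config default


    async def pypi(request: Request) -> Response:
        return Response(content="pypi mirror")  # stand-in for the real handler


    @app.middleware("http")
    async def capture_request(request: Request, call_next):
        hostname = request.url.hostname
        if not hostname or not hostname.endswith(f".{BASE_DOMAIN}"):
            return await call_next(request)
        site = hostname[: -len(f".{BASE_DOMAIN}")]  # "pypi.local.homeinfra.org" -> "pypi"
        if site == "pypi":
            return await pypi(request)
        return await call_next(request)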
--- sites/docker.py
+++ sites/docker.py
@@ -2,18 +2,18 @@ import base64
 import json
 import re
 import time
-from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://registry-1.docker.io"
 
-cached_token: Dict[str, str] = {}
+cached_token = {
+
+}
 
 # https://github.com/opencontainers/distribution-spec/blob/main/spec.md
 name_regex = "[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*"
@@ -38,6 +38,7 @@ def try_extract_image_name(path):
 def get_docker_token(name):
     cached = cached_token.get(name, {})
     exp = cached.get("exp", 0)
+
     if exp > time.time():
         return cached.get("token", 0)
 
@@ -47,13 +48,12 @@ def get_docker_token(name):
         "service": "registry.docker.io",
     }
 
-    client = httpx.Client()
-    response = client.get(url, params=params)
+    response = httpx.get(url, params=params, verify=False)
     response.raise_for_status()
 
     token_data = response.json()
     token = token_data["token"]
-    payload = token.split(".")[1]
+    payload = (token.split(".")[1])
     padding = len(payload) % 4
     payload += "=" * padding
 
@@ -61,27 +61,14 @@ def get_docker_token(name):
     assert payload["iss"] == "auth.docker.io"
     assert len(payload["access"]) > 0
 
-    cached_token[name] = {"exp": payload["exp"], "token": token}
+    cached_token[name] = {
+        "exp": payload["exp"],
+        "token": token
+    }
 
     return token
 
 
-def inject_token(name: str, req: Request, httpx_req: httpx.Request):
-    docker_token = get_docker_token(f"{name}")
-    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-    return httpx_req
-
-
-async def post_process(request: Request, response: Response):
-    if response.status_code == 307:
-        location = response.headers["location"]
-        # TODO: logger
-        print("[redirect]", location)
-        return await try_file_based_cache(request, location)
-
-    return response
-
-
 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
@@ -95,20 +82,19 @@ async def docker(request: Request):
     name, operation, reference = try_extract_image_name(path)
 
     if not name:
-        return Response(content="404 Not Found", status_code=404)
+        return Response(content='404 Not Found', status_code=404)
 
     # support docker pull xxx which name without library
-    if "/" not in name:
+    if '/' not in name:
         name = f"library/{name}"
 
     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
 
-    # logger
-    print("[PARSED]", path, name, operation, reference, target_url)
+    print('[PARSED]', path, name, operation, reference, target_url)
 
-    return await direct_proxy(
-        request,
-        target_url,
-        pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,
-    )
+    def inject_token(req, httpx_req):
+        docker_token = get_docker_token(f"{name}")
+        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+        return httpx_req
+
+    return await direct_proxy(request, target_url, pre_process=inject_token)
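
Note on get_docker_token above: both sides cache the registry bearer token per image name and reuse it until the JWT's exp claim passes. The padding line kept as context (padding = len(payload) % 4) is correct when the remainder is 0 or 2 but over-pads when it is 3; the canonical fix-up for unpadded base64url is -len(payload) % 4. A small sketch of reading the claim set, with a hypothetical decode helper (the token is only parsed here, never verified):

    import base64
    import json
    import time

    cached_token = {}


    def decode_claims(token: str) -> dict:
        """Parse (without verifying) the claim set of a JWT access token."""
        payload = token.split(".")[1]
        payload += "=" * (-len(payload) % 4)  # restore base64url padding exactly
        return json.loads(base64.urlsafe_b64decode(payload))


    def is_fresh(name: str) -> bool:
        cached = cached_token.get(name, {})
        return cached.get("exp", 0) > time.time()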
--- sites/pypi.py
+++ sites/pypi.py
@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response
 
 from proxy.direct import direct_proxy
-from proxy.file_cache import try_file_based_cache
+from proxy.cached import try_get_cache
 
 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_file_based_cache(request, target_url)
+        return await try_get_cache(request, target_url)
 
     return await direct_proxy(request, target_url, post_process=pypi_replace)
--- sites/torch.py
+++ sites/torch.py
@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.file_cache import try_file_based_cache
+from proxy.cached import try_get_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://download.pytorch.org"
@@ -19,9 +19,6 @@ async def torch(request: Request):
     target_url = BASE_URL + path
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_file_based_cache(request, target_url)
+        return await try_get_cache(request, target_url)
 
-    return await direct_proxy(
-        request,
-        target_url,
-    )
+    return await direct_proxy(request, target_url, )