chore: format

Anonymous 2024-06-07 23:00:03 +08:00
parent 610cbf769c
commit ff41a826f3
8 changed files with 117 additions and 84 deletions

aria2_api.py

@@ -11,18 +11,20 @@ logger = logging.getLogger(__name__)

 async def send_request(method, params=None):
     payload = {
-        'jsonrpc': '2.0',
-        'id': uuid.uuid4().hex,
-        'method': method,
-        'params': [f'token:{RPC_SECRET}'] + (params or [])
+        "jsonrpc": "2.0",
+        "id": uuid.uuid4().hex,
+        "method": method,
+        "params": [f"token:{RPC_SECRET}"] + (params or []),
     }
     # specify the internal API call don't use proxy
-    async with httpx.AsyncClient(mounts={
-        "all://": httpx.AsyncHTTPTransport()
-    }) as client:
+    async with httpx.AsyncClient(
+        mounts={"all://": httpx.AsyncHTTPTransport()}
+    ) as client:
         response = await client.post(ARIA2_RPC_URL, json=payload)
-        logger.info(f"aria2 request: {method} {params} -> {response.status_code} {response.text}")
+        logger.info(
+            f"aria2 request: {method} {params} -> {response.status_code} {response.text}"
+        )

         try:
             return response.json()
         except json.JSONDecodeError as e:
@@ -30,38 +32,35 @@ async def send_request(method, params=None):
         raise e


-async def add_download(url, save_dir='/app/cache'):
-    method = 'aria2.addUri'
-    params = [[url],
-              {'dir': save_dir,
-               'header': []
-              }]
+async def add_download(url, save_dir="/app/cache"):
+    method = "aria2.addUri"
+    params = [[url], {"dir": save_dir, "header": []}]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def pause_download(gid):
-    method = 'aria2.pause'
+    method = "aria2.pause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def resume_download(gid):
-    method = 'aria2.unpause'
+    method = "aria2.unpause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def get_status(gid):
-    method = 'aria2.tellStatus'
+    method = "aria2.tellStatus"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def list_downloads():
-    method = 'aria2.tellActive'
+    method = "aria2.tellActive"
     response = await send_request(method)
-    return response['result']
+    return response["result"]
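
Note: the mounts={"all://": httpx.AsyncHTTPTransport()} pattern above is how httpx is told to ignore HTTP_PROXY/HTTPS_PROXY/ALL_PROXY for these internal RPC calls; by default a client trusts those environment variables. A minimal standalone sketch of the same pattern, with a placeholder URL:

import asyncio

import httpx


async def main():
    # An explicit transport mounted for "all://" bypasses any proxy
    # configured via environment variables; trust_env=False is an
    # alternative that also ignores env-based SSL and auth settings.
    async with httpx.AsyncClient(
        mounts={"all://": httpx.AsyncHTTPTransport()}
    ) as client:
        response = await client.get("http://example.com/")  # placeholder URL
        print(response.status_code)


asyncio.run(main())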

config.py

@@ -1,8 +1,8 @@
 import os

-ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
-RPC_SECRET = os.environ.get("RPC_SECRET", '')
-BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
+ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
+RPC_SECRET = os.environ.get("RPC_SECRET", "")
+BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
 SCHEME = os.environ.get("SCHEME", None)

 assert SCHEME in ["http", "https"]

proxy/direct.py

@@ -9,24 +9,22 @@ from starlette.responses import Response

 SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
 AsyncPreProcessor = Callable[
-    [Request, HttpxRequest],
-    Coroutine[Request, HttpxRequest, HttpxRequest]
+    [Request, HttpxRequest], Coroutine[Request, HttpxRequest, HttpxRequest]
 ]

 SyncPostProcessor = Callable[[Request, Response], Response]
 AsyncPostProcessor = Callable[
-    [Request, Response],
-    Coroutine[Request, Response, Response]
+    [Request, Response], Coroutine[Request, Response, Response]
 ]


 async def direct_proxy(
-        request: Request,
-        target_url: str,
-        pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
-        post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
-        cache_ttl: int = 3600,
+    request: Request,
+    target_url: str,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    cache_ttl: int = 3600,
 ) -> Response:
     # httpx will use the following environment variables to determine the proxy
     # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
@@ -36,7 +34,11 @@ async def direct_proxy(
            if key not in ["user-agent", "accept"]:
                del req_headers[key]

-        httpx_req: HttpxRequest = client.build_request(request.method, target_url, headers=req_headers, )
+        httpx_req: HttpxRequest = client.build_request(
+            request.method,
+            target_url,
+            headers=req_headers,
+        )

         if pre_process:
             new_httpx_req = pre_process(request, httpx_req)
@@ -56,7 +58,7 @@ async def direct_proxy(
         response = Response(
             headers=res_headers,
             content=content,
-            status_code=upstream_response.status_code
+            status_code=upstream_response.status_code,
         )

         if post_process:
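
Aside: the Sync*/Async* unions above mean direct_proxy must accept both plain and coroutine processors. The dispatch itself is outside the hunks shown, but the usual shape is to call the processor and await the result only if it is awaitable; a self-contained sketch with hypothetical types:

import asyncio
import inspect
from typing import Awaitable, Callable, Union

# Hypothetical stand-ins for the processor unions above.
Processor = Union[Callable[[int], int], Callable[[int], Awaitable[int]]]


async def apply_processor(proc: Processor, value: int) -> int:
    # Call first, then await only when the processor was async.
    result = proc(value)
    if inspect.isawaitable(result):
        result = await result
    return result


async def main():
    print(await apply_processor(lambda v: v + 1, 1))  # sync path

    async def double(v: int) -> int:
        return v * 2

    print(await apply_processor(double, 2))  # async path


asyncio.run(main())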

View File

@@ -13,7 +13,7 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT

 from aria2_api import add_download
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2
 from typing import Optional, Callable

 logger = logging.getLogger(__name__)
@@ -66,17 +66,17 @@ def make_cached_response(url):


 async def get_url_content_length(url):
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
+    async with httpx.AsyncClient() as client:
         head_response = await client.head(url)
-        content_len = (head_response.headers.get("content-length", None))
+        content_len = head_response.headers.get("content-length", None)
         return content_len


 async def try_file_based_cache(
-        request: Request,
-        target_url: str,
-        download_wait_time: int = 60,
-        post_process: Optional[Callable[[Request, Response], Response]] = None,
+    request: Request,
+    target_url: str,
+    download_wait_time: int = 60,
+    post_process: Optional[Callable[[Request, Response], Response]] = None,
 ) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
@@ -87,21 +87,26 @@ async def try_file_based_cache(

     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
-        return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                        status_code=HTTP_504_GATEWAY_TIMEOUT)
+        return Response(
+            content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+            status_code=HTTP_504_GATEWAY_TIMEOUT,
+        )

     assert cache_status == DownloadingStatus.NOT_FOUND

     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)

-    processed_url = quote(target_url, safe='/:?=&%')
+    processed_url = quote(target_url, safe="/:?=&%")
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
         logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
-        return Response(content=f"Failed to add download: {e}", status_code=HTTP_500_INTERNAL_SERVER_ERROR)
+        return Response(
+            content=f"Failed to add download: {e}",
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+        )

     # wait for download finished
     for _ in range(download_wait_time):
@@ -110,5 +115,7 @@ async def try_file_based_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             return make_cached_response(target_url)

     logger.info(f"Download is not finished, return 503 for {target_url}")
-    return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                    status_code=HTTP_504_GATEWAY_TIMEOUT)
+    return Response(
+        content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+        status_code=HTTP_504_GATEWAY_TIMEOUT,
+    )
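
The wait loop referenced above ("# wait for download finished") follows a poll-until-timeout shape: check the cache roughly once per second for up to download_wait_time seconds, then fall through to the 504 response. A generic, self-contained sketch of that shape (names hypothetical):

import asyncio
from typing import Callable


async def wait_until(check: Callable[[], bool], timeout_s: int = 60) -> bool:
    # Poll once per second; True as soon as check() passes, False on timeout.
    for _ in range(timeout_s):
        if check():
            return True
        await asyncio.sleep(1)
    return False


# Example: a condition that is immediately true.
print(asyncio.run(wait_until(lambda: True)))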

server.py

@@ -10,7 +10,13 @@ from starlette.requests import Request
 from starlette.responses import RedirectResponse, Response
 from starlette.staticfiles import StaticFiles

-from config import BASE_DOMAIN, RPC_SECRET, EXTERNAL_URL_ARIA2, EXTERNAL_HOST_ARIA2, SCHEME
+from config import (
+    BASE_DOMAIN,
+    RPC_SECRET,
+    EXTERNAL_URL_ARIA2,
+    EXTERNAL_HOST_ARIA2,
+    SCHEME,
+)
 from sites.docker import docker
 from sites.npm import npm
 from sites.pypi import pypi
@@ -18,7 +24,11 @@ from sites.torch import torch

 app = FastAPI()

-app.mount("/aria2/", StaticFiles(directory="/wwwroot/"), name="static", )
+app.mount(
+    "/aria2/",
+    StaticFiles(directory="/wwwroot/"),
+    name="static",
+)


 async def aria2(request: Request, call_next):
@@ -26,17 +36,24 @@ async def aria2(request: Request, call_next):
         return RedirectResponse("/aria2/index.html")

     if request.url.path == "/jsonrpc":
         # dont use proxy for internal API
-        async with httpx.AsyncClient(mounts={
-            "all://": httpx.AsyncHTTPTransport()
-        }) as client:
-            data = (await request.body())
-            response = await client.request(url="http://aria2:6800/jsonrpc",
-                                            method=request.method,
-                                            headers=request.headers, content=data)
+        async with httpx.AsyncClient(
+            mounts={"all://": httpx.AsyncHTTPTransport()}
+        ) as client:
+            data = await request.body()
+            response = await client.request(
+                url="http://aria2:6800/jsonrpc",
+                method=request.method,
+                headers=request.headers,
+                content=data,
+            )
             headers = response.headers
             headers.pop("content-length", None)
             headers.pop("content-encoding", None)
-            return Response(content=response.content, status_code=response.status_code, headers=headers)
+            return Response(
+                content=response.content,
+                status_code=response.status_code,
+                headers=headers,
+            )

     return await call_next(request)
@@ -64,7 +81,7 @@ async def capture_request(request: Request, call_next: Callable):
     return await call_next(request)


-if __name__ == '__main__':
+if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
     print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
@@ -75,11 +92,11 @@ if __name__ == '__main__':
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()

     params = {
-        'protocol': SCHEME,
-        'host': EXTERNAL_HOST_ARIA2,
-        'port': '443' if SCHEME == 'https' else '80',
-        'interface': 'jsonrpc',
-        'secret': aria2_secret
+        "protocol": SCHEME,
+        "host": EXTERNAL_HOST_ARIA2,
+        "port": "443" if SCHEME == "https" else "80",
+        "interface": "jsonrpc",
+        "secret": aria2_secret,
     }

     query_string = urllib.parse.urlencode(params)
@@ -88,4 +105,11 @@ if __name__ == '__main__':
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
     # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
-    uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
+    uvicorn.run(
+        app="server:app",
+        host="0.0.0.0",
+        port=port,
+        reload=True,
+        proxy_headers=True,
+        forwarded_allow_ips="*",
+    )
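
The aria2() and capture_request() functions above follow Starlette's HTTP middleware signature (a request plus a call_next continuation); the registration decorator sits outside the hunks shown here, so this is a sketch of the assumed wiring rather than the repo's exact code:

from fastapi import FastAPI
from starlette.requests import Request
from starlette.responses import RedirectResponse

app = FastAPI()


@app.middleware("http")
async def redirect_root(request: Request, call_next):
    # Short-circuit one path; hand everything else to the next handler.
    if request.url.path == "/":
        return RedirectResponse("/aria2/index.html")
    return await call_next(request)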

sites/docker.py

@@ -13,9 +13,7 @@ from proxy.direct import direct_proxy

 BASE_URL = "https://registry-1.docker.io"

-cached_token: Dict[str, str] = {
-}
-
+cached_token: Dict[str, str] = {}

 # https://github.com/opencontainers/distribution-spec/blob/main/spec.md
 name_regex = "[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*"
@@ -40,7 +38,6 @@ def try_extract_image_name(path):

 def get_docker_token(name):
     cached = cached_token.get(name, {})
-
     exp = cached.get("exp", 0)
     if exp > time.time():
         return cached.get("token", 0)
@@ -50,12 +47,13 @@ def get_docker_token(name):
         "service": "registry.docker.io",
     }

-    response = httpx.get(url, params=params, verify=False)
+    client = httpx.Client()
+    response = client.get(url, params=params)
     response.raise_for_status()

     token_data = response.json()
     token = token_data["token"]

-    payload = (token.split(".")[1])
+    payload = token.split(".")[1]
     padding = len(payload) % 4
     payload += "=" * padding
@@ -63,10 +61,7 @@ def get_docker_token(name):

     assert payload["iss"] == "auth.docker.io"
     assert len(payload["access"]) > 0

-    cached_token[name] = {
-        "exp": payload["exp"],
-        "token": token
-    }
+    cached_token[name] = {"exp": payload["exp"], "token": token}

     return token
@@ -100,17 +95,20 @@ async def docker(request: Request):
     name, operation, reference = try_extract_image_name(path)

     if not name:
-        return Response(content='404 Not Found', status_code=404)
+        return Response(content="404 Not Found", status_code=404)

     # support docker pull xxx which name without library
-    if '/' not in name:
+    if "/" not in name:
         name = f"library/{name}"

     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"

     # logger
-    print('[PARSED]', path, name, operation, reference, target_url)
+    print("[PARSED]", path, name, operation, reference, target_url)

-    return await direct_proxy(request, target_url,
-                              pre_process=lambda req, http_req: inject_token(name, req, http_req),
-                              post_process=post_process)
+    return await direct_proxy(
+        request,
+        target_url,
+        pre_process=lambda req, http_req: inject_token(name, req, http_req),
+        post_process=post_process,
+    )
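
get_docker_token() above caches registry tokens and reads the expiry by decoding the JWT's middle segment, which is base64url without padding. A standalone sketch of that decode step, with the token built inline for the example; (-len(payload)) % 4 computes the exact number of "=" characters needed:

import base64
import json


def decode_jwt_payload(token: str) -> dict:
    payload = token.split(".")[1]
    # base64 input must be a multiple of 4 chars; pad with 0-3 "=".
    payload += "=" * (-len(payload) % 4)
    return json.loads(base64.urlsafe_b64decode(payload))


# Hypothetical unsigned token whose payload is {"iss": "auth.docker.io"}.
claims = base64.urlsafe_b64encode(b'{"iss": "auth.docker.io"}').decode().rstrip("=")
token = "eyJhbGciOiJub25lIn0." + claims + "."
print(decode_jwt_payload(token))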

sites/torch.py

@@ -21,4 +21,7 @@ async def torch(request: Request):
     if path.endswith(".whl") or path.endswith(".tar.gz"):
         return await try_file_based_cache(request, target_url)

-    return await direct_proxy(request, target_url, )
+    return await direct_proxy(
+        request,
+        target_url,
+    )
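
The suffix check above splits traffic: artifact downloads (.whl, .tar.gz), which presumably never change once published, go through the aria2 file cache, while everything else (index pages) streams through direct_proxy uncached. A standalone sketch of that routing predicate, with hypothetical paths:

def should_cache(path: str) -> bool:
    # str.endswith accepts a tuple of suffixes.
    return path.endswith((".whl", ".tar.gz"))


assert should_cache("/whl/torch-2.3.0-cp311-none.whl")
assert not should_cache("/whl/torch_stable.html")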