Compare commits

...

2 Commits

Author      SHA1         Message                                  Date
Anonymous   ff41a826f3   chore: format                            2024-06-09 15:30:22 +08:00
Anonymous   610cbf769c   feat: support pre and post processors    2024-06-09 15:30:22 +08:00
10 changed files with 197 additions and 113 deletions

View File

@@ -11,18 +11,20 @@ logger = logging.getLogger(__name__)
 async def send_request(method, params=None):
     payload = {
-        'jsonrpc': '2.0',
-        'id': uuid.uuid4().hex,
-        'method': method,
-        'params': [f'token:{RPC_SECRET}'] + (params or [])
+        "jsonrpc": "2.0",
+        "id": uuid.uuid4().hex,
+        "method": method,
+        "params": [f"token:{RPC_SECRET}"] + (params or []),
     }
     # mount a plain transport so the internal API call doesn't go through a proxy
-    async with httpx.AsyncClient(mounts={
-        "all://": httpx.AsyncHTTPTransport()
-    }) as client:
+    async with httpx.AsyncClient(
+        mounts={"all://": httpx.AsyncHTTPTransport()}
+    ) as client:
         response = await client.post(ARIA2_RPC_URL, json=payload)
-        logger.info(f"aria2 request: {method} {params} -> {response.status_code} {response.text}")
+        logger.info(
+            f"aria2 request: {method} {params} -> {response.status_code} {response.text}"
+        )
         try:
             return response.json()
         except json.JSONDecodeError as e:
@@ -30,38 +32,35 @@ async def send_request(method, params=None):
             raise e

-async def add_download(url, save_dir='/app/cache'):
-    method = 'aria2.addUri'
-    params = [[url],
-              {'dir': save_dir,
-               'header': []
-               }]
+async def add_download(url, save_dir="/app/cache"):
+    method = "aria2.addUri"
+    params = [[url], {"dir": save_dir, "header": []}]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]

 async def pause_download(gid):
-    method = 'aria2.pause'
+    method = "aria2.pause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]

 async def resume_download(gid):
-    method = 'aria2.unpause'
+    method = "aria2.unpause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]

 async def get_status(gid):
-    method = 'aria2.tellStatus'
+    method = "aria2.tellStatus"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]

 async def list_downloads():
-    method = 'aria2.tellActive'
+    method = "aria2.tellActive"
     response = await send_request(method)
-    return response['result']
+    return response["result"]
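For context on how these wrappers compose: aria2.addUri returns a GID handle that the status calls then take. A minimal driver sketch, assuming the module above is importable as aria2_api (the polling loop and URL here are illustrative, not part of this change):

import asyncio

from aria2_api import add_download, get_status

async def download_and_wait(url):
    gid = await add_download(url)  # aria2.addUri returns a GID
    while True:
        status = await get_status(gid)  # aria2.tellStatus result dict
        # aria2 reports: active / waiting / paused / error / complete / removed
        if status["status"] in ("complete", "error"):
            return status["status"]
        await asyncio.sleep(1)

print(asyncio.run(download_and_wait("https://example.com/file.bin")))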

View File

@@ -1,10 +1,9 @@
 import os

-ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
-RPC_SECRET = os.environ.get("RPC_SECRET", '')
-BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
-PROXY = os.environ.get("PROXY", None)
+ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
+RPC_SECRET = os.environ.get("RPC_SECRET", "")
+BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]
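Since os.environ.get returns None when the variable is unset and the assert only accepts "http" or "https", SCHEME is effectively a required setting. A quick sketch of the failure mode, assuming the file is importable as config:

import os

# without SCHEME set, `import config` fails at the assert
os.environ.pop("SCHEME", None)
try:
    import config  # noqa: F401
except AssertionError:
    print("SCHEME must be set to 'http' or 'https'")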

View File

View File

@@ -1,47 +1,73 @@
 import typing
+from typing import Callable, Coroutine

 import httpx
-import starlette.requests
-import starlette.responses
+from httpx import Request as HttpxRequest
+from starlette.requests import Request
+from starlette.responses import Response

-from config import PROXY

+SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
+AsyncPreProcessor = Callable[
+    [Request, HttpxRequest], Coroutine[Request, HttpxRequest, HttpxRequest]
+]
+SyncPostProcessor = Callable[[Request, Response], Response]
+AsyncPostProcessor = Callable[
+    [Request, Response], Coroutine[Request, Response, Response]
+]

 async def direct_proxy(
-    request: starlette.requests.Request,
-    target_url: str,
-    pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
-    post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
-) -> typing.Optional[starlette.responses.Response]:
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
-        headers = request.headers.mutablecopy()
-        for key in headers.keys():
+    request: Request,
+    target_url: str,
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    cache_ttl: int = 3600,
+) -> Response:
+    # httpx will use the following environment variables to determine the proxy:
+    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
+    async with httpx.AsyncClient() as client:
+        req_headers = request.headers.mutablecopy()
+        for key in req_headers.keys():
             if key not in ["user-agent", "accept"]:
-                del headers[key]
+                del req_headers[key]

-        httpx_req = client.build_request(request.method, target_url, headers=headers, )
+        httpx_req: HttpxRequest = client.build_request(
+            request.method,
+            target_url,
+            headers=req_headers,
+        )

         if pre_process:
-            httpx_req = pre_process(request, httpx_req)
+            new_httpx_req = pre_process(request, httpx_req)
+            if isinstance(new_httpx_req, HttpxRequest):
+                httpx_req = new_httpx_req
+            else:
+                # an async pre-processor returns a coroutine; await it
+                httpx_req = await new_httpx_req

         upstream_response = await client.send(httpx_req)

-        headers = upstream_response.headers
-        cl = headers.pop("content-length", None)
-        ce = headers.pop("content-encoding", None)
+        # TODO: move to post_process
+        if upstream_response.status_code == 307:
+            location = upstream_response.headers["location"]
+            print("catch redirect", location)
+
+        res_headers = upstream_response.headers
+        cl = res_headers.pop("content-length", None)
+        ce = res_headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
         content = upstream_response.content
-        response = starlette.responses.Response(
-            headers=headers,
+        response = Response(
+            headers=res_headers,
             content=content,
-            status_code=upstream_response.status_code)
+            status_code=upstream_response.status_code,
+        )

         if post_process:
-            response = post_process(request, response)
+            new_res = post_process(request, response)
+            if isinstance(new_res, Response):
+                final_res = new_res
+            elif isinstance(new_res, Coroutine):
+                final_res = await new_res
+            else:
+                final_res = response
+        else:
+            final_res = response

-        return response
+        return final_res
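The isinstance checks above are what let pre_process and post_process be either plain functions or coroutine functions: a sync processor's return value is used directly, while an async processor's coroutine is awaited. A hypothetical pair of processors (the names and header values are illustrative):

from httpx import Request as HttpxRequest
from starlette.requests import Request
from starlette.responses import Response

def add_header(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
    # sync pre-processor: the returned request is used as-is
    httpx_req.headers["X-Example"] = "1"
    return httpx_req

async def rewrite_links(request: Request, response: Response) -> Response:
    # async post-processor: direct_proxy awaits the coroutine it gets back
    body = response.body.replace(b"http://", b"https://")
    return Response(content=body, status_code=response.status_code)

# response = await direct_proxy(request, target_url,
#                               pre_process=add_header,
#                               post_process=rewrite_links)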

View File

@@ -13,17 +13,24 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 from aria2_api import add_download
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2
+from typing import Optional, Callable

 logger = logging.getLogger(__name__)

 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
+    hostname = parsed_url.hostname
+    path = parsed_url.path
+    assert hostname
+    assert path
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
+    cache_file = (base_dir / hostname / path[1:]).resolve()
     assert cache_file.is_relative_to(base_dir)
     return str(cache_file), os.path.dirname(cache_file)
@@ -59,18 +66,18 @@ def make_cached_response(url):
 async def get_url_content_length(url):
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
+    async with httpx.AsyncClient() as client:
         head_response = await client.head(url)
-        content_len = (head_response.headers.get("content-length", None))
+        content_len = head_response.headers.get("content-length", None)
         return content_len

-async def try_get_cache(
-    request: Request,
-    target_url: str,
-    download_wait_time: int = 60,
-    post_process: typing.Callable[[Request, Response], Response] = None,
-) -> typing.Optional[Response]:
+async def try_file_based_cache(
+    request: Request,
+    target_url: str,
+    download_wait_time: int = 60,
+    post_process: Optional[Callable[[Request, Response], Response]] = None,
+) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -80,21 +87,26 @@ async def try_get_cache(
     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
-        return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                        status_code=HTTP_504_GATEWAY_TIMEOUT)
+        return Response(
+            content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+            status_code=HTTP_504_GATEWAY_TIMEOUT,
+        )
     assert cache_status == DownloadingStatus.NOT_FOUND

     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
-    processed_url = quote(target_url, safe='/:?=&')
+    processed_url = quote(target_url, safe="/:?=&%")
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
-        logger.error(f"Download error, return 503 for {target_url}", exc_info=e)
-        return Response(content=f"Failed to add download: {e}", status_code=HTTP_500_INTERNAL_SERVER_ERROR)
+        logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
+        return Response(
+            content=f"Failed to add download: {e}",
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+        )

     # wait for download finished
     for _ in range(download_wait_time):
# wait for download finished
for _ in range(download_wait_time):
@@ -103,5 +115,7 @@ async def try_get_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
-    return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                    status_code=HTTP_504_GATEWAY_TIMEOUT)
+    return Response(
+        content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+        status_code=HTTP_504_GATEWAY_TIMEOUT,
+    )
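The asserts plus the final is_relative_to check in get_cache_file_and_folder are the path-traversal guard: the URL's host and path are mapped under CACHE_DIR, and anything that resolves outside it is rejected. A standalone sketch of the same mapping rule, assuming CACHE_DIR="/app/cache":

import os
import pathlib
from urllib.parse import urlparse

CACHE_DIR = "/app/cache"

def cache_path(url):
    parsed = urlparse(url)
    assert parsed.hostname and parsed.path.startswith("/")
    base_dir = pathlib.Path(CACHE_DIR)
    cache_file = (base_dir / parsed.hostname / parsed.path[1:]).resolve()
    # .resolve() collapses any ../ segments, so escape attempts fail this check
    assert cache_file.is_relative_to(base_dir)
    return str(cache_file), os.path.dirname(cache_file)

# ('/app/cache/files.pythonhosted.org/packages/a/b/pkg.whl',
#  '/app/cache/files.pythonhosted.org/packages/a/b')
print(cache_path("https://files.pythonhosted.org/packages/a/b/pkg.whl"))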

View File

@@ -1,6 +1,7 @@
 import base64
 import signal
 import urllib.parse
+from typing import Callable

 import httpx
 import uvicorn
@@ -9,7 +10,13 @@ from starlette.requests import Request
 from starlette.responses import RedirectResponse, Response
 from starlette.staticfiles import StaticFiles

-from config import BASE_DOMAIN, RPC_SECRET, EXTERNAL_URL_ARIA2, EXTERNAL_HOST_ARIA2, SCHEME
+from config import (
+    BASE_DOMAIN,
+    RPC_SECRET,
+    EXTERNAL_URL_ARIA2,
+    EXTERNAL_HOST_ARIA2,
+    SCHEME,
+)
 from sites.docker import docker
 from sites.npm import npm
 from sites.pypi import pypi
@@ -17,30 +24,45 @@ from sites.torch import torch
 app = FastAPI()
-app.mount("/aria2/", StaticFiles(directory="/wwwroot/"), name="static", )
+app.mount(
+    "/aria2/",
+    StaticFiles(directory="/wwwroot/"),
+    name="static",
+)

 async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
-        async with httpx.AsyncClient(mounts={
-            "all://": httpx.AsyncHTTPTransport()
-        }) as client:
-            data = (await request.body())
-            response = await client.request(url="http://aria2:6800/jsonrpc",
-                                            method=request.method,
-                                            headers=request.headers, content=data)
+        # don't use a proxy for the internal API
+        async with httpx.AsyncClient(
+            mounts={"all://": httpx.AsyncHTTPTransport()}
+        ) as client:
+            data = await request.body()
+            response = await client.request(
+                url="http://aria2:6800/jsonrpc",
+                method=request.method,
+                headers=request.headers,
+                content=data,
+            )
             headers = response.headers
             headers.pop("content-length", None)
             headers.pop("content-encoding", None)
-            return Response(content=response.content, status_code=response.status_code, headers=headers)
+            return Response(
+                content=response.content,
+                status_code=response.status_code,
+                headers=headers,
+            )
     return await call_next(request)

 @app.middleware("http")
-async def capture_request(request: Request, call_next: callable):
+async def capture_request(request: Request, call_next: Callable):
     hostname = request.url.hostname
+    if not hostname:
+        return Response(content="Bad Request", status_code=400)
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
@@ -59,7 +81,7 @@ async def capture_request(request: Request, call_next: callable):
         return await call_next(request)

-if __name__ == '__main__':
+if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
     print(f"Server started at {SCHEME}://*.{BASE_DOMAIN}")
@@ -70,17 +92,24 @@ if __name__ == '__main__':
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
     params = {
-        'protocol': SCHEME,
-        'host': EXTERNAL_HOST_ARIA2,
-        'port': '443' if SCHEME == 'https' else '80',
-        'interface': 'jsonrpc',
-        'secret': aria2_secret
+        "protocol": SCHEME,
+        "host": EXTERNAL_HOST_ARIA2,
+        "port": "443" if SCHEME == "https" else "80",
+        "interface": "jsonrpc",
+        "secret": aria2_secret,
     }
     query_string = urllib.parse.urlencode(params)
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHME is https
+    # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
-    uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
+    uvicorn.run(
+        app="server:app",
+        host="0.0.0.0",
+        port=port,
+        reload=True,
+        proxy_headers=True,
+        forwarded_allow_ips="*",
+    )
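The capture_request middleware (its dispatch body is elided between the hunks above) keys routing off the subdomain under BASE_DOMAIN. A hypothetical reconstruction of that dispatch, purely to show the shape — the handler table and the removesuffix call are assumptions, not the repository's code:

from starlette.requests import Request

SITES = {"pypi": pypi, "torch": torch, "docker": docker, "npm": npm}

async def dispatch(request: Request, call_next):
    # e.g. pypi.local.homeinfra.org -> "pypi"
    sub = request.url.hostname.removesuffix(f".{BASE_DOMAIN}")
    handler = SITES.get(sub)
    if handler is not None:
        return await handler(request)
    return await call_next(request)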

View File

View File

@@ -2,18 +2,18 @@ import base64
 import json
 import re
 import time
+from typing import Dict

 import httpx
 from starlette.requests import Request
 from starlette.responses import Response

+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy

 BASE_URL = "https://registry-1.docker.io"

-cached_token = {
-}
+cached_token: Dict[str, str] = {}

 # https://github.com/opencontainers/distribution-spec/blob/main/spec.md
 name_regex = "[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*"
@@ -38,7 +38,6 @@ def try_extract_image_name(path):
 def get_docker_token(name):
     cached = cached_token.get(name, {})
-
     exp = cached.get("exp", 0)
     if exp > time.time():
         return cached.get("token", 0)
@@ -48,12 +47,13 @@ def get_docker_token(name):
         "service": "registry.docker.io",
     }

-    response = httpx.get(url, params=params, verify=False)
+    client = httpx.Client()
+    response = client.get(url, params=params)
     response.raise_for_status()

     token_data = response.json()
     token = token_data["token"]

-    payload = (token.split(".")[1])
+    payload = token.split(".")[1]
     padding = -len(payload) % 4
     payload += "=" * padding
@@ -61,14 +61,27 @@ def get_docker_token(name):
     assert payload["iss"] == "auth.docker.io"
     assert len(payload["access"]) > 0

-    cached_token[name] = {
-        "exp": payload["exp"],
-        "token": token
-    }
+    cached_token[name] = {"exp": payload["exp"], "token": token}
     return token

+def inject_token(name: str, req: Request, httpx_req: httpx.Request):
+    docker_token = get_docker_token(f"{name}")
+    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+    return httpx_req

+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+        # TODO: logger
+        print("[redirect]", location)
+        return await try_file_based_cache(request, location)
+    return response

 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
async def docker(request: Request):
path = request.url.path
print("[request]", request.method, request.url)
@@ -82,19 +95,20 @@ async def docker(request: Request):
     name, operation, reference = try_extract_image_name(path)

     if not name:
-        return Response(content='404 Not Found', status_code=404)
+        return Response(content="404 Not Found", status_code=404)

     # support `docker pull xxx` where the name omits the library/ prefix
-    if '/' not in name:
+    if "/" not in name:
         name = f"library/{name}"

     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
-    print('[PARSED]', path, name, operation, reference, target_url)
+    # TODO: logger
+    print("[PARSED]", path, name, operation, reference, target_url)

-    def inject_token(req, httpx_req):
-        docker_token = get_docker_token(f"{name}")
-        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-        return httpx_req
-
-    return await direct_proxy(request, target_url, pre_process=inject_token)
+    return await direct_proxy(
+        request,
+        target_url,
+        pre_process=lambda req, http_req: inject_token(name, req, http_req),
+        post_process=post_process,
+    )
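get_docker_token caches tokens per image name and re-reads the exp claim from the JWT's second segment; the decode step itself sits between the hunks above. A sketch of what it presumably does — urlsafe_b64decode and the helper names are assumptions based on the surrounding asserts, not the repository's code:

import base64
import json
import time

def jwt_payload(token):
    seg = token.split(".")[1]
    seg += "=" * (-len(seg) % 4)  # pad base64url to a multiple of 4
    return json.loads(base64.urlsafe_b64decode(seg))

def token_is_fresh(token):
    # mirrors the `exp > time.time()` check in get_docker_token
    return jwt_payload(token)["exp"] > time.time()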

View File

@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response

 from proxy.direct import direct_proxy
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache

 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"

@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)

     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)

     return await direct_proxy(request, target_url, post_process=pypi_replace)

View File

@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response

-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy

 BASE_URL = "https://download.pytorch.org"

@@ -19,6 +19,9 @@ async def torch(request: Request):
     target_url = BASE_URL + path

     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)

-    return await direct_proxy(request, target_url, )
+    return await direct_proxy(
+        request,
+        target_url,
+    )
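After this change, pypi.py and torch.py share the same shape: cacheable artifacts (.whl, .tar.gz) go through try_file_based_cache, everything else is proxied directly. A hypothetical handler for a new mirror following that pattern (the BASE_URL and function name here are made up for illustration):

from starlette.requests import Request
from starlette.responses import Response

from proxy.direct import direct_proxy
from proxy.file_cache import try_file_based_cache

BASE_URL = "https://mirror.example.org"  # assumption: any upstream file host

async def example_site(request: Request) -> Response:
    target_url = BASE_URL + request.url.path
    if request.url.path.endswith((".whl", ".tar.gz")):
        # large immutable artifacts: hand off to aria2-backed file cache
        return await try_file_based_cache(request, target_url)
    # metadata and index pages: stream through directly
    return await direct_proxy(request, target_url)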