Compare commits


2 Commits

Author     SHA1        Message                                 Date
Anonymous  ff41a826f3  chore: format                           2024-06-09 15:30:22 +08:00
Anonymous  610cbf769c  feat: support pre and post processors   2024-06-09 15:30:22 +08:00
10 changed files with 197 additions and 113 deletions

View File: aria2_api.py

```diff
@@ -11,18 +11,20 @@ logger = logging.getLogger(__name__)
 async def send_request(method, params=None):
     payload = {
-        'jsonrpc': '2.0',
-        'id': uuid.uuid4().hex,
-        'method': method,
-        'params': [f'token:{RPC_SECRET}'] + (params or [])
+        "jsonrpc": "2.0",
+        "id": uuid.uuid4().hex,
+        "method": method,
+        "params": [f"token:{RPC_SECRET}"] + (params or []),
     }
     # specify the internal API call don't use proxy
-    async with httpx.AsyncClient(mounts={
-        "all://": httpx.AsyncHTTPTransport()
-    }) as client:
+    async with httpx.AsyncClient(
+        mounts={"all://": httpx.AsyncHTTPTransport()}
+    ) as client:
         response = await client.post(ARIA2_RPC_URL, json=payload)
-        logger.info(f"aria2 request: {method} {params} -> {response.status_code} {response.text}")
+        logger.info(
+            f"aria2 request: {method} {params} -> {response.status_code} {response.text}"
+        )
         try:
             return response.json()
         except json.JSONDecodeError as e:
@@ -30,38 +32,35 @@ async def send_request(method, params=None):
             raise e


-async def add_download(url, save_dir='/app/cache'):
-    method = 'aria2.addUri'
-    params = [[url],
-              {'dir': save_dir,
-               'header': []
-               }]
+async def add_download(url, save_dir="/app/cache"):
+    method = "aria2.addUri"
+    params = [[url], {"dir": save_dir, "header": []}]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def pause_download(gid):
-    method = 'aria2.pause'
+    method = "aria2.pause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def resume_download(gid):
-    method = 'aria2.unpause'
+    method = "aria2.unpause"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def get_status(gid):
-    method = 'aria2.tellStatus'
+    method = "aria2.tellStatus"
     params = [gid]
     response = await send_request(method, params)
-    return response['result']
+    return response["result"]


 async def list_downloads():
-    method = 'aria2.tellActive'
+    method = "aria2.tellActive"
     response = await send_request(method)
-    return response['result']
+    return response["result"]
```
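For reference, aria2's JSON-RPC convention passes the secret as a `token:`-prefixed first positional parameter, which is exactly what `send_request` prepends. A sketch of the wire payload produced by `add_download` (URL and secret are made-up values):

```python
# What add_download("https://example.com/pkg.whl") would POST to ARIA2_RPC_URL,
# assuming RPC_SECRET == "s3cret" (hypothetical values):
payload = {
    "jsonrpc": "2.0",
    "id": "9f8e7d6c",                         # uuid.uuid4().hex in send_request
    "method": "aria2.addUri",
    "params": [
        "token:s3cret",                       # secret prepended by send_request
        ["https://example.com/pkg.whl"],      # list of URIs for aria2.addUri
        {"dir": "/app/cache", "header": []},  # download options
    ],
}
```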

View File: config.py

```diff
@@ -1,10 +1,9 @@
 import os

-ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
-RPC_SECRET = os.environ.get("RPC_SECRET", '')
-BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
-PROXY = os.environ.get("PROXY", None)
+ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
+RPC_SECRET = os.environ.get("RPC_SECRET", "")
+BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]
```
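Dropping `PROXY` from the config works because httpx reads proxy settings from the environment by default (see the comment added in proxy/direct.py below). A minimal sketch with hypothetical deployment values:

```python
import os
import httpx

# Hypothetical value; with trust_env enabled (the httpx default), a plain
# AsyncClient() picks this up, replacing the old PROXY config entry:
os.environ["HTTPS_PROXY"] = "http://proxy.internal:8080"

async def fetch(url: str) -> bytes:
    async with httpx.AsyncClient() as client:  # proxied via HTTPS_PROXY
        return (await client.get(url)).content

async def fetch_direct(url: str) -> bytes:
    # Explicit transport mounts bypass environment proxies, as the aria2
    # RPC calls in this diff do:
    async with httpx.AsyncClient(
        mounts={"all://": httpx.AsyncHTTPTransport()}
    ) as client:
        return (await client.get(url)).content
```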

View File (diff not shown)

View File: proxy/direct.py

```diff
@@ -1,47 +1,73 @@
 import typing
+from typing import Callable, Coroutine

 import httpx
-import starlette.requests
-import starlette.responses
+from httpx import Request as HttpxRequest
+from starlette.requests import Request
+from starlette.responses import Response

-from config import PROXY
+SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
+AsyncPreProcessor = Callable[
+    [Request, HttpxRequest], Coroutine[Request, HttpxRequest, HttpxRequest]
+]
+
+SyncPostProcessor = Callable[[Request, Response], Response]
+AsyncPostProcessor = Callable[
+    [Request, Response], Coroutine[Request, Response, Response]
+]


 async def direct_proxy(
-    request: starlette.requests.Request,
+    request: Request,
     target_url: str,
-    pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
-    post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
-) -> typing.Optional[starlette.responses.Response]:
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
-        headers = request.headers.mutablecopy()
-        for key in headers.keys():
+    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+    cache_ttl: int = 3600,
+) -> Response:
+    # httpx will use the following environment variables to determine the proxy
+    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
+    async with httpx.AsyncClient() as client:
+        req_headers = request.headers.mutablecopy()
+        for key in req_headers.keys():
             if key not in ["user-agent", "accept"]:
-                del headers[key]
-        httpx_req = client.build_request(request.method, target_url, headers=headers, )
+                del req_headers[key]
+        httpx_req: HttpxRequest = client.build_request(
+            request.method,
+            target_url,
+            headers=req_headers,
+        )
         if pre_process:
-            httpx_req = pre_process(request, httpx_req)
+            new_httpx_req = pre_process(request, httpx_req)
+            if isinstance(new_httpx_req, HttpxRequest):
+                httpx_req = new_httpx_req
+            else:
+                httpx_req = await new_httpx_req
         upstream_response = await client.send(httpx_req)
-        # TODO: move to post_process
-        if upstream_response.status_code == 307:
-            location = upstream_response.headers["location"]
-            print("catch redirect", location)
-        headers = upstream_response.headers
-        cl = headers.pop("content-length", None)
-        ce = headers.pop("content-encoding", None)
+        res_headers = upstream_response.headers
+        cl = res_headers.pop("content-length", None)
+        ce = res_headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
         content = upstream_response.content
-        response = starlette.responses.Response(
-            headers=headers,
+        response = Response(
+            headers=res_headers,
             content=content,
-            status_code=upstream_response.status_code)
+            status_code=upstream_response.status_code,
+        )
         if post_process:
-            response = post_process(request, response)
-        return response
+            new_res = post_process(request, response)
+            if isinstance(new_res, Response):
+                final_res = new_res
+            elif isinstance(new_res, Coroutine):
+                final_res = await new_res
+            else:
+                final_res = response
+        return final_res
```
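With the new `Union` signatures, a caller can pass either a plain function or a coroutine function for each hook; `direct_proxy` awaits the result only when it is not already a final value. A minimal caller sketch (the header name is hypothetical):

```python
from httpx import Request as HttpxRequest
from starlette.requests import Request
from starlette.responses import Response

def tag_request(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
    # Sync pre-processor: returns an httpx.Request, used as-is.
    httpx_req.headers["X-Example-Tag"] = "demo"  # hypothetical header
    return httpx_req

async def log_response(request: Request, response: Response) -> Response:
    # Async post-processor: direct_proxy detects the coroutine and awaits it.
    print(request.url.path, "->", response.status_code)
    return response

# Inside a site handler:
#     return await direct_proxy(request, target_url,
#                               pre_process=tag_request, post_process=log_response)
```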

View File: proxy/file_cache.py (renamed from proxy/cached.py)

```diff
@@ -13,17 +13,24 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 from aria2_api import add_download
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2
+from typing import Optional, Callable

 logger = logging.getLogger(__name__)


 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
+    hostname = parsed_url.hostname
+    path = parsed_url.path
+    assert hostname
+    assert path
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
+    cache_file = (base_dir / hostname / path[1:]).resolve()
     assert cache_file.is_relative_to(base_dir)
     return str(cache_file), os.path.dirname(cache_file)
```
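The new locals make the hostname/path checks explicit, and the `resolve()` plus `is_relative_to` pair keeps a crafted URL from escaping the cache root; the mapping itself is just host-then-path under `CACHE_DIR`. A sketch with a hypothetical value:

```python
# Hypothetical call and result, assuming CACHE_DIR == "/app/cache":
get_cache_file_and_folder("https://files.pythonhosted.org/packages/ab/cd/demo.whl")
# -> ("/app/cache/files.pythonhosted.org/packages/ab/cd/demo.whl",
#     "/app/cache/files.pythonhosted.org/packages/ab/cd")
```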
```diff
@@ -59,18 +66,18 @@ def make_cached_response(url):
 async def get_url_content_length(url):
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
+    async with httpx.AsyncClient() as client:
         head_response = await client.head(url)
-        content_len = (head_response.headers.get("content-length", None))
+        content_len = head_response.headers.get("content-length", None)
         return content_len


-async def try_get_cache(
+async def try_file_based_cache(
     request: Request,
     target_url: str,
     download_wait_time: int = 60,
-    post_process: typing.Callable[[Request, Response], Response] = None,
-) -> typing.Optional[Response]:
+    post_process: Optional[Callable[[Request, Response], Response]] = None,
+) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -80,21 +87,26 @@ async def try_get_cache(
     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
-        return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                        status_code=HTTP_504_GATEWAY_TIMEOUT)
+        return Response(
+            content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+            status_code=HTTP_504_GATEWAY_TIMEOUT,
+        )
     assert cache_status == DownloadingStatus.NOT_FOUND
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
-    processed_url = quote(target_url, safe='/:?=&')
+    processed_url = quote(target_url, safe="/:?=&%")
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
-        logger.error(f"Download error, return 503 for {target_url}", exc_info=e)
-        return Response(content=f"Failed to add download: {e}", status_code=HTTP_500_INTERNAL_SERVER_ERROR)
+        logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
+        return Response(
+            content=f"Failed to add download: {e}",
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+        )

     # wait for download finished
     for _ in range(download_wait_time):
@@ -103,5 +115,7 @@ async def try_get_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
-    return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-                    status_code=HTTP_504_GATEWAY_TIMEOUT)
+    return Response(
+        content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+        status_code=HTTP_504_GATEWAY_TIMEOUT,
+    )
```
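The widened `safe` set in `quote` matters for URLs that are already percent-encoded: with `%` marked safe, existing escapes are no longer re-encoded. For example:

```python
from urllib.parse import quote

url = "https://example.com/a%20b.whl?x=1"  # hypothetical, already encoded
quote(url, safe="/:?=&%")  # 'https://example.com/a%20b.whl?x=1'   (unchanged)
quote(url, safe="/:?=&")   # 'https://example.com/a%2520b.whl?x=1' (double-encoded)
```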

View File: server.py

```diff
@@ -1,6 +1,7 @@
 import base64
 import signal
 import urllib.parse
+from typing import Callable

 import httpx
 import uvicorn
@@ -9,7 +10,13 @@ from starlette.requests import Request
 from starlette.responses import RedirectResponse, Response
 from starlette.staticfiles import StaticFiles

-from config import BASE_DOMAIN, RPC_SECRET, EXTERNAL_URL_ARIA2, EXTERNAL_HOST_ARIA2, SCHEME
+from config import (
+    BASE_DOMAIN,
+    RPC_SECRET,
+    EXTERNAL_URL_ARIA2,
+    EXTERNAL_HOST_ARIA2,
+    SCHEME,
+)
 from sites.docker import docker
 from sites.npm import npm
 from sites.pypi import pypi
@@ -17,30 +24,45 @@ from sites.torch import torch

 app = FastAPI()
-app.mount("/aria2/", StaticFiles(directory="/wwwroot/"), name="static", )
+app.mount(
+    "/aria2/",
+    StaticFiles(directory="/wwwroot/"),
+    name="static",
+)


 async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
-        async with httpx.AsyncClient(mounts={
-            "all://": httpx.AsyncHTTPTransport()
-        }) as client:
-            data = (await request.body())
-            response = await client.request(url="http://aria2:6800/jsonrpc",
-                                            method=request.method,
-                                            headers=request.headers, content=data)
+        # dont use proxy for internal API
+        async with httpx.AsyncClient(
+            mounts={"all://": httpx.AsyncHTTPTransport()}
+        ) as client:
+            data = await request.body()
+            response = await client.request(
+                url="http://aria2:6800/jsonrpc",
+                method=request.method,
+                headers=request.headers,
+                content=data,
+            )
             headers = response.headers
             headers.pop("content-length", None)
             headers.pop("content-encoding", None)
-            return Response(content=response.content, status_code=response.status_code, headers=headers)
+            return Response(
+                content=response.content,
+                status_code=response.status_code,
+                headers=headers,
+            )
     return await call_next(request)


 @app.middleware("http")
-async def capture_request(request: Request, call_next: callable):
+async def capture_request(request: Request, call_next: Callable):
     hostname = request.url.hostname
+    if not hostname:
+        return Response(content="Bad Request", status_code=400)
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
@@ -59,7 +81,7 @@ async def capture_request(request: Request, call_next: callable):
     return await call_next(request)


-if __name__ == '__main__':
+if __name__ == "__main__":
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
     print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
@@ -70,17 +92,24 @@ if __name__ == '__main__':
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
     params = {
-        'protocol': SCHEME,
-        'host': EXTERNAL_HOST_ARIA2,
-        'port': '443' if SCHEME == 'https' else '80',
-        'interface': 'jsonrpc',
-        'secret': aria2_secret
+        "protocol": SCHEME,
+        "host": EXTERNAL_HOST_ARIA2,
+        "port": "443" if SCHEME == "https" else "80",
+        "interface": "jsonrpc",
+        "secret": aria2_secret,
     }
     query_string = urllib.parse.urlencode(params)
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHME is https
+    # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
-    uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
+    uvicorn.run(
+        app="server:app",
+        host="0.0.0.0",
+        port=port,
+        reload=True,
+        proxy_headers=True,
+        forwarded_allow_ips="*",
+    )
```
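With made-up values standing in for the config entries, the deep link printed at startup pre-fills the RPC settings of the bundled aria2 web UI (the `#!/settings/rpc/set` route):

```python
import base64
import urllib.parse

# Hypothetical values standing in for the config module's entries:
RPC_SECRET = "s3cret"
SCHEME = "https"
EXTERNAL_HOST_ARIA2 = "aria2.local.homeinfra.org"
EXTERNAL_URL_ARIA2 = f"{SCHEME}://{EXTERNAL_HOST_ARIA2}"

params = {
    "protocol": SCHEME,
    "host": EXTERNAL_HOST_ARIA2,
    "port": "443" if SCHEME == "https" else "80",
    "interface": "jsonrpc",
    "secret": base64.b64encode(RPC_SECRET.encode()).decode(),
}
print(EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + urllib.parse.urlencode(params))
# https://aria2.local.homeinfra.org#!/settings/rpc/set?protocol=https&host=aria2.local.homeinfra.org&port=443&interface=jsonrpc&secret=czNjcmV0
```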

View File (diff not shown)

View File: sites/docker.py

```diff
@@ -2,18 +2,18 @@ import base64
 import json
 import re
 import time
+from typing import Dict

 import httpx
 from starlette.requests import Request
 from starlette.responses import Response

+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy

 BASE_URL = "https://registry-1.docker.io"

-cached_token = {
-}
+cached_token: Dict[str, str] = {}

 # https://github.com/opencontainers/distribution-spec/blob/main/spec.md
 name_regex = "[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*"
@@ -38,7 +38,6 @@ def try_extract_image_name(path):
 def get_docker_token(name):
     cached = cached_token.get(name, {})
     exp = cached.get("exp", 0)
-
     if exp > time.time():
         return cached.get("token", 0)
@@ -48,12 +47,13 @@ def get_docker_token(name):
         "service": "registry.docker.io",
     }

-    response = httpx.get(url, params=params, verify=False)
+    client = httpx.Client()
+    response = client.get(url, params=params)
     response.raise_for_status()

     token_data = response.json()
     token = token_data["token"]
-    payload = (token.split(".")[1])
+    payload = token.split(".")[1]
     padding = len(payload) % 4
     payload += "=" * padding
@@ -61,14 +61,27 @@ def get_docker_token(name):
     assert payload["iss"] == "auth.docker.io"
     assert len(payload["access"]) > 0

-    cached_token[name] = {
-        "exp": payload["exp"],
-        "token": token
-    }
+    cached_token[name] = {"exp": payload["exp"], "token": token}
     return token


+def inject_token(name: str, req: Request, httpx_req: httpx.Request):
+    docker_token = get_docker_token(f"{name}")
+    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+    return httpx_req
+
+
+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+        # TODO: logger
+        print("[redirect]", location)
+        return await try_file_based_cache(request, location)
+    return response


 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
@@ -82,19 +95,20 @@ async def docker(request: Request):
     name, operation, reference = try_extract_image_name(path)
     if not name:
-        return Response(content='404 Not Found', status_code=404)
+        return Response(content="404 Not Found", status_code=404)

     # support docker pull xxx which name without library
-    if '/' not in name:
+    if "/" not in name:
         name = f"library/{name}"

     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
-    print('[PARSED]', path, name, operation, reference, target_url)
+    # logger
+    print("[PARSED]", path, name, operation, reference, target_url)

-    def inject_token(req, httpx_req):
-        docker_token = get_docker_token(f"{name}")
-        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-        return httpx_req
-
-    return await direct_proxy(request, target_url, pre_process=inject_token)
+    return await direct_proxy(
+        request,
+        target_url,
+        pre_process=lambda req, http_req: inject_token(name, req, http_req),
+        post_process=post_process,
+    )
```
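Pulled out of `docker()`, the two hooks now compose: the lambda closes over the parsed image `name` for token injection, and the module-level `post_process` turns registry 307 redirects (how blob downloads are served) into file-cache hits. Roughly, for `docker pull ubuntu:latest`:

```python
# Sketch of the request flow (paths per the OCI distribution spec linked above):
#   GET /v2/library/ubuntu/manifests/latest
#     try_extract_image_name -> ("library/ubuntu", "manifests", "latest")
#     target_url = "https://registry-1.docker.io/v2/library/ubuntu/manifests/latest"
#     pre_process  -> inject_token adds "Authorization: Bearer <token>"
#   GET /v2/library/ubuntu/blobs/sha256:...
#     upstream answers 307 with a CDN location
#     post_process -> try_file_based_cache(request, location) downloads/serves it
```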

View File: sites/pypi.py

```diff
@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response

 from proxy.direct import direct_proxy
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache

 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)

     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)

     return await direct_proxy(request, target_url, post_process=pypi_replace)
```

View File: sites/torch.py

```diff
@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response

-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy

 BASE_URL = "https://download.pytorch.org"
@@ -19,6 +19,9 @@ async def torch(request: Request):
     target_url = BASE_URL + path

     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)

-    return await direct_proxy(request, target_url, )
+    return await direct_proxy(
+        request,
+        target_url,
+    )
```
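End to end, each `sites/` handler maps one upstream. Assuming the middleware routes the matching subdomain to each handler (the dispatch code falls outside the hunks shown), a client would simply point its tooling at the proxy:

```python
# Hypothetical client-side usage, with BASE_DOMAIN's default value:
#   pip install --index-url https://pypi.local.homeinfra.org/simple/ requests
#   pip install torch --extra-index-url https://torch.local.homeinfra.org/whl/cpu
# .whl / .tar.gz requests go through try_file_based_cache; everything else
# streams through direct_proxy.
```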