feat: support pre and post processors

Anonymous 2024-06-07 22:33:23 +08:00
parent 14ea13da43
commit 610cbf769c
8 changed files with 96 additions and 45 deletions

View File

@@ -4,7 +4,6 @@ ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
 RPC_SECRET = os.environ.get("RPC_SECRET", '')
 BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
-PROXY = os.environ.get("PROXY", None)
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]
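
Note: with the PROXY setting removed, outbound proxying is expected to come from httpx's standard environment variables (see the comment added in the proxy module below). A minimal sketch; the proxy URL is hypothetical:

    import httpx

    # httpx reads HTTP_PROXY / HTTPS_PROXY / ALL_PROXY from the environment by
    # default; trust_env=True is already the default and is shown only for emphasis.
    # e.g. HTTPS_PROXY=http://proxy.example:8080  (hypothetical value)
    client = httpx.AsyncClient(trust_env=True)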

View File

@@ -1,47 +1,71 @@
 import typing
+from typing import Callable, Coroutine
 import httpx
-import starlette.requests
-import starlette.responses
-from config import PROXY
+from httpx import Request as HttpxRequest
+from starlette.requests import Request
+from starlette.responses import Response
+
+SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
+AsyncPreProcessor = Callable[
+    [Request, HttpxRequest],
+    Coroutine[Request, HttpxRequest, HttpxRequest]
+]
+SyncPostProcessor = Callable[[Request, Response], Response]
+AsyncPostProcessor = Callable[
+    [Request, Response],
+    Coroutine[Request, Response, Response]
+]
 
 async def direct_proxy(
-        request: starlette.requests.Request,
+        request: Request,
         target_url: str,
-        pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
-        post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
-) -> typing.Optional[starlette.responses.Response]:
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
-        headers = request.headers.mutablecopy()
-        for key in headers.keys():
+        pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+        post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+        cache_ttl: int = 3600,
+) -> Response:
+    # httpx will use the following environment variables to determine the proxy
+    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
+    async with httpx.AsyncClient() as client:
+        req_headers = request.headers.mutablecopy()
+        for key in req_headers.keys():
             if key not in ["user-agent", "accept"]:
-                del headers[key]
-        httpx_req = client.build_request(request.method, target_url, headers=headers, )
+                del req_headers[key]
+        httpx_req: HttpxRequest = client.build_request(request.method, target_url, headers=req_headers, )
         if pre_process:
-            httpx_req = pre_process(request, httpx_req)
+            new_httpx_req = pre_process(request, httpx_req)
+            if isinstance(new_httpx_req, HttpxRequest):
+                httpx_req = new_httpx_req
+            else:
+                httpx_req = await new_httpx_req
         upstream_response = await client.send(httpx_req)
-        # TODO: move to post_process
-        if upstream_response.status_code == 307:
-            location = upstream_response.headers["location"]
-            print("catch redirect", location)
-        headers = upstream_response.headers
-        cl = headers.pop("content-length", None)
-        ce = headers.pop("content-encoding", None)
+        res_headers = upstream_response.headers
+        cl = res_headers.pop("content-length", None)
+        ce = res_headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
         content = upstream_response.content
-        response = starlette.responses.Response(
-            headers=headers,
+        response = Response(
+            headers=res_headers,
             content=content,
-            status_code=upstream_response.status_code)
+            status_code=upstream_response.status_code
+        )
         if post_process:
-            response = post_process(request, response)
+            new_res = post_process(request, response)
+            if isinstance(new_res, Response):
+                final_res = new_res
+            elif isinstance(new_res, Coroutine):
+                final_res = await new_res
+            else:
+                final_res = response
-        return response
+        return final_res
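
Note: a minimal usage sketch of the new hooks. direct_proxy accepts both sync and async processors and awaits the result only when a coroutine comes back; add_header and log_status are illustrative names, not part of the commit:

    from httpx import Request as HttpxRequest
    from starlette.requests import Request
    from starlette.responses import Response

    def add_header(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
        # sync pre-processor: adjust the outgoing upstream request
        httpx_req.headers["X-Example"] = "1"  # hypothetical header
        return httpx_req

    async def log_status(request: Request, response: Response) -> Response:
        # async post-processor: inspect (or replace) the response before returning it
        print("[upstream]", response.status_code, request.url.path)
        return response

    # inside an endpoint:
    #     return await direct_proxy(request, target_url,
    #                               pre_process=add_header, post_process=log_status)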

View File

@@ -14,16 +14,23 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TI
 from aria2_api import add_download
 from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from typing import Optional, Callable
 
 logger = logging.getLogger(__name__)
 
 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
+    hostname = parsed_url.hostname
+    path = parsed_url.path
+    assert hostname
+    assert path
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
+    cache_file = (base_dir / hostname / path[1:]).resolve()
     assert cache_file.is_relative_to(base_dir)
     return str(cache_file), os.path.dirname(cache_file)
@@ -65,12 +72,12 @@ async def get_url_content_length(url):
     return content_len
 
-async def try_get_cache(
+async def try_file_based_cache(
         request: Request,
         target_url: str,
         download_wait_time: int = 60,
-        post_process: typing.Callable[[Request, Response], Response] = None,
-) -> typing.Optional[Response]:
+        post_process: Optional[Callable[[Request, Response], Response]] = None,
+) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -88,7 +95,7 @@ async def try_get_cache(
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
-    processed_url = quote(target_url, safe='/:?=&')
+    processed_url = quote(target_url, safe='/:?=&%')
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
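
Note: a sketch of the URL-to-path mapping that get_cache_file_and_folder performs, with a hypothetical CACHE_DIR; the resolve()/is_relative_to() pair guards against "../" path traversal. Adding '%' to quote()'s safe set keeps already-percent-encoded URLs from being double-encoded:

    import os
    import pathlib
    from urllib.parse import urlparse, quote

    CACHE_DIR = "/app/cache"  # hypothetical value for illustration

    def cache_paths(url: str):
        parsed = urlparse(url)
        base_dir = pathlib.Path(CACHE_DIR)
        cache_file = (base_dir / parsed.hostname / parsed.path[1:]).resolve()
        assert cache_file.is_relative_to(base_dir)  # blocks ../ escapes
        return str(cache_file), os.path.dirname(cache_file)

    print(cache_paths("https://files.pythonhosted.org/packages/ab/cd/pkg-1.0-py3-none-any.whl"))
    # -> ('/app/cache/files.pythonhosted.org/packages/ab/cd/pkg-1.0-py3-none-any.whl',
    #     '/app/cache/files.pythonhosted.org/packages/ab/cd')

    print(quote("https://example.org/a%20b.whl", safe='/:?=&%'))
    # -> unchanged: the already-encoded %20 is left alone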

View File

@@ -1,6 +1,7 @@
 import base64
 import signal
 import urllib.parse
+from typing import Callable
 
 import httpx
 import uvicorn
@@ -24,6 +25,7 @@ async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
+        # dont use proxy for internal API
         async with httpx.AsyncClient(mounts={
             "all://": httpx.AsyncHTTPTransport()
         }) as client:
@@ -39,8 +41,11 @@ async def aria2(request: Request, call_next):
 @app.middleware("http")
-async def capture_request(request: Request, call_next: callable):
+async def capture_request(request: Request, call_next: Callable):
     hostname = request.url.hostname
+    if not hostname:
+        return Response(content="Bad Request", status_code=400)
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
@@ -81,6 +86,6 @@ if __name__ == '__main__':
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHME is https
+    # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
     uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
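
Note: the new guard matters because a request without a usable Host header gives request.url.hostname == None, and the old code then called None.endswith(...). A minimal sketch of the failure mode, using the default BASE_DOMAIN from config:

    hostname = None  # what starlette reports when the Host header is missing
    try:
        hostname.endswith(".local.homeinfra.org")  # the old code path
    except AttributeError as exc:
        print("previously crashed with:", exc)  # now short-circuited to a 400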

View File

View File

@@ -2,16 +2,18 @@ import base64
 import json
 import re
 import time
+from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://registry-1.docker.io"
 
-cached_token = {
+cached_token: Dict[str, str] = {
 }
@@ -69,6 +71,22 @@ def get_docker_token(name):
     return token
 
+def inject_token(name: str, req: Request, httpx_req: httpx.Request):
+    docker_token = get_docker_token(f"{name}")
+    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+    return httpx_req
+
+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+        # TODO: logger
+        print("[redirect]", location)
+        return await try_file_based_cache(request, location)
+    return response
+
 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
@@ -90,11 +108,9 @@ async def docker(request: Request):
     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
-    # logger
     print('[PARSED]', path, name, operation, reference, target_url)
-    def inject_token(req, httpx_req):
-        docker_token = get_docker_token(f"{name}")
-        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-        return httpx_req
-    return await direct_proxy(request, target_url, pre_process=inject_token)
+    return await direct_proxy(request, target_url,
+                              pre_process=lambda req, http_req: inject_token(name, req, http_req),
+                              post_process=post_process)
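
Note: the lambda exists only to bind name, since direct_proxy calls its pre_process hook with (request, httpx_req). An equivalent formulation with functools.partial, sketched under the assumption it runs inside the docker() handler where inject_token and post_process are in scope; the image name is illustrative:

    from functools import partial

    name = "library/ubuntu"  # example image name
    # partial(inject_token, name) pre-binds the first argument, leaving a
    # callable with the (request, httpx_req) shape that direct_proxy expects
    pre = partial(inject_token, name)
    # return await direct_proxy(request, target_url,
    #                           pre_process=pre, post_process=post_process)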

View File

@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response
 
 from proxy.direct import direct_proxy
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 
 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)
     return await direct_proxy(request, target_url, post_process=pypi_replace)

View File

@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://download.pytorch.org"
@@ -19,6 +19,6 @@ async def torch(request: Request):
     target_url = BASE_URL + path
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)
     return await direct_proxy(request, target_url, )