mirror of
https://github.com/NoCLin/LightMirrors
synced 2025-06-17 01:19:58 +08:00
feat: support pre and post processors
This commit is contained in:
parent
14ea13da43
commit
610cbf769c
@@ -4,7 +4,6 @@ ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
 RPC_SECRET = os.environ.get("RPC_SECRET", '')
 BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
-
 PROXY = os.environ.get("PROXY", None)
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]
 
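Note: SCHEME defaults to None here, so the assert fails unless the SCHEME environment variable is set explicitly. A hypothetical, more forgiving variant (not part of this commit) would default it:

    import os

    # Assumption: defaulting to "http" keeps a bare deployment booting;
    # the commit itself requires SCHEME to be set explicitly.
    SCHEME = os.environ.get("SCHEME", "http")
    assert SCHEME in ["http", "https"], f"unsupported SCHEME: {SCHEME}"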
@@ -1,47 +1,71 @@
 import typing
+from typing import Callable, Coroutine
 
 import httpx
-import starlette.requests
-import starlette.responses
+from httpx import Request as HttpxRequest
+from starlette.requests import Request
+from starlette.responses import Response
 
-from config import PROXY
+SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
+
+AsyncPreProcessor = Callable[
+    [Request, HttpxRequest],
+    Coroutine[Request, HttpxRequest, HttpxRequest]
+]
+
+SyncPostProcessor = Callable[[Request, Response], Response]
+
+AsyncPostProcessor = Callable[
+    [Request, Response],
+    Coroutine[Request, Response, Response]
+]
 
 
 async def direct_proxy(
-        request: starlette.requests.Request,
+        request: Request,
         target_url: str,
-        pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
-        post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
-) -> typing.Optional[starlette.responses.Response]:
-    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
-        headers = request.headers.mutablecopy()
-        for key in headers.keys():
+        pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
+        post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
+        cache_ttl: int = 3600,
+) -> Response:
+    # httpx will use the following environment variables to determine the proxy
+    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
+    async with httpx.AsyncClient() as client:
+        req_headers = request.headers.mutablecopy()
+        for key in req_headers.keys():
             if key not in ["user-agent", "accept"]:
-                del headers[key]
+                del req_headers[key]
 
-        httpx_req = client.build_request(request.method, target_url, headers=headers, )
+        httpx_req: HttpxRequest = client.build_request(request.method, target_url, headers=req_headers, )
 
         if pre_process:
-            httpx_req = pre_process(request, httpx_req)
+            new_httpx_req = pre_process(request, httpx_req)
+            if isinstance(new_httpx_req, HttpxRequest):
+                httpx_req = new_httpx_req
+            else:
+                httpx_req = await new_httpx_req
 
         upstream_response = await client.send(httpx_req)
 
+        # TODO: move to post_process
         if upstream_response.status_code == 307:
             location = upstream_response.headers["location"]
             print("catch redirect", location)
 
-        headers = upstream_response.headers
-        cl = headers.pop("content-length", None)
-        ce = headers.pop("content-encoding", None)
+        res_headers = upstream_response.headers
+        cl = res_headers.pop("content-length", None)
+        ce = res_headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
         content = upstream_response.content
-        response = starlette.responses.Response(
-            headers=headers,
+        response = Response(
+            headers=res_headers,
             content=content,
-            status_code=upstream_response.status_code)
+            status_code=upstream_response.status_code
+        )
 
         if post_process:
-            response = post_process(request, response)
+            new_res = post_process(request, response)
+            if isinstance(new_res, Response):
+                final_res = new_res
+            elif isinstance(new_res, Coroutine):
+                final_res = await new_res
+        else:
+            final_res = response
 
-        return response
+        return final_res
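Note: the new Union types mean direct_proxy accepts both plain functions and coroutine functions as hooks; it awaits the returned value only when it is not already an HttpxRequest/Response. A minimal usage sketch (add_header and log_response are illustrative names, not from the commit):

    from httpx import Request as HttpxRequest
    from starlette.requests import Request
    from starlette.responses import Response

    from proxy.direct import direct_proxy


    def add_header(request: Request, httpx_req: HttpxRequest) -> HttpxRequest:
        # sync pre-processor: the returned HttpxRequest is used as-is
        httpx_req.headers["x-mirror"] = "1"
        return httpx_req


    async def log_response(request: Request, response: Response) -> Response:
        # async post-processor: direct_proxy awaits the coroutine it gets back
        print(request.url.path, response.status_code)
        return response


    async def handler(request: Request) -> Response:
        # hypothetical upstream URL, for illustration only
        return await direct_proxy(request, "https://example.org/file",
                                  pre_process=add_header,
                                  post_process=log_response)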
@@ -14,16 +14,23 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 from aria2_api import add_download
 
 from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
+from typing import Optional, Callable
+
+logger = logging.getLogger(__name__)
 
 
 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
+    hostname = parsed_url.hostname
+    path = parsed_url.path
+    assert hostname
+    assert path
 
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
+    cache_file = (base_dir / hostname / path[1:]).resolve()
 
     assert cache_file.is_relative_to(base_dir)
 
     return str(cache_file), os.path.dirname(cache_file)
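Note: the resolve() plus is_relative_to() pair is what blocks path traversal out of the cache root. A standalone sketch of the guard, assuming CACHE_DIR is /app/cache and using a hypothetical hostile URL:

    import pathlib
    from urllib.parse import urlparse

    base_dir = pathlib.Path("/app/cache")  # assumed CACHE_DIR
    url = "https://files.example.org/../../etc/passwd"  # hypothetical

    parsed = urlparse(url)
    cache_file = (base_dir / parsed.hostname / parsed.path[1:]).resolve()
    print(cache_file)                           # /app/etc/passwd after ".." collapse
    print(cache_file.is_relative_to(base_dir))  # False, so the assert fires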
@@ -65,12 +72,12 @@ async def get_url_content_length(url):
     return content_len
 
 
-async def try_get_cache(
+async def try_file_based_cache(
         request: Request,
         target_url: str,
         download_wait_time: int = 60,
-        post_process: typing.Callable[[Request, Response], Response] = None,
-) -> typing.Optional[Response]:
+        post_process: Optional[Callable[[Request, Response], Response]] = None,
+) -> Response:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -88,7 +95,7 @@ async def try_get_cache(
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
 
-    processed_url = quote(target_url, safe='/:?=&')
+    processed_url = quote(target_url, safe='/:?=&%')
 
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
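Note: adding % to the safe set matters when the target URL is already percent-encoded; without it, quote() escapes the escapes. A quick illustration with a hypothetical file name:

    from urllib.parse import quote

    url = "https://files.example.org/pkg/foo%2Bcpu.whl"  # already encoded

    print(quote(url, safe='/:?=&'))   # ...foo%252Bcpu.whl (double-encoded)
    print(quote(url, safe='/:?=&%'))  # ...foo%2Bcpu.whl   (passed through intact)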
@@ -1,6 +1,7 @@
 import base64
 import signal
 import urllib.parse
+from typing import Callable
 
 import httpx
 import uvicorn
@@ -24,6 +25,7 @@ async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
+        # dont use proxy for internal API
         async with httpx.AsyncClient(mounts={
            "all://": httpx.AsyncHTTPTransport()
         }) as client:
@@ -39,8 +41,11 @@ async def aria2(request: Request, call_next):
 
 
 @app.middleware("http")
-async def capture_request(request: Request, call_next: callable):
+async def capture_request(request: Request, call_next: Callable):
     hostname = request.url.hostname
+    if not hostname:
+        return Response(content="Bad Request", status_code=400)
+
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
 
@@ -81,6 +86,6 @@ if __name__ == '__main__':
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
 
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
-    # FIXME: only proxy headers if SCHME is https
+    # FIXME: only proxy headers if SCHEME is https
     # reload only in dev mode
     uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")
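Note: the explicit "all://" mount pins the /jsonrpc client to a plain AsyncHTTPTransport, so the internal call ignores any HTTP_PROXY/HTTPS_PROXY settings that direct_proxy now honors via the environment. A minimal sketch of the same pattern (payload and endpoint are illustrative; the endpoint matches the ARIA2_RPC_URL default):

    import httpx

    async def call_internal_rpc(payload: dict) -> httpx.Response:
        # An explicitly mounted transport takes precedence over proxy
        # mounts derived from environment variables.
        async with httpx.AsyncClient(mounts={
            "all://": httpx.AsyncHTTPTransport()
        }) as client:
            return await client.post("http://aria2:6800/jsonrpc", json=payload)

Passing trust_env=False to the client would achieve the same isolation.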
mirrors/sites/__init__.py (new empty file, 0 lines)
@@ -2,16 +2,18 @@ import base64
 import json
 import re
 import time
+from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://registry-1.docker.io"
 
-cached_token = {
+cached_token: Dict[str, str] = {
 
 }
@@ -69,6 +71,22 @@ def get_docker_token(name):
     return token
 
 
+def inject_token(name: str, req: Request, httpx_req: httpx.Request):
+    docker_token = get_docker_token(f"{name}")
+    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+    return httpx_req
+
+
+async def post_process(request: Request, response: Response):
+    if response.status_code == 307:
+        location = response.headers["location"]
+        # TODO: logger
+        print("[redirect]", location)
+        return await try_file_based_cache(request, location)
+
+    return response
+
+
 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
@@ -90,11 +108,9 @@ async def docker(request: Request):
 
     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
 
     # logger
     print('[PARSED]', path, name, operation, reference, target_url)
 
-    def inject_token(req, httpx_req):
-        docker_token = get_docker_token(f"{name}")
-        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-        return httpx_req
-    return await direct_proxy(request, target_url, pre_process=inject_token)
+    return await direct_proxy(request, target_url,
+                              pre_process=lambda req, http_req: inject_token(name, req, http_req),
+                              post_process=post_process)
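Note: hoisting inject_token to module level adds name as an explicit first parameter, so the call site re-binds it with a lambda to fit the two-argument pre_process shape. functools.partial expresses the same binding (sketch, not from the commit):

    from functools import partial

    # pre(req, httpx_req) == inject_token(name, req, httpx_req)
    pre = partial(inject_token, name)
    # return await direct_proxy(request, target_url,
    #                           pre_process=pre, post_process=post_process)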
@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response
 
 from proxy.direct import direct_proxy
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 
 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)
 
     return await direct_proxy(request, target_url, post_process=pypi_replace)
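Note: the dispatch rule here and in torch.py below is the same: immutable artifacts (.whl, .tar.gz) go through the aria2-backed file cache, everything else stays on the streaming proxy so rewritten index pages remain fresh. A condensed sketch of the decision with hypothetical paths:

    for path in ["/simple/requests/",
                 "/packages/any/requests-2.31.0-py3-none-any.whl"]:
        if path.endswith(".whl") or path.endswith(".tar.gz"):
            print(path, "-> try_file_based_cache (download once, serve from disk)")
        else:
            print(path, "-> direct_proxy (live fetch, post-processed)")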
@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.cached import try_get_cache
+from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://download.pytorch.org"
@@ -19,6 +19,6 @@ async def torch(request: Request):
     target_url = BASE_URL + path
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_get_cache(request, target_url)
+        return await try_file_based_cache(request, target_url)
 
     return await direct_proxy(request, target_url, )