Compare commits


No commits in common. "ff41a826f31b1a074bff8757901115649fdd3981" and "14ea13da43a510c4338e95bc97ea36d49fb4bdfe" have entirely different histories.

10 changed files with 114 additions and 198 deletions

View File

@@ -11,20 +11,18 @@ logger = logging.getLogger(__name__)
 
 
 async def send_request(method, params=None):
     payload = {
-        "jsonrpc": "2.0",
-        "id": uuid.uuid4().hex,
-        "method": method,
-        "params": [f"token:{RPC_SECRET}"] + (params or []),
+        'jsonrpc': '2.0',
+        'id': uuid.uuid4().hex,
+        'method': method,
+        'params': [f'token:{RPC_SECRET}'] + (params or [])
     }
     # specify the internal API call don't use proxy
-    async with httpx.AsyncClient(
-        mounts={"all://": httpx.AsyncHTTPTransport()}
-    ) as client:
+    async with httpx.AsyncClient(mounts={
+        "all://": httpx.AsyncHTTPTransport()
+    }) as client:
         response = await client.post(ARIA2_RPC_URL, json=payload)
-        logger.info(
-            f"aria2 request: {method} {params} -> {response.status_code} {response.text}"
-        )
+        logger.info(f"aria2 request: {method} {params} -> {response.status_code} {response.text}")
         try:
             return response.json()
         except json.JSONDecodeError as e:
@@ -32,35 +30,38 @@ async def send_request(method, params=None):
             raise e
 
 
-async def add_download(url, save_dir="/app/cache"):
-    method = "aria2.addUri"
-    params = [[url], {"dir": save_dir, "header": []}]
+async def add_download(url, save_dir='/app/cache'):
+    method = 'aria2.addUri'
+    params = [[url],
+              {'dir': save_dir,
+               'header': []
+               }]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def pause_download(gid):
-    method = "aria2.pause"
+    method = 'aria2.pause'
     params = [gid]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def resume_download(gid):
-    method = "aria2.unpause"
+    method = 'aria2.unpause'
     params = [gid]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def get_status(gid):
-    method = "aria2.tellStatus"
+    method = 'aria2.tellStatus'
    params = [gid]
     response = await send_request(method, params)
-    return response["result"]
+    return response['result']
 
 
 async def list_downloads():
-    method = "aria2.tellActive"
+    method = 'aria2.tellActive'
     response = await send_request(method)
-    return response["result"]
+    return response['result']

View File

@@ -1,9 +1,10 @@
 import os
 
-ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", "http://aria2:6800/jsonrpc")
-RPC_SECRET = os.environ.get("RPC_SECRET", "")
-BASE_DOMAIN = os.environ.get("BASE_DOMAIN", "local.homeinfra.org")
+ARIA2_RPC_URL = os.environ.get("ARIA2_RPC_URL", 'http://aria2:6800/jsonrpc')
+RPC_SECRET = os.environ.get("RPC_SECRET", '')
+BASE_DOMAIN = os.environ.get("BASE_DOMAIN", 'local.homeinfra.org')
+PROXY = os.environ.get("PROXY", None)
 SCHEME = os.environ.get("SCHEME", None)
 assert SCHEME in ["http", "https"]

View File

@@ -13,24 +13,17 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT
 from aria2_api import add_download
-from config import CACHE_DIR, EXTERNAL_URL_ARIA2
-from typing import Optional, Callable
+from config import CACHE_DIR, EXTERNAL_URL_ARIA2, PROXY
 
 logger = logging.getLogger(__name__)
 
 
 def get_cache_file_and_folder(url: str) -> typing.Tuple[str, str]:
     parsed_url = urlparse(url)
-    hostname = parsed_url.hostname
-    path = parsed_url.path
-    assert hostname
-    assert path
     base_dir = pathlib.Path(CACHE_DIR)
     assert parsed_url.path[0] == "/"
     assert parsed_url.path[-1] != "/"
-    cache_file = (base_dir / hostname / path[1:]).resolve()
+    cache_file = (base_dir / parsed_url.hostname / parsed_url.path[1:]).resolve()
     assert cache_file.is_relative_to(base_dir)
     return str(cache_file), os.path.dirname(cache_file)
@@ -66,18 +59,18 @@ def make_cached_response(url):
 async def get_url_content_length(url):
-    async with httpx.AsyncClient() as client:
+    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
         head_response = await client.head(url)
-        content_len = head_response.headers.get("content-length", None)
+        content_len = (head_response.headers.get("content-length", None))
         return content_len
 
 
-async def try_file_based_cache(
+async def try_get_cache(
     request: Request,
     target_url: str,
     download_wait_time: int = 60,
-    post_process: Optional[Callable[[Request, Response], Response]] = None,
-) -> Response:
+    post_process: typing.Callable[[Request, Response], Response] = None,
+) -> typing.Optional[Response]:
     cache_status = lookup_cache(target_url)
     if cache_status == DownloadingStatus.DOWNLOADED:
         resp = make_cached_response(target_url)
@@ -87,26 +80,21 @@ async def try_file_based_cache(
     if cache_status == DownloadingStatus.DOWNLOADING:
         logger.info(f"Download is not finished, return 503 for {target_url}")
-        return Response(
-            content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-            status_code=HTTP_504_GATEWAY_TIMEOUT,
-        )
+        return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+                        status_code=HTTP_504_GATEWAY_TIMEOUT)
 
     assert cache_status == DownloadingStatus.NOT_FOUND
     cache_file, cache_file_dir = get_cache_file_and_folder(target_url)
     print("prepare to download", target_url, cache_file, cache_file_dir)
-    processed_url = quote(target_url, safe="/:?=&%")
+    processed_url = quote(target_url, safe='/:?=&')
     try:
         await add_download(processed_url, save_dir=cache_file_dir)
     except Exception as e:
         logger.error(f"Download error, return 503500 for {target_url}", exc_info=e)
-        return Response(
-            content=f"Failed to add download: {e}",
-            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
-        )
+        return Response(content=f"Failed to add download: {e}", status_code=HTTP_500_INTERNAL_SERVER_ERROR)
 
     # wait for download finished
     for _ in range(download_wait_time):
@@ -115,7 +103,5 @@ async def try_file_based_cache(
         if cache_status == DownloadingStatus.DOWNLOADED:
             return make_cached_response(target_url)
     logger.info(f"Download is not finished, return 503 for {target_url}")
-    return Response(
-        content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
-        status_code=HTTP_504_GATEWAY_TIMEOUT,
-    )
+    return Response(content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
+                    status_code=HTTP_504_GATEWAY_TIMEOUT)

View File

@@ -1,73 +1,47 @@
 import typing
-from typing import Callable, Coroutine
 
 import httpx
-from httpx import Request as HttpxRequest
-from starlette.requests import Request
-from starlette.responses import Response
+import starlette.requests
+import starlette.responses
 
-SyncPreProcessor = Callable[[Request, HttpxRequest], HttpxRequest]
-AsyncPreProcessor = Callable[
-    [Request, HttpxRequest], Coroutine[Request, HttpxRequest, HttpxRequest]
-]
-SyncPostProcessor = Callable[[Request, Response], Response]
-AsyncPostProcessor = Callable[
-    [Request, Response], Coroutine[Request, Response, Response]
-]
+from config import PROXY
 
 
 async def direct_proxy(
-    request: Request,
+    request: starlette.requests.Request,
     target_url: str,
-    pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
-    post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
-    cache_ttl: int = 3600,
-) -> Response:
-    # httpx will use the following environment variables to determine the proxy
-    # https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
-    async with httpx.AsyncClient() as client:
-        req_headers = request.headers.mutablecopy()
-        for key in req_headers.keys():
-            if key not in ["user-agent", "accept"]:
-                del req_headers[key]
-
-        httpx_req: HttpxRequest = client.build_request(
-            request.method,
-            target_url,
-            headers=req_headers,
-        )
+    pre_process: typing.Callable[[starlette.requests.Request, httpx.Request], httpx.Request] = None,
+    post_process: typing.Callable[[starlette.requests.Request, httpx.Response], httpx.Response] = None,
+) -> typing.Optional[starlette.responses.Response]:
+    async with httpx.AsyncClient(proxy=PROXY, verify=False) as client:
+        headers = request.headers.mutablecopy()
+        for key in headers.keys():
+            if key not in ["user-agent", "accept"]:
+                del headers[key]
+
+        httpx_req = client.build_request(request.method, target_url, headers=headers, )
 
         if pre_process:
-            new_httpx_req = pre_process(request, httpx_req)
-            if isinstance(new_httpx_req, HttpxRequest):
-                httpx_req = new_httpx_req
-            else:
-                httpx_req = await new_httpx_req
+            httpx_req = pre_process(request, httpx_req)
 
         upstream_response = await client.send(httpx_req)
 
-        res_headers = upstream_response.headers
+        # TODO: move to post_process
+        if upstream_response.status_code == 307:
+            location = upstream_response.headers["location"]
+            print("catch redirect", location)
 
-        cl = res_headers.pop("content-length", None)
-        ce = res_headers.pop("content-encoding", None)
+        headers = upstream_response.headers
+        cl = headers.pop("content-length", None)
+        ce = headers.pop("content-encoding", None)
         # print(target_url, cl, ce)
 
         content = upstream_response.content
-        response = Response(
-            headers=res_headers,
+        response = starlette.responses.Response(
+            headers=headers,
             content=content,
-            status_code=upstream_response.status_code,
-        )
+            status_code=upstream_response.status_code)
 
         if post_process:
-            new_res = post_process(request, response)
-            if isinstance(new_res, Response):
-                final_res = new_res
-            elif isinstance(new_res, Coroutine):
-                final_res = await new_res
-        else:
-            final_res = response
-
-        return final_res
+            response = post_process(request, response)
+
+        return response

View File

@@ -1,7 +1,6 @@
 import base64
 import signal
 import urllib.parse
-from typing import Callable
 
 import httpx
 import uvicorn
@@ -10,13 +9,7 @@ from starlette.requests import Request
 from starlette.responses import RedirectResponse, Response
 from starlette.staticfiles import StaticFiles
 
-from config import (
-    BASE_DOMAIN,
-    RPC_SECRET,
-    EXTERNAL_URL_ARIA2,
-    EXTERNAL_HOST_ARIA2,
-    SCHEME,
-)
+from config import BASE_DOMAIN, RPC_SECRET, EXTERNAL_URL_ARIA2, EXTERNAL_HOST_ARIA2, SCHEME
 from sites.docker import docker
 from sites.npm import npm
 from sites.pypi import pypi
@@ -24,45 +17,30 @@ from sites.torch import torch
 app = FastAPI()
 
-app.mount(
-    "/aria2/",
-    StaticFiles(directory="/wwwroot/"),
-    name="static",
-)
+app.mount("/aria2/", StaticFiles(directory="/wwwroot/"), name="static", )
 
 
 async def aria2(request: Request, call_next):
     if request.url.path == "/":
         return RedirectResponse("/aria2/index.html")
     if request.url.path == "/jsonrpc":
-        # dont use proxy for internal API
-        async with httpx.AsyncClient(
-            mounts={"all://": httpx.AsyncHTTPTransport()}
-        ) as client:
-            data = await request.body()
-            response = await client.request(
-                url="http://aria2:6800/jsonrpc",
-                method=request.method,
-                headers=request.headers,
-                content=data,
-            )
+        async with httpx.AsyncClient(mounts={
+            "all://": httpx.AsyncHTTPTransport()
+        }) as client:
+            data = (await request.body())
+            response = await client.request(url="http://aria2:6800/jsonrpc",
+                                            method=request.method,
+                                            headers=request.headers, content=data)
         headers = response.headers
         headers.pop("content-length", None)
        headers.pop("content-encoding", None)
-        return Response(
-            content=response.content,
-            status_code=response.status_code,
-            headers=headers,
-        )
+        return Response(content=response.content, status_code=response.status_code, headers=headers)
     return await call_next(request)
 
 
 @app.middleware("http")
-async def capture_request(request: Request, call_next: Callable):
+async def capture_request(request: Request, call_next: callable):
     hostname = request.url.hostname
-    if not hostname:
-        return Response(content="Bad Request", status_code=400)
     if not hostname.endswith(f".{BASE_DOMAIN}"):
         return await call_next(request)
@@ -81,7 +59,7 @@ async def capture_request(request: Request, call_next: Callable):
     return await call_next(request)
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     signal.signal(signal.SIGINT, signal.SIG_DFL)
     port = 80
     print(f"Server started at {SCHEME}://*.{BASE_DOMAIN})")
@@ -92,24 +70,17 @@ if __name__ == "__main__":
     aria2_secret = base64.b64encode(RPC_SECRET.encode()).decode()
     params = {
-        "protocol": SCHEME,
-        "host": EXTERNAL_HOST_ARIA2,
-        "port": "443" if SCHEME == "https" else "80",
-        "interface": "jsonrpc",
-        "secret": aria2_secret,
+        'protocol': SCHEME,
+        'host': EXTERNAL_HOST_ARIA2,
+        'port': '443' if SCHEME == 'https' else '80',
+        'interface': 'jsonrpc',
+        'secret': aria2_secret
     }
     query_string = urllib.parse.urlencode(params)
     aria2_url_with_auth = EXTERNAL_URL_ARIA2 + "#!/settings/rpc/set?" + query_string
     print(f"Download manager (Aria2) at {aria2_url_with_auth}")
 
-    # FIXME: only proxy headers if SCHEME is https
+    # FIXME: only proxy headers if SCHME is https
     # reload only in dev mode
-    uvicorn.run(
-        app="server:app",
-        host="0.0.0.0",
-        port=port,
-        reload=True,
-        proxy_headers=True,
-        forwarded_allow_ips="*",
-    )
+    uvicorn.run(app="server:app", host="0.0.0.0", port=port, reload=True, proxy_headers=True, forwarded_allow_ips="*")

View File

@@ -2,18 +2,18 @@ import base64
 import json
 import re
 import time
-from typing import Dict
 
 import httpx
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.file_cache import try_file_based_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://registry-1.docker.io"
 
-cached_token: Dict[str, str] = {}
+cached_token = {
+}
 
 # https://github.com/opencontainers/distribution-spec/blob/main/spec.md
 name_regex = "[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*"
@@ -38,6 +38,7 @@ def try_extract_image_name(path):
 
 def get_docker_token(name):
     cached = cached_token.get(name, {})
     exp = cached.get("exp", 0)
+
     if exp > time.time():
         return cached.get("token", 0)
@@ -47,13 +48,12 @@ def get_docker_token(name):
         "service": "registry.docker.io",
     }
 
-    client = httpx.Client()
-    response = client.get(url, params=params)
+    response = httpx.get(url, params=params, verify=False)
     response.raise_for_status()
 
     token_data = response.json()
     token = token_data["token"]
 
-    payload = token.split(".")[1]
+    payload = (token.split(".")[1])
     padding = len(payload) % 4
     payload += "=" * padding
@@ -61,27 +61,14 @@ def get_docker_token(name):
     assert payload["iss"] == "auth.docker.io"
     assert len(payload["access"]) > 0
 
-    cached_token[name] = {"exp": payload["exp"], "token": token}
+    cached_token[name] = {
+        "exp": payload["exp"],
+        "token": token
+    }
     return token
 
 
-def inject_token(name: str, req: Request, httpx_req: httpx.Request):
-    docker_token = get_docker_token(f"{name}")
-    httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
-    return httpx_req
-
-
-async def post_process(request: Request, response: Response):
-    if response.status_code == 307:
-        location = response.headers["location"]
-        # TODO: logger
-        print("[redirect]", location)
-        return await try_file_based_cache(request, location)
-    return response
-
-
 async def docker(request: Request):
     path = request.url.path
     print("[request]", request.method, request.url)
@@ -95,20 +82,19 @@ async def docker(request: Request):
     name, operation, reference = try_extract_image_name(path)
     if not name:
-        return Response(content="404 Not Found", status_code=404)
+        return Response(content='404 Not Found', status_code=404)
 
     # support docker pull xxx which name without library
-    if "/" not in name:
+    if '/' not in name:
         name = f"library/{name}"
 
     target_url = BASE_URL + f"/v2/{name}/{operation}/{reference}"
-    # logger
-    print("[PARSED]", path, name, operation, reference, target_url)
+    print('[PARSED]', path, name, operation, reference, target_url)
 
-    return await direct_proxy(
-        request,
-        target_url,
-        pre_process=lambda req, http_req: inject_token(name, req, http_req),
-        post_process=post_process,
-    )
+    def inject_token(req, httpx_req):
+        docker_token = get_docker_token(f"{name}")
+        httpx_req.headers["Authorization"] = f"Bearer {docker_token}"
+        return httpx_req
+
+    return await direct_proxy(request, target_url, pre_process=inject_token)

View File

@@ -4,7 +4,7 @@ from starlette.requests import Request
 from starlette.responses import Response
 
 from proxy.direct import direct_proxy
-from proxy.file_cache import try_file_based_cache
+from proxy.cached import try_get_cache
 
 pypi_file_base_url = "https://files.pythonhosted.org"
 pypi_base_url = "https://pypi.org"
@@ -39,6 +39,6 @@ async def pypi(request: Request) -> Response:
         return Response(content="Not Found", status_code=404)
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_file_based_cache(request, target_url)
+        return await try_get_cache(request, target_url)
 
     return await direct_proxy(request, target_url, post_process=pypi_replace)

View File

@@ -1,7 +1,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from proxy.file_cache import try_file_based_cache
+from proxy.cached import try_get_cache
 from proxy.direct import direct_proxy
 
 BASE_URL = "https://download.pytorch.org"
@@ -19,9 +19,6 @@ async def torch(request: Request):
     target_url = BASE_URL + path
 
     if path.endswith(".whl") or path.endswith(".tar.gz"):
-        return await try_file_based_cache(request, target_url)
+        return await try_get_cache(request, target_url)
 
-    return await direct_proxy(
-        request,
-        target_url,
-    )
+    return await direct_proxy(request, target_url, )