diff --git a/src/capport/api/app.py b/src/capport/api/app.py index b6bc3d4..cca81b7 100644 --- a/src/capport/api/app.py +++ b/src/capport/api/app.py @@ -5,5 +5,6 @@ from .app_cls import MyQuartApp app = MyQuartApp(__name__) __import__('capport.api.setup') +__import__('capport.api.proxy_fix') __import__('capport.api.lang') __import__('capport.api.views') diff --git a/src/capport/api/proxy_fix.py b/src/capport/api/proxy_fix.py new file mode 100644 index 0000000..c574127 --- /dev/null +++ b/src/capport/api/proxy_fix.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import ipaddress +import typing + +import quart +from werkzeug.http import parse_list_header +import werkzeug + +from .app import app + + +def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str]=()) -> typing.Optional[str]: + if not value_list: + return None + values = parse_list_header(value_list) + if values and values[0]: + if not allowed or values[0] in allowed: + return values[0] + return None + + +def local_proxy_fix(request: quart.Request): + try: + addr = ipaddress.ip_address(request.remote_addr) + except ValueError: + # TODO: accept unix sockets somehow too? + return + if not addr.is_loopback: + return + client = _get_first_in_list(request.headers.get('X-Forwarded-For')) + if not client: + # assume this is always set behind reverse proxies supporting any of the headers + return + request.remote_addr = client + scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https')) + port: typing.Optional[int] = None + if scheme: + port = 443 if scheme == 'https' else 80 + request.scheme = scheme + host = _get_first_in_list(request.headers.get('X-Forwarded-Host')) + port_s: typing.Optional[str] + if host: + request.host = host + if ':' in host and not host.endswith(']'): + try: + _, port_s = host.rsplit(':', maxsplit=1) + port = int(port_s) + except ValueError: + # ignore invalid port in host header + pass + port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port')) + if port_s: + try: + port = int(port_s) + except ValueError: + # ignore invalid port in header + pass + if port: + if request.server and len(request.server) == 2: + request.server = (request.server[0], port) + root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix')) + if root_path: + request.root_path = root_path + + +class LocalProxyFixRequestHandler: + def __init__(self, orig_handle_request): + self._orig_handle_request = orig_handle_request + + async def __call__(self, request: quart.Request) -> typing.Union[quart.Response, werkzeug.Response]: + # need to patch request before url_adapter is built + local_proxy_fix(request) + return await self._orig_handle_request(request) + + +app.handle_request = LocalProxyFixRequestHandler(app.handle_request) # type: ignore diff --git a/src/capport/api/views.py b/src/capport/api/views.py index d4c9787..2368aab 100644 --- a/src/capport/api/views.py +++ b/src/capport/api/views.py @@ -25,17 +25,6 @@ def get_client_ip() -> cptypes.IPAddress: except ValueError as e: _logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}') quart.abort(500, 'Invalid client address') - if addr.is_loopback: - forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For') - if len(forw_addr_headers) == 1: - try: - return ipaddress.ip_address(forw_addr_headers[0]) - except ValueError as e: - _logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}') - quart.abort(500, 'Invalid client address') - elif forw_addr_headers: - _logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})') - quart.abort(500, 'Invalid client address') return addr