diff --git a/.gitignore b/.gitignore index 7ce4267..8be8a4b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__ venv capport.yaml +custom diff --git a/src/capport/api/__init__.py b/src/capport/api/__init__.py index 8a018e1..e69de29 100644 --- a/src/capport/api/__init__.py +++ b/src/capport/api/__init__.py @@ -1,172 +0,0 @@ -from __future__ import annotations - -import ipaddress -import logging -import typing -import uuid - -import capport.database -import capport.comm.hub -import capport.comm.message -import capport.utils.cli -import capport.utils.ipneigh -import quart -import quart_trio -import trio -from capport import cptypes -from capport.config import Config - - -app = quart_trio.QuartTrio(__name__) -_logger = logging.getLogger(__name__) - - -config: typing.Optional[Config] = None -hub: typing.Optional[capport.comm.hub.Hub] = None -hub_app: typing.Optional[ApiHubApp] = None -nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None - - -def get_client_ip() -> cptypes.IPAddress: - try: - addr = ipaddress.ip_address(quart.request.remote_addr) - except ValueError as e: - _logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}') - quart.abort(500, 'Invalid client address') - if addr.is_loopback: - forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For') - if len(forw_addr_headers) == 1: - try: - return ipaddress.ip_address(forw_addr_headers[0]) - except ValueError as e: - _logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}') - quart.abort(500, 'Invalid client address') - elif forw_addr_headers: - _logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})') - quart.abort(500, 'Invalid client address') - return addr - - -async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]: - assert nc # for mypy - if not address: - address = get_client_ip() - return await nc.get_neighbor_mac(address) - - -async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress: - mac = await get_client_mac_if_present(address) - if mac is None: - _logger.warning(f"Couldn't find MAC addresss for {address}") - quart.abort(404, 'Unknown client') - return mac - - -class ApiHubApp(capport.comm.hub.HubApplication): - async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: - # TODO: support websocket notification updates to clients? - pass - - -async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None: - assert config # for mypy - assert hub # for mypy - pu = capport.database.PendingUpdates() - try: - hub.database.login(mac, config.session_timeout, pending_updates=pu) - except capport.database.NotReadyYet as e: - quart.abort(500, str(e)) - - if pu.macs: - _logger.info(f'User {mac} (with IP {address}) logged in') - for msg in pu.serialize(): - await hub.broadcast(msg) - - -async def user_logout(mac: cptypes.MacAddress) -> None: - assert hub # for mypy - pu = capport.database.PendingUpdates() - try: - hub.database.logout(mac, pending_updates=pu) - except capport.database.NotReadyYet as e: - quart.abort(500, str(e)) - if pu.macs: - _logger.info(f'User {mac} logged out') - for msg in pu.serialize(): - await hub.broadcast(msg) - - -async def user_lookup() -> cptypes.MacPublicState: - assert hub # for mypy - address = get_client_ip() - mac = await get_client_mac_if_present(address) - if not mac: - return cptypes.MacPublicState.from_missing_mac(address) - else: - return hub.database.lookup(address, mac) - - -async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None: - global hub - global hub_app - global nc - assert config # for mypy - try: - async with capport.utils.ipneigh.connect() as mync: - nc = mync - _logger.info("Running hub for API") - myapp = ApiHubApp() - myhub = capport.comm.hub.Hub(config=config, app=myapp) - hub = myhub - hub_app = myapp - await myhub.run(task_status=task_status) - finally: - hub = None - hub_app = None - nc = None - _logger.info("Done running hub for API") - await app.shutdown() - - -@app.before_serving -async def init(): - global config - config = Config.load_default_once() - app.secret_key = config.cookie_secret - capport.utils.cli.init_logger(config) - await app.nursery.start(_run_hub) - - -# @app.route('/all') -# async def route_all(): -# return hub_app.database.as_json() - - -@app.route('/', methods=['GET']) -async def index(): - state = await user_lookup() - return await quart.render_template('index.html', state=state) - - -@app.route('/login', methods=['POST']) -async def login(): - address = get_client_ip() - mac = await get_client_mac(address) - await user_login(address, mac) - await quart.flash('Logged in') - return quart.redirect('/', code=303) - - -@app.route('/logout', methods=['POST']) -async def logout(): - mac = await get_client_mac() - await user_logout(mac) - await quart.flash('Logged out') - return quart.redirect('/', code=303) - - -@app.route('/api/captive-portal', methods=['GET']) -# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908 -async def captive_api(): - state = await user_lookup() - return state.to_rfc8908(config) diff --git a/src/capport/api/app.py b/src/capport/api/app.py new file mode 100644 index 0000000..b9e144e --- /dev/null +++ b/src/capport/api/app.py @@ -0,0 +1,8 @@ +from .app_cls import MyQuartApp + + +app = MyQuartApp(__name__) + + +__import__('capport.api.setup') +__import__('capport.api.views') diff --git a/src/capport/api/app_cls.py b/src/capport/api/app_cls.py new file mode 100644 index 0000000..5f9087d --- /dev/null +++ b/src/capport/api/app_cls.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import os +import os.path +import typing + +import capport.comm.hub +import capport.config +import capport.utils.ipneigh +import jinja2 +import quart.templating +import quart_trio + + +class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader): + app: MyQuartApp + + def __init__(self, app: MyQuartApp) -> None: + super().__init__(app) + + def _loaders(self) -> typing.Generator[jinja2.BaseLoader, None, None]: + if self.app.custom_loader: + yield self.app.custom_loader + for loader in super()._loaders(): + yield loader + + +class MyQuartApp(quart_trio.QuartTrio): + my_nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None + my_hub: typing.Optional[capport.comm.hub.Hub] = None + my_config: capport.config.Config + custom_loader: typing.Optional[jinja2.FileSystemLoader] = None + + def __init__(self, import_name: str, **kwargs) -> None: + self.my_config = capport.config.Config.load_default_once() + kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates')) + cust_templ = os.path.join('custom', 'templates') + if os.path.exists(cust_templ): + self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ)) + super().__init__(import_name, **kwargs) + self.debug = self.my_config.debug + self.secret_key = self.my_config.cookie_secret + + def create_global_jinja_loader(self) -> DispatchingJinjaLoader: + """Create and return a global (not blueprint specific) Jinja loader.""" + return DispatchingJinjaLoader(self) diff --git a/src/capport/api/setup.py b/src/capport/api/setup.py new file mode 100644 index 0000000..7437e1a --- /dev/null +++ b/src/capport/api/setup.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import logging +import uuid + +import capport.database +import capport.comm.hub +import capport.comm.message +import capport.utils.cli +import capport.utils.ipneigh +import trio + +from .app import app + + +_logger = logging.getLogger(__name__) + + +class ApiHubApp(capport.comm.hub.HubApplication): + async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + # TODO: support websocket notification updates to clients? + pass + + +async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None: + try: + async with capport.utils.ipneigh.connect() as mync: + app.my_nc = mync + _logger.info("Running hub for API") + myapp = ApiHubApp() + myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp) + app.my_hub = myhub + await myhub.run(task_status=task_status) + finally: + app.my_hub = None + app.my_nc = None + _logger.info("Done running hub for API") + await app.shutdown() + + +@app.before_serving +async def init(): + app.debug = app.my_config.debug + app.secret_key = app.my_config.cookie_secret + capport.utils.cli.init_logger(app.my_config) + await app.nursery.start(_run_hub) diff --git a/src/capport/api/templates/base.html b/src/capport/api/templates/base.html index 88b314c..a0b0b67 100644 --- a/src/capport/api/templates/base.html +++ b/src/capport/api/templates/base.html @@ -2,7 +2,7 @@ - {% block title %}Captive Portal Universität Stuttgart{% endblock %} + {% block title %}Captive Portal{% endblock %} diff --git a/src/capport/api/views.py b/src/capport/api/views.py new file mode 100644 index 0000000..97786c9 --- /dev/null +++ b/src/capport/api/views.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import ipaddress +import logging +import typing + +import capport.database +import capport.comm.hub +import capport.comm.message +import capport.utils.cli +import capport.utils.ipneigh +import quart +from capport import cptypes + +from .app import app + + +_logger = logging.getLogger(__name__) + + +def get_client_ip() -> cptypes.IPAddress: + try: + addr = ipaddress.ip_address(quart.request.remote_addr) + except ValueError as e: + _logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}') + quart.abort(500, 'Invalid client address') + if addr.is_loopback: + forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For') + if len(forw_addr_headers) == 1: + try: + return ipaddress.ip_address(forw_addr_headers[0]) + except ValueError as e: + _logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}') + quart.abort(500, 'Invalid client address') + elif forw_addr_headers: + _logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})') + quart.abort(500, 'Invalid client address') + return addr + + +async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]: + assert app.my_nc # for mypy + if not address: + address = get_client_ip() + return await app.my_nc.get_neighbor_mac(address) + + +async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress: + mac = await get_client_mac_if_present(address) + if mac is None: + _logger.warning(f"Couldn't find MAC addresss for {address}") + quart.abort(404, 'Unknown client') + return mac + + +async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None: + assert app.my_hub # for mypy + pu = capport.database.PendingUpdates() + try: + app.my_hub.database.login(mac, app.my_config.session_timeout, pending_updates=pu) + except capport.database.NotReadyYet as e: + quart.abort(500, str(e)) + + if pu.macs: + _logger.info(f'User {mac} (with IP {address}) logged in') + for msg in pu.serialize(): + await app.my_hub.broadcast(msg) + + +async def user_logout(mac: cptypes.MacAddress) -> None: + assert app.my_hub # for mypy + pu = capport.database.PendingUpdates() + try: + app.my_hub.database.logout(mac, pending_updates=pu) + except capport.database.NotReadyYet as e: + quart.abort(500, str(e)) + if pu.macs: + _logger.info(f'User {mac} logged out') + for msg in pu.serialize(): + await app.my_hub.broadcast(msg) + + +async def user_lookup() -> cptypes.MacPublicState: + assert app.my_hub # for mypy + address = get_client_ip() + mac = await get_client_mac_if_present(address) + if not mac: + return cptypes.MacPublicState.from_missing_mac(address) + else: + return app.my_hub.database.lookup(address, mac) + + +# @app.route('/all') +# async def route_all(): +# return app.my_hub.database.as_json() + + +@app.route('/', methods=['GET']) +async def index(): + state = await user_lookup() + return await quart.render_template('index.html', state=state) + + +@app.route('/login', methods=['POST']) +async def login(): + address = get_client_ip() + mac = await get_client_mac(address) + await user_login(address, mac) + await quart.flash('Logged in') + return quart.redirect('/', code=303) + + +@app.route('/logout', methods=['POST']) +async def logout(): + mac = await get_client_mac() + await user_logout(mac) + await quart.flash('Logged out') + return quart.redirect('/', code=303) + + +@app.route('/api/captive-portal', methods=['GET']) +# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908 +async def captive_api(): + state = await user_lookup() + return state.to_rfc8908(app.my_config) diff --git a/start-api.sh b/start-api.sh index 1d9d365..5bb91de 100755 --- a/start-api.sh +++ b/start-api.sh @@ -5,4 +5,4 @@ set -e base=$(dirname "$(readlink -f "$0")") cd "${base}" -exec ./venv/bin/hypercorn --config python:capport.api.hypercorn_conf capport.api "$@" +exec ./venv/bin/hypercorn --config python:capport.api.hypercorn_conf capport.api.app "$@"