From d1050d2ee4734144f0c8225c7ace970ce138af50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Mon, 4 Apr 2022 19:21:51 +0200 Subject: [PATCH] initial commit --- .gitignore | 6 + .pycodestyle | 11 + .pylintrc | 20 + LICENSE | 19 + README.md | 10 + capport-example.yaml | 7 + mypy | 24 ++ protobuf/compile.sh | 10 + protobuf/message.proto | 42 +++ pylint | 20 + pyproject.toml | 6 + setup-venv.sh | 11 + setup.cfg | 36 ++ setup.py | 6 + src/capport/__init__.py | 0 src/capport/api/__init__.py | 169 +++++++++ src/capport/api/templates/index.html | 22 ++ src/capport/comm/__init__.py | 0 src/capport/comm/hub.py | 441 +++++++++++++++++++++++ src/capport/comm/message.py | 37 ++ src/capport/comm/message.pyi | 93 +++++ src/capport/comm/protobuf/message_pb2.py | 355 ++++++++++++++++++ src/capport/config.py | 34 ++ src/capport/control/__init__.py | 0 src/capport/control/run.py | 56 +++ src/capport/cptypes.py | 94 +++++ src/capport/database.py | 200 ++++++++++ src/capport/utils/__init__.py | 0 src/capport/utils/cli.py | 17 + src/capport/utils/ipneigh.py | 63 ++++ src/capport/utils/nft_set.py | 181 ++++++++++ src/capport/utils/nft_socket.py | 217 +++++++++++ start-api.sh | 8 + start-control.sh | 8 + 34 files changed, 2223 insertions(+) create mode 100644 .gitignore create mode 100644 .pycodestyle create mode 100644 .pylintrc create mode 100644 LICENSE create mode 100644 README.md create mode 100644 capport-example.yaml create mode 100755 mypy create mode 100755 protobuf/compile.sh create mode 100644 protobuf/message.proto create mode 100755 pylint create mode 100644 pyproject.toml create mode 100755 setup-venv.sh create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 src/capport/__init__.py create mode 100644 src/capport/api/__init__.py create mode 100644 src/capport/api/templates/index.html create mode 100644 src/capport/comm/__init__.py create mode 100644 src/capport/comm/hub.py create mode 100644 src/capport/comm/message.py create mode 100644 src/capport/comm/message.pyi create mode 100644 src/capport/comm/protobuf/message_pb2.py create mode 100644 src/capport/config.py create mode 100644 src/capport/control/__init__.py create mode 100644 src/capport/control/run.py create mode 100644 src/capport/cptypes.py create mode 100644 src/capport/database.py create mode 100644 src/capport/utils/__init__.py create mode 100644 src/capport/utils/cli.py create mode 100644 src/capport/utils/ipneigh.py create mode 100644 src/capport/utils/nft_set.py create mode 100644 src/capport/utils/nft_socket.py create mode 100755 start-api.sh create mode 100755 start-control.sh diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7ce4267 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.vscode +*.pyc +*.egg-info +__pycache__ +venv +capport.yaml diff --git a/.pycodestyle b/.pycodestyle new file mode 100644 index 0000000..8f3dbe0 --- /dev/null +++ b/.pycodestyle @@ -0,0 +1,11 @@ +[pycodestyle] +# E241 multiple spaces after ':' [ want to align stuff ] +# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ] +# E501 line too long [ temporary? can't disable it in certain places.. ] +# E701 multiple statements on one line (colon) [ perfectly readable ] +# E713 test for membership should be ‘not in’ [ disagree: want `not a in x` ] +# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ] +# W503 Line break occurred before a binary operator [ pep8 flipped on this (also contradicts W504) ] +ignore = E241,E266,E501,E701,E713,E714,W503 +max-line-length = 120 +exclude = 00*.py,generated.py diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..f05a15a --- /dev/null +++ b/.pylintrc @@ -0,0 +1,20 @@ +[MESSAGES CONTROL] + +disable=logging-fstring-interpolation + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=120 + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +[DESIGN] + +# Maximum number of locals for function / method body. +max-locals=20 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=0 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ab84ef7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2022 Universität Stuttgart + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..de72aa8 --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +# python Captive Portal + +### Installation + +Either clone repository (and install dependencies either through distribution or as virtualenv with `./setup-venv.sh`) or install as package. + +[`pipx`](https://pypa.github.io/pipx/) (available in debian as package) can be used to install in separate virtual environment: + + pipx install https://github.tik.uni-stuttgart.de/NKS/python-capport + diff --git a/capport-example.yaml b/capport-example.yaml new file mode 100644 index 0000000..7e124f2 --- /dev/null +++ b/capport-example.yaml @@ -0,0 +1,7 @@ +--- +secret: mysecret +controllers: +- capport-controller1.example.com +- capport-controller2.example.com +session-timeout: 3600 # in seconds +venue-info-url: 'https://example.com' diff --git a/mypy b/mypy new file mode 100755 index 0000000..095418b --- /dev/null +++ b/mypy @@ -0,0 +1,24 @@ +#!/bin/sh + +### check type annotations with mypy + +set -e + +base=$(dirname "$(readlink -f "$0")") +cd "${base}" + +if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then + echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!" + exit 1 +fi + +if [ ! -x ./venv/bin/mypy ]; then + ./venv/bin/pip install mypy +fi + +site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])') +if [ ! -d "${site_pkgs}/trio_typing" ]; then + ./venv/bin/pip install trio-typing[mypy] +fi + +./venv/bin/mypy --install-types src diff --git a/protobuf/compile.sh b/protobuf/compile.sh new file mode 100755 index 0000000..031bcdf --- /dev/null +++ b/protobuf/compile.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +set -e + +cd "$(dirname "$(readlink -f "$0")")" + +rm -rf ../src/capport/comm/protobuf/message_pb2.py +mkdir -p ../src/capport/comm/protobuf + +protoc --python_out=../src/capport/comm/protobuf message.proto diff --git a/protobuf/message.proto b/protobuf/message.proto new file mode 100644 index 0000000..ef9fe1d --- /dev/null +++ b/protobuf/message.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +package capport; + +message Message { + oneof oneof { + Hello hello = 1; + AuthenticationResult authentication_result = 2; + Ping ping = 3; + MacStates mac_states = 10; + } +} + +// sent by clients and servers as first message +message Hello { + bytes instance_id = 1; + string hostname = 2; + bool is_controller = 3; + bytes authentication = 4; +} + +// tell peer whether hello authentication was good +message AuthenticationResult { + bool success = 1; +} + +message Ping { + bytes payload = 1; +} + +message MacStates { + repeated MacState states = 1; +} + +message MacState { + bytes mac_address = 1; + // Seconds of UTC time since epoch + int64 last_change = 2; + // Seconds of UTC time since epoch + int64 allow_until = 3; + bool allowed = 4; +} diff --git a/pylint b/pylint new file mode 100755 index 0000000..d88ee0c --- /dev/null +++ b/pylint @@ -0,0 +1,20 @@ +#!/bin/sh + +### check type annotations with mypy + +set -e + +base=$(dirname "$(readlink -f "$0")") +cd "${base}" + +if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then + echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!" + exit 1 +fi + +if [ ! -x ./venv/bin/pylint ]; then + # need current pylint to deal with more recent python features + ./venv/bin/pip install pylint +fi + +./venv/bin/pylint src diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..374b58c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel" +] +build-backend = "setuptools.build_meta" diff --git a/setup-venv.sh b/setup-venv.sh new file mode 100755 index 0000000..4880b73 --- /dev/null +++ b/setup-venv.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +set -e + +self=$(dirname "$(readlink -f "$0")") +cd "${self}" + +python3 -m venv venv + +# install cli extras +./venv/bin/pip install -e '.' diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..00b80b4 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,36 @@ +[metadata] +name = capport-tik-nks +version = 0.0.1 +author = Stefan Bühler +author_email = stefan.buehler@tik.uni-stuttgart.de +description = Captive Portal +long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.tik.uni-stuttgart.de/NKS/python-capport +project_urls = + Bug Tracker = https://github.tik.uni-stuttgart.de/NKS/python-capport/issues +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: MIT License + Operating System :: OS Independent + +[options] +package_dir = + = src +packages = find: +python_requires = >=3.9 +install_requires = + trio + quart-trio + quart + hypercorn[trio] + PyYAML + protobuf + pyroute2 + +[options.packages.find] +where = src + +[options.entry_points] +console_scripts = + capport-control = capport.control.run:main diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f9c1139 --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +# https://github.com/pypa/setuptools/issues/2816 +# allow editable install on older pip versions +from setuptools import setup + +if __name__ == "__main__": + setup() diff --git a/src/capport/__init__.py b/src/capport/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/capport/api/__init__.py b/src/capport/api/__init__.py new file mode 100644 index 0000000..ac785e2 --- /dev/null +++ b/src/capport/api/__init__.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import ipaddress +import logging +import typing +import uuid + +import capport.database +import capport.comm.hub +import capport.comm.message +import capport.utils.cli +import capport.utils.ipneigh +import quart +import quart_trio +import trio +from capport import cptypes +from capport.config import Config + + +app = quart_trio.QuartTrio(__name__) +_logger = logging.getLogger(__name__) + + +config: typing.Optional[Config] = None +hub: typing.Optional[capport.comm.hub.Hub] = None +hub_app: typing.Optional[ApiHubApp] = None +nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None + + +def get_client_ip() -> cptypes.IPAddress: + try: + addr = ipaddress.ip_address(quart.request.remote_addr) + except ValueError as e: + _logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}') + quart.abort(500, 'Invalid client address') + if addr.is_loopback: + forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For') + if len(forw_addr_headers) == 1: + try: + return ipaddress.ip_address(forw_addr_headers[0]) + except ValueError as e: + _logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}') + quart.abort(500, 'Invalid client address') + elif forw_addr_headers: + _logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})') + quart.abort(500, 'Invalid client address') + return addr + + +async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]: + assert nc # for mypy + if not address: + address = get_client_ip() + return await nc.get_neighbor_mac(address) + + +async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress: + mac = await get_client_mac_if_present(address) + if mac is None: + _logger.warning(f"Couldn't find MAC addresss for {address}") + quart.abort(404, 'Unknown client') + return mac + + +class ApiHubApp(capport.comm.hub.HubApplication): + async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + # TODO: support websocket notification updates to clients? + pass + + +async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None: + assert config # for mypy + assert hub # for mypy + pu = capport.database.PendingUpdates() + try: + hub.database.login(mac, config.session_timeout, pending_updates=pu) + except capport.database.NotReadyYet as e: + quart.abort(500, str(e)) + + if pu.macs: + _logger.info(f'User {mac} (with IP {address}) logged in') + for msg in pu.serialize(): + await hub.broadcast(msg) + + +async def user_logout(mac: cptypes.MacAddress) -> None: + assert hub # for mypy + pu = capport.database.PendingUpdates() + try: + hub.database.logout(mac, pending_updates=pu) + except capport.database.NotReadyYet as e: + quart.abort(500, str(e)) + if pu.macs: + _logger.info(f'User {mac} logged out') + for msg in pu.serialize(): + await hub.broadcast(msg) + + +async def user_lookup() -> cptypes.MacPublicState: + assert hub # for mypy + address = get_client_ip() + mac = await get_client_mac_if_present(address) + if not mac: + return cptypes.MacPublicState.from_missing_mac(address) + else: + return hub.database.lookup(address, mac) + + +async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None: + global hub + global hub_app + global nc + assert config # for mypy + try: + async with capport.utils.ipneigh.connect() as mync: + nc = mync + _logger.info("Running hub for API") + myapp = ApiHubApp() + myhub = capport.comm.hub.Hub(config=config, app=myapp) + hub = myhub + hub_app = myapp + await myhub.run(task_status=task_status) + finally: + hub = None + hub_app = None + nc = None + _logger.info("Done running hub for API") + await app.shutdown() + + +@app.before_serving +async def init(): + global config + config = Config.load() + capport.utils.cli.init_logger(config) + await app.nursery.start(_run_hub) + + +# @app.route('/all') +# async def route_all(): +# return hub_app.database.as_json() + + +@app.route('/', methods=['GET']) +async def index(): + state = await user_lookup() + return await quart.render_template('index.html', state=state) + + +@app.route('/login', methods=['POST']) +async def login(): + address = get_client_ip() + mac = await get_client_mac(address) + await user_login(address, mac) + return quart.redirect('/', code=303) + + +@app.route('/logout', methods=['POST']) +async def logout(): + mac = await get_client_mac() + await user_logout(mac) + return quart.redirect('/', code=303) + + +@app.route('/api/captive-portal', methods=['GET']) +# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908 +async def captive_api(): + state = await user_lookup() + return state.to_rfc8908(config) diff --git a/src/capport/api/templates/index.html b/src/capport/api/templates/index.html new file mode 100644 index 0000000..bc762a7 --- /dev/null +++ b/src/capport/api/templates/index.html @@ -0,0 +1,22 @@ + + + + + Captive Portal Universität Stuttgart + + + + {% if not state.mac %} + It seems you're accessing this site from outside the network this captive portal is running for. + {% elif state.captive %} + To get access to the internet please accept our usage guidelines by clicking this button: +
+ {% else %} + You already accepted out conditions and are currently granted access to the internet: +
+
+
+ Your current session will last for {{ state.allowed_remaining }} seconds. + {% endif %} + + \ No newline at end of file diff --git a/src/capport/comm/__init__.py b/src/capport/comm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/capport/comm/hub.py b/src/capport/comm/hub.py new file mode 100644 index 0000000..2401f4e --- /dev/null +++ b/src/capport/comm/hub.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import random +import socket +import ssl +import struct +import typing +import uuid + +import capport.database +import capport.comm.message +import trio + +if typing.TYPE_CHECKING: + from ..config import Config + + +_logger = logging.getLogger(__name__) + + +class HubConnectionReadError(ConnectionError): + pass + + +class HubConnectionClosedError(ConnectionError): + pass + + +class LoopbackConnectionError(Exception): + pass + + +class Channel: + def __init__(self, hub: Hub, transport_stream, server_side: bool): + self._hub = hub + self._serverside = server_side + self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side) + _logger.debug(f"{self}: created (server_side={server_side})") + + def __repr__(self) -> str: + return f'Channel[0x{id(self):x}]' + + async def do_handshake(self) -> capport.comm.message.Hello: + try: + await self._ssl.do_handshake() + ssl_binding = self._ssl.get_channel_binding() + if not ssl_binding: + # binding mustn't be None after successful handshake + raise ConnectionError("Missing SSL channel binding") + except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: + raise ConnectionError(e) from None + msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message() + await self.send_msg(msg) + peer_hello = (await self.recv_msg()).to_variant() + if not isinstance(peer_hello, capport.comm.message.Hello): + raise HubConnectionReadError("Expected Hello as first message") + auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)) + await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message()) + peer_auth = (await self.recv_msg()).to_variant() + if not isinstance(peer_auth, capport.comm.message.AuthenticationResult): + raise HubConnectionReadError("Expected AuthenticationResult as second message") + if not auth_succ or not peer_auth.success: + raise HubConnectionReadError("Authentication failed") + return peer_hello + + async def _read(self, num: int) -> bytes: + assert num > 0 + buf = b'' + # _logger.debug(f"{self}:_read({num})") + while num > 0: + try: + part = await self._ssl.receive_some(num) + except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: + raise ConnectionError(e) from None + # _logger.debug(f"{self}:_read({num}) got part {part!r}") + if len(part) == 0: + if len(buf) == 0: + raise HubConnectionClosedError() + raise HubConnectionReadError("Unexpected end of TLS stream") + buf += part + num -= len(part) + if num < 0: + raise HubConnectionReadError("TLS receive_some returned too much") + return buf + + async def _recv_raw_msg(self) -> bytes: + len_bytes = await self._read(4) + chunk_size, = struct.unpack('!I', len_bytes) + chunk = await self._read(chunk_size) + if chunk is None: + raise HubConnectionReadError("Unexpected end of TLS stream after chunk length") + return chunk + + async def recv_msg(self) -> capport.comm.message.Message: + try: + chunk = await self._recv_raw_msg() + except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: + raise ConnectionError(e) from None + msg = capport.comm.message.Message() + msg.ParseFromString(chunk) + return msg + + async def _send_raw(self, chunk: bytes) -> None: + try: + await self._ssl.send_all(chunk) + except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: + raise ConnectionError(e) from None + + async def send_msg(self, msg: capport.comm.message.Message): + chunk = msg.SerializeToString(deterministic=True) + chunk_size = len(chunk) + len_bytes = struct.pack('!I', chunk_size) + chunk = len_bytes + chunk + await self._send_raw(chunk) + + async def aclose(self): + try: + await self._ssl.aclose() + except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: + raise ConnectionError(e) from None + + +class Connection: + PING_INTERVAL = 10 + RECEIVE_TIMEOUT = 15 + SEND_TIMEOUT = 5 + + def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello): + self._channel = channel + self._hub = hub + tx: trio.MemorySendChannel + rx: trio.MemoryReceiveChannel + (tx, rx) = trio.open_memory_channel(64) + self._pending_tx = tx + self._pending_rx = rx + self.peer: capport.comm.message.Hello = peer + self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id) + self.closed = trio.Event() # set by Hub._lost_peer + _logger.debug(f"{self._channel}: authenticated -> {self.peer_id}") + + async def _sender(self, cancel_scope: trio.CancelScope) -> None: + try: + msg: typing.Optional[capport.comm.message.Message] + while True: + msg = None + # make sure we send something every PING_INTERVAL + with trio.move_on_after(self.PING_INTERVAL): + msg = await self._pending_rx.receive() + # if send blocks too long we're in trouble + with trio.fail_after(self.SEND_TIMEOUT): + if msg: + await self._channel.send_msg(msg) + else: + await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message()) + except trio.TooSlowError: + _logger.warning(f"{self._channel}: send timed out") + except ConnectionError as e: + _logger.warning(f"{self._channel}: failed sending: {e!r}") + except Exception as e: + _logger.exception(f"{self._channel}: failed sending") + finally: + cancel_scope.cancel() + + async def _receive(self, cancel_scope: trio.CancelScope) -> None: + try: + while True: + try: + with trio.fail_after(self.RECEIVE_TIMEOUT): + msg = await self._channel.recv_msg() + except (HubConnectionClosedError, ConnectionResetError): + return + except trio.TooSlowError: + _logger.warning(f"{self._channel}: receive timed out") + return + await self._hub._received_msg(self.peer_id, msg) + except ConnectionError as e: + _logger.warning(f"{self._channel}: failed receiving: {e!r}") + except Exception: + _logger.exception(f"{self._channel}: failed receiving") + finally: + cancel_scope.cancel() + + async def _inner_run(self) -> None: + if self.peer_id == self._hub._instance_id: + # connected to ourself, don't need that + raise LoopbackConnectionError() + async with trio.open_nursery() as nursery: + nursery.start_soon(self._sender, nursery.cancel_scope) + # be nice and wait for new_peer beforce receiving messages + # (won't work on failover to a second connection) + await nursery.start(self._hub._new_peer, self.peer_id, self) + nursery.start_soon(self._receive, nursery.cancel_scope) + + async def send_msg(self, *msgs: capport.comm.message.Message): + try: + for msg in msgs: + await self._pending_tx.send(msg) + except trio.ClosedResourceError: + pass + + async def _run(self) -> None: + try: + await self._inner_run() + finally: + _logger.debug(f"{self._channel}: finished message handling") + # basic (non-async) cleanup + self._hub._lost_peer(self.peer_id, self) + self._pending_tx.close() + self._pending_rx.close() + # allow 3 seconds for proper cleanup + with trio.CancelScope(shield=True, deadline=trio.current_time() + 3): + try: + await self._channel.aclose() + except OSError: + pass + + @staticmethod + async def run(hub: Hub, transport_stream, server_side: bool) -> None: + channel = Channel(hub, transport_stream, server_side) + try: + with trio.fail_after(5): + peer = await channel.do_handshake() + except trio.TooSlowError: + _logger.warning(f"Handshake timed out") + return + conn = Connection(hub, channel, peer) + await conn._run() + + +class ControllerConn: + def __init__(self, hub: Hub, hostname: str): + self._hub = hub + self.hostname = hostname + self.loopback = False + + async def _connect(self): + _logger.info(f"Connecting to controller at {self.hostname}") + with trio.fail_after(5): + try: + stream = await trio.open_tcp_stream(self.hostname, 5000) + except OSError as e: + _logger.warning(f"Failed to connect to controller at {self.hostname}: {e}") + return + try: + await Connection.run(self._hub, stream, server_side=False) + finally: + _logger.info(f"Connection to {self.hostname} closed") + + async def run(self): + while True: + try: + await self._connect() + except LoopbackConnectionError: + _logger.debug(f"Connection to {self.hostname} reached ourself") + self.loopback = True + return + except trio.TooSlowError: + pass + # try again later + retry_splay = random.random() * 5 + await trio.sleep(10 + retry_splay) + + +class HubApplication: + def is_controller(self) -> bool: + return False + + async def new_peer(self, *, peer_id: uuid.UUID) -> None: + _logger.info(f"New peer {peer_id}") + + def lost_peer(self, *, peer_id: uuid.UUID) -> None: + _logger.warning(f"Lost peer {peer_id}") + + async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None: + _logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}") + + async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None: + if _logger.isEnabledFor(logging.DEBUG): + _logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}") + + async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + if _logger.isEnabledFor(logging.DEBUG): + _logger.debug(f"Received new states from {from_peer_id}: {pending_updates}") + + +class Hub: + def __init__(self, config: Config, app: HubApplication) -> None: + self._config = config + self._instance_id = uuid.uuid4() + self._hostname = socket.getfqdn() + self.database = capport.database.Database() + self._app = app + self._is_controller = bool(app.is_controller()) + self._anon_context = ssl.SSLContext() + # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon + self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2 + self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2 + # -> AECDH-AES256-SHA + # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way? + self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0') + self._controllers: typing.Dict[str, ControllerConn] = {} + self._established: typing.Dict[uuid.UUID, Connection] = {} + + async def _accept(self, stream): + remotename = stream.socket.getpeername() + if isinstance(remotename, tuple) and len(remotename) == 2: + remote = f'[{remotename[0]}]:{remotename[1]}' + else: + remote = str(remotename) + try: + await Connection.run(self, stream, server_side=True) + except LoopbackConnectionError: + pass + except trio.TooSlowError: + pass + finally: + _logger.debug(f"Connection from {remote} closed") + + async def _listen(self, task_status=trio.TASK_STATUS_IGNORED): + await trio.serve_tcp(self._accept, 5000, task_status=task_status) + + async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): + async with trio.open_nursery() as nursery: + if self._is_controller: + await nursery.start(self._listen) + + for name in self._config.controllers: + conn = ControllerConn(self, name) + self._controllers[name] = conn + + task_status.started() + + for conn in self._controllers.values(): + nursery.start_soon(conn.run) + + await trio.sleep_forever() + + def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes: + m = hmac.new(self._config.secret.encode('utf8'), digestmod=hashlib.sha256) + if server_side: + m.update(b'server$') + else: + m.update(b'client$') + m.update(ssl_binding) + return m.digest() + + def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello: + return capport.comm.message.Hello( + instance_id=self._instance_id.bytes, + hostname=self._hostname, + is_controller=self._is_controller, + authentication=self._calc_authentication(ssl_binding, server_side), + ) + + async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None: + # send database (and all changes) to peers + await self.send(*self.database.serialize(), to=peer_id) + + async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None: + have = self._established.get(peer_id, None) + if not have: + # peer unknown, "normal start" + # no "await" between get above and set here!!! + self._established[peer_id] = conn + # first wait for app to handle new peer + await self._app.new_peer(peer_id=peer_id) + task_status.started() + await self._sync_new_connection(peer_id, conn) + return + + # peer already known - immediately allow receiving messages, then sync connection + task_status.started() + await self._sync_new_connection(peer_id, conn) + # now try to register connection for outgoing messages + while True: + # recheck whether peer is currently known (due to awaits since last get) + have = self._established.get(peer_id, None) + if have: + # already got a connection, nothing to do as long as it lives + await have.closed.wait() + else: + # make `conn` new outgoing connection for peer + # no "await" between get above and set here!!! + self._established[peer_id] = conn + await self._app.new_peer(peer_id=peer_id) + return + + def _lost_peer(self, peer_id: uuid.UUID, conn: Connection): + have = self._established.get(peer_id, None) + lost = False + if have is conn: + lost = True + self._established.pop(peer_id) + conn.closed.set() + # only notify if this was the active connection + if lost: + # even when we failover to another connection we still need to resync + # as we don't know which messages might have got lost + # -> always trigger lost_peer + self._app.lost_peer(peer_id=peer_id) + + async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None: + variant = msg.to_variant() + if isinstance(variant, capport.comm.message.Hello): + pass + elif isinstance(variant, capport.comm.message.AuthenticationResult): + pass + elif isinstance(variant, capport.comm.message.Ping): + pass + elif isinstance(variant, capport.comm.message.MacStates): + await self._app.received_mac_state(from_peer_id=peer_id, states=variant) + pu = capport.database.PendingUpdates() + for state in variant.states: + self.database.received_mac_state(state, pending_updates=pu) + if pu.macs: + # re-broadcast all received updates to all peers + await self.broadcast(*pu.serialize(), exclude=peer_id) + await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu) + else: + await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg) + + def peer_is_controller(self, peer_id: uuid.UUID) -> bool: + conn = self._established.get(peer_id) + if conn: + return conn.peer.is_controller + return False + + async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID): + conn = self._established.get(to) + if conn: + await conn.send_msg(*msgs) + + async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID]=None): + async with trio.open_nursery() as nursery: + for peer_id, conn in self._established.items(): + if peer_id == exclude: + continue + nursery.start_soon(conn.send_msg, *msgs) diff --git a/src/capport/comm/message.py b/src/capport/comm/message.py new file mode 100644 index 0000000..7cd1750 --- /dev/null +++ b/src/capport/comm/message.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import typing + +from .protobuf import message_pb2 + + +def _message_to_variant(self: message_pb2.Message) -> typing.Any: + variant_name = self.WhichOneof('oneof') + if variant_name: + return getattr(self, variant_name) + return None + + +def _make_to_message(oneof_field): + def to_message(self) -> message_pb2.Message: + msg = message_pb2.Message(**{oneof_field: self}) + return msg + return to_message + + + +def _monkey_patch(): + g = globals() + g['Message'] = message_pb2.Message + message_pb2.Message.to_variant = _message_to_variant + for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields: + type_name = field.message_type.name + field_type = getattr(message_pb2, type_name) + field_type.to_message = _make_to_message(field.name) + g[type_name] = field_type + + +# also re-exports all message types +_monkey_patch() +# not a variant of Message, still re-export +MacState = message_pb2.MacState diff --git a/src/capport/comm/message.pyi b/src/capport/comm/message.pyi new file mode 100644 index 0000000..573ab12 --- /dev/null +++ b/src/capport/comm/message.pyi @@ -0,0 +1,93 @@ +import google.protobuf.message +import typing + + +# manually maintained typehints for protobuf created (and monkey-patched) types + + +class Message(google.protobuf.message.Message): + hello: Hello + authentication_result: AuthenticationResult + ping: Ping + mac_states: MacStates + + def __init__( + self, + *, + hello: typing.Optional[Hello]=None, + authentication_result: typing.Optional[AuthenticationResult]=None, + ping: typing.Optional[Ping]=None, + mac_states: typing.Optional[MacStates]=None, + ) -> None: ... + + def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ... + + +class Hello(google.protobuf.message.Message): + instance_id: bytes + hostname: str + is_controller: bool + authentication: bytes + + def __init__( + self, + *, + instance_id: bytes=b'', + hostname: str='', + is_controller: bool=False, + authentication: bytes=b'', + ) -> None: ... + + def to_message(self) -> Message: ... + + +class AuthenticationResult(google.protobuf.message.Message): + success: bool + + def __init__( + self, + *, + success: bool=False, + ) -> None: ... + + def to_message(self) -> Message: ... + + +class Ping(google.protobuf.message.Message): + payload: bytes + + def __init__( + self, + *, + payload: bytes=b'', + ) -> None: ... + + def to_message(self) -> Message: ... + + +class MacStates(google.protobuf.message.Message): + states: typing.List[MacState] + + def __init__( + self, + *, + states: typing.List[MacState]=[], + ) -> None: ... + + def to_message(self) -> Message: ... + + +class MacState(google.protobuf.message.Message): + mac_address: bytes + last_change: int # Seconds of UTC time since epoch + allow_until: int # Seconds of UTC time since epoch + allowed: bool + + def __init__( + self, + *, + mac_address: bytes=b'', + last_change: int=0, + allow_until: int=0, + allowed: bool=False, + ) -> None: ... diff --git a/src/capport/comm/protobuf/message_pb2.py b/src/capport/comm/protobuf/message_pb2.py new file mode 100644 index 0000000..2af318c --- /dev/null +++ b/src/capport/comm/protobuf/message_pb2.py @@ -0,0 +1,355 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: message.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='message.proto', + package='capport', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\rmessage.proto\x12\x07\x63\x61pport\"\xbc\x01\n\x07Message\x12\x1f\n\x05hello\x18\x01 \x01(\x0b\x32\x0e.capport.HelloH\x00\x12>\n\x15\x61uthentication_result\x18\x02 \x01(\x0b\x32\x1d.capport.AuthenticationResultH\x00\x12\x1d\n\x04ping\x18\x03 \x01(\x0b\x32\r.capport.PingH\x00\x12(\n\nmac_states\x18\n \x01(\x0b\x32\x12.capport.MacStatesH\x00\x42\x07\n\x05oneof\"]\n\x05Hello\x12\x13\n\x0binstance_id\x18\x01 \x01(\x0c\x12\x10\n\x08hostname\x18\x02 \x01(\t\x12\x15\n\ris_controller\x18\x03 \x01(\x08\x12\x16\n\x0e\x61uthentication\x18\x04 \x01(\x0c\"\'\n\x14\x41uthenticationResult\x12\x0f\n\x07success\x18\x01 \x01(\x08\"\x17\n\x04Ping\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\".\n\tMacStates\x12!\n\x06states\x18\x01 \x03(\x0b\x32\x11.capport.MacState\"Z\n\x08MacState\x12\x13\n\x0bmac_address\x18\x01 \x01(\x0c\x12\x13\n\x0blast_change\x18\x02 \x01(\x03\x12\x13\n\x0b\x61llow_until\x18\x03 \x01(\x03\x12\x0f\n\x07\x61llowed\x18\x04 \x01(\x08\x62\x06proto3' +) + + + + +_MESSAGE = _descriptor.Descriptor( + name='Message', + full_name='capport.Message', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='hello', full_name='capport.Message.hello', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='authentication_result', full_name='capport.Message.authentication_result', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='ping', full_name='capport.Message.ping', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='mac_states', full_name='capport.Message.mac_states', index=3, + number=10, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='oneof', full_name='capport.Message.oneof', + index=0, containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + ], + serialized_start=27, + serialized_end=215, +) + + +_HELLO = _descriptor.Descriptor( + name='Hello', + full_name='capport.Hello', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='instance_id', full_name='capport.Hello.instance_id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='hostname', full_name='capport.Hello.hostname', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='is_controller', full_name='capport.Hello.is_controller', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='authentication', full_name='capport.Hello.authentication', index=3, + number=4, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=217, + serialized_end=310, +) + + +_AUTHENTICATIONRESULT = _descriptor.Descriptor( + name='AuthenticationResult', + full_name='capport.AuthenticationResult', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='success', full_name='capport.AuthenticationResult.success', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=312, + serialized_end=351, +) + + +_PING = _descriptor.Descriptor( + name='Ping', + full_name='capport.Ping', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='payload', full_name='capport.Ping.payload', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=353, + serialized_end=376, +) + + +_MACSTATES = _descriptor.Descriptor( + name='MacStates', + full_name='capport.MacStates', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='states', full_name='capport.MacStates.states', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=378, + serialized_end=424, +) + + +_MACSTATE = _descriptor.Descriptor( + name='MacState', + full_name='capport.MacState', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='mac_address', full_name='capport.MacState.mac_address', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='last_change', full_name='capport.MacState.last_change', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='allow_until', full_name='capport.MacState.allow_until', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='allowed', full_name='capport.MacState.allowed', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=426, + serialized_end=516, +) + +_MESSAGE.fields_by_name['hello'].message_type = _HELLO +_MESSAGE.fields_by_name['authentication_result'].message_type = _AUTHENTICATIONRESULT +_MESSAGE.fields_by_name['ping'].message_type = _PING +_MESSAGE.fields_by_name['mac_states'].message_type = _MACSTATES +_MESSAGE.oneofs_by_name['oneof'].fields.append( + _MESSAGE.fields_by_name['hello']) +_MESSAGE.fields_by_name['hello'].containing_oneof = _MESSAGE.oneofs_by_name['oneof'] +_MESSAGE.oneofs_by_name['oneof'].fields.append( + _MESSAGE.fields_by_name['authentication_result']) +_MESSAGE.fields_by_name['authentication_result'].containing_oneof = _MESSAGE.oneofs_by_name['oneof'] +_MESSAGE.oneofs_by_name['oneof'].fields.append( + _MESSAGE.fields_by_name['ping']) +_MESSAGE.fields_by_name['ping'].containing_oneof = _MESSAGE.oneofs_by_name['oneof'] +_MESSAGE.oneofs_by_name['oneof'].fields.append( + _MESSAGE.fields_by_name['mac_states']) +_MESSAGE.fields_by_name['mac_states'].containing_oneof = _MESSAGE.oneofs_by_name['oneof'] +_MACSTATES.fields_by_name['states'].message_type = _MACSTATE +DESCRIPTOR.message_types_by_name['Message'] = _MESSAGE +DESCRIPTOR.message_types_by_name['Hello'] = _HELLO +DESCRIPTOR.message_types_by_name['AuthenticationResult'] = _AUTHENTICATIONRESULT +DESCRIPTOR.message_types_by_name['Ping'] = _PING +DESCRIPTOR.message_types_by_name['MacStates'] = _MACSTATES +DESCRIPTOR.message_types_by_name['MacState'] = _MACSTATE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Message = _reflection.GeneratedProtocolMessageType('Message', (_message.Message,), { + 'DESCRIPTOR' : _MESSAGE, + '__module__' : 'message_pb2' + # @@protoc_insertion_point(class_scope:capport.Message) + }) +_sym_db.RegisterMessage(Message) + +Hello = _reflection.GeneratedProtocolMessageType('Hello', (_message.Message,), { + 'DESCRIPTOR' : _HELLO, + '__module__' : 'message_pb2' + # @@protoc_insertion_point(class_scope:capport.Hello) + }) +_sym_db.RegisterMessage(Hello) + +AuthenticationResult = _reflection.GeneratedProtocolMessageType('AuthenticationResult', (_message.Message,), { + 'DESCRIPTOR' : _AUTHENTICATIONRESULT, + '__module__' : 'message_pb2' + # @@protoc_insertion_point(class_scope:capport.AuthenticationResult) + }) +_sym_db.RegisterMessage(AuthenticationResult) + +Ping = _reflection.GeneratedProtocolMessageType('Ping', (_message.Message,), { + 'DESCRIPTOR' : _PING, + '__module__' : 'message_pb2' + # @@protoc_insertion_point(class_scope:capport.Ping) + }) +_sym_db.RegisterMessage(Ping) + +MacStates = _reflection.GeneratedProtocolMessageType('MacStates', (_message.Message,), { + 'DESCRIPTOR' : _MACSTATES, + '__module__' : 'message_pb2' + # @@protoc_insertion_point(class_scope:capport.MacStates) + }) +_sym_db.RegisterMessage(MacStates) + +MacState = _reflection.GeneratedProtocolMessageType('MacState', (_message.Message,), { + 'DESCRIPTOR' : _MACSTATE, + '__module__' : 'message_pb2' + # @@protoc_insertion_point(class_scope:capport.MacState) + }) +_sym_db.RegisterMessage(MacState) + + +# @@protoc_insertion_point(module_scope) diff --git a/src/capport/config.py b/src/capport/config.py new file mode 100644 index 0000000..6ebb52d --- /dev/null +++ b/src/capport/config.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import dataclasses +import os.path +import typing + +import yaml + + +@dataclasses.dataclass +class Config: + controllers: typing.List[str] + secret: str + venue_info_url: typing.Optional[str] + session_timeout: int # in seconds + debug: bool + + @staticmethod + def load(filename: typing.Optional[str]=None) -> 'Config': + if filename is None: + for name in ('capport.yaml', '/etc/capport.yaml'): + if os.path.exists(name): + return Config.load(name) + raise RuntimeError("Missing config file") + with open(filename) as f: + data = yaml.safe_load(f) + controllers = list(map(str, data['controllers'])) + return Config( + controllers=controllers, + secret=str(data['secret']), + venue_info_url=str(data.get('venue-info-url')), + session_timeout=data.get('session-timeout', 3600), + debug=data.get('debug', False) + ) diff --git a/src/capport/control/__init__.py b/src/capport/control/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/capport/control/run.py b/src/capport/control/run.py new file mode 100644 index 0000000..106c5c5 --- /dev/null +++ b/src/capport/control/run.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import logging +import uuid + +import capport.database +import capport.comm.hub +import capport.comm.message +import capport.config +import capport.utils.cli +import capport.utils.nft_set +import trio +from capport import cptypes + +_logger = logging.getLogger(__name__) + + +class ControlApp(capport.comm.hub.HubApplication): + hub: capport.comm.hub.Hub + + def __init__(self) -> None: + super().__init__() + self.nft_set = capport.utils.nft_set.NftSet() + + def is_controller(self) -> bool: + return True + + async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None: + # deploy changes to netfilter set + inserts = [] + removals = [] + now = cptypes.Timestamp.now() + for mac, state in pending_updates.macs.items(): + rem = state.allowed_remaining(now) + if rem > 0: + inserts.append((mac, rem)) + else: + removals.append(mac) + self.nft_set.bulk_insert(inserts) + self.nft_set.bulk_remove(removals) + + +async def amain(config: capport.config.Config) -> None: + app = ControlApp() + hub = capport.comm.hub.Hub(config=config, app=app) + app.hub = hub + await hub.run() + + +def main() -> None: + config = capport.config.Config.load() + capport.utils.cli.init_logger(config) + try: + trio.run(amain, config) + except (KeyboardInterrupt, InterruptedError): + print() diff --git a/src/capport/cptypes.py b/src/capport/cptypes.py new file mode 100644 index 0000000..39b86d8 --- /dev/null +++ b/src/capport/cptypes.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import dataclasses +import datetime +import ipaddress +import json +import time +import typing + +import quart + +if typing.TYPE_CHECKING: + from .config import Config + + +IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + + +@dataclasses.dataclass(frozen=True) +class MacAddress: + raw: bytes + + def __str__(self) -> str: + return self.raw.hex(':') + + def __repr__(self) -> str: + return repr(str(self)) + + @staticmethod + def parse(s: str) -> MacAddress: + return MacAddress(bytes.fromhex(s.replace(':', ''))) + + +@dataclasses.dataclass(frozen=True, order=True) +class Timestamp: + epoch: int + + def __str__(self) -> str: + try: + ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc) + return ts.isoformat(sep=' ') + except OSError: + return f'epoch@{self.epoch}' + + def __repr__(self) -> str: + return repr(str(self)) + + @staticmethod + def now() -> Timestamp: + now = int(time.time()) + return Timestamp(epoch=now) + + @staticmethod + def from_protobuf(epoch: int) -> typing.Optional[Timestamp]: + if epoch: + return Timestamp(epoch=epoch) + return None + + +@dataclasses.dataclass +class MacPublicState: + address: IPAddress + mac: typing.Optional[MacAddress] + allowed_remaining: int + + @staticmethod + def from_missing_mac(address: IPAddress) -> MacPublicState: + return MacPublicState( + address=address, + mac=None, + allowed_remaining=0, + ) + + @property + def allowed(self) -> bool: + return self.allowed_remaining > 0 + + @property + def captive(self) -> bool: + return not self.allowed + + def to_rfc8908(self, config: Config) -> quart.Response: + response: typing.Dict[str, typing.Any] = { + 'user-portal-url': quart.url_for('index', _external=True), + } + if config.venue_info_url: + response['venue-info-url'] = config.venue_info_url + if self.captive: + response['captive'] = True + else: + response['captive'] = False + response['seconds-remaining'] = self.allowed_remaining + response['can-extend-session'] = True + return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json') diff --git a/src/capport/database.py b/src/capport/database.py new file mode 100644 index 0000000..c634c49 --- /dev/null +++ b/src/capport/database.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import dataclasses +import typing + +import capport.comm.message +from capport import cptypes + + +@dataclasses.dataclass +class MacEntry: + # entry can be removed if last_change was some time ago and allow_until wasn't set + # or got reached. + WAIT_LAST_CHANGE_SECONDS = 60 + WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10 + + # last_change: timestamp of last change (sent by system initiating the change) + last_change: cptypes.Timestamp + # only if allowed is true and allow_until is set the device can communicate with the internet + # allow_until must not go backwards (and not get unset) + allow_until: typing.Optional[cptypes.Timestamp] + allowed: bool + + @staticmethod + def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]: + if len(msg.mac_address) < 6: + raise Exception("Invalid MacState: mac_address too short") + addr = cptypes.MacAddress(raw=msg.mac_address) + last_change = cptypes.Timestamp.from_protobuf(msg.last_change) + if not last_change: + raise Exception(f"Invalid MacState[{addr}]: missing last_change") + allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until) + return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed)) + + def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState: + allow_until = 0 + if self.allow_until: + allow_until = self.allow_until.epoch + return capport.comm.message.MacState( + mac_address=addr.raw, + last_change=self.last_change.epoch, + allow_until=allow_until, + allowed=self.allowed, + ) + + def as_json(self) -> dict: + allow_until = None + if self.allow_until: + allow_until = self.allow_until.epoch + return dict( + last_change=self.last_change.epoch, + allow_until=allow_until, + allowed=self.allowed, + ) + + def merge(self, new: MacEntry) -> bool: + changed = False + if new.last_change > self.last_change: + changed = True + self.last_change = new.last_change + self.allowed = new.allowed + elif new.last_change == self.last_change: + # same last_change: set allowed if one allowed + if new.allowed and not self.allowed: + changed = True + self.allowed = True + # set allow_until to max of both + if new.allow_until: # if not set nothing to change in local data + if not self.allow_until or self.allow_until < new.allow_until: + changed = True + self.allow_until = new.allow_until + return changed + + def timeout(self) -> cptypes.Timestamp: + elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS + if self.allow_until: + eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS + if eau > elc: + return cptypes.Timestamp(epoch=eau) + return cptypes.Timestamp(epoch=elc) + + # returns 0 if not allowed + def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int: + if not self.allowed or not self.allow_until: + return 0 + if not now: + now = cptypes.Timestamp.now() + assert self.allow_until + return max(self.allow_until.epoch - now.epoch, 0) + + def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool: + if not now: + now = cptypes.Timestamp.now() + return now.epoch > self.timeout().epoch + + +# might use this to serialize into file - don't need Message variant there +def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]: + result: typing.List[capport.comm.message.MacStates] = [] + current = capport.comm.message.MacStates() + for addr, entry in macs.items(): + state = entry.to_state(addr) + current.states.append(state) + if len(current.states) >= 1024: # split into messages with 1024 states + result.append(current) + current = capport.comm.message.MacStates() + if len(current.states): + result.append(current) + return result + + +def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]: + return [s.to_message() for s in _serialize_mac_states(macs)] + + +class NotReadyYet(Exception): + def __init__(self, msg: str, wait: int): + self.wait = wait # seconds to wait + super().__init__(msg) + + +@dataclasses.dataclass +class Database: + _macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict) + + def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates): + (addr, new_entry) = MacEntry.parse_state(state) + old_entry = self._macs.get(addr) + if not old_entry: + # only redistribute if not outdated + if not new_entry.outdated(): + self._macs[addr] = new_entry + pending_updates.macs[addr] = new_entry + elif old_entry.merge(new_entry): + if old_entry.outdated(): + # remove local entry, but still redistribute + self._macs.pop(addr) + pending_updates.macs[addr] = old_entry + + def serialize(self) -> typing.List[capport.comm.message.Message]: + return _serialize_mac_states_as_messages(self._macs) + + def as_json(self) -> dict: + return { + str(addr): entry.as_json() + for addr, entry in self._macs.items() + } + + def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState: + entry = self._macs.get(mac) + if entry: + allowed_remaining = entry.allowed_remaining() + else: + allowed_remaining = 0 + return cptypes.MacPublicState( + address=address, + mac=mac, + allowed_remaining=allowed_remaining, + ) + + def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8): + now = cptypes.Timestamp.now() + allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout) + new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True) + + entry = self._macs.get(mac) + if not entry: + self._macs[mac] = new_entry + pending_updates.macs[mac] = new_entry + elif entry.allowed_remaining(now) > renew_maximum * session_timeout: + # too much time left on clock, not renewing session + return + elif entry.merge(new_entry): + pending_updates.macs[mac] = entry + elif not entry.allowed_remaining() > 0: + # entry should have been updated - can only fail due to `now < entry.last_change` + # i.e. out of sync clocks + wait = entry.last_change.epoch - now.epoch + raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait) + + def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates): + now = cptypes.Timestamp.now() + new_entry = MacEntry(last_change=now, allow_until=None, allowed=False) + entry = self._macs.get(mac) + if entry: + if entry.merge(new_entry): + pending_updates.macs[mac] = entry + elif entry.allowed_remaining() > 0: + # still logged in. can only happen with `now <= entry.last_change` + # clocks not necessarily out of sync, but you can't logout in the same second you logged in + wait = entry.last_change.epoch - now.epoch + 1 + raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait) + + +@dataclasses.dataclass +class PendingUpdates: + macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict) + + def serialize(self) -> typing.List[capport.comm.message.Message]: + return _serialize_mac_states_as_messages(self.macs) diff --git a/src/capport/utils/__init__.py b/src/capport/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/capport/utils/cli.py b/src/capport/utils/cli.py new file mode 100644 index 0000000..4ed1fce --- /dev/null +++ b/src/capport/utils/cli.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import logging + +import capport.config + + +def init_logger(config: capport.config.Config): + loglevel = logging.INFO + if config.debug: + loglevel = logging.DEBUG + logging.basicConfig( + format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S %z]', + level=loglevel, + ) + logging.getLogger('hypercorn').propagate = False diff --git a/src/capport/utils/ipneigh.py b/src/capport/utils/ipneigh.py new file mode 100644 index 0000000..8cc6dcf --- /dev/null +++ b/src/capport/utils/ipneigh.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import contextlib +import errno +import typing + +import pr2modules.iproute.linux +import pr2modules.netlink.exceptions +from capport import cptypes + + +@contextlib.asynccontextmanager +async def connect(): + yield NeighborController() + + +# TODO: run blocking iproute calls in a different thread? +class NeighborController: + def __init__(self): + self.ip = pr2modules.iproute.linux.IPRoute() + + async def get_neighbor( + self, + address: cptypes.IPAddress, + *, + index: int=0, # interface index + flags: int=0, + ) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]: + if not index: + route = await self.get_route(address) + if route is None: + return None + index = route.get_attr(route.name2nla('oif')) + try: + return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0] + except pr2modules.netlink.exceptions.NetlinkError as e: + if e.code == errno.ENOENT: + return None + raise + + async def get_neighbor_mac( + self, + address: cptypes.IPAddress, + *, + index: int=0, # interface index + flags: int=0, + ) -> typing.Optional[cptypes.MacAddress]: + neigh = await self.get_neighbor(address, index=index, flags=flags) + if neigh is None: + return None + mac = neigh.get_attr(neigh.name2nla('lladdr')) + return cptypes.MacAddress.parse(mac) + + async def get_route( + self, + address: cptypes.IPAddress, + ) -> typing.Optional[pr2modules.iproute.linux.rtmsg]: + try: + return self.ip.route('get', dst=str(address))[0] + except pr2modules.netlink.exceptions.NetlinkError as e: + if e.code == errno.ENOENT: + return None + raise diff --git a/src/capport/utils/nft_set.py b/src/capport/utils/nft_set.py new file mode 100644 index 0000000..7b3fcfd --- /dev/null +++ b/src/capport/utils/nft_set.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import typing + +import pr2modules.netlink +from capport import cptypes +from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket + +from .nft_socket import NFTSocket + +NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX" + + +def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]: + # to seconds + if msecs is None: + return None + return msecs / 1000.0 + + +class NftSet: + def __init__(self): + self._socket = NFTSocket() + self._socket.bind() + + @staticmethod + def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem: + attrs: typing.Dict[str, typing.Any] = { + 'NFTA_SET_ELEM_KEY': dict( + NFTA_DATA_VALUE=mac.raw, + ), + } + if timeout: + attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout) + return attrs + + def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None: + ser_entries = [ + self._set_elem(mac) + for mac, _timeout in entries + ] + ser_entries_with_timeout = [ + self._set_elem(mac, timeout) + for mac, timeout in entries + ] + with self._socket.begin() as tx: + # create doesn't affect existing elements, so: + # make sure entries exists + tx.put( + _nftsocket.NFT_MSG_NEWSETELEM, + pr2modules.netlink.NLM_F_CREATE, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, + ), + ) + # drop entries (would fail if it doesn't exist) + tx.put( + _nftsocket.NFT_MSG_DELSETELEM, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, + ), + ) + # now create entries with new timeout value + tx.put( + _nftsocket.NFT_MSG_NEWSETELEM, + pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout, + ), + ) + + def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None: + # limit chunk size + while len(entries) > 0: + self._bulk_insert(entries[:1024]) + entries = entries[1024:] + + def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None: + self.bulk_insert([(mac, timeout)]) + + def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None: + ser_entries = [ + self._set_elem(mac) + for mac in entries + ] + with self._socket.begin() as tx: + # make sure entries exists + tx.put( + _nftsocket.NFT_MSG_NEWSETELEM, + pr2modules.netlink.NLM_F_CREATE, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, + ), + ) + # drop entries (would fail if it doesn't exist) + tx.put( + _nftsocket.NFT_MSG_DELSETELEM, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, + ), + ) + + def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None: + # limit chunk size + while len(entries) > 0: + self._bulk_remove(entries[:1024]) + entries = entries[1024:] + + def remove(self, mac: cptypes.MacAddress) -> None: + self.bulk_remove([mac]) + + def list(self) -> list: + responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg] + responses = self._socket.nft_dump( + _nftsocket.NFT_MSG_GETSETELEM, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + ) + ) + return [ + { + 'mac': cptypes.MacAddress( + elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'), + ), + 'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)), + 'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)), + } + for response in responses + for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', []) + ] + + def flush(self) -> None: + self._socket.nft_put( + _nftsocket.NFT_MSG_DELSETELEM, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_ELEM_LIST_SET='allowed', + ) + ) + + def create(self): + with self._socket.begin() as tx: + tx.put( + _nftsocket.NFT_MSG_NEWTABLE, + pr2modules.netlink.NLM_F_CREATE, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_TABLE_NAME='captive_mark', + ), + ) + tx.put( + _nftsocket.NFT_MSG_NEWSET, + pr2modules.netlink.NLM_F_CREATE, + nfgen_family=NFPROTO_INET, + attrs=dict( + NFTA_SET_TABLE='captive_mark', + NFTA_SET_NAME='allowed', + NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT + NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft + NFTA_SET_KEY_LEN=6, # length of key: mac address + NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction + ), + ) diff --git a/src/capport/utils/nft_socket.py b/src/capport/utils/nft_socket.py new file mode 100644 index 0000000..968bd0c --- /dev/null +++ b/src/capport/utils/nft_socket.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import contextlib +import typing +import threading + +from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket +import pr2modules.netlink +import pr2modules.netlink.nlsocket +from pr2modules.netlink.nfnetlink import nfgen_msg +from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES + + +NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" + + +_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base) + + +# nft uses NESTED for those.. lets do the same +_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED +_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED + + +def _monkey_patch_pyroute2(): + import pr2modules.netlink + # overwrite setdefault on nlmsg_base class hierarchy + _orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue + + def _nlmsg_base__setvalue(self, value): + if not self.header or not self['header'] or not isinstance(value, dict): + return _orig_setvalue(self, value) + header = value.pop('header', {}) + res = _orig_setvalue(self, value) + self['header'].update(header) + return res + + def overwrite_methods(cls: typing.Type) -> None: + if cls.setvalue is _orig_setvalue: + cls.setvalue = _nlmsg_base__setvalue + for subcls in cls.__subclasses__(): + overwrite_methods(subcls) + + overwrite_methods(pr2modules.netlink.nlmsg_base) +_monkey_patch_pyroute2() + + +def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase: + msg = msg_class() + for key, value in header.items(): + msg['header'][key] = value + for key, value in fields.items(): + msg[key] = value + if attrs: + attr_list = msg['attrs'] + r_nla_map = msg_class._nlmsg_base__r_nla_map + for key, value in attrs.items(): + if msg_class.prefix: + key = msg_class.name2nla(key) + prime = r_nla_map[key] + nla_class = prime['class'] + if issubclass(nla_class, pr2modules.netlink.nla): + # support passing nested attributes as dicts of subattributes (or lists of those) + if prime['nla_array']: + value = [ + _build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem + for elem in value + ] + elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict): + value = _build(nla_class, attrs=value) + attr_list.append([key, value]) + return msg + + +class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket): + policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy + + def __init__(self) -> None: + super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER) + policy = { + (x | (NFNL_SUBSYS_NFTABLES << 8)): y + for (x, y) in self.policy.items() + } + self.register_policy(policy) + + @contextlib.contextmanager + def begin(self) -> typing.Generator[NFTTransaction, None, None]: + try: + tx = NFTTransaction(socket=self) + yield tx + # autocommit when no exception was raised + # (only commits if it wasn't aborted) + tx.autocommit() + finally: + # abort does nothing if commit went through + tx.abort() + + def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + with self.begin() as tx: + tx.put(msg_type, msg_flags, attrs=attrs, **fields) + + def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + msg_flags |= pr2modules.netlink.NLM_F_DUMP + return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) + + def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] + msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type + msg_flags |= pr2modules.netlink.NLM_F_REQUEST + msg = _build(msg_class, attrs=attrs, **fields) + return self.nlm_request(msg, msg_type, msg_flags) + + +class NFTTransaction: + def __init__(self, socket: NFTSocket) -> None: + self._socket = socket + self._data = b'' + self._seqnum = self._socket.addr_pool.alloc() + self._closed = False + # neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK + # at the end of an transaction to make sure it worked. + # we could use a different sequence number for all changes, and wait for an ACK for each of them + # (but we'd also need to check for errors on the BEGIN sequence number). + # the other solution: use the same sequence number for all messages in the batch, and add ACK + # only to the final message (before END) - if we get the ACK we known all other messages before + # worked out. + self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None + begin_msg = _build( + nfgen_msg, + res_id=NFNL_SUBSYS_NFTABLES, + header=dict( + type=0x10, # NFNL_MSG_BATCH_BEGIN + flags=pr2modules.netlink.NLM_F_REQUEST, + sequence_number=self._seqnum, + ), + ) + begin_msg.encode() + self._data += begin_msg.data + + def abort(self) -> None: + """ + Aborts if transaction wasn't already committed or aborted + """ + if not self._closed: + self._closed = True + # unused seqnum + self._socket.addr_pool.free(self._seqnum) + + def autocommit(self) -> None: + """ + Commits if transaction wasn't already committed or aborted + """ + if self._closed: + return + self.commit() + + def commit(self) -> None: + if self._closed: + raise Exception("Transaction already closed") + if not self._final_msg: + # no inner messages were queued... just abort transaction + self.abort() + return + self._closed = True + # request ACK only on the last message (before END) + self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK + self._final_msg.encode() + self._data += self._final_msg.data + self._final_msg = None + # batch end + end_msg = _build( + nfgen_msg, + res_id=NFNL_SUBSYS_NFTABLES, + header=dict( + type=0x11, # NFNL_MSG_BATCH_END + flags=pr2modules.netlink.NLM_F_REQUEST, + sequence_number=self._seqnum, + ), + ) + end_msg.encode() + self._data += end_msg.data + # need to create backlog for our sequence number + with self._socket.lock[self._seqnum]: + self._socket.backlog[self._seqnum] = [] + # send + self._socket.sendto(self._data, (0, 0)) + try: + for _msg in self._socket.get(msg_seq=self._seqnum): + # we should see at most one ACK - real errors get raised anyway + pass + finally: + with self._socket.lock[0]: + # clear messages from "seq 0" queue - because if there + # was an error in our backlog, it got raised and the + # remaining messages moved to 0 + self._socket.backlog[0] = [] + + def _put(self, msg: nfgen_msg) -> None: + if self._closed: + raise Exception("Transaction already closed") + if self._final_msg: + # previous message wasn't the final one, encode it without ACK + self._final_msg.encode() + self._data += self._final_msg.data + self._final_msg = msg + + def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: + msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] + msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST + msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set! + header = dict( + type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type, + flags=msg_flags, + sequence_number=self._seqnum, + ) + msg = _build(msg_class, attrs=attrs, header=header, **fields) + self._put(msg) diff --git a/start-api.sh b/start-api.sh new file mode 100755 index 0000000..50c2163 --- /dev/null +++ b/start-api.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -e + +base=$(dirname "$(readlink -f "$0")") +cd "${base}" + +exec ./venv/bin/hypercorn -k trio capport.api "$@" diff --git a/start-control.sh b/start-control.sh new file mode 100755 index 0000000..a5e7d09 --- /dev/null +++ b/start-control.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +set -e + +base=$(dirname "$(readlink -f "$0")") +cd "${base}" + +exec ./venv/bin/capport-control "$@"