
initial commit

Stefan Bühler 2022-04-04 19:21:51 +02:00
commit d1050d2ee4
34 changed files with 2223 additions and 0 deletions

.gitignore vendored Normal file (+6)

@@ -0,0 +1,6 @@
.vscode
*.pyc
*.egg-info
__pycache__
venv
capport.yaml

.pycodestyle Normal file (+11)

@@ -0,0 +1,11 @@
[pycodestyle]
# E241 multiple spaces after ':' [ want to align stuff ]
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
# E501 line too long [ temporary? can't disable it in certain places.. ]
# E701 multiple statements on one line (colon) [ perfectly readable ]
# E713 test for membership should be not in [ disagree: want `not a in x` ]
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
# W503 Line break occurred before a binary operator [ pep8 flipped on this (also contradicts W504) ]
ignore = E241,E266,E501,E701,E713,E714,W503
max-line-length = 120
exclude = 00*.py,generated.py

.pylintrc Normal file (+20)

@@ -0,0 +1,20 @@
[MESSAGES CONTROL]
disable=logging-fstring-interpolation
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=120
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
[DESIGN]
# Maximum number of locals for function / method body.
max-locals=20
# Minimum number of public methods for a class (see R0903).
min-public-methods=0

LICENSE Normal file (+19)

@@ -0,0 +1,19 @@
Copyright (c) 2022 Universität Stuttgart
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md Normal file (+10)

@@ -0,0 +1,10 @@
# python Captive Portal
### Installation
Either clone the repository (and install dependencies either through your distribution or into a virtualenv with `./setup-venv.sh`) or install it as a package.

[`pipx`](https://pypa.github.io/pipx/) (available in Debian as a package) can be used to install it in a separate virtual environment:

    pipx install https://github.tik.uni-stuttgart.de/NKS/python-capport

capport-example.yaml Normal file (+7)

@@ -0,0 +1,7 @@
---
secret: mysecret
controllers:
- capport-controller1.example.com
- capport-controller2.example.com
session-timeout: 3600 # in seconds
venue-info-url: 'https://example.com'

mypy Executable file (+24)

@@ -0,0 +1,24 @@
#!/bin/sh
### check type annotations with mypy
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
exit 1
fi
if [ ! -x ./venv/bin/mypy ]; then
./venv/bin/pip install mypy
fi
site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])')
if [ ! -d "${site_pkgs}/trio_typing" ]; then
./venv/bin/pip install trio-typing[mypy]
fi
./venv/bin/mypy --install-types src

protobuf/compile.sh Executable file (+10)

@@ -0,0 +1,10 @@
#!/bin/bash
set -e
cd "$(dirname "$(readlink -f "$0")")"
rm -rf ../src/capport/comm/protobuf/message_pb2.py
mkdir -p ../src/capport/comm/protobuf
protoc --python_out=../src/capport/comm/protobuf message.proto

protobuf/message.proto Normal file (+42)

@@ -0,0 +1,42 @@
syntax = "proto3";
package capport;
message Message {
oneof oneof {
Hello hello = 1;
AuthenticationResult authentication_result = 2;
Ping ping = 3;
MacStates mac_states = 10;
}
}
// sent by clients and servers as first message
message Hello {
bytes instance_id = 1;
string hostname = 2;
bool is_controller = 3;
bytes authentication = 4;
}
// tell peer whether hello authentication was good
message AuthenticationResult {
bool success = 1;
}
message Ping {
bytes payload = 1;
}
message MacStates {
repeated MacState states = 1;
}
message MacState {
bytes mac_address = 1;
// Seconds of UTC time since epoch
int64 last_change = 2;
// Seconds of UTC time since epoch
int64 allow_until = 3;
bool allowed = 4;
}
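
A minimal round-trip sketch for this envelope, assuming the `message_pb2` module generated by `protobuf/compile.sh` is importable (the hostname here is a made-up value): wrap a variant in the `Message` oneof, serialize, parse, and dispatch on the variant.

# Sketch: round-trip a Hello through the Message oneof envelope.
# Assumes message_pb2 was generated by protobuf/compile.sh.
from capport.comm.protobuf import message_pb2

hello = message_pb2.Hello(instance_id=b'\x00' * 16, hostname='portal-1', is_controller=False)
msg = message_pb2.Message(hello=hello)
wire = msg.SerializeToString()

parsed = message_pb2.Message()
parsed.ParseFromString(wire)
assert parsed.WhichOneof('oneof') == 'hello'   # tells us which variant is set
print(parsed.hello.hostname)                   # -> 'portal-1'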

pylint Executable file (+20)

@@ -0,0 +1,20 @@
#!/bin/sh
### check code style with pylint
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
exit 1
fi
if [ ! -x ./venv/bin/pylint ]; then
# need current pylint to deal with more recent python features
./venv/bin/pip install pylint
fi
./venv/bin/pylint src

pyproject.toml Normal file (+6)

@@ -0,0 +1,6 @@
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"

setup-venv.sh Executable file (+11)

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
self=$(dirname "$(readlink -f "$0")")
cd "${self}"
python3 -m venv venv
# editable install of the package (no extras yet)
./venv/bin/pip install -e '.'

setup.cfg Normal file (+36)

@@ -0,0 +1,36 @@
[metadata]
name = capport-tik-nks
version = 0.0.1
author = Stefan Bühler
author_email = stefan.buehler@tik.uni-stuttgart.de
description = Captive Portal
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.tik.uni-stuttgart.de/NKS/python-capport
project_urls =
Bug Tracker = https://github.tik.uni-stuttgart.de/NKS/python-capport/issues
classifiers =
Programming Language :: Python :: 3
License :: OSI Approved :: MIT License
Operating System :: OS Independent
[options]
package_dir =
= src
packages = find:
python_requires = >=3.9
install_requires =
trio
quart-trio
quart
hypercorn[trio]
PyYAML
protobuf
pyroute2
[options.packages.find]
where = src
[options.entry_points]
console_scripts =
capport-control = capport.control.run:main

setup.py Normal file (+6)

@@ -0,0 +1,6 @@
# https://github.com/pypa/setuptools/issues/2816
# allow editable install on older pip versions
from setuptools import setup
if __name__ == "__main__":
setup()

src/capport/__init__.py Normal file (empty)

src/capport/api/__init__.py Normal file (+169)

@@ -0,0 +1,169 @@
from __future__ import annotations
import ipaddress
import logging
import typing
import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.utils.cli
import capport.utils.ipneigh
import quart
import quart_trio
import trio
from capport import cptypes
from capport.config import Config
app = quart_trio.QuartTrio(__name__)
_logger = logging.getLogger(__name__)
config: typing.Optional[Config] = None
hub: typing.Optional[capport.comm.hub.Hub] = None
hub_app: typing.Optional[ApiHubApp] = None
nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
def get_client_ip() -> cptypes.IPAddress:
try:
addr = ipaddress.ip_address(quart.request.remote_addr)
except ValueError as e:
_logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}')
quart.abort(500, 'Invalid client address')
if addr.is_loopback:
forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For')
if len(forw_addr_headers) == 1:
try:
return ipaddress.ip_address(forw_addr_headers[0])
except ValueError as e:
_logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}')
quart.abort(500, 'Invalid client address')
elif forw_addr_headers:
_logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})')
quart.abort(500, 'Invalid client address')
return addr
async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]:
assert nc # for mypy
if not address:
address = get_client_ip()
return await nc.get_neighbor_mac(address)
async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress:
mac = await get_client_mac_if_present(address)
if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}")
quart.abort(404, 'Unknown client')
return mac
class ApiHubApp(capport.comm.hub.HubApplication):
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
# TODO: support websocket notification updates to clients?
pass
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
assert config # for mypy
assert hub # for mypy
pu = capport.database.PendingUpdates()
try:
hub.database.login(mac, config.session_timeout, pending_updates=pu)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu.macs:
_logger.info(f'User {mac} (with IP {address}) logged in')
for msg in pu.serialize():
await hub.broadcast(msg)
async def user_logout(mac: cptypes.MacAddress) -> None:
assert hub # for mypy
pu = capport.database.PendingUpdates()
try:
hub.database.logout(mac, pending_updates=pu)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu.macs:
_logger.info(f'User {mac} logged out')
for msg in pu.serialize():
await hub.broadcast(msg)
async def user_lookup() -> cptypes.MacPublicState:
assert hub # for mypy
address = get_client_ip()
mac = await get_client_mac_if_present(address)
if not mac:
return cptypes.MacPublicState.from_missing_mac(address)
else:
return hub.database.lookup(address, mac)
async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
global hub
global hub_app
global nc
assert config # for mypy
try:
async with capport.utils.ipneigh.connect() as mync:
nc = mync
_logger.info("Running hub for API")
myapp = ApiHubApp()
myhub = capport.comm.hub.Hub(config=config, app=myapp)
hub = myhub
hub_app = myapp
await myhub.run(task_status=task_status)
finally:
hub = None
hub_app = None
nc = None
_logger.info("Done running hub for API")
await app.shutdown()
@app.before_serving
async def init():
global config
config = Config.load()
capport.utils.cli.init_logger(config)
await app.nursery.start(_run_hub)
# @app.route('/all')
# async def route_all():
# return hub_app.database.as_json()
@app.route('/', methods=['GET'])
async def index():
state = await user_lookup()
return await quart.render_template('index.html', state=state)
@app.route('/login', methods=['POST'])
async def login():
address = get_client_ip()
mac = await get_client_mac(address)
await user_login(address, mac)
return quart.redirect('/', code=303)
@app.route('/logout', methods=['POST'])
async def logout():
mac = await get_client_mac()
await user_logout(mac)
return quart.redirect('/', code=303)
@app.route('/api/captive-portal', methods=['GET'])
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
async def captive_api():
state = await user_lookup()
return state.to_rfc8908(config)
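
setup.cfg pulls in hypercorn[trio]; a hedged sketch of serving this Quart-Trio app from Python (the bind address is an assumption, and the app's before_serving hook still expects a capport.yaml to load):

# Sketch: serve the Quart-Trio app with Hypercorn on trio.
# Bind address/port are assumptions, not part of this commit.
import trio
from hypercorn.config import Config as HyperConfig
from hypercorn.trio import serve

from capport.api import app

hyper_config = HyperConfig()
hyper_config.bind = ['0.0.0.0:8080']  # assumed; adjust to your deployment
trio.run(serve, app, hyper_config)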

src/capport/api/templates/index.html Normal file (+22)

@@ -0,0 +1,22 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Captive Portal Universität Stuttgart</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
</head>
<body>
{% if not state.mac %}
It seems you're accessing this site from outside the network this captive portal is running for.
{% elif state.captive %}
To get access to the internet please accept our usage guidelines by clicking this button:
<form method="POST" action="/login"><button type="submit">Accept</button></form>
{% else %}
You already accepted our conditions and are currently granted access to the internet:
<form method="POST" action="/login"><button type="submit">Renew session</button></form>
<form method="POST" action="/logout"><button type="submit">Close session</button></form>
<br>
Your current session will last for {{ state.allowed_remaining }} seconds.
{% endif %}
</body>
</html>

src/capport/comm/__init__.py Normal file (empty)

src/capport/comm/hub.py Normal file (+441)

@@ -0,0 +1,441 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import random
import socket
import ssl
import struct
import typing
import uuid
import capport.database
import capport.comm.message
import trio
if typing.TYPE_CHECKING:
from ..config import Config
_logger = logging.getLogger(__name__)
class HubConnectionReadError(ConnectionError):
pass
class HubConnectionClosedError(ConnectionError):
pass
class LoopbackConnectionError(Exception):
pass
class Channel:
def __init__(self, hub: Hub, transport_stream, server_side: bool):
self._hub = hub
self._serverside = server_side
self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side)
_logger.debug(f"{self}: created (server_side={server_side})")
def __repr__(self) -> str:
return f'Channel[0x{id(self):x}]'
async def do_handshake(self) -> capport.comm.message.Hello:
try:
await self._ssl.do_handshake()
ssl_binding = self._ssl.get_channel_binding()
if not ssl_binding:
# binding mustn't be None after successful handshake
raise ConnectionError("Missing SSL channel binding")
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message()
await self.send_msg(msg)
peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message")
auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
raise HubConnectionReadError("Expected AuthenticationResult as second message")
if not auth_succ or not peer_auth.success:
raise HubConnectionReadError("Authentication failed")
return peer_hello
async def _read(self, num: int) -> bytes:
assert num > 0
buf = b''
# _logger.debug(f"{self}:_read({num})")
while num > 0:
try:
part = await self._ssl.receive_some(num)
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
# _logger.debug(f"{self}:_read({num}) got part {part!r}")
if len(part) == 0:
if len(buf) == 0:
raise HubConnectionClosedError()
raise HubConnectionReadError("Unexpected end of TLS stream")
buf += part
num -= len(part)
if num < 0:
raise HubConnectionReadError("TLS receive_some returned too much")
return buf
async def _recv_raw_msg(self) -> bytes:
len_bytes = await self._read(4)
chunk_size, = struct.unpack('!I', len_bytes)
chunk = await self._read(chunk_size)
if chunk is None:
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
return chunk
async def recv_msg(self) -> capport.comm.message.Message:
try:
chunk = await self._recv_raw_msg()
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
msg = capport.comm.message.Message()
msg.ParseFromString(chunk)
return msg
async def _send_raw(self, chunk: bytes) -> None:
try:
await self._ssl.send_all(chunk)
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
async def send_msg(self, msg: capport.comm.message.Message):
chunk = msg.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
chunk = len_bytes + chunk
await self._send_raw(chunk)
async def aclose(self):
try:
await self._ssl.aclose()
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
class Connection:
PING_INTERVAL = 10
RECEIVE_TIMEOUT = 15
SEND_TIMEOUT = 5
def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello):
self._channel = channel
self._hub = hub
tx: trio.MemorySendChannel
rx: trio.MemoryReceiveChannel
(tx, rx) = trio.open_memory_channel(64)
self._pending_tx = tx
self._pending_rx = rx
self.peer: capport.comm.message.Hello = peer
self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id)
self.closed = trio.Event() # set by Hub._lost_peer
_logger.debug(f"{self._channel}: authenticated -> {self.peer_id}")
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
try:
msg: typing.Optional[capport.comm.message.Message]
while True:
msg = None
# make sure we send something every PING_INTERVAL
with trio.move_on_after(self.PING_INTERVAL):
msg = await self._pending_rx.receive()
# if send blocks too long we're in trouble
with trio.fail_after(self.SEND_TIMEOUT):
if msg:
await self._channel.send_msg(msg)
else:
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
except trio.TooSlowError:
_logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed sending: {e!r}")
except Exception as e:
_logger.exception(f"{self._channel}: failed sending")
finally:
cancel_scope.cancel()
async def _receive(self, cancel_scope: trio.CancelScope) -> None:
try:
while True:
try:
with trio.fail_after(self.RECEIVE_TIMEOUT):
msg = await self._channel.recv_msg()
except (HubConnectionClosedError, ConnectionResetError):
return
except trio.TooSlowError:
_logger.warning(f"{self._channel}: receive timed out")
return
await self._hub._received_msg(self.peer_id, msg)
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed receiving: {e!r}")
except Exception:
_logger.exception(f"{self._channel}: failed receiving")
finally:
cancel_scope.cancel()
async def _inner_run(self) -> None:
if self.peer_id == self._hub._instance_id:
# connected to ourself, don't need that
raise LoopbackConnectionError()
async with trio.open_nursery() as nursery:
nursery.start_soon(self._sender, nursery.cancel_scope)
# be nice and wait for new_peer before receiving messages
# (won't work on failover to a second connection)
await nursery.start(self._hub._new_peer, self.peer_id, self)
nursery.start_soon(self._receive, nursery.cancel_scope)
async def send_msg(self, *msgs: capport.comm.message.Message):
try:
for msg in msgs:
await self._pending_tx.send(msg)
except trio.ClosedResourceError:
pass
async def _run(self) -> None:
try:
await self._inner_run()
finally:
_logger.debug(f"{self._channel}: finished message handling")
# basic (non-async) cleanup
self._hub._lost_peer(self.peer_id, self)
self._pending_tx.close()
self._pending_rx.close()
# allow 3 seconds for proper cleanup
with trio.CancelScope(shield=True, deadline=trio.current_time() + 3):
try:
await self._channel.aclose()
except OSError:
pass
@staticmethod
async def run(hub: Hub, transport_stream, server_side: bool) -> None:
channel = Channel(hub, transport_stream, server_side)
try:
with trio.fail_after(5):
peer = await channel.do_handshake()
except trio.TooSlowError:
_logger.warning(f"Handshake timed out")
return
conn = Connection(hub, channel, peer)
await conn._run()
class ControllerConn:
def __init__(self, hub: Hub, hostname: str):
self._hub = hub
self.hostname = hostname
self.loopback = False
async def _connect(self):
_logger.info(f"Connecting to controller at {self.hostname}")
with trio.fail_after(5):
try:
stream = await trio.open_tcp_stream(self.hostname, 5000)
except OSError as e:
_logger.warning(f"Failed to connect to controller at {self.hostname}: {e}")
return
try:
await Connection.run(self._hub, stream, server_side=False)
finally:
_logger.info(f"Connection to {self.hostname} closed")
async def run(self):
while True:
try:
await self._connect()
except LoopbackConnectionError:
_logger.debug(f"Connection to {self.hostname} reached ourself")
self.loopback = True
return
except trio.TooSlowError:
pass
# try again later
retry_splay = random.random() * 5
await trio.sleep(10 + retry_splay)
class HubApplication:
def is_controller(self) -> bool:
return False
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.info(f"New peer {peer_id}")
def lost_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.warning(f"Lost peer {peer_id}")
async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
_logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}")
async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
class Hub:
def __init__(self, config: Config, app: HubApplication) -> None:
self._config = config
self._instance_id = uuid.uuid4()
self._hostname = socket.getfqdn()
self.database = capport.database.Database()
self._app = app
self._is_controller = bool(app.is_controller())
self._anon_context = ssl.SSLContext()
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
# -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
self._controllers: typing.Dict[str, ControllerConn] = {}
self._established: typing.Dict[uuid.UUID, Connection] = {}
async def _accept(self, stream):
remotename = stream.socket.getpeername()
if isinstance(remotename, tuple) and len(remotename) == 2:
remote = f'[{remotename[0]}]:{remotename[1]}'
else:
remote = str(remotename)
try:
await Connection.run(self, stream, server_side=True)
except LoopbackConnectionError:
pass
except trio.TooSlowError:
pass
finally:
_logger.debug(f"Connection from {remote} closed")
async def _listen(self, task_status=trio.TASK_STATUS_IGNORED):
await trio.serve_tcp(self._accept, 5000, task_status=task_status)
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
async with trio.open_nursery() as nursery:
if self._is_controller:
await nursery.start(self._listen)
for name in self._config.controllers:
conn = ControllerConn(self, name)
self._controllers[name] = conn
task_status.started()
for conn in self._controllers.values():
nursery.start_soon(conn.run)
await trio.sleep_forever()
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
m = hmac.new(self._config.secret.encode('utf8'), digestmod=hashlib.sha256)
if server_side:
m.update(b'server$')
else:
m.update(b'client$')
m.update(ssl_binding)
return m.digest()
def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello:
return capport.comm.message.Hello(
instance_id=self._instance_id.bytes,
hostname=self._hostname,
is_controller=self._is_controller,
authentication=self._calc_authentication(ssl_binding, server_side),
)
async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None:
# send database (and all changes) to peers
await self.send(*self.database.serialize(), to=peer_id)
async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None:
have = self._established.get(peer_id, None)
if not have:
# peer unknown, "normal start"
# no "await" between get above and set here!!!
self._established[peer_id] = conn
# first wait for app to handle new peer
await self._app.new_peer(peer_id=peer_id)
task_status.started()
await self._sync_new_connection(peer_id, conn)
return
# peer already known - immediately allow receiving messages, then sync connection
task_status.started()
await self._sync_new_connection(peer_id, conn)
# now try to register connection for outgoing messages
while True:
# recheck whether peer is currently known (due to awaits since last get)
have = self._established.get(peer_id, None)
if have:
# already got a connection, nothing to do as long as it lives
await have.closed.wait()
else:
# make `conn` new outgoing connection for peer
# no "await" between get above and set here!!!
self._established[peer_id] = conn
await self._app.new_peer(peer_id=peer_id)
return
def _lost_peer(self, peer_id: uuid.UUID, conn: Connection):
have = self._established.get(peer_id, None)
lost = False
if have is conn:
lost = True
self._established.pop(peer_id)
conn.closed.set()
# only notify if this was the active connection
if lost:
# even when we failover to another connection we still need to resync
# as we don't know which messages might have got lost
# -> always trigger lost_peer
self._app.lost_peer(peer_id=peer_id)
async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
variant = msg.to_variant()
if isinstance(variant, capport.comm.message.Hello):
pass
elif isinstance(variant, capport.comm.message.AuthenticationResult):
pass
elif isinstance(variant, capport.comm.message.Ping):
pass
elif isinstance(variant, capport.comm.message.MacStates):
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
pu = capport.database.PendingUpdates()
for state in variant.states:
self.database.received_mac_state(state, pending_updates=pu)
if pu.macs:
# re-broadcast all received updates to all peers
await self.broadcast(*pu.serialize(), exclude=peer_id)
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
else:
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)
def peer_is_controller(self, peer_id: uuid.UUID) -> bool:
conn = self._established.get(peer_id)
if conn:
return conn.peer.is_controller
return False
async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID):
conn = self._established.get(to)
if conn:
await conn.send_msg(*msgs)
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID]=None):
async with trio.open_nursery() as nursery:
for peer_id, conn in self._established.items():
if peer_id == exclude:
continue
nursery.start_soon(conn.send_msg, *msgs)
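
The handshake's mutual authentication above is an HMAC-SHA256 over the TLS channel binding, domain-separated by role so one side's tag cannot be reflected back to it. A standalone sketch of the check (secret and binding are made-up values):

# Sketch of the handshake authentication from Hub._calc_authentication:
# HMAC-SHA256 over the TLS channel binding, domain-separated by role.
import hashlib
import hmac

def calc_authentication(secret: str, ssl_binding: bytes, server_side: bool) -> bytes:
    m = hmac.new(secret.encode('utf8'), digestmod=hashlib.sha256)
    m.update(b'server$' if server_side else b'client$')
    m.update(ssl_binding)
    return m.digest()

binding = b'example-tls-unique-binding'  # made-up stand-in for get_channel_binding()
server_tag = calc_authentication('mysecret', binding, server_side=True)
client_tag = calc_authentication('mysecret', binding, server_side=False)
assert server_tag != client_tag  # role prefix prevents reflecting a peer's tag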

src/capport/comm/message.py Normal file (+37)

@@ -0,0 +1,37 @@
from __future__ import annotations
import typing
from .protobuf import message_pb2
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
variant_name = self.WhichOneof('oneof')
if variant_name:
return getattr(self, variant_name)
return None
def _make_to_message(oneof_field):
def to_message(self) -> message_pb2.Message:
msg = message_pb2.Message(**{oneof_field: self})
return msg
return to_message
def _monkey_patch():
g = globals()
g['Message'] = message_pb2.Message
message_pb2.Message.to_variant = _message_to_variant
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
type_name = field.message_type.name
field_type = getattr(message_pb2, type_name)
field_type.to_message = _make_to_message(field.name)
g[type_name] = field_type
# also re-exports all message types
_monkey_patch()
# not a variant of Message, still re-export
MacState = message_pb2.MacState
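
Usage sketch for the patched-in helpers: wrapping and unwrapping becomes a one-liner in each direction.

# Sketch: the monkey-patched to_message/to_variant helpers re-exported above.
import capport.comm.message

ping = capport.comm.message.Ping(payload=b'ping')
msg = ping.to_message()       # wrap the variant into the Message envelope
variant = msg.to_variant()    # unwrap it again
assert isinstance(variant, capport.comm.message.Ping)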

src/capport/comm/message.pyi Normal file (+93)

@@ -0,0 +1,93 @@
import google.protobuf.message
import typing
# manually maintained typehints for protobuf created (and monkey-patched) types
class Message(google.protobuf.message.Message):
hello: Hello
authentication_result: AuthenticationResult
ping: Ping
mac_states: MacStates
def __init__(
self,
*,
hello: typing.Optional[Hello]=None,
authentication_result: typing.Optional[AuthenticationResult]=None,
ping: typing.Optional[Ping]=None,
mac_states: typing.Optional[MacStates]=None,
) -> None: ...
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
class Hello(google.protobuf.message.Message):
instance_id: bytes
hostname: str
is_controller: bool
authentication: bytes
def __init__(
self,
*,
instance_id: bytes=b'',
hostname: str='',
is_controller: bool=False,
authentication: bytes=b'',
) -> None: ...
def to_message(self) -> Message: ...
class AuthenticationResult(google.protobuf.message.Message):
success: bool
def __init__(
self,
*,
success: bool=False,
) -> None: ...
def to_message(self) -> Message: ...
class Ping(google.protobuf.message.Message):
payload: bytes
def __init__(
self,
*,
payload: bytes=b'',
) -> None: ...
def to_message(self) -> Message: ...
class MacStates(google.protobuf.message.Message):
states: typing.List[MacState]
def __init__(
self,
*,
states: typing.List[MacState]=[],
) -> None: ...
def to_message(self) -> Message: ...
class MacState(google.protobuf.message.Message):
mac_address: bytes
last_change: int # Seconds of UTC time since epoch
allow_until: int # Seconds of UTC time since epoch
allowed: bool
def __init__(
self,
*,
mac_address: bytes=b'',
last_change: int=0,
allow_until: int=0,
allowed: bool=False,
) -> None: ...

src/capport/comm/protobuf/message_pb2.py Normal file (+355)

@@ -0,0 +1,355 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: message.proto
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='message.proto',
package='capport',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\rmessage.proto\x12\x07\x63\x61pport\"\xbc\x01\n\x07Message\x12\x1f\n\x05hello\x18\x01 \x01(\x0b\x32\x0e.capport.HelloH\x00\x12>\n\x15\x61uthentication_result\x18\x02 \x01(\x0b\x32\x1d.capport.AuthenticationResultH\x00\x12\x1d\n\x04ping\x18\x03 \x01(\x0b\x32\r.capport.PingH\x00\x12(\n\nmac_states\x18\n \x01(\x0b\x32\x12.capport.MacStatesH\x00\x42\x07\n\x05oneof\"]\n\x05Hello\x12\x13\n\x0binstance_id\x18\x01 \x01(\x0c\x12\x10\n\x08hostname\x18\x02 \x01(\t\x12\x15\n\ris_controller\x18\x03 \x01(\x08\x12\x16\n\x0e\x61uthentication\x18\x04 \x01(\x0c\"\'\n\x14\x41uthenticationResult\x12\x0f\n\x07success\x18\x01 \x01(\x08\"\x17\n\x04Ping\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\".\n\tMacStates\x12!\n\x06states\x18\x01 \x03(\x0b\x32\x11.capport.MacState\"Z\n\x08MacState\x12\x13\n\x0bmac_address\x18\x01 \x01(\x0c\x12\x13\n\x0blast_change\x18\x02 \x01(\x03\x12\x13\n\x0b\x61llow_until\x18\x03 \x01(\x03\x12\x0f\n\x07\x61llowed\x18\x04 \x01(\x08\x62\x06proto3'
)
_MESSAGE = _descriptor.Descriptor(
name='Message',
full_name='capport.Message',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='hello', full_name='capport.Message.hello', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='authentication_result', full_name='capport.Message.authentication_result', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='ping', full_name='capport.Message.ping', index=2,
number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='mac_states', full_name='capport.Message.mac_states', index=3,
number=10, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='oneof', full_name='capport.Message.oneof',
index=0, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[]),
],
serialized_start=27,
serialized_end=215,
)
_HELLO = _descriptor.Descriptor(
name='Hello',
full_name='capport.Hello',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='instance_id', full_name='capport.Hello.instance_id', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='hostname', full_name='capport.Hello.hostname', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='is_controller', full_name='capport.Hello.is_controller', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='authentication', full_name='capport.Hello.authentication', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=217,
serialized_end=310,
)
_AUTHENTICATIONRESULT = _descriptor.Descriptor(
name='AuthenticationResult',
full_name='capport.AuthenticationResult',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='success', full_name='capport.AuthenticationResult.success', index=0,
number=1, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=312,
serialized_end=351,
)
_PING = _descriptor.Descriptor(
name='Ping',
full_name='capport.Ping',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='payload', full_name='capport.Ping.payload', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=353,
serialized_end=376,
)
_MACSTATES = _descriptor.Descriptor(
name='MacStates',
full_name='capport.MacStates',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='states', full_name='capport.MacStates.states', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=378,
serialized_end=424,
)
_MACSTATE = _descriptor.Descriptor(
name='MacState',
full_name='capport.MacState',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='mac_address', full_name='capport.MacState.mac_address', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='last_change', full_name='capport.MacState.last_change', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='allow_until', full_name='capport.MacState.allow_until', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='allowed', full_name='capport.MacState.allowed', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=426,
serialized_end=516,
)
_MESSAGE.fields_by_name['hello'].message_type = _HELLO
_MESSAGE.fields_by_name['authentication_result'].message_type = _AUTHENTICATIONRESULT
_MESSAGE.fields_by_name['ping'].message_type = _PING
_MESSAGE.fields_by_name['mac_states'].message_type = _MACSTATES
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['hello'])
_MESSAGE.fields_by_name['hello'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['authentication_result'])
_MESSAGE.fields_by_name['authentication_result'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['ping'])
_MESSAGE.fields_by_name['ping'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['mac_states'])
_MESSAGE.fields_by_name['mac_states'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MACSTATES.fields_by_name['states'].message_type = _MACSTATE
DESCRIPTOR.message_types_by_name['Message'] = _MESSAGE
DESCRIPTOR.message_types_by_name['Hello'] = _HELLO
DESCRIPTOR.message_types_by_name['AuthenticationResult'] = _AUTHENTICATIONRESULT
DESCRIPTOR.message_types_by_name['Ping'] = _PING
DESCRIPTOR.message_types_by_name['MacStates'] = _MACSTATES
DESCRIPTOR.message_types_by_name['MacState'] = _MACSTATE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Message = _reflection.GeneratedProtocolMessageType('Message', (_message.Message,), {
'DESCRIPTOR' : _MESSAGE,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.Message)
})
_sym_db.RegisterMessage(Message)
Hello = _reflection.GeneratedProtocolMessageType('Hello', (_message.Message,), {
'DESCRIPTOR' : _HELLO,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.Hello)
})
_sym_db.RegisterMessage(Hello)
AuthenticationResult = _reflection.GeneratedProtocolMessageType('AuthenticationResult', (_message.Message,), {
'DESCRIPTOR' : _AUTHENTICATIONRESULT,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.AuthenticationResult)
})
_sym_db.RegisterMessage(AuthenticationResult)
Ping = _reflection.GeneratedProtocolMessageType('Ping', (_message.Message,), {
'DESCRIPTOR' : _PING,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.Ping)
})
_sym_db.RegisterMessage(Ping)
MacStates = _reflection.GeneratedProtocolMessageType('MacStates', (_message.Message,), {
'DESCRIPTOR' : _MACSTATES,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.MacStates)
})
_sym_db.RegisterMessage(MacStates)
MacState = _reflection.GeneratedProtocolMessageType('MacState', (_message.Message,), {
'DESCRIPTOR' : _MACSTATE,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.MacState)
})
_sym_db.RegisterMessage(MacState)
# @@protoc_insertion_point(module_scope)

src/capport/config.py Normal file (+34)

@@ -0,0 +1,34 @@
from __future__ import annotations
import dataclasses
import os.path
import typing
import yaml
@dataclasses.dataclass
class Config:
controllers: typing.List[str]
secret: str
venue_info_url: typing.Optional[str]
session_timeout: int # in seconds
debug: bool
@staticmethod
def load(filename: typing.Optional[str]=None) -> 'Config':
if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'):
if os.path.exists(name):
return Config.load(name)
raise RuntimeError("Missing config file")
with open(filename) as f:
data = yaml.safe_load(f)
controllers = list(map(str, data['controllers']))
venue_info_url = data.get('venue-info-url')
return Config(
controllers=controllers,
secret=str(data['secret']),
venue_info_url=str(venue_info_url) if venue_info_url is not None else None,
session_timeout=data.get('session-timeout', 3600),
debug=data.get('debug', False)
)
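
A sketch of loading the capport-example.yaml shipped above (field values follow directly from that file and the defaults here):

# Sketch: what Config.load produces for capport-example.yaml.
from capport.config import Config

config = Config.load('capport-example.yaml')
assert config.controllers == ['capport-controller1.example.com',
                              'capport-controller2.example.com']
assert config.secret == 'mysecret'
assert config.session_timeout == 3600
assert config.venue_info_url == 'https://example.com'
assert config.debug is False  # defaults to False when not set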

src/capport/control/__init__.py Normal file (empty)

src/capport/control/run.py Normal file (+56)

@@ -0,0 +1,56 @@
from __future__ import annotations
import logging
import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.config
import capport.utils.cli
import capport.utils.nft_set
import trio
from capport import cptypes
_logger = logging.getLogger(__name__)
class ControlApp(capport.comm.hub.HubApplication):
hub: capport.comm.hub.Hub
def __init__(self) -> None:
super().__init__()
self.nft_set = capport.utils.nft_set.NftSet()
def is_controller(self) -> bool:
return True
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
# deploy changes to netfilter set
inserts = []
removals = []
now = cptypes.Timestamp.now()
for mac, state in pending_updates.macs.items():
rem = state.allowed_remaining(now)
if rem > 0:
inserts.append((mac, rem))
else:
removals.append(mac)
self.nft_set.bulk_insert(inserts)
self.nft_set.bulk_remove(removals)
async def amain(config: capport.config.Config) -> None:
app = ControlApp()
hub = capport.comm.hub.Hub(config=config, app=app)
app.hub = hub
await hub.run()
def main() -> None:
config = capport.config.Config.load()
capport.utils.cli.init_logger(config)
try:
trio.run(amain, config)
except (KeyboardInterrupt, InterruptedError):
print()

src/capport/cptypes.py Normal file (+94)

@@ -0,0 +1,94 @@
from __future__ import annotations
import dataclasses
import datetime
import ipaddress
import json
import time
import typing
import quart
if typing.TYPE_CHECKING:
from .config import Config
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@dataclasses.dataclass(frozen=True)
class MacAddress:
raw: bytes
def __str__(self) -> str:
return self.raw.hex(':')
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def parse(s: str) -> MacAddress:
return MacAddress(bytes.fromhex(s.replace(':', '')))
@dataclasses.dataclass(frozen=True, order=True)
class Timestamp:
epoch: int
def __str__(self) -> str:
try:
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
return ts.isoformat(sep=' ')
except OSError:
return f'epoch@{self.epoch}'
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def now() -> Timestamp:
now = int(time.time())
return Timestamp(epoch=now)
@staticmethod
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
if epoch:
return Timestamp(epoch=epoch)
return None
@dataclasses.dataclass
class MacPublicState:
address: IPAddress
mac: typing.Optional[MacAddress]
allowed_remaining: int
@staticmethod
def from_missing_mac(address: IPAddress) -> MacPublicState:
return MacPublicState(
address=address,
mac=None,
allowed_remaining=0,
)
@property
def allowed(self) -> bool:
return self.allowed_remaining > 0
@property
def captive(self) -> bool:
return not self.allowed
def to_rfc8908(self, config: Config) -> quart.Response:
response: typing.Dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True),
}
if config.venue_info_url:
response['venue-info-url'] = config.venue_info_url
if self.captive:
response['captive'] = True
else:
response['captive'] = False
response['seconds-remaining'] = self.allowed_remaining
response['can-extend-session'] = True
return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json')
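
For an authenticated client this yields an application/captive+json body roughly like the following literal (values illustrative; user-portal-url depends on the deployment):

# Illustrative RFC 8908 payload for a logged-in client:
{
    "user-portal-url": "https://portal.example.com/",  # assumed external URL
    "venue-info-url": "https://example.com",
    "captive": False,
    "seconds-remaining": 1234,
    "can-extend-session": True,
}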

src/capport/database.py Normal file (+200)

@@ -0,0 +1,200 @@
from __future__ import annotations
import dataclasses
import typing
import capport.comm.message
from capport import cptypes
@dataclasses.dataclass
class MacEntry:
# entry can be removed if last_change was some time ago and allow_until wasn't set
# or got reached.
WAIT_LAST_CHANGE_SECONDS = 60
WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10
# last_change: timestamp of last change (sent by system initiating the change)
last_change: cptypes.Timestamp
# only if allowed is true and allow_until is set the device can communicate with the internet
# allow_until must not go backwards (and not get unset)
allow_until: typing.Optional[cptypes.Timestamp]
allowed: bool
@staticmethod
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
if len(msg.mac_address) < 6:
raise Exception("Invalid MacState: mac_address too short")
addr = cptypes.MacAddress(raw=msg.mac_address)
last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
if not last_change:
raise Exception(f"Invalid MacState[{addr}]: missing last_change")
allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))
def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
allow_until = 0
if self.allow_until:
allow_until = self.allow_until.epoch
return capport.comm.message.MacState(
mac_address=addr.raw,
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def as_json(self) -> dict:
allow_until = None
if self.allow_until:
allow_until = self.allow_until.epoch
return dict(
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def merge(self, new: MacEntry) -> bool:
changed = False
if new.last_change > self.last_change:
changed = True
self.last_change = new.last_change
self.allowed = new.allowed
elif new.last_change == self.last_change:
# same last_change: set allowed if one allowed
if new.allowed and not self.allowed:
changed = True
self.allowed = True
# set allow_until to max of both
if new.allow_until: # if not set nothing to change in local data
if not self.allow_until or self.allow_until < new.allow_until:
changed = True
self.allow_until = new.allow_until
return changed
def timeout(self) -> cptypes.Timestamp:
elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
if self.allow_until:
eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
if eau > elc:
return cptypes.Timestamp(epoch=eau)
return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int:
if not self.allowed or not self.allow_until:
return 0
if not now:
now = cptypes.Timestamp.now()
assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool:
if not now:
now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch
# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates()
for addr, entry in macs.items():
state = entry.to_state(addr)
current.states.append(state)
if len(current.states) >= 1024: # split into messages with 1024 states
result.append(current)
current = capport.comm.message.MacStates()
if len(current.states):
result.append(current)
return result
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)]
class NotReadyYet(Exception):
def __init__(self, msg: str, wait: int):
self.wait = wait # seconds to wait
super().__init__(msg)
@dataclasses.dataclass
class Database:
_macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
(addr, new_entry) = MacEntry.parse_state(state)
old_entry = self._macs.get(addr)
if not old_entry:
# only redistribute if not outdated
if not new_entry.outdated():
self._macs[addr] = new_entry
pending_updates.macs[addr] = new_entry
elif old_entry.merge(new_entry):
if old_entry.outdated():
# remove local entry, but still redistribute
self._macs.pop(addr)
pending_updates.macs[addr] = old_entry
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict:
return {
str(addr): entry.as_json()
for addr, entry in self._macs.items()
}
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac)
if entry:
allowed_remaining = entry.allowed_remaining()
else:
allowed_remaining = 0
return cptypes.MacPublicState(
address=address,
mac=mac,
allowed_remaining=allowed_remaining,
)
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
now = cptypes.Timestamp.now()
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
entry = self._macs.get(mac)
if not entry:
self._macs[mac] = new_entry
pending_updates.macs[mac] = new_entry
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
# too much time left on clock, not renewing session
return
elif entry.merge(new_entry):
pending_updates.macs[mac] = entry
elif not entry.allowed_remaining() > 0:
# entry should have been updated - can only fail due to `now < entry.last_change`
# i.e. out of sync clocks
wait = entry.last_change.epoch - now.epoch
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
now = cptypes.Timestamp.now()
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
entry = self._macs.get(mac)
if entry:
if entry.merge(new_entry):
pending_updates.macs[mac] = entry
elif entry.allowed_remaining() > 0:
# still logged in. can only happen with `now <= entry.last_change`
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
wait = entry.last_change.epoch - now.epoch + 1
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
@dataclasses.dataclass
class PendingUpdates:
macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self.macs)
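
A sketch of how login, PendingUpdates and lookup fit together (MAC and IP addresses are made up):

# Sketch: Database.login records the grant and queues it in PendingUpdates
# for redistribution; lookup then reports the remaining session time.
import ipaddress

from capport import cptypes
from capport.database import Database, PendingUpdates

db = Database()
mac = cptypes.MacAddress.parse('00:11:22:33:44:55')  # made-up address
ip = ipaddress.ip_address('192.0.2.7')               # made-up address

pu = PendingUpdates()
db.login(mac, 3600, pending_updates=pu)
assert mac in pu.macs  # entry queued for broadcast to peers
state = db.lookup(ip, mac)
assert state.allowed and state.allowed_remaining <= 3600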

src/capport/utils/__init__.py Normal file (empty)

src/capport/utils/cli.py Normal file (+17)

@@ -0,0 +1,17 @@
from __future__ import annotations
import logging
import capport.config
def init_logger(config: capport.config.Config):
loglevel = logging.INFO
if config.debug:
loglevel = logging.DEBUG
logging.basicConfig(
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S %z]',
level=loglevel,
)
logging.getLogger('hypercorn').propagate = False

src/capport/utils/ipneigh.py Normal file (+63)

@@ -0,0 +1,63 @@
from __future__ import annotations
import contextlib
import errno
import typing
import pr2modules.iproute.linux
import pr2modules.netlink.exceptions
from capport import cptypes
@contextlib.asynccontextmanager
async def connect():
yield NeighborController()
# TODO: run blocking iproute calls in a different thread?
class NeighborController:
def __init__(self):
self.ip = pr2modules.iproute.linux.IPRoute()
async def get_neighbor(
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]:
if not index:
route = await self.get_route(address)
if route is None:
return None
index = route.get_attr(route.name2nla('oif'))
try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
except pr2modules.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise
async def get_neighbor_mac(
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
) -> typing.Optional[cptypes.MacAddress]:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
return None
mac = neigh.get_attr(neigh.name2nla('lladdr'))
return cptypes.MacAddress.parse(mac)
async def get_route(
self,
address: cptypes.IPAddress,
) -> typing.Optional[pr2modules.iproute.linux.rtmsg]:
try:
return self.ip.route('get', dst=str(address))[0]
except pr2modules.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise
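
A usage sketch (the IP address is made up; the lookup returns None when the kernel neighbor table has no entry):

# Sketch: resolve the MAC address behind a client IP via the kernel
# neighbor table, as the API does for incoming requests.
import ipaddress

import trio

import capport.utils.ipneigh

async def main() -> None:
    async with capport.utils.ipneigh.connect() as nc:
        mac = await nc.get_neighbor_mac(ipaddress.ip_address('192.0.2.7'))  # made-up IP
        print(mac)  # None if the kernel has no neighbor entry

trio.run(main)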

src/capport/utils/nft_set.py Normal file (+181)

@@ -0,0 +1,181 @@
from __future__ import annotations

import typing

import pr2modules.netlink

from capport import cptypes
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket

from .nft_socket import NFTSocket

NFPROTO_INET: int = 1  # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"


def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
    # convert milliseconds to seconds
    if msecs is None:
        return None
    return msecs / 1000.0


class NftSet:
    def __init__(self):
        self._socket = NFTSocket()
        self._socket.bind()

    @staticmethod
    def _set_elem(
        mac: cptypes.MacAddress,
        timeout: typing.Optional[typing.Union[int, float]] = None,
    ) -> typing.Dict[str, typing.Any]:
        # returns the attribute dict for one set element; _build in nft_socket
        # converts nested dicts like this into proper netlink messages
        attrs: typing.Dict[str, typing.Any] = {
            'NFTA_SET_ELEM_KEY': dict(
                NFTA_DATA_VALUE=mac.raw,
            ),
        }
        if timeout:
            attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000 * timeout)
        return attrs

    def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
        ser_entries = [
            self._set_elem(mac)
            for mac, _timeout in entries
        ]
        ser_entries_with_timeout = [
            self._set_elem(mac, timeout)
            for mac, timeout in entries
        ]
        with self._socket.begin() as tx:
            # create doesn't affect existing elements, so:
            # make sure entries exist
            tx.put(
                _nftsocket.NFT_MSG_NEWSETELEM,
                pr2modules.netlink.NLM_F_CREATE,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_SET_TABLE='captive_mark',
                    NFTA_SET_ELEM_LIST_SET='allowed',
                    NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
                ),
            )
            # drop entries (would fail if they didn't exist)
            tx.put(
                _nftsocket.NFT_MSG_DELSETELEM,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_SET_TABLE='captive_mark',
                    NFTA_SET_ELEM_LIST_SET='allowed',
                    NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
                ),
            )
            # now create entries with the new timeout value
            tx.put(
                _nftsocket.NFT_MSG_NEWSETELEM,
                pr2modules.netlink.NLM_F_CREATE | pr2modules.netlink.NLM_F_EXCL,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_SET_TABLE='captive_mark',
                    NFTA_SET_ELEM_LIST_SET='allowed',
                    NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
                ),
            )

    def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
        # limit chunk size
        while len(entries) > 0:
            self._bulk_insert(entries[:1024])
            entries = entries[1024:]

    def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
        self.bulk_insert([(mac, timeout)])

    def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
        ser_entries = [
            self._set_elem(mac)
            for mac in entries
        ]
        with self._socket.begin() as tx:
            # make sure entries exist
            tx.put(
                _nftsocket.NFT_MSG_NEWSETELEM,
                pr2modules.netlink.NLM_F_CREATE,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_SET_TABLE='captive_mark',
                    NFTA_SET_ELEM_LIST_SET='allowed',
                    NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
                ),
            )
            # drop entries (would fail if they didn't exist)
            tx.put(
                _nftsocket.NFT_MSG_DELSETELEM,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_SET_TABLE='captive_mark',
                    NFTA_SET_ELEM_LIST_SET='allowed',
                    NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
                ),
            )

    def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
        # limit chunk size
        while len(entries) > 0:
            self._bulk_remove(entries[:1024])
            entries = entries[1024:]

    def remove(self, mac: cptypes.MacAddress) -> None:
        self.bulk_remove([mac])

    def list(self) -> typing.List[typing.Dict[str, typing.Any]]:
        responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
        responses = self._socket.nft_dump(
            _nftsocket.NFT_MSG_GETSETELEM,
            nfgen_family=NFPROTO_INET,
            attrs=dict(
                NFTA_SET_TABLE='captive_mark',
                NFTA_SET_ELEM_LIST_SET='allowed',
            ),
        )
        return [
            {
                'mac': cptypes.MacAddress(
                    elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
                ),
                'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
                'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
            }
            for response in responses
            for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
        ]

    def flush(self) -> None:
        self._socket.nft_put(
            _nftsocket.NFT_MSG_DELSETELEM,
            nfgen_family=NFPROTO_INET,
            attrs=dict(
                NFTA_SET_TABLE='captive_mark',
                NFTA_SET_ELEM_LIST_SET='allowed',
            ),
        )

    def create(self):
        with self._socket.begin() as tx:
            tx.put(
                _nftsocket.NFT_MSG_NEWTABLE,
                pr2modules.netlink.NLM_F_CREATE,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_TABLE_NAME='captive_mark',
                ),
            )
            tx.put(
                _nftsocket.NFT_MSG_NEWSET,
                pr2modules.netlink.NLM_F_CREATE,
                nfgen_family=NFPROTO_INET,
                attrs=dict(
                    NFTA_SET_TABLE='captive_mark',
                    NFTA_SET_NAME='allowed',
                    NFTA_SET_FLAGS=0x10,  # NFT_SET_TIMEOUT
                    NFTA_SET_KEY_TYPE=9,  # nft type for "type ether_addr" - only relevant for userspace nft
                    NFTA_SET_KEY_LEN=6,  # length of key: mac address
                    NFTA_SET_ID=1,  # kernel seems to need a set id unique per transaction
                ),
            )
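
# Usage sketch (illustrative, requires CAP_NET_ADMIN; table/set names are
# fixed to 'captive_mark'/'allowed' by the methods above):
#
#   from capport import cptypes
#
#   s = NftSet()
#   s.create()  # NLM_F_CREATE without NLM_F_EXCL: ok if table/set already exist
#   mac = cptypes.MacAddress.parse('00:11:22:33:44:55')
#   s.insert(mac, timeout=3600)  # seconds; converted to msecs for the kernel
#   print(s.list())  # [{'mac': ..., 'timeout': 3600.0, 'expiration': ...}]
#   s.remove(mac)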

View File

@ -0,0 +1,217 @@
from __future__ import annotations

import contextlib
import threading
import typing

import pr2modules.netlink
import pr2modules.netlink.nlsocket

from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
from pr2modules.netlink.nfnetlink import nfgen_msg
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket

NFPROTO_INET: int = 1  # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"

_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base)

# nft uses NLA_F_NESTED for these; let's do the same
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED


def _monkey_patch_pyroute2():
    # overwrite setvalue on the nlmsg_base class hierarchy so that a 'header'
    # key in a value dict updates the message header instead of being dropped
    _orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue

    def _nlmsg_base__setvalue(self, value):
        if not self.header or not self['header'] or not isinstance(value, dict):
            return _orig_setvalue(self, value)
        header = value.pop('header', {})
        res = _orig_setvalue(self, value)
        self['header'].update(header)
        return res

    def overwrite_methods(cls: typing.Type) -> None:
        if cls.setvalue is _orig_setvalue:
            cls.setvalue = _nlmsg_base__setvalue
        for subcls in cls.__subclasses__():
            overwrite_methods(subcls)

    overwrite_methods(pr2modules.netlink.nlmsg_base)


_monkey_patch_pyroute2()


def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict = {}, header: typing.Dict = {}, **fields) -> _NlMsgBase:
    msg = msg_class()
    for key, value in header.items():
        msg['header'][key] = value
    for key, value in fields.items():
        msg[key] = value
    if attrs:
        attr_list = msg['attrs']
        # access the name-mangled NLA map ("__r_nla_map") of the message class
        r_nla_map = msg_class._nlmsg_base__r_nla_map
        for key, value in attrs.items():
            if msg_class.prefix:
                key = msg_class.name2nla(key)
            prime = r_nla_map[key]
            nla_class = prime['class']
            if issubclass(nla_class, pr2modules.netlink.nla):
                # support passing nested attributes as dicts of subattributes (or lists of those)
                if prime['nla_array']:
                    value = [
                        _build(nla_class, attrs=elem)
                        if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict)
                        else elem
                        for elem in value
                    ]
                elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict):
                    value = _build(nla_class, attrs=value)
            attr_list.append([key, value])
    return msg
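
# Illustrative sketch (not executed; names mirror the calls in the set module):
# _build lets callers pass nested NLAs as plain dicts, e.g.
#
#   msg = _build(
#       _nftsocket.nft_set_elem_list_msg,
#       header=dict(sequence_number=0),
#       nfgen_family=NFPROTO_INET,
#       attrs=dict(
#           NFTA_SET_TABLE='captive_mark',
#           NFTA_SET_ELEM_LIST_SET='allowed',
#       ),
#   )
#
# nested dicts are converted recursively to the proper nla classes before
# being appended to the attribute list, so msg.encode() just works.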
class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket):
    policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy

    def __init__(self) -> None:
        super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER)
        policy = {
            (x | (NFNL_SUBSYS_NFTABLES << 8)): y
            for (x, y) in self.policy.items()
        }
        self.register_policy(policy)

    @contextlib.contextmanager
    def begin(self) -> typing.Generator[NFTTransaction, None, None]:
        # create the transaction outside the try block: if the constructor
        # fails there is nothing to abort in the finally handler
        tx = NFTTransaction(socket=self)
        try:
            yield tx
            # autocommit when no exception was raised
            # (only commits if it wasn't aborted)
            tx.autocommit()
        finally:
            # abort does nothing if commit went through
            tx.abort()

    def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: typing.Dict = {}, **fields) -> None:
        with self.begin() as tx:
            tx.put(msg_type, msg_flags, attrs=attrs, **fields)

    def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: typing.Dict = {}, **fields) -> typing.Iterator[_nftsocket.nft_gen_msg]:
        msg_flags |= pr2modules.netlink.NLM_F_DUMP
        return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)

    def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: typing.Dict = {}, **fields) -> typing.Iterator[_nftsocket.nft_gen_msg]:
        msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
        msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
        msg_flags |= pr2modules.netlink.NLM_F_REQUEST
        msg = _build(msg_class, attrs=attrs, **fields)
        return self.nlm_request(msg, msg_type, msg_flags)
class NFTTransaction:
    def __init__(self, socket: NFTSocket) -> None:
        self._socket = socket
        self._data = b''
        self._seqnum = self._socket.addr_pool.alloc()
        self._closed = False
        # neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
        # at the end of a transaction to make sure it worked.
        # we could use a different sequence number for each change and wait for an ACK for every
        # one of them (but we'd also need to check for errors on the BEGIN sequence number).
        # the other solution: use the same sequence number for all messages in the batch, and
        # request an ACK only on the final message (before END) - if we get that ACK we know all
        # earlier messages worked out.
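        # resulting batch layout (illustration, assuming k queued messages):
        #
        #   NFNL_MSG_BATCH_BEGIN   (seq N, NLM_F_REQUEST)
        #   message 1              (seq N, NLM_F_REQUEST)
        #   ...
        #   message k              (seq N, NLM_F_REQUEST | NLM_F_ACK)
        #   NFNL_MSG_BATCH_END     (seq N, NLM_F_REQUEST)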
        self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
        begin_msg = _build(
            nfgen_msg,
            res_id=NFNL_SUBSYS_NFTABLES,
            header=dict(
                type=0x10,  # NFNL_MSG_BATCH_BEGIN
                flags=pr2modules.netlink.NLM_F_REQUEST,
                sequence_number=self._seqnum,
            ),
        )
        begin_msg.encode()
        self._data += begin_msg.data

    def abort(self) -> None:
        """
        Aborts if transaction wasn't already committed or aborted
        """
        if not self._closed:
            self._closed = True
            # unused seqnum
            self._socket.addr_pool.free(self._seqnum)

    def autocommit(self) -> None:
        """
        Commits if transaction wasn't already committed or aborted
        """
        if self._closed:
            return
        self.commit()

    def commit(self) -> None:
        if self._closed:
            raise Exception("Transaction already closed")
        if not self._final_msg:
            # no inner messages were queued... just abort the transaction
            self.abort()
            return
        self._closed = True
        # request ACK only on the last message (before END)
        self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK
        self._final_msg.encode()
        self._data += self._final_msg.data
        self._final_msg = None
        # batch end
        end_msg = _build(
            nfgen_msg,
            res_id=NFNL_SUBSYS_NFTABLES,
            header=dict(
                type=0x11,  # NFNL_MSG_BATCH_END
                flags=pr2modules.netlink.NLM_F_REQUEST,
                sequence_number=self._seqnum,
            ),
        )
        end_msg.encode()
        self._data += end_msg.data
        # need to create backlog for our sequence number
        with self._socket.lock[self._seqnum]:
            self._socket.backlog[self._seqnum] = []
        # send
        self._socket.sendto(self._data, (0, 0))
        try:
            for _msg in self._socket.get(msg_seq=self._seqnum):
                # we should see at most one ACK - real errors get raised anyway
                pass
        finally:
            with self._socket.lock[0]:
                # clear messages from the "seq 0" queue - because if there
                # was an error in our backlog, it got raised and the
                # remaining messages moved to 0
                self._socket.backlog[0] = []

    def _put(self, msg: nfgen_msg) -> None:
        if self._closed:
            raise Exception("Transaction already closed")
        if self._final_msg:
            # the previous message wasn't the final one after all; encode it without ACK
            self._final_msg.encode()
            self._data += self._final_msg.data
        self._final_msg = msg

    def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: typing.Dict = {}, **fields) -> None:
        msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
        msg_flags |= pr2modules.netlink.NLM_F_REQUEST  # always set REQUEST
        msg_flags &= ~pr2modules.netlink.NLM_F_ACK  # make sure ACK is not set!
        header = dict(
            type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
            flags=msg_flags,
            sequence_number=self._seqnum,
        )
        msg = _build(msg_class, attrs=attrs, header=header, **fields)
        self._put(msg)
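
# Usage sketch (illustrative; mirrors how the set module drives this API):
#
#   sock = NFTSocket()
#   sock.bind()
#   with sock.begin() as tx:
#       tx.put(
#           _nftsocket.NFT_MSG_NEWTABLE,
#           pr2modules.netlink.NLM_F_CREATE,
#           nfgen_family=NFPROTO_INET,
#           attrs=dict(NFTA_TABLE_NAME='captive_mark'),
#       )
#   # leaving the with block commits the batch and waits for the single ACK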

8
start-api.sh Executable file
View File

@ -0,0 +1,8 @@
#!/bin/bash
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
exec ./venv/bin/hypercorn -k trio capport.api "$@"

8
start-control.sh Executable file
View File

@ -0,0 +1,8 @@
#!/bin/bash
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
exec ./venv/bin/capport-control "$@"