3
0

initial commit

This commit is contained in:
2022-04-04 19:21:51 +02:00
commit d1050d2ee4
34 changed files with 2223 additions and 0 deletions

0
src/capport/__init__.py Normal file
View File

169
src/capport/api/__init__.py Normal file
View File

@ -0,0 +1,169 @@
from __future__ import annotations
import ipaddress
import logging
import typing
import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.utils.cli
import capport.utils.ipneigh
import quart
import quart_trio
import trio
from capport import cptypes
from capport.config import Config
app = quart_trio.QuartTrio(__name__)
_logger = logging.getLogger(__name__)
config: typing.Optional[Config] = None
hub: typing.Optional[capport.comm.hub.Hub] = None
hub_app: typing.Optional[ApiHubApp] = None
nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
def get_client_ip() -> cptypes.IPAddress:
try:
addr = ipaddress.ip_address(quart.request.remote_addr)
except ValueError as e:
_logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}')
quart.abort(500, 'Invalid client address')
if addr.is_loopback:
forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For')
if len(forw_addr_headers) == 1:
try:
return ipaddress.ip_address(forw_addr_headers[0])
except ValueError as e:
_logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}')
quart.abort(500, 'Invalid client address')
elif forw_addr_headers:
_logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})')
quart.abort(500, 'Invalid client address')
return addr
async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]:
assert nc # for mypy
if not address:
address = get_client_ip()
return await nc.get_neighbor_mac(address)
async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress:
mac = await get_client_mac_if_present(address)
if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}")
quart.abort(404, 'Unknown client')
return mac
class ApiHubApp(capport.comm.hub.HubApplication):
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
# TODO: support websocket notification updates to clients?
pass
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
assert config # for mypy
assert hub # for mypy
pu = capport.database.PendingUpdates()
try:
hub.database.login(mac, config.session_timeout, pending_updates=pu)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu.macs:
_logger.info(f'User {mac} (with IP {address}) logged in')
for msg in pu.serialize():
await hub.broadcast(msg)
async def user_logout(mac: cptypes.MacAddress) -> None:
assert hub # for mypy
pu = capport.database.PendingUpdates()
try:
hub.database.logout(mac, pending_updates=pu)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu.macs:
_logger.info(f'User {mac} logged out')
for msg in pu.serialize():
await hub.broadcast(msg)
async def user_lookup() -> cptypes.MacPublicState:
assert hub # for mypy
address = get_client_ip()
mac = await get_client_mac_if_present(address)
if not mac:
return cptypes.MacPublicState.from_missing_mac(address)
else:
return hub.database.lookup(address, mac)
async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
global hub
global hub_app
global nc
assert config # for mypy
try:
async with capport.utils.ipneigh.connect() as mync:
nc = mync
_logger.info("Running hub for API")
myapp = ApiHubApp()
myhub = capport.comm.hub.Hub(config=config, app=myapp)
hub = myhub
hub_app = myapp
await myhub.run(task_status=task_status)
finally:
hub = None
hub_app = None
nc = None
_logger.info("Done running hub for API")
await app.shutdown()
@app.before_serving
async def init():
global config
config = Config.load()
capport.utils.cli.init_logger(config)
await app.nursery.start(_run_hub)
# @app.route('/all')
# async def route_all():
# return hub_app.database.as_json()
@app.route('/', methods=['GET'])
async def index():
state = await user_lookup()
return await quart.render_template('index.html', state=state)
@app.route('/login', methods=['POST'])
async def login():
address = get_client_ip()
mac = await get_client_mac(address)
await user_login(address, mac)
return quart.redirect('/', code=303)
@app.route('/logout', methods=['POST'])
async def logout():
mac = await get_client_mac()
await user_logout(mac)
return quart.redirect('/', code=303)
@app.route('/api/captive-portal', methods=['GET'])
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
async def captive_api():
state = await user_lookup()
return state.to_rfc8908(config)

View File

@ -0,0 +1,22 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Captive Portal Universität Stuttgart</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
</head>
<body>
{% if not state.mac %}
It seems you're accessing this site from outside the network this captive portal is running for.
{% elif state.captive %}
To get access to the internet please accept our usage guidelines by clicking this button:
<form method="POST" action="/login"><button type="submit">Accept</button></form>
{% else %}
You already accepted out conditions and are currently granted access to the internet:
<form method="POST" action="/login"><button type="submit">Renew session</button></form>
<form method="POST" action="/logout"><button type="submit">Close session</button></form>
<br>
Your current session will last for {{ state.allowed_remaining }} seconds.
{% endif %}
</body>
</html>

View File

441
src/capport/comm/hub.py Normal file
View File

@ -0,0 +1,441 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import random
import socket
import ssl
import struct
import typing
import uuid
import capport.database
import capport.comm.message
import trio
if typing.TYPE_CHECKING:
from ..config import Config
_logger = logging.getLogger(__name__)
class HubConnectionReadError(ConnectionError):
pass
class HubConnectionClosedError(ConnectionError):
pass
class LoopbackConnectionError(Exception):
pass
class Channel:
def __init__(self, hub: Hub, transport_stream, server_side: bool):
self._hub = hub
self._serverside = server_side
self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side)
_logger.debug(f"{self}: created (server_side={server_side})")
def __repr__(self) -> str:
return f'Channel[0x{id(self):x}]'
async def do_handshake(self) -> capport.comm.message.Hello:
try:
await self._ssl.do_handshake()
ssl_binding = self._ssl.get_channel_binding()
if not ssl_binding:
# binding mustn't be None after successful handshake
raise ConnectionError("Missing SSL channel binding")
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message()
await self.send_msg(msg)
peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message")
auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
raise HubConnectionReadError("Expected AuthenticationResult as second message")
if not auth_succ or not peer_auth.success:
raise HubConnectionReadError("Authentication failed")
return peer_hello
async def _read(self, num: int) -> bytes:
assert num > 0
buf = b''
# _logger.debug(f"{self}:_read({num})")
while num > 0:
try:
part = await self._ssl.receive_some(num)
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
# _logger.debug(f"{self}:_read({num}) got part {part!r}")
if len(part) == 0:
if len(buf) == 0:
raise HubConnectionClosedError()
raise HubConnectionReadError("Unexpected end of TLS stream")
buf += part
num -= len(part)
if num < 0:
raise HubConnectionReadError("TLS receive_some returned too much")
return buf
async def _recv_raw_msg(self) -> bytes:
len_bytes = await self._read(4)
chunk_size, = struct.unpack('!I', len_bytes)
chunk = await self._read(chunk_size)
if chunk is None:
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
return chunk
async def recv_msg(self) -> capport.comm.message.Message:
try:
chunk = await self._recv_raw_msg()
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
msg = capport.comm.message.Message()
msg.ParseFromString(chunk)
return msg
async def _send_raw(self, chunk: bytes) -> None:
try:
await self._ssl.send_all(chunk)
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
async def send_msg(self, msg: capport.comm.message.Message):
chunk = msg.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
chunk = len_bytes + chunk
await self._send_raw(chunk)
async def aclose(self):
try:
await self._ssl.aclose()
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
class Connection:
PING_INTERVAL = 10
RECEIVE_TIMEOUT = 15
SEND_TIMEOUT = 5
def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello):
self._channel = channel
self._hub = hub
tx: trio.MemorySendChannel
rx: trio.MemoryReceiveChannel
(tx, rx) = trio.open_memory_channel(64)
self._pending_tx = tx
self._pending_rx = rx
self.peer: capport.comm.message.Hello = peer
self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id)
self.closed = trio.Event() # set by Hub._lost_peer
_logger.debug(f"{self._channel}: authenticated -> {self.peer_id}")
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
try:
msg: typing.Optional[capport.comm.message.Message]
while True:
msg = None
# make sure we send something every PING_INTERVAL
with trio.move_on_after(self.PING_INTERVAL):
msg = await self._pending_rx.receive()
# if send blocks too long we're in trouble
with trio.fail_after(self.SEND_TIMEOUT):
if msg:
await self._channel.send_msg(msg)
else:
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
except trio.TooSlowError:
_logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed sending: {e!r}")
except Exception as e:
_logger.exception(f"{self._channel}: failed sending")
finally:
cancel_scope.cancel()
async def _receive(self, cancel_scope: trio.CancelScope) -> None:
try:
while True:
try:
with trio.fail_after(self.RECEIVE_TIMEOUT):
msg = await self._channel.recv_msg()
except (HubConnectionClosedError, ConnectionResetError):
return
except trio.TooSlowError:
_logger.warning(f"{self._channel}: receive timed out")
return
await self._hub._received_msg(self.peer_id, msg)
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed receiving: {e!r}")
except Exception:
_logger.exception(f"{self._channel}: failed receiving")
finally:
cancel_scope.cancel()
async def _inner_run(self) -> None:
if self.peer_id == self._hub._instance_id:
# connected to ourself, don't need that
raise LoopbackConnectionError()
async with trio.open_nursery() as nursery:
nursery.start_soon(self._sender, nursery.cancel_scope)
# be nice and wait for new_peer beforce receiving messages
# (won't work on failover to a second connection)
await nursery.start(self._hub._new_peer, self.peer_id, self)
nursery.start_soon(self._receive, nursery.cancel_scope)
async def send_msg(self, *msgs: capport.comm.message.Message):
try:
for msg in msgs:
await self._pending_tx.send(msg)
except trio.ClosedResourceError:
pass
async def _run(self) -> None:
try:
await self._inner_run()
finally:
_logger.debug(f"{self._channel}: finished message handling")
# basic (non-async) cleanup
self._hub._lost_peer(self.peer_id, self)
self._pending_tx.close()
self._pending_rx.close()
# allow 3 seconds for proper cleanup
with trio.CancelScope(shield=True, deadline=trio.current_time() + 3):
try:
await self._channel.aclose()
except OSError:
pass
@staticmethod
async def run(hub: Hub, transport_stream, server_side: bool) -> None:
channel = Channel(hub, transport_stream, server_side)
try:
with trio.fail_after(5):
peer = await channel.do_handshake()
except trio.TooSlowError:
_logger.warning(f"Handshake timed out")
return
conn = Connection(hub, channel, peer)
await conn._run()
class ControllerConn:
def __init__(self, hub: Hub, hostname: str):
self._hub = hub
self.hostname = hostname
self.loopback = False
async def _connect(self):
_logger.info(f"Connecting to controller at {self.hostname}")
with trio.fail_after(5):
try:
stream = await trio.open_tcp_stream(self.hostname, 5000)
except OSError as e:
_logger.warning(f"Failed to connect to controller at {self.hostname}: {e}")
return
try:
await Connection.run(self._hub, stream, server_side=False)
finally:
_logger.info(f"Connection to {self.hostname} closed")
async def run(self):
while True:
try:
await self._connect()
except LoopbackConnectionError:
_logger.debug(f"Connection to {self.hostname} reached ourself")
self.loopback = True
return
except trio.TooSlowError:
pass
# try again later
retry_splay = random.random() * 5
await trio.sleep(10 + retry_splay)
class HubApplication:
def is_controller(self) -> bool:
return False
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.info(f"New peer {peer_id}")
def lost_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.warning(f"Lost peer {peer_id}")
async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
_logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}")
async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
class Hub:
def __init__(self, config: Config, app: HubApplication) -> None:
self._config = config
self._instance_id = uuid.uuid4()
self._hostname = socket.getfqdn()
self.database = capport.database.Database()
self._app = app
self._is_controller = bool(app.is_controller())
self._anon_context = ssl.SSLContext()
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
# -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
self._controllers: typing.Dict[str, ControllerConn] = {}
self._established: typing.Dict[uuid.UUID, Connection] = {}
async def _accept(self, stream):
remotename = stream.socket.getpeername()
if isinstance(remotename, tuple) and len(remotename) == 2:
remote = f'[{remotename[0]}]:{remotename[1]}'
else:
remote = str(remotename)
try:
await Connection.run(self, stream, server_side=True)
except LoopbackConnectionError:
pass
except trio.TooSlowError:
pass
finally:
_logger.debug(f"Connection from {remote} closed")
async def _listen(self, task_status=trio.TASK_STATUS_IGNORED):
await trio.serve_tcp(self._accept, 5000, task_status=task_status)
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
async with trio.open_nursery() as nursery:
if self._is_controller:
await nursery.start(self._listen)
for name in self._config.controllers:
conn = ControllerConn(self, name)
self._controllers[name] = conn
task_status.started()
for conn in self._controllers.values():
nursery.start_soon(conn.run)
await trio.sleep_forever()
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
m = hmac.new(self._config.secret.encode('utf8'), digestmod=hashlib.sha256)
if server_side:
m.update(b'server$')
else:
m.update(b'client$')
m.update(ssl_binding)
return m.digest()
def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello:
return capport.comm.message.Hello(
instance_id=self._instance_id.bytes,
hostname=self._hostname,
is_controller=self._is_controller,
authentication=self._calc_authentication(ssl_binding, server_side),
)
async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None:
# send database (and all changes) to peers
await self.send(*self.database.serialize(), to=peer_id)
async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None:
have = self._established.get(peer_id, None)
if not have:
# peer unknown, "normal start"
# no "await" between get above and set here!!!
self._established[peer_id] = conn
# first wait for app to handle new peer
await self._app.new_peer(peer_id=peer_id)
task_status.started()
await self._sync_new_connection(peer_id, conn)
return
# peer already known - immediately allow receiving messages, then sync connection
task_status.started()
await self._sync_new_connection(peer_id, conn)
# now try to register connection for outgoing messages
while True:
# recheck whether peer is currently known (due to awaits since last get)
have = self._established.get(peer_id, None)
if have:
# already got a connection, nothing to do as long as it lives
await have.closed.wait()
else:
# make `conn` new outgoing connection for peer
# no "await" between get above and set here!!!
self._established[peer_id] = conn
await self._app.new_peer(peer_id=peer_id)
return
def _lost_peer(self, peer_id: uuid.UUID, conn: Connection):
have = self._established.get(peer_id, None)
lost = False
if have is conn:
lost = True
self._established.pop(peer_id)
conn.closed.set()
# only notify if this was the active connection
if lost:
# even when we failover to another connection we still need to resync
# as we don't know which messages might have got lost
# -> always trigger lost_peer
self._app.lost_peer(peer_id=peer_id)
async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
variant = msg.to_variant()
if isinstance(variant, capport.comm.message.Hello):
pass
elif isinstance(variant, capport.comm.message.AuthenticationResult):
pass
elif isinstance(variant, capport.comm.message.Ping):
pass
elif isinstance(variant, capport.comm.message.MacStates):
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
pu = capport.database.PendingUpdates()
for state in variant.states:
self.database.received_mac_state(state, pending_updates=pu)
if pu.macs:
# re-broadcast all received updates to all peers
await self.broadcast(*pu.serialize(), exclude=peer_id)
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
else:
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)
def peer_is_controller(self, peer_id: uuid.UUID) -> bool:
conn = self._established.get(peer_id)
if conn:
return conn.peer.is_controller
return False
async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID):
conn = self._established.get(to)
if conn:
await conn.send_msg(*msgs)
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID]=None):
async with trio.open_nursery() as nursery:
for peer_id, conn in self._established.items():
if peer_id == exclude:
continue
nursery.start_soon(conn.send_msg, *msgs)

View File

@ -0,0 +1,37 @@
from __future__ import annotations
import typing
from .protobuf import message_pb2
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
variant_name = self.WhichOneof('oneof')
if variant_name:
return getattr(self, variant_name)
return None
def _make_to_message(oneof_field):
def to_message(self) -> message_pb2.Message:
msg = message_pb2.Message(**{oneof_field: self})
return msg
return to_message
def _monkey_patch():
g = globals()
g['Message'] = message_pb2.Message
message_pb2.Message.to_variant = _message_to_variant
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
type_name = field.message_type.name
field_type = getattr(message_pb2, type_name)
field_type.to_message = _make_to_message(field.name)
g[type_name] = field_type
# also re-exports all message types
_monkey_patch()
# not a variant of Message, still re-export
MacState = message_pb2.MacState

View File

@ -0,0 +1,93 @@
import google.protobuf.message
import typing
# manually maintained typehints for protobuf created (and monkey-patched) types
class Message(google.protobuf.message.Message):
hello: Hello
authentication_result: AuthenticationResult
ping: Ping
mac_states: MacStates
def __init__(
self,
*,
hello: typing.Optional[Hello]=None,
authentication_result: typing.Optional[AuthenticationResult]=None,
ping: typing.Optional[Ping]=None,
mac_states: typing.Optional[MacStates]=None,
) -> None: ...
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
class Hello(google.protobuf.message.Message):
instance_id: bytes
hostname: str
is_controller: bool
authentication: bytes
def __init__(
self,
*,
instance_id: bytes=b'',
hostname: str='',
is_controller: bool=False,
authentication: bytes=b'',
) -> None: ...
def to_message(self) -> Message: ...
class AuthenticationResult(google.protobuf.message.Message):
success: bool
def __init__(
self,
*,
success: bool=False,
) -> None: ...
def to_message(self) -> Message: ...
class Ping(google.protobuf.message.Message):
payload: bytes
def __init__(
self,
*,
payload: bytes=b'',
) -> None: ...
def to_message(self) -> Message: ...
class MacStates(google.protobuf.message.Message):
states: typing.List[MacState]
def __init__(
self,
*,
states: typing.List[MacState]=[],
) -> None: ...
def to_message(self) -> Message: ...
class MacState(google.protobuf.message.Message):
mac_address: bytes
last_change: int # Seconds of UTC time since epoch
allow_until: int # Seconds of UTC time since epoch
allowed: bool
def __init__(
self,
*,
mac_address: bytes=b'',
last_change: int=0,
allow_until: int=0,
allowed: bool=False,
) -> None: ...

View File

@ -0,0 +1,355 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: message.proto
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='message.proto',
package='capport',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\rmessage.proto\x12\x07\x63\x61pport\"\xbc\x01\n\x07Message\x12\x1f\n\x05hello\x18\x01 \x01(\x0b\x32\x0e.capport.HelloH\x00\x12>\n\x15\x61uthentication_result\x18\x02 \x01(\x0b\x32\x1d.capport.AuthenticationResultH\x00\x12\x1d\n\x04ping\x18\x03 \x01(\x0b\x32\r.capport.PingH\x00\x12(\n\nmac_states\x18\n \x01(\x0b\x32\x12.capport.MacStatesH\x00\x42\x07\n\x05oneof\"]\n\x05Hello\x12\x13\n\x0binstance_id\x18\x01 \x01(\x0c\x12\x10\n\x08hostname\x18\x02 \x01(\t\x12\x15\n\ris_controller\x18\x03 \x01(\x08\x12\x16\n\x0e\x61uthentication\x18\x04 \x01(\x0c\"\'\n\x14\x41uthenticationResult\x12\x0f\n\x07success\x18\x01 \x01(\x08\"\x17\n\x04Ping\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\".\n\tMacStates\x12!\n\x06states\x18\x01 \x03(\x0b\x32\x11.capport.MacState\"Z\n\x08MacState\x12\x13\n\x0bmac_address\x18\x01 \x01(\x0c\x12\x13\n\x0blast_change\x18\x02 \x01(\x03\x12\x13\n\x0b\x61llow_until\x18\x03 \x01(\x03\x12\x0f\n\x07\x61llowed\x18\x04 \x01(\x08\x62\x06proto3'
)
_MESSAGE = _descriptor.Descriptor(
name='Message',
full_name='capport.Message',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='hello', full_name='capport.Message.hello', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='authentication_result', full_name='capport.Message.authentication_result', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='ping', full_name='capport.Message.ping', index=2,
number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='mac_states', full_name='capport.Message.mac_states', index=3,
number=10, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='oneof', full_name='capport.Message.oneof',
index=0, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[]),
],
serialized_start=27,
serialized_end=215,
)
_HELLO = _descriptor.Descriptor(
name='Hello',
full_name='capport.Hello',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='instance_id', full_name='capport.Hello.instance_id', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='hostname', full_name='capport.Hello.hostname', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='is_controller', full_name='capport.Hello.is_controller', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='authentication', full_name='capport.Hello.authentication', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=217,
serialized_end=310,
)
_AUTHENTICATIONRESULT = _descriptor.Descriptor(
name='AuthenticationResult',
full_name='capport.AuthenticationResult',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='success', full_name='capport.AuthenticationResult.success', index=0,
number=1, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=312,
serialized_end=351,
)
_PING = _descriptor.Descriptor(
name='Ping',
full_name='capport.Ping',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='payload', full_name='capport.Ping.payload', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=353,
serialized_end=376,
)
_MACSTATES = _descriptor.Descriptor(
name='MacStates',
full_name='capport.MacStates',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='states', full_name='capport.MacStates.states', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=378,
serialized_end=424,
)
_MACSTATE = _descriptor.Descriptor(
name='MacState',
full_name='capport.MacState',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='mac_address', full_name='capport.MacState.mac_address', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='last_change', full_name='capport.MacState.last_change', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='allow_until', full_name='capport.MacState.allow_until', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='allowed', full_name='capport.MacState.allowed', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=426,
serialized_end=516,
)
_MESSAGE.fields_by_name['hello'].message_type = _HELLO
_MESSAGE.fields_by_name['authentication_result'].message_type = _AUTHENTICATIONRESULT
_MESSAGE.fields_by_name['ping'].message_type = _PING
_MESSAGE.fields_by_name['mac_states'].message_type = _MACSTATES
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['hello'])
_MESSAGE.fields_by_name['hello'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['authentication_result'])
_MESSAGE.fields_by_name['authentication_result'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['ping'])
_MESSAGE.fields_by_name['ping'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MESSAGE.oneofs_by_name['oneof'].fields.append(
_MESSAGE.fields_by_name['mac_states'])
_MESSAGE.fields_by_name['mac_states'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
_MACSTATES.fields_by_name['states'].message_type = _MACSTATE
DESCRIPTOR.message_types_by_name['Message'] = _MESSAGE
DESCRIPTOR.message_types_by_name['Hello'] = _HELLO
DESCRIPTOR.message_types_by_name['AuthenticationResult'] = _AUTHENTICATIONRESULT
DESCRIPTOR.message_types_by_name['Ping'] = _PING
DESCRIPTOR.message_types_by_name['MacStates'] = _MACSTATES
DESCRIPTOR.message_types_by_name['MacState'] = _MACSTATE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Message = _reflection.GeneratedProtocolMessageType('Message', (_message.Message,), {
'DESCRIPTOR' : _MESSAGE,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.Message)
})
_sym_db.RegisterMessage(Message)
Hello = _reflection.GeneratedProtocolMessageType('Hello', (_message.Message,), {
'DESCRIPTOR' : _HELLO,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.Hello)
})
_sym_db.RegisterMessage(Hello)
AuthenticationResult = _reflection.GeneratedProtocolMessageType('AuthenticationResult', (_message.Message,), {
'DESCRIPTOR' : _AUTHENTICATIONRESULT,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.AuthenticationResult)
})
_sym_db.RegisterMessage(AuthenticationResult)
Ping = _reflection.GeneratedProtocolMessageType('Ping', (_message.Message,), {
'DESCRIPTOR' : _PING,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.Ping)
})
_sym_db.RegisterMessage(Ping)
MacStates = _reflection.GeneratedProtocolMessageType('MacStates', (_message.Message,), {
'DESCRIPTOR' : _MACSTATES,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.MacStates)
})
_sym_db.RegisterMessage(MacStates)
MacState = _reflection.GeneratedProtocolMessageType('MacState', (_message.Message,), {
'DESCRIPTOR' : _MACSTATE,
'__module__' : 'message_pb2'
# @@protoc_insertion_point(class_scope:capport.MacState)
})
_sym_db.RegisterMessage(MacState)
# @@protoc_insertion_point(module_scope)

34
src/capport/config.py Normal file
View File

@ -0,0 +1,34 @@
from __future__ import annotations
import dataclasses
import os.path
import typing
import yaml
@dataclasses.dataclass
class Config:
controllers: typing.List[str]
secret: str
venue_info_url: typing.Optional[str]
session_timeout: int # in seconds
debug: bool
@staticmethod
def load(filename: typing.Optional[str]=None) -> 'Config':
if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'):
if os.path.exists(name):
return Config.load(name)
raise RuntimeError("Missing config file")
with open(filename) as f:
data = yaml.safe_load(f)
controllers = list(map(str, data['controllers']))
return Config(
controllers=controllers,
secret=str(data['secret']),
venue_info_url=str(data.get('venue-info-url')),
session_timeout=data.get('session-timeout', 3600),
debug=data.get('debug', False)
)

View File

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import logging
import uuid
import capport.database
import capport.comm.hub
import capport.comm.message
import capport.config
import capport.utils.cli
import capport.utils.nft_set
import trio
from capport import cptypes
_logger = logging.getLogger(__name__)
class ControlApp(capport.comm.hub.HubApplication):
hub: capport.comm.hub.Hub
def __init__(self) -> None:
super().__init__()
self.nft_set = capport.utils.nft_set.NftSet()
def is_controller(self) -> bool:
return True
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
# deploy changes to netfilter set
inserts = []
removals = []
now = cptypes.Timestamp.now()
for mac, state in pending_updates.macs.items():
rem = state.allowed_remaining(now)
if rem > 0:
inserts.append((mac, rem))
else:
removals.append(mac)
self.nft_set.bulk_insert(inserts)
self.nft_set.bulk_remove(removals)
async def amain(config: capport.config.Config) -> None:
app = ControlApp()
hub = capport.comm.hub.Hub(config=config, app=app)
app.hub = hub
await hub.run()
def main() -> None:
config = capport.config.Config.load()
capport.utils.cli.init_logger(config)
try:
trio.run(amain, config)
except (KeyboardInterrupt, InterruptedError):
print()

94
src/capport/cptypes.py Normal file
View File

@ -0,0 +1,94 @@
from __future__ import annotations
import dataclasses
import datetime
import ipaddress
import json
import time
import typing
import quart
if typing.TYPE_CHECKING:
from .config import Config
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@dataclasses.dataclass(frozen=True)
class MacAddress:
raw: bytes
def __str__(self) -> str:
return self.raw.hex(':')
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def parse(s: str) -> MacAddress:
return MacAddress(bytes.fromhex(s.replace(':', '')))
@dataclasses.dataclass(frozen=True, order=True)
class Timestamp:
epoch: int
def __str__(self) -> str:
try:
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
return ts.isoformat(sep=' ')
except OSError:
return f'epoch@{self.epoch}'
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def now() -> Timestamp:
now = int(time.time())
return Timestamp(epoch=now)
@staticmethod
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
if epoch:
return Timestamp(epoch=epoch)
return None
@dataclasses.dataclass
class MacPublicState:
address: IPAddress
mac: typing.Optional[MacAddress]
allowed_remaining: int
@staticmethod
def from_missing_mac(address: IPAddress) -> MacPublicState:
return MacPublicState(
address=address,
mac=None,
allowed_remaining=0,
)
@property
def allowed(self) -> bool:
return self.allowed_remaining > 0
@property
def captive(self) -> bool:
return not self.allowed
def to_rfc8908(self, config: Config) -> quart.Response:
response: typing.Dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True),
}
if config.venue_info_url:
response['venue-info-url'] = config.venue_info_url
if self.captive:
response['captive'] = True
else:
response['captive'] = False
response['seconds-remaining'] = self.allowed_remaining
response['can-extend-session'] = True
return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json')

200
src/capport/database.py Normal file
View File

@ -0,0 +1,200 @@
from __future__ import annotations
import dataclasses
import typing
import capport.comm.message
from capport import cptypes
@dataclasses.dataclass
class MacEntry:
# entry can be removed if last_change was some time ago and allow_until wasn't set
# or got reached.
WAIT_LAST_CHANGE_SECONDS = 60
WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10
# last_change: timestamp of last change (sent by system initiating the change)
last_change: cptypes.Timestamp
# only if allowed is true and allow_until is set the device can communicate with the internet
# allow_until must not go backwards (and not get unset)
allow_until: typing.Optional[cptypes.Timestamp]
allowed: bool
@staticmethod
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
if len(msg.mac_address) < 6:
raise Exception("Invalid MacState: mac_address too short")
addr = cptypes.MacAddress(raw=msg.mac_address)
last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
if not last_change:
raise Exception(f"Invalid MacState[{addr}]: missing last_change")
allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))
def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
allow_until = 0
if self.allow_until:
allow_until = self.allow_until.epoch
return capport.comm.message.MacState(
mac_address=addr.raw,
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def as_json(self) -> dict:
allow_until = None
if self.allow_until:
allow_until = self.allow_until.epoch
return dict(
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def merge(self, new: MacEntry) -> bool:
changed = False
if new.last_change > self.last_change:
changed = True
self.last_change = new.last_change
self.allowed = new.allowed
elif new.last_change == self.last_change:
# same last_change: set allowed if one allowed
if new.allowed and not self.allowed:
changed = True
self.allowed = True
# set allow_until to max of both
if new.allow_until: # if not set nothing to change in local data
if not self.allow_until or self.allow_until < new.allow_until:
changed = True
self.allow_until = new.allow_until
return changed
def timeout(self) -> cptypes.Timestamp:
elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
if self.allow_until:
eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
if eau > elc:
return cptypes.Timestamp(epoch=eau)
return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int:
if not self.allowed or not self.allow_until:
return 0
if not now:
now = cptypes.Timestamp.now()
assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool:
if not now:
now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch
# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates()
for addr, entry in macs.items():
state = entry.to_state(addr)
current.states.append(state)
if len(current.states) >= 1024: # split into messages with 1024 states
result.append(current)
current = capport.comm.message.MacStates()
if len(current.states):
result.append(current)
return result
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)]
class NotReadyYet(Exception):
def __init__(self, msg: str, wait: int):
self.wait = wait # seconds to wait
super().__init__(msg)
@dataclasses.dataclass
class Database:
_macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
(addr, new_entry) = MacEntry.parse_state(state)
old_entry = self._macs.get(addr)
if not old_entry:
# only redistribute if not outdated
if not new_entry.outdated():
self._macs[addr] = new_entry
pending_updates.macs[addr] = new_entry
elif old_entry.merge(new_entry):
if old_entry.outdated():
# remove local entry, but still redistribute
self._macs.pop(addr)
pending_updates.macs[addr] = old_entry
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict:
return {
str(addr): entry.as_json()
for addr, entry in self._macs.items()
}
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac)
if entry:
allowed_remaining = entry.allowed_remaining()
else:
allowed_remaining = 0
return cptypes.MacPublicState(
address=address,
mac=mac,
allowed_remaining=allowed_remaining,
)
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
now = cptypes.Timestamp.now()
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
entry = self._macs.get(mac)
if not entry:
self._macs[mac] = new_entry
pending_updates.macs[mac] = new_entry
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
# too much time left on clock, not renewing session
return
elif entry.merge(new_entry):
pending_updates.macs[mac] = entry
elif not entry.allowed_remaining() > 0:
# entry should have been updated - can only fail due to `now < entry.last_change`
# i.e. out of sync clocks
wait = entry.last_change.epoch - now.epoch
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
now = cptypes.Timestamp.now()
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
entry = self._macs.get(mac)
if entry:
if entry.merge(new_entry):
pending_updates.macs[mac] = entry
elif entry.allowed_remaining() > 0:
# still logged in. can only happen with `now <= entry.last_change`
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
wait = entry.last_change.epoch - now.epoch + 1
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
@dataclasses.dataclass
class PendingUpdates:
macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self.macs)

View File

17
src/capport/utils/cli.py Normal file
View File

@ -0,0 +1,17 @@
from __future__ import annotations
import logging
import capport.config
def init_logger(config: capport.config.Config):
loglevel = logging.INFO
if config.debug:
loglevel = logging.DEBUG
logging.basicConfig(
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S %z]',
level=loglevel,
)
logging.getLogger('hypercorn').propagate = False

View File

@ -0,0 +1,63 @@
from __future__ import annotations
import contextlib
import errno
import typing
import pr2modules.iproute.linux
import pr2modules.netlink.exceptions
from capport import cptypes
@contextlib.asynccontextmanager
async def connect():
yield NeighborController()
# TODO: run blocking iproute calls in a different thread?
class NeighborController:
def __init__(self):
self.ip = pr2modules.iproute.linux.IPRoute()
async def get_neighbor(
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]:
if not index:
route = await self.get_route(address)
if route is None:
return None
index = route.get_attr(route.name2nla('oif'))
try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
except pr2modules.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise
async def get_neighbor_mac(
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
) -> typing.Optional[cptypes.MacAddress]:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
return None
mac = neigh.get_attr(neigh.name2nla('lladdr'))
return cptypes.MacAddress.parse(mac)
async def get_route(
self,
address: cptypes.IPAddress,
) -> typing.Optional[pr2modules.iproute.linux.rtmsg]:
try:
return self.ip.route('get', dst=str(address))[0]
except pr2modules.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise

View File

@ -0,0 +1,181 @@
from __future__ import annotations
import typing
import pr2modules.netlink
from capport import cptypes
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
from .nft_socket import NFTSocket
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
# to seconds
if msecs is None:
return None
return msecs / 1000.0
class NftSet:
def __init__(self):
self._socket = NFTSocket()
self._socket.bind()
@staticmethod
def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: typing.Dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict(
NFTA_DATA_VALUE=mac.raw,
),
}
if timeout:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
return attrs
def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
ser_entries = [
self._set_elem(mac)
for mac, _timeout in entries
]
ser_entries_with_timeout = [
self._set_elem(mac, timeout)
for mac, timeout in entries
]
with self._socket.begin() as tx:
# create doesn't affect existing elements, so:
# make sure entries exists
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# drop entries (would fail if it doesn't exist)
tx.put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# now create entries with new timeout value
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
),
)
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_insert(entries[:1024])
entries = entries[1024:]
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
ser_entries = [
self._set_elem(mac)
for mac in entries
]
with self._socket.begin() as tx:
# make sure entries exists
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# drop entries (would fail if it doesn't exist)
tx.put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_remove(entries[:1024])
entries = entries[1024:]
def remove(self, mac: cptypes.MacAddress) -> None:
self.bulk_remove([mac])
def list(self) -> list:
responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
responses = self._socket.nft_dump(
_nftsocket.NFT_MSG_GETSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
)
return [
{
'mac': cptypes.MacAddress(
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
),
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
}
for response in responses
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
]
def flush(self) -> None:
self._socket.nft_put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
)
def create(self):
with self._socket.begin() as tx:
tx.put(
_nftsocket.NFT_MSG_NEWTABLE,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_TABLE_NAME='captive_mark',
),
)
tx.put(
_nftsocket.NFT_MSG_NEWSET,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_NAME='allowed',
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
NFTA_SET_KEY_LEN=6, # length of key: mac address
NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction
),
)

View File

@ -0,0 +1,217 @@
from __future__ import annotations
import contextlib
import typing
import threading
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
import pr2modules.netlink
import pr2modules.netlink.nlsocket
from pr2modules.netlink.nfnetlink import nfgen_msg
from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base)
# nft uses NESTED for those.. lets do the same
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED
def _monkey_patch_pyroute2():
import pr2modules.netlink
# overwrite setdefault on nlmsg_base class hierarchy
_orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue
def _nlmsg_base__setvalue(self, value):
if not self.header or not self['header'] or not isinstance(value, dict):
return _orig_setvalue(self, value)
header = value.pop('header', {})
res = _orig_setvalue(self, value)
self['header'].update(header)
return res
def overwrite_methods(cls: typing.Type) -> None:
if cls.setvalue is _orig_setvalue:
cls.setvalue = _nlmsg_base__setvalue
for subcls in cls.__subclasses__():
overwrite_methods(subcls)
overwrite_methods(pr2modules.netlink.nlmsg_base)
_monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
for key, value in fields.items():
msg[key] = value
if attrs:
attr_list = msg['attrs']
r_nla_map = msg_class._nlmsg_base__r_nla_map
for key, value in attrs.items():
if msg_class.prefix:
key = msg_class.name2nla(key)
prime = r_nla_map[key]
nla_class = prime['class']
if issubclass(nla_class, pr2modules.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']:
value = [
_build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem
for elem in value
]
elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict):
value = _build(nla_class, attrs=value)
attr_list.append([key, value])
return msg
class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket):
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None:
super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER)
policy = {
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
for (x, y) in self.policy.items()
}
self.register_policy(policy)
@contextlib.contextmanager
def begin(self) -> typing.Generator[NFTTransaction, None, None]:
try:
tx = NFTTransaction(socket=self)
yield tx
# autocommit when no exception was raised
# (only commits if it wasn't aborted)
tx.autocommit()
finally:
# abort does nothing if commit went through
tx.abort()
def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
with self.begin() as tx:
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_flags |= pr2modules.netlink.NLM_F_DUMP
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pr2modules.netlink.NLM_F_REQUEST
msg = _build(msg_class, attrs=attrs, **fields)
return self.nlm_request(msg, msg_type, msg_flags)
class NFTTransaction:
def __init__(self, socket: NFTSocket) -> None:
self._socket = socket
self._data = b''
self._seqnum = self._socket.addr_pool.alloc()
self._closed = False
# neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
# at the end of an transaction to make sure it worked.
# we could use a different sequence number for all changes, and wait for an ACK for each of them
# (but we'd also need to check for errors on the BEGIN sequence number).
# the other solution: use the same sequence number for all messages in the batch, and add ACK
# only to the final message (before END) - if we get the ACK we known all other messages before
# worked out.
self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
begin_msg = _build(
nfgen_msg,
res_id=NFNL_SUBSYS_NFTABLES,
header=dict(
type=0x10, # NFNL_MSG_BATCH_BEGIN
flags=pr2modules.netlink.NLM_F_REQUEST,
sequence_number=self._seqnum,
),
)
begin_msg.encode()
self._data += begin_msg.data
def abort(self) -> None:
"""
Aborts if transaction wasn't already committed or aborted
"""
if not self._closed:
self._closed = True
# unused seqnum
self._socket.addr_pool.free(self._seqnum)
def autocommit(self) -> None:
"""
Commits if transaction wasn't already committed or aborted
"""
if self._closed:
return
self.commit()
def commit(self) -> None:
if self._closed:
raise Exception("Transaction already closed")
if not self._final_msg:
# no inner messages were queued... just abort transaction
self.abort()
return
self._closed = True
# request ACK only on the last message (before END)
self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK
self._final_msg.encode()
self._data += self._final_msg.data
self._final_msg = None
# batch end
end_msg = _build(
nfgen_msg,
res_id=NFNL_SUBSYS_NFTABLES,
header=dict(
type=0x11, # NFNL_MSG_BATCH_END
flags=pr2modules.netlink.NLM_F_REQUEST,
sequence_number=self._seqnum,
),
)
end_msg.encode()
self._data += end_msg.data
# need to create backlog for our sequence number
with self._socket.lock[self._seqnum]:
self._socket.backlog[self._seqnum] = []
# send
self._socket.sendto(self._data, (0, 0))
try:
for _msg in self._socket.get(msg_seq=self._seqnum):
# we should see at most one ACK - real errors get raised anyway
pass
finally:
with self._socket.lock[0]:
# clear messages from "seq 0" queue - because if there
# was an error in our backlog, it got raised and the
# remaining messages moved to 0
self._socket.backlog[0] = []
def _put(self, msg: nfgen_msg) -> None:
if self._closed:
raise Exception("Transaction already closed")
if self._final_msg:
# previous message wasn't the final one, encode it without ACK
self._final_msg.encode()
self._data += self._final_msg.data
self._final_msg = msg
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set!
header = dict(
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
flags=msg_flags,
sequence_number=self._seqnum,
)
msg = _build(msg_class, attrs=attrs, header=header, **fields)
self._put(msg)