initial commit
This commit is contained in:
commit
d1050d2ee4
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
.vscode
|
||||||
|
*.pyc
|
||||||
|
*.egg-info
|
||||||
|
__pycache__
|
||||||
|
venv
|
||||||
|
capport.yaml
|
11
.pycodestyle
Normal file
11
.pycodestyle
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
[pycodestyle]
|
||||||
|
# E241 multiple spaces after ':' [ want to align stuff ]
|
||||||
|
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
|
||||||
|
# E501 line too long [ temporary? can't disable it in certain places.. ]
|
||||||
|
# E701 multiple statements on one line (colon) [ perfectly readable ]
|
||||||
|
# E713 test for membership should be ‘not in’ [ disagree: want `not a in x` ]
|
||||||
|
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
|
||||||
|
# W503 Line break occurred before a binary operator [ pep8 flipped on this (also contradicts W504) ]
|
||||||
|
ignore = E241,E266,E501,E701,E713,E714,W503
|
||||||
|
max-line-length = 120
|
||||||
|
exclude = 00*.py,generated.py
|
20
.pylintrc
Normal file
20
.pylintrc
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
[MESSAGES CONTROL]
|
||||||
|
|
||||||
|
disable=logging-fstring-interpolation
|
||||||
|
|
||||||
|
[FORMAT]
|
||||||
|
|
||||||
|
# Maximum number of characters on a single line.
|
||||||
|
max-line-length=120
|
||||||
|
|
||||||
|
# Allow the body of an if to be on the same line as the test if there is no
|
||||||
|
# else.
|
||||||
|
single-line-if-stmt=yes
|
||||||
|
|
||||||
|
[DESIGN]
|
||||||
|
|
||||||
|
# Maximum number of locals for function / method body.
|
||||||
|
max-locals=20
|
||||||
|
|
||||||
|
# Minimum number of public methods for a class (see R0903).
|
||||||
|
min-public-methods=0
|
19
LICENSE
Normal file
19
LICENSE
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
Copyright (c) 2022 Universität Stuttgart
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
10
README.md
Normal file
10
README.md
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# python Captive Portal
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
Either clone repository (and install dependencies either through distribution or as virtualenv with `./setup-venv.sh`) or install as package.
|
||||||
|
|
||||||
|
[`pipx`](https://pypa.github.io/pipx/) (available in debian as package) can be used to install in separate virtual environment:
|
||||||
|
|
||||||
|
pipx install https://github.tik.uni-stuttgart.de/NKS/python-capport
|
||||||
|
|
7
capport-example.yaml
Normal file
7
capport-example.yaml
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
---
|
||||||
|
secret: mysecret
|
||||||
|
controllers:
|
||||||
|
- capport-controller1.example.com
|
||||||
|
- capport-controller2.example.com
|
||||||
|
session-timeout: 3600 # in seconds
|
||||||
|
venue-info-url: 'https://example.com'
|
24
mypy
Executable file
24
mypy
Executable file
@ -0,0 +1,24 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
### check type annotations with mypy
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
base=$(dirname "$(readlink -f "$0")")
|
||||||
|
cd "${base}"
|
||||||
|
|
||||||
|
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
|
||||||
|
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -x ./venv/bin/mypy ]; then
|
||||||
|
./venv/bin/pip install mypy
|
||||||
|
fi
|
||||||
|
|
||||||
|
site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])')
|
||||||
|
if [ ! -d "${site_pkgs}/trio_typing" ]; then
|
||||||
|
./venv/bin/pip install trio-typing[mypy]
|
||||||
|
fi
|
||||||
|
|
||||||
|
./venv/bin/mypy --install-types src
|
10
protobuf/compile.sh
Executable file
10
protobuf/compile.sh
Executable file
@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
cd "$(dirname "$(readlink -f "$0")")"
|
||||||
|
|
||||||
|
rm -rf ../src/capport/comm/protobuf/message_pb2.py
|
||||||
|
mkdir -p ../src/capport/comm/protobuf
|
||||||
|
|
||||||
|
protoc --python_out=../src/capport/comm/protobuf message.proto
|
42
protobuf/message.proto
Normal file
42
protobuf/message.proto
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package capport;
|
||||||
|
|
||||||
|
message Message {
|
||||||
|
oneof oneof {
|
||||||
|
Hello hello = 1;
|
||||||
|
AuthenticationResult authentication_result = 2;
|
||||||
|
Ping ping = 3;
|
||||||
|
MacStates mac_states = 10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sent by clients and servers as first message
|
||||||
|
message Hello {
|
||||||
|
bytes instance_id = 1;
|
||||||
|
string hostname = 2;
|
||||||
|
bool is_controller = 3;
|
||||||
|
bytes authentication = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tell peer whether hello authentication was good
|
||||||
|
message AuthenticationResult {
|
||||||
|
bool success = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Ping {
|
||||||
|
bytes payload = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MacStates {
|
||||||
|
repeated MacState states = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MacState {
|
||||||
|
bytes mac_address = 1;
|
||||||
|
// Seconds of UTC time since epoch
|
||||||
|
int64 last_change = 2;
|
||||||
|
// Seconds of UTC time since epoch
|
||||||
|
int64 allow_until = 3;
|
||||||
|
bool allowed = 4;
|
||||||
|
}
|
20
pylint
Executable file
20
pylint
Executable file
@ -0,0 +1,20 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
### check type annotations with mypy
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
base=$(dirname "$(readlink -f "$0")")
|
||||||
|
cd "${base}"
|
||||||
|
|
||||||
|
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
|
||||||
|
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -x ./venv/bin/pylint ]; then
|
||||||
|
# need current pylint to deal with more recent python features
|
||||||
|
./venv/bin/pip install pylint
|
||||||
|
fi
|
||||||
|
|
||||||
|
./venv/bin/pylint src
|
6
pyproject.toml
Normal file
6
pyproject.toml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = [
|
||||||
|
"setuptools>=42",
|
||||||
|
"wheel"
|
||||||
|
]
|
||||||
|
build-backend = "setuptools.build_meta"
|
11
setup-venv.sh
Executable file
11
setup-venv.sh
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
self=$(dirname "$(readlink -f "$0")")
|
||||||
|
cd "${self}"
|
||||||
|
|
||||||
|
python3 -m venv venv
|
||||||
|
|
||||||
|
# install cli extras
|
||||||
|
./venv/bin/pip install -e '.'
|
36
setup.cfg
Normal file
36
setup.cfg
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
[metadata]
|
||||||
|
name = capport-tik-nks
|
||||||
|
version = 0.0.1
|
||||||
|
author = Stefan Bühler
|
||||||
|
author_email = stefan.buehler@tik.uni-stuttgart.de
|
||||||
|
description = Captive Portal
|
||||||
|
long_description = file: README.md
|
||||||
|
long_description_content_type = text/markdown
|
||||||
|
url = https://github.tik.uni-stuttgart.de/NKS/python-capport
|
||||||
|
project_urls =
|
||||||
|
Bug Tracker = https://github.tik.uni-stuttgart.de/NKS/python-capport/issues
|
||||||
|
classifiers =
|
||||||
|
Programming Language :: Python :: 3
|
||||||
|
License :: OSI Approved :: MIT License
|
||||||
|
Operating System :: OS Independent
|
||||||
|
|
||||||
|
[options]
|
||||||
|
package_dir =
|
||||||
|
= src
|
||||||
|
packages = find:
|
||||||
|
python_requires = >=3.9
|
||||||
|
install_requires =
|
||||||
|
trio
|
||||||
|
quart-trio
|
||||||
|
quart
|
||||||
|
hypercorn[trio]
|
||||||
|
PyYAML
|
||||||
|
protobuf
|
||||||
|
pyroute2
|
||||||
|
|
||||||
|
[options.packages.find]
|
||||||
|
where = src
|
||||||
|
|
||||||
|
[options.entry_points]
|
||||||
|
console_scripts =
|
||||||
|
capport-control = capport.control.run:main
|
6
setup.py
Normal file
6
setup.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
# https://github.com/pypa/setuptools/issues/2816
|
||||||
|
# allow editable install on older pip versions
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup()
|
0
src/capport/__init__.py
Normal file
0
src/capport/__init__.py
Normal file
169
src/capport/api/__init__.py
Normal file
169
src/capport/api/__init__.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import capport.database
|
||||||
|
import capport.comm.hub
|
||||||
|
import capport.comm.message
|
||||||
|
import capport.utils.cli
|
||||||
|
import capport.utils.ipneigh
|
||||||
|
import quart
|
||||||
|
import quart_trio
|
||||||
|
import trio
|
||||||
|
from capport import cptypes
|
||||||
|
from capport.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
app = quart_trio.QuartTrio(__name__)
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
config: typing.Optional[Config] = None
|
||||||
|
hub: typing.Optional[capport.comm.hub.Hub] = None
|
||||||
|
hub_app: typing.Optional[ApiHubApp] = None
|
||||||
|
nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_ip() -> cptypes.IPAddress:
|
||||||
|
try:
|
||||||
|
addr = ipaddress.ip_address(quart.request.remote_addr)
|
||||||
|
except ValueError as e:
|
||||||
|
_logger.warning(f'Invalid client address {quart.request.remote_addr!r}: {e}')
|
||||||
|
quart.abort(500, 'Invalid client address')
|
||||||
|
if addr.is_loopback:
|
||||||
|
forw_addr_headers = quart.request.headers.getlist('X-Forwarded-For')
|
||||||
|
if len(forw_addr_headers) == 1:
|
||||||
|
try:
|
||||||
|
return ipaddress.ip_address(forw_addr_headers[0])
|
||||||
|
except ValueError as e:
|
||||||
|
_logger.warning(f'Invalid forwarded client address {forw_addr_headers!r} (from {addr}): {e}')
|
||||||
|
quart.abort(500, 'Invalid client address')
|
||||||
|
elif forw_addr_headers:
|
||||||
|
_logger.warning(f'Multiple forwarded client addresses {forw_addr_headers!r} (from {addr})')
|
||||||
|
quart.abort(500, 'Invalid client address')
|
||||||
|
return addr
|
||||||
|
|
||||||
|
|
||||||
|
async def get_client_mac_if_present(address: typing.Optional[cptypes.IPAddress]=None) -> typing.Optional[cptypes.MacAddress]:
|
||||||
|
assert nc # for mypy
|
||||||
|
if not address:
|
||||||
|
address = get_client_ip()
|
||||||
|
return await nc.get_neighbor_mac(address)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_client_mac(address: typing.Optional[cptypes.IPAddress]=None) -> cptypes.MacAddress:
|
||||||
|
mac = await get_client_mac_if_present(address)
|
||||||
|
if mac is None:
|
||||||
|
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
||||||
|
quart.abort(404, 'Unknown client')
|
||||||
|
return mac
|
||||||
|
|
||||||
|
|
||||||
|
class ApiHubApp(capport.comm.hub.HubApplication):
|
||||||
|
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
|
||||||
|
# TODO: support websocket notification updates to clients?
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
|
||||||
|
assert config # for mypy
|
||||||
|
assert hub # for mypy
|
||||||
|
pu = capport.database.PendingUpdates()
|
||||||
|
try:
|
||||||
|
hub.database.login(mac, config.session_timeout, pending_updates=pu)
|
||||||
|
except capport.database.NotReadyYet as e:
|
||||||
|
quart.abort(500, str(e))
|
||||||
|
|
||||||
|
if pu.macs:
|
||||||
|
_logger.info(f'User {mac} (with IP {address}) logged in')
|
||||||
|
for msg in pu.serialize():
|
||||||
|
await hub.broadcast(msg)
|
||||||
|
|
||||||
|
|
||||||
|
async def user_logout(mac: cptypes.MacAddress) -> None:
|
||||||
|
assert hub # for mypy
|
||||||
|
pu = capport.database.PendingUpdates()
|
||||||
|
try:
|
||||||
|
hub.database.logout(mac, pending_updates=pu)
|
||||||
|
except capport.database.NotReadyYet as e:
|
||||||
|
quart.abort(500, str(e))
|
||||||
|
if pu.macs:
|
||||||
|
_logger.info(f'User {mac} logged out')
|
||||||
|
for msg in pu.serialize():
|
||||||
|
await hub.broadcast(msg)
|
||||||
|
|
||||||
|
|
||||||
|
async def user_lookup() -> cptypes.MacPublicState:
|
||||||
|
assert hub # for mypy
|
||||||
|
address = get_client_ip()
|
||||||
|
mac = await get_client_mac_if_present(address)
|
||||||
|
if not mac:
|
||||||
|
return cptypes.MacPublicState.from_missing_mac(address)
|
||||||
|
else:
|
||||||
|
return hub.database.lookup(address, mac)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||||
|
global hub
|
||||||
|
global hub_app
|
||||||
|
global nc
|
||||||
|
assert config # for mypy
|
||||||
|
try:
|
||||||
|
async with capport.utils.ipneigh.connect() as mync:
|
||||||
|
nc = mync
|
||||||
|
_logger.info("Running hub for API")
|
||||||
|
myapp = ApiHubApp()
|
||||||
|
myhub = capport.comm.hub.Hub(config=config, app=myapp)
|
||||||
|
hub = myhub
|
||||||
|
hub_app = myapp
|
||||||
|
await myhub.run(task_status=task_status)
|
||||||
|
finally:
|
||||||
|
hub = None
|
||||||
|
hub_app = None
|
||||||
|
nc = None
|
||||||
|
_logger.info("Done running hub for API")
|
||||||
|
await app.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_serving
|
||||||
|
async def init():
|
||||||
|
global config
|
||||||
|
config = Config.load()
|
||||||
|
capport.utils.cli.init_logger(config)
|
||||||
|
await app.nursery.start(_run_hub)
|
||||||
|
|
||||||
|
|
||||||
|
# @app.route('/all')
|
||||||
|
# async def route_all():
|
||||||
|
# return hub_app.database.as_json()
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/', methods=['GET'])
|
||||||
|
async def index():
|
||||||
|
state = await user_lookup()
|
||||||
|
return await quart.render_template('index.html', state=state)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/login', methods=['POST'])
|
||||||
|
async def login():
|
||||||
|
address = get_client_ip()
|
||||||
|
mac = await get_client_mac(address)
|
||||||
|
await user_login(address, mac)
|
||||||
|
return quart.redirect('/', code=303)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/logout', methods=['POST'])
|
||||||
|
async def logout():
|
||||||
|
mac = await get_client_mac()
|
||||||
|
await user_logout(mac)
|
||||||
|
return quart.redirect('/', code=303)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/captive-portal', methods=['GET'])
|
||||||
|
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
|
||||||
|
async def captive_api():
|
||||||
|
state = await user_lookup()
|
||||||
|
return state.to_rfc8908(config)
|
22
src/capport/api/templates/index.html
Normal file
22
src/capport/api/templates/index.html
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8" />
|
||||||
|
<title>Captive Portal Universität Stuttgart</title>
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
{% if not state.mac %}
|
||||||
|
It seems you're accessing this site from outside the network this captive portal is running for.
|
||||||
|
{% elif state.captive %}
|
||||||
|
To get access to the internet please accept our usage guidelines by clicking this button:
|
||||||
|
<form method="POST" action="/login"><button type="submit">Accept</button></form>
|
||||||
|
{% else %}
|
||||||
|
You already accepted out conditions and are currently granted access to the internet:
|
||||||
|
<form method="POST" action="/login"><button type="submit">Renew session</button></form>
|
||||||
|
<form method="POST" action="/logout"><button type="submit">Close session</button></form>
|
||||||
|
<br>
|
||||||
|
Your current session will last for {{ state.allowed_remaining }} seconds.
|
||||||
|
{% endif %}
|
||||||
|
</body>
|
||||||
|
</html>
|
0
src/capport/comm/__init__.py
Normal file
0
src/capport/comm/__init__.py
Normal file
441
src/capport/comm/hub.py
Normal file
441
src/capport/comm/hub.py
Normal file
@ -0,0 +1,441 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
import struct
|
||||||
|
import typing
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import capport.database
|
||||||
|
import capport.comm.message
|
||||||
|
import trio
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from ..config import Config
|
||||||
|
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HubConnectionReadError(ConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HubConnectionClosedError(ConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LoopbackConnectionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Channel:
|
||||||
|
def __init__(self, hub: Hub, transport_stream, server_side: bool):
|
||||||
|
self._hub = hub
|
||||||
|
self._serverside = server_side
|
||||||
|
self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side)
|
||||||
|
_logger.debug(f"{self}: created (server_side={server_side})")
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'Channel[0x{id(self):x}]'
|
||||||
|
|
||||||
|
async def do_handshake(self) -> capport.comm.message.Hello:
|
||||||
|
try:
|
||||||
|
await self._ssl.do_handshake()
|
||||||
|
ssl_binding = self._ssl.get_channel_binding()
|
||||||
|
if not ssl_binding:
|
||||||
|
# binding mustn't be None after successful handshake
|
||||||
|
raise ConnectionError("Missing SSL channel binding")
|
||||||
|
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||||
|
raise ConnectionError(e) from None
|
||||||
|
msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message()
|
||||||
|
await self.send_msg(msg)
|
||||||
|
peer_hello = (await self.recv_msg()).to_variant()
|
||||||
|
if not isinstance(peer_hello, capport.comm.message.Hello):
|
||||||
|
raise HubConnectionReadError("Expected Hello as first message")
|
||||||
|
auth_succ = (peer_hello.authentication == self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
|
||||||
|
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
|
||||||
|
peer_auth = (await self.recv_msg()).to_variant()
|
||||||
|
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
|
||||||
|
raise HubConnectionReadError("Expected AuthenticationResult as second message")
|
||||||
|
if not auth_succ or not peer_auth.success:
|
||||||
|
raise HubConnectionReadError("Authentication failed")
|
||||||
|
return peer_hello
|
||||||
|
|
||||||
|
async def _read(self, num: int) -> bytes:
|
||||||
|
assert num > 0
|
||||||
|
buf = b''
|
||||||
|
# _logger.debug(f"{self}:_read({num})")
|
||||||
|
while num > 0:
|
||||||
|
try:
|
||||||
|
part = await self._ssl.receive_some(num)
|
||||||
|
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||||
|
raise ConnectionError(e) from None
|
||||||
|
# _logger.debug(f"{self}:_read({num}) got part {part!r}")
|
||||||
|
if len(part) == 0:
|
||||||
|
if len(buf) == 0:
|
||||||
|
raise HubConnectionClosedError()
|
||||||
|
raise HubConnectionReadError("Unexpected end of TLS stream")
|
||||||
|
buf += part
|
||||||
|
num -= len(part)
|
||||||
|
if num < 0:
|
||||||
|
raise HubConnectionReadError("TLS receive_some returned too much")
|
||||||
|
return buf
|
||||||
|
|
||||||
|
async def _recv_raw_msg(self) -> bytes:
|
||||||
|
len_bytes = await self._read(4)
|
||||||
|
chunk_size, = struct.unpack('!I', len_bytes)
|
||||||
|
chunk = await self._read(chunk_size)
|
||||||
|
if chunk is None:
|
||||||
|
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
async def recv_msg(self) -> capport.comm.message.Message:
|
||||||
|
try:
|
||||||
|
chunk = await self._recv_raw_msg()
|
||||||
|
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||||
|
raise ConnectionError(e) from None
|
||||||
|
msg = capport.comm.message.Message()
|
||||||
|
msg.ParseFromString(chunk)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
async def _send_raw(self, chunk: bytes) -> None:
|
||||||
|
try:
|
||||||
|
await self._ssl.send_all(chunk)
|
||||||
|
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||||
|
raise ConnectionError(e) from None
|
||||||
|
|
||||||
|
async def send_msg(self, msg: capport.comm.message.Message):
|
||||||
|
chunk = msg.SerializeToString(deterministic=True)
|
||||||
|
chunk_size = len(chunk)
|
||||||
|
len_bytes = struct.pack('!I', chunk_size)
|
||||||
|
chunk = len_bytes + chunk
|
||||||
|
await self._send_raw(chunk)
|
||||||
|
|
||||||
|
async def aclose(self):
|
||||||
|
try:
|
||||||
|
await self._ssl.aclose()
|
||||||
|
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||||
|
raise ConnectionError(e) from None
|
||||||
|
|
||||||
|
|
||||||
|
class Connection:
|
||||||
|
PING_INTERVAL = 10
|
||||||
|
RECEIVE_TIMEOUT = 15
|
||||||
|
SEND_TIMEOUT = 5
|
||||||
|
|
||||||
|
def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello):
|
||||||
|
self._channel = channel
|
||||||
|
self._hub = hub
|
||||||
|
tx: trio.MemorySendChannel
|
||||||
|
rx: trio.MemoryReceiveChannel
|
||||||
|
(tx, rx) = trio.open_memory_channel(64)
|
||||||
|
self._pending_tx = tx
|
||||||
|
self._pending_rx = rx
|
||||||
|
self.peer: capport.comm.message.Hello = peer
|
||||||
|
self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id)
|
||||||
|
self.closed = trio.Event() # set by Hub._lost_peer
|
||||||
|
_logger.debug(f"{self._channel}: authenticated -> {self.peer_id}")
|
||||||
|
|
||||||
|
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
|
||||||
|
try:
|
||||||
|
msg: typing.Optional[capport.comm.message.Message]
|
||||||
|
while True:
|
||||||
|
msg = None
|
||||||
|
# make sure we send something every PING_INTERVAL
|
||||||
|
with trio.move_on_after(self.PING_INTERVAL):
|
||||||
|
msg = await self._pending_rx.receive()
|
||||||
|
# if send blocks too long we're in trouble
|
||||||
|
with trio.fail_after(self.SEND_TIMEOUT):
|
||||||
|
if msg:
|
||||||
|
await self._channel.send_msg(msg)
|
||||||
|
else:
|
||||||
|
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
|
||||||
|
except trio.TooSlowError:
|
||||||
|
_logger.warning(f"{self._channel}: send timed out")
|
||||||
|
except ConnectionError as e:
|
||||||
|
_logger.warning(f"{self._channel}: failed sending: {e!r}")
|
||||||
|
except Exception as e:
|
||||||
|
_logger.exception(f"{self._channel}: failed sending")
|
||||||
|
finally:
|
||||||
|
cancel_scope.cancel()
|
||||||
|
|
||||||
|
async def _receive(self, cancel_scope: trio.CancelScope) -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with trio.fail_after(self.RECEIVE_TIMEOUT):
|
||||||
|
msg = await self._channel.recv_msg()
|
||||||
|
except (HubConnectionClosedError, ConnectionResetError):
|
||||||
|
return
|
||||||
|
except trio.TooSlowError:
|
||||||
|
_logger.warning(f"{self._channel}: receive timed out")
|
||||||
|
return
|
||||||
|
await self._hub._received_msg(self.peer_id, msg)
|
||||||
|
except ConnectionError as e:
|
||||||
|
_logger.warning(f"{self._channel}: failed receiving: {e!r}")
|
||||||
|
except Exception:
|
||||||
|
_logger.exception(f"{self._channel}: failed receiving")
|
||||||
|
finally:
|
||||||
|
cancel_scope.cancel()
|
||||||
|
|
||||||
|
async def _inner_run(self) -> None:
|
||||||
|
if self.peer_id == self._hub._instance_id:
|
||||||
|
# connected to ourself, don't need that
|
||||||
|
raise LoopbackConnectionError()
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
nursery.start_soon(self._sender, nursery.cancel_scope)
|
||||||
|
# be nice and wait for new_peer beforce receiving messages
|
||||||
|
# (won't work on failover to a second connection)
|
||||||
|
await nursery.start(self._hub._new_peer, self.peer_id, self)
|
||||||
|
nursery.start_soon(self._receive, nursery.cancel_scope)
|
||||||
|
|
||||||
|
async def send_msg(self, *msgs: capport.comm.message.Message):
|
||||||
|
try:
|
||||||
|
for msg in msgs:
|
||||||
|
await self._pending_tx.send(msg)
|
||||||
|
except trio.ClosedResourceError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
try:
|
||||||
|
await self._inner_run()
|
||||||
|
finally:
|
||||||
|
_logger.debug(f"{self._channel}: finished message handling")
|
||||||
|
# basic (non-async) cleanup
|
||||||
|
self._hub._lost_peer(self.peer_id, self)
|
||||||
|
self._pending_tx.close()
|
||||||
|
self._pending_rx.close()
|
||||||
|
# allow 3 seconds for proper cleanup
|
||||||
|
with trio.CancelScope(shield=True, deadline=trio.current_time() + 3):
|
||||||
|
try:
|
||||||
|
await self._channel.aclose()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def run(hub: Hub, transport_stream, server_side: bool) -> None:
|
||||||
|
channel = Channel(hub, transport_stream, server_side)
|
||||||
|
try:
|
||||||
|
with trio.fail_after(5):
|
||||||
|
peer = await channel.do_handshake()
|
||||||
|
except trio.TooSlowError:
|
||||||
|
_logger.warning(f"Handshake timed out")
|
||||||
|
return
|
||||||
|
conn = Connection(hub, channel, peer)
|
||||||
|
await conn._run()
|
||||||
|
|
||||||
|
|
||||||
|
class ControllerConn:
|
||||||
|
def __init__(self, hub: Hub, hostname: str):
|
||||||
|
self._hub = hub
|
||||||
|
self.hostname = hostname
|
||||||
|
self.loopback = False
|
||||||
|
|
||||||
|
async def _connect(self):
|
||||||
|
_logger.info(f"Connecting to controller at {self.hostname}")
|
||||||
|
with trio.fail_after(5):
|
||||||
|
try:
|
||||||
|
stream = await trio.open_tcp_stream(self.hostname, 5000)
|
||||||
|
except OSError as e:
|
||||||
|
_logger.warning(f"Failed to connect to controller at {self.hostname}: {e}")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await Connection.run(self._hub, stream, server_side=False)
|
||||||
|
finally:
|
||||||
|
_logger.info(f"Connection to {self.hostname} closed")
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self._connect()
|
||||||
|
except LoopbackConnectionError:
|
||||||
|
_logger.debug(f"Connection to {self.hostname} reached ourself")
|
||||||
|
self.loopback = True
|
||||||
|
return
|
||||||
|
except trio.TooSlowError:
|
||||||
|
pass
|
||||||
|
# try again later
|
||||||
|
retry_splay = random.random() * 5
|
||||||
|
await trio.sleep(10 + retry_splay)
|
||||||
|
|
||||||
|
|
||||||
|
class HubApplication:
|
||||||
|
def is_controller(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
|
||||||
|
_logger.info(f"New peer {peer_id}")
|
||||||
|
|
||||||
|
def lost_peer(self, *, peer_id: uuid.UUID) -> None:
|
||||||
|
_logger.warning(f"Lost peer {peer_id}")
|
||||||
|
|
||||||
|
async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
|
||||||
|
_logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}")
|
||||||
|
|
||||||
|
async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None:
|
||||||
|
if _logger.isEnabledFor(logging.DEBUG):
|
||||||
|
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
|
||||||
|
|
||||||
|
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
|
||||||
|
if _logger.isEnabledFor(logging.DEBUG):
|
||||||
|
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
|
||||||
|
|
||||||
|
|
||||||
|
class Hub:
|
||||||
|
def __init__(self, config: Config, app: HubApplication) -> None:
|
||||||
|
self._config = config
|
||||||
|
self._instance_id = uuid.uuid4()
|
||||||
|
self._hostname = socket.getfqdn()
|
||||||
|
self.database = capport.database.Database()
|
||||||
|
self._app = app
|
||||||
|
self._is_controller = bool(app.is_controller())
|
||||||
|
self._anon_context = ssl.SSLContext()
|
||||||
|
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
|
||||||
|
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||||
|
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||||
|
# -> AECDH-AES256-SHA
|
||||||
|
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
|
||||||
|
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
|
||||||
|
self._controllers: typing.Dict[str, ControllerConn] = {}
|
||||||
|
self._established: typing.Dict[uuid.UUID, Connection] = {}
|
||||||
|
|
||||||
|
async def _accept(self, stream):
|
||||||
|
remotename = stream.socket.getpeername()
|
||||||
|
if isinstance(remotename, tuple) and len(remotename) == 2:
|
||||||
|
remote = f'[{remotename[0]}]:{remotename[1]}'
|
||||||
|
else:
|
||||||
|
remote = str(remotename)
|
||||||
|
try:
|
||||||
|
await Connection.run(self, stream, server_side=True)
|
||||||
|
except LoopbackConnectionError:
|
||||||
|
pass
|
||||||
|
except trio.TooSlowError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
_logger.debug(f"Connection from {remote} closed")
|
||||||
|
|
||||||
|
async def _listen(self, task_status=trio.TASK_STATUS_IGNORED):
|
||||||
|
await trio.serve_tcp(self._accept, 5000, task_status=task_status)
|
||||||
|
|
||||||
|
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
if self._is_controller:
|
||||||
|
await nursery.start(self._listen)
|
||||||
|
|
||||||
|
for name in self._config.controllers:
|
||||||
|
conn = ControllerConn(self, name)
|
||||||
|
self._controllers[name] = conn
|
||||||
|
|
||||||
|
task_status.started()
|
||||||
|
|
||||||
|
for conn in self._controllers.values():
|
||||||
|
nursery.start_soon(conn.run)
|
||||||
|
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
|
||||||
|
m = hmac.new(self._config.secret.encode('utf8'), digestmod=hashlib.sha256)
|
||||||
|
if server_side:
|
||||||
|
m.update(b'server$')
|
||||||
|
else:
|
||||||
|
m.update(b'client$')
|
||||||
|
m.update(ssl_binding)
|
||||||
|
return m.digest()
|
||||||
|
|
||||||
|
def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello:
|
||||||
|
return capport.comm.message.Hello(
|
||||||
|
instance_id=self._instance_id.bytes,
|
||||||
|
hostname=self._hostname,
|
||||||
|
is_controller=self._is_controller,
|
||||||
|
authentication=self._calc_authentication(ssl_binding, server_side),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None:
|
||||||
|
# send database (and all changes) to peers
|
||||||
|
await self.send(*self.database.serialize(), to=peer_id)
|
||||||
|
|
||||||
|
async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||||
|
have = self._established.get(peer_id, None)
|
||||||
|
if not have:
|
||||||
|
# peer unknown, "normal start"
|
||||||
|
# no "await" between get above and set here!!!
|
||||||
|
self._established[peer_id] = conn
|
||||||
|
# first wait for app to handle new peer
|
||||||
|
await self._app.new_peer(peer_id=peer_id)
|
||||||
|
task_status.started()
|
||||||
|
await self._sync_new_connection(peer_id, conn)
|
||||||
|
return
|
||||||
|
|
||||||
|
# peer already known - immediately allow receiving messages, then sync connection
|
||||||
|
task_status.started()
|
||||||
|
await self._sync_new_connection(peer_id, conn)
|
||||||
|
# now try to register connection for outgoing messages
|
||||||
|
while True:
|
||||||
|
# recheck whether peer is currently known (due to awaits since last get)
|
||||||
|
have = self._established.get(peer_id, None)
|
||||||
|
if have:
|
||||||
|
# already got a connection, nothing to do as long as it lives
|
||||||
|
await have.closed.wait()
|
||||||
|
else:
|
||||||
|
# make `conn` new outgoing connection for peer
|
||||||
|
# no "await" between get above and set here!!!
|
||||||
|
self._established[peer_id] = conn
|
||||||
|
await self._app.new_peer(peer_id=peer_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
def _lost_peer(self, peer_id: uuid.UUID, conn: Connection):
|
||||||
|
have = self._established.get(peer_id, None)
|
||||||
|
lost = False
|
||||||
|
if have is conn:
|
||||||
|
lost = True
|
||||||
|
self._established.pop(peer_id)
|
||||||
|
conn.closed.set()
|
||||||
|
# only notify if this was the active connection
|
||||||
|
if lost:
|
||||||
|
# even when we failover to another connection we still need to resync
|
||||||
|
# as we don't know which messages might have got lost
|
||||||
|
# -> always trigger lost_peer
|
||||||
|
self._app.lost_peer(peer_id=peer_id)
|
||||||
|
|
||||||
|
async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
|
||||||
|
variant = msg.to_variant()
|
||||||
|
if isinstance(variant, capport.comm.message.Hello):
|
||||||
|
pass
|
||||||
|
elif isinstance(variant, capport.comm.message.AuthenticationResult):
|
||||||
|
pass
|
||||||
|
elif isinstance(variant, capport.comm.message.Ping):
|
||||||
|
pass
|
||||||
|
elif isinstance(variant, capport.comm.message.MacStates):
|
||||||
|
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
|
||||||
|
pu = capport.database.PendingUpdates()
|
||||||
|
for state in variant.states:
|
||||||
|
self.database.received_mac_state(state, pending_updates=pu)
|
||||||
|
if pu.macs:
|
||||||
|
# re-broadcast all received updates to all peers
|
||||||
|
await self.broadcast(*pu.serialize(), exclude=peer_id)
|
||||||
|
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
|
||||||
|
else:
|
||||||
|
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)
|
||||||
|
|
||||||
|
def peer_is_controller(self, peer_id: uuid.UUID) -> bool:
|
||||||
|
conn = self._established.get(peer_id)
|
||||||
|
if conn:
|
||||||
|
return conn.peer.is_controller
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID):
|
||||||
|
conn = self._established.get(to)
|
||||||
|
if conn:
|
||||||
|
await conn.send_msg(*msgs)
|
||||||
|
|
||||||
|
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID]=None):
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for peer_id, conn in self._established.items():
|
||||||
|
if peer_id == exclude:
|
||||||
|
continue
|
||||||
|
nursery.start_soon(conn.send_msg, *msgs)
|
37
src/capport/comm/message.py
Normal file
37
src/capport/comm/message.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from .protobuf import message_pb2
|
||||||
|
|
||||||
|
|
||||||
|
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
|
||||||
|
variant_name = self.WhichOneof('oneof')
|
||||||
|
if variant_name:
|
||||||
|
return getattr(self, variant_name)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _make_to_message(oneof_field):
|
||||||
|
def to_message(self) -> message_pb2.Message:
|
||||||
|
msg = message_pb2.Message(**{oneof_field: self})
|
||||||
|
return msg
|
||||||
|
return to_message
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _monkey_patch():
|
||||||
|
g = globals()
|
||||||
|
g['Message'] = message_pb2.Message
|
||||||
|
message_pb2.Message.to_variant = _message_to_variant
|
||||||
|
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
|
||||||
|
type_name = field.message_type.name
|
||||||
|
field_type = getattr(message_pb2, type_name)
|
||||||
|
field_type.to_message = _make_to_message(field.name)
|
||||||
|
g[type_name] = field_type
|
||||||
|
|
||||||
|
|
||||||
|
# also re-exports all message types
|
||||||
|
_monkey_patch()
|
||||||
|
# not a variant of Message, still re-export
|
||||||
|
MacState = message_pb2.MacState
|
93
src/capport/comm/message.pyi
Normal file
93
src/capport/comm/message.pyi
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import google.protobuf.message
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
# manually maintained typehints for protobuf created (and monkey-patched) types
|
||||||
|
|
||||||
|
|
||||||
|
class Message(google.protobuf.message.Message):
|
||||||
|
hello: Hello
|
||||||
|
authentication_result: AuthenticationResult
|
||||||
|
ping: Ping
|
||||||
|
mac_states: MacStates
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
hello: typing.Optional[Hello]=None,
|
||||||
|
authentication_result: typing.Optional[AuthenticationResult]=None,
|
||||||
|
ping: typing.Optional[Ping]=None,
|
||||||
|
mac_states: typing.Optional[MacStates]=None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class Hello(google.protobuf.message.Message):
|
||||||
|
instance_id: bytes
|
||||||
|
hostname: str
|
||||||
|
is_controller: bool
|
||||||
|
authentication: bytes
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
instance_id: bytes=b'',
|
||||||
|
hostname: str='',
|
||||||
|
is_controller: bool=False,
|
||||||
|
authentication: bytes=b'',
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationResult(google.protobuf.message.Message):
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
success: bool=False,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
|
class Ping(google.protobuf.message.Message):
|
||||||
|
payload: bytes
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
payload: bytes=b'',
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
|
class MacStates(google.protobuf.message.Message):
|
||||||
|
states: typing.List[MacState]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
states: typing.List[MacState]=[],
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
|
class MacState(google.protobuf.message.Message):
|
||||||
|
mac_address: bytes
|
||||||
|
last_change: int # Seconds of UTC time since epoch
|
||||||
|
allow_until: int # Seconds of UTC time since epoch
|
||||||
|
allowed: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
mac_address: bytes=b'',
|
||||||
|
last_change: int=0,
|
||||||
|
allow_until: int=0,
|
||||||
|
allowed: bool=False,
|
||||||
|
) -> None: ...
|
355
src/capport/comm/protobuf/message_pb2.py
Normal file
355
src/capport/comm/protobuf/message_pb2.py
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# source: message.proto
|
||||||
|
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import message as _message
|
||||||
|
from google.protobuf import reflection as _reflection
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||||
|
name='message.proto',
|
||||||
|
package='capport',
|
||||||
|
syntax='proto3',
|
||||||
|
serialized_options=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
serialized_pb=b'\n\rmessage.proto\x12\x07\x63\x61pport\"\xbc\x01\n\x07Message\x12\x1f\n\x05hello\x18\x01 \x01(\x0b\x32\x0e.capport.HelloH\x00\x12>\n\x15\x61uthentication_result\x18\x02 \x01(\x0b\x32\x1d.capport.AuthenticationResultH\x00\x12\x1d\n\x04ping\x18\x03 \x01(\x0b\x32\r.capport.PingH\x00\x12(\n\nmac_states\x18\n \x01(\x0b\x32\x12.capport.MacStatesH\x00\x42\x07\n\x05oneof\"]\n\x05Hello\x12\x13\n\x0binstance_id\x18\x01 \x01(\x0c\x12\x10\n\x08hostname\x18\x02 \x01(\t\x12\x15\n\ris_controller\x18\x03 \x01(\x08\x12\x16\n\x0e\x61uthentication\x18\x04 \x01(\x0c\"\'\n\x14\x41uthenticationResult\x12\x0f\n\x07success\x18\x01 \x01(\x08\"\x17\n\x04Ping\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\".\n\tMacStates\x12!\n\x06states\x18\x01 \x03(\x0b\x32\x11.capport.MacState\"Z\n\x08MacState\x12\x13\n\x0bmac_address\x18\x01 \x01(\x0c\x12\x13\n\x0blast_change\x18\x02 \x01(\x03\x12\x13\n\x0b\x61llow_until\x18\x03 \x01(\x03\x12\x0f\n\x07\x61llowed\x18\x04 \x01(\x08\x62\x06proto3'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
_MESSAGE = _descriptor.Descriptor(
|
||||||
|
name='Message',
|
||||||
|
full_name='capport.Message',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='hello', full_name='capport.Message.hello', index=0,
|
||||||
|
number=1, type=11, cpp_type=10, label=1,
|
||||||
|
has_default_value=False, default_value=None,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='authentication_result', full_name='capport.Message.authentication_result', index=1,
|
||||||
|
number=2, type=11, cpp_type=10, label=1,
|
||||||
|
has_default_value=False, default_value=None,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='ping', full_name='capport.Message.ping', index=2,
|
||||||
|
number=3, type=11, cpp_type=10, label=1,
|
||||||
|
has_default_value=False, default_value=None,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='mac_states', full_name='capport.Message.mac_states', index=3,
|
||||||
|
number=10, type=11, cpp_type=10, label=1,
|
||||||
|
has_default_value=False, default_value=None,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
_descriptor.OneofDescriptor(
|
||||||
|
name='oneof', full_name='capport.Message.oneof',
|
||||||
|
index=0, containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[]),
|
||||||
|
],
|
||||||
|
serialized_start=27,
|
||||||
|
serialized_end=215,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_HELLO = _descriptor.Descriptor(
|
||||||
|
name='Hello',
|
||||||
|
full_name='capport.Hello',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='instance_id', full_name='capport.Hello.instance_id', index=0,
|
||||||
|
number=1, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=b"",
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='hostname', full_name='capport.Hello.hostname', index=1,
|
||||||
|
number=2, type=9, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='is_controller', full_name='capport.Hello.is_controller', index=2,
|
||||||
|
number=3, type=8, cpp_type=7, label=1,
|
||||||
|
has_default_value=False, default_value=False,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='authentication', full_name='capport.Hello.authentication', index=3,
|
||||||
|
number=4, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=b"",
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=217,
|
||||||
|
serialized_end=310,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_AUTHENTICATIONRESULT = _descriptor.Descriptor(
|
||||||
|
name='AuthenticationResult',
|
||||||
|
full_name='capport.AuthenticationResult',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='success', full_name='capport.AuthenticationResult.success', index=0,
|
||||||
|
number=1, type=8, cpp_type=7, label=1,
|
||||||
|
has_default_value=False, default_value=False,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=312,
|
||||||
|
serialized_end=351,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PING = _descriptor.Descriptor(
|
||||||
|
name='Ping',
|
||||||
|
full_name='capport.Ping',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='payload', full_name='capport.Ping.payload', index=0,
|
||||||
|
number=1, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=b"",
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=353,
|
||||||
|
serialized_end=376,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_MACSTATES = _descriptor.Descriptor(
|
||||||
|
name='MacStates',
|
||||||
|
full_name='capport.MacStates',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='states', full_name='capport.MacStates.states', index=0,
|
||||||
|
number=1, type=11, cpp_type=10, label=3,
|
||||||
|
has_default_value=False, default_value=[],
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=378,
|
||||||
|
serialized_end=424,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_MACSTATE = _descriptor.Descriptor(
|
||||||
|
name='MacState',
|
||||||
|
full_name='capport.MacState',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='mac_address', full_name='capport.MacState.mac_address', index=0,
|
||||||
|
number=1, type=12, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=b"",
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='last_change', full_name='capport.MacState.last_change', index=1,
|
||||||
|
number=2, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='allow_until', full_name='capport.MacState.allow_until', index=2,
|
||||||
|
number=3, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='allowed', full_name='capport.MacState.allowed', index=3,
|
||||||
|
number=4, type=8, cpp_type=7, label=1,
|
||||||
|
has_default_value=False, default_value=False,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=426,
|
||||||
|
serialized_end=516,
|
||||||
|
)
|
||||||
|
|
||||||
|
_MESSAGE.fields_by_name['hello'].message_type = _HELLO
|
||||||
|
_MESSAGE.fields_by_name['authentication_result'].message_type = _AUTHENTICATIONRESULT
|
||||||
|
_MESSAGE.fields_by_name['ping'].message_type = _PING
|
||||||
|
_MESSAGE.fields_by_name['mac_states'].message_type = _MACSTATES
|
||||||
|
_MESSAGE.oneofs_by_name['oneof'].fields.append(
|
||||||
|
_MESSAGE.fields_by_name['hello'])
|
||||||
|
_MESSAGE.fields_by_name['hello'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
|
||||||
|
_MESSAGE.oneofs_by_name['oneof'].fields.append(
|
||||||
|
_MESSAGE.fields_by_name['authentication_result'])
|
||||||
|
_MESSAGE.fields_by_name['authentication_result'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
|
||||||
|
_MESSAGE.oneofs_by_name['oneof'].fields.append(
|
||||||
|
_MESSAGE.fields_by_name['ping'])
|
||||||
|
_MESSAGE.fields_by_name['ping'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
|
||||||
|
_MESSAGE.oneofs_by_name['oneof'].fields.append(
|
||||||
|
_MESSAGE.fields_by_name['mac_states'])
|
||||||
|
_MESSAGE.fields_by_name['mac_states'].containing_oneof = _MESSAGE.oneofs_by_name['oneof']
|
||||||
|
_MACSTATES.fields_by_name['states'].message_type = _MACSTATE
|
||||||
|
DESCRIPTOR.message_types_by_name['Message'] = _MESSAGE
|
||||||
|
DESCRIPTOR.message_types_by_name['Hello'] = _HELLO
|
||||||
|
DESCRIPTOR.message_types_by_name['AuthenticationResult'] = _AUTHENTICATIONRESULT
|
||||||
|
DESCRIPTOR.message_types_by_name['Ping'] = _PING
|
||||||
|
DESCRIPTOR.message_types_by_name['MacStates'] = _MACSTATES
|
||||||
|
DESCRIPTOR.message_types_by_name['MacState'] = _MACSTATE
|
||||||
|
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||||
|
|
||||||
|
Message = _reflection.GeneratedProtocolMessageType('Message', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _MESSAGE,
|
||||||
|
'__module__' : 'message_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:capport.Message)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(Message)
|
||||||
|
|
||||||
|
Hello = _reflection.GeneratedProtocolMessageType('Hello', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _HELLO,
|
||||||
|
'__module__' : 'message_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:capport.Hello)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(Hello)
|
||||||
|
|
||||||
|
AuthenticationResult = _reflection.GeneratedProtocolMessageType('AuthenticationResult', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _AUTHENTICATIONRESULT,
|
||||||
|
'__module__' : 'message_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:capport.AuthenticationResult)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(AuthenticationResult)
|
||||||
|
|
||||||
|
Ping = _reflection.GeneratedProtocolMessageType('Ping', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _PING,
|
||||||
|
'__module__' : 'message_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:capport.Ping)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(Ping)
|
||||||
|
|
||||||
|
MacStates = _reflection.GeneratedProtocolMessageType('MacStates', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _MACSTATES,
|
||||||
|
'__module__' : 'message_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:capport.MacStates)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(MacStates)
|
||||||
|
|
||||||
|
MacState = _reflection.GeneratedProtocolMessageType('MacState', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _MACSTATE,
|
||||||
|
'__module__' : 'message_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:capport.MacState)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(MacState)
|
||||||
|
|
||||||
|
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
34
src/capport/config.py
Normal file
34
src/capport/config.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import os.path
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Config:
|
||||||
|
controllers: typing.List[str]
|
||||||
|
secret: str
|
||||||
|
venue_info_url: typing.Optional[str]
|
||||||
|
session_timeout: int # in seconds
|
||||||
|
debug: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(filename: typing.Optional[str]=None) -> 'Config':
|
||||||
|
if filename is None:
|
||||||
|
for name in ('capport.yaml', '/etc/capport.yaml'):
|
||||||
|
if os.path.exists(name):
|
||||||
|
return Config.load(name)
|
||||||
|
raise RuntimeError("Missing config file")
|
||||||
|
with open(filename) as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
controllers = list(map(str, data['controllers']))
|
||||||
|
return Config(
|
||||||
|
controllers=controllers,
|
||||||
|
secret=str(data['secret']),
|
||||||
|
venue_info_url=str(data.get('venue-info-url')),
|
||||||
|
session_timeout=data.get('session-timeout', 3600),
|
||||||
|
debug=data.get('debug', False)
|
||||||
|
)
|
0
src/capport/control/__init__.py
Normal file
0
src/capport/control/__init__.py
Normal file
56
src/capport/control/run.py
Normal file
56
src/capport/control/run.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import capport.database
|
||||||
|
import capport.comm.hub
|
||||||
|
import capport.comm.message
|
||||||
|
import capport.config
|
||||||
|
import capport.utils.cli
|
||||||
|
import capport.utils.nft_set
|
||||||
|
import trio
|
||||||
|
from capport import cptypes
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ControlApp(capport.comm.hub.HubApplication):
|
||||||
|
hub: capport.comm.hub.Hub
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.nft_set = capport.utils.nft_set.NftSet()
|
||||||
|
|
||||||
|
def is_controller(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def mac_states_changed(self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates) -> None:
|
||||||
|
# deploy changes to netfilter set
|
||||||
|
inserts = []
|
||||||
|
removals = []
|
||||||
|
now = cptypes.Timestamp.now()
|
||||||
|
for mac, state in pending_updates.macs.items():
|
||||||
|
rem = state.allowed_remaining(now)
|
||||||
|
if rem > 0:
|
||||||
|
inserts.append((mac, rem))
|
||||||
|
else:
|
||||||
|
removals.append(mac)
|
||||||
|
self.nft_set.bulk_insert(inserts)
|
||||||
|
self.nft_set.bulk_remove(removals)
|
||||||
|
|
||||||
|
|
||||||
|
async def amain(config: capport.config.Config) -> None:
|
||||||
|
app = ControlApp()
|
||||||
|
hub = capport.comm.hub.Hub(config=config, app=app)
|
||||||
|
app.hub = hub
|
||||||
|
await hub.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
config = capport.config.Config.load()
|
||||||
|
capport.utils.cli.init_logger(config)
|
||||||
|
try:
|
||||||
|
trio.run(amain, config)
|
||||||
|
except (KeyboardInterrupt, InterruptedError):
|
||||||
|
print()
|
94
src/capport/cptypes.py
Normal file
94
src/capport/cptypes.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import datetime
|
||||||
|
import ipaddress
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import quart
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
|
|
||||||
|
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class MacAddress:
|
||||||
|
raw: bytes
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.raw.hex(':')
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return repr(str(self))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse(s: str) -> MacAddress:
|
||||||
|
return MacAddress(bytes.fromhex(s.replace(':', '')))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True, order=True)
|
||||||
|
class Timestamp:
|
||||||
|
epoch: int
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
try:
|
||||||
|
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
|
||||||
|
return ts.isoformat(sep=' ')
|
||||||
|
except OSError:
|
||||||
|
return f'epoch@{self.epoch}'
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return repr(str(self))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def now() -> Timestamp:
|
||||||
|
now = int(time.time())
|
||||||
|
return Timestamp(epoch=now)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
|
||||||
|
if epoch:
|
||||||
|
return Timestamp(epoch=epoch)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MacPublicState:
|
||||||
|
address: IPAddress
|
||||||
|
mac: typing.Optional[MacAddress]
|
||||||
|
allowed_remaining: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_missing_mac(address: IPAddress) -> MacPublicState:
|
||||||
|
return MacPublicState(
|
||||||
|
address=address,
|
||||||
|
mac=None,
|
||||||
|
allowed_remaining=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allowed(self) -> bool:
|
||||||
|
return self.allowed_remaining > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def captive(self) -> bool:
|
||||||
|
return not self.allowed
|
||||||
|
|
||||||
|
def to_rfc8908(self, config: Config) -> quart.Response:
|
||||||
|
response: typing.Dict[str, typing.Any] = {
|
||||||
|
'user-portal-url': quart.url_for('index', _external=True),
|
||||||
|
}
|
||||||
|
if config.venue_info_url:
|
||||||
|
response['venue-info-url'] = config.venue_info_url
|
||||||
|
if self.captive:
|
||||||
|
response['captive'] = True
|
||||||
|
else:
|
||||||
|
response['captive'] = False
|
||||||
|
response['seconds-remaining'] = self.allowed_remaining
|
||||||
|
response['can-extend-session'] = True
|
||||||
|
return quart.Response(json.dumps(response), headers={'Cache-Control': 'private'}, content_type='application/captive+json')
|
200
src/capport/database.py
Normal file
200
src/capport/database.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import capport.comm.message
|
||||||
|
from capport import cptypes
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MacEntry:
|
||||||
|
# entry can be removed if last_change was some time ago and allow_until wasn't set
|
||||||
|
# or got reached.
|
||||||
|
WAIT_LAST_CHANGE_SECONDS = 60
|
||||||
|
WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10
|
||||||
|
|
||||||
|
# last_change: timestamp of last change (sent by system initiating the change)
|
||||||
|
last_change: cptypes.Timestamp
|
||||||
|
# only if allowed is true and allow_until is set the device can communicate with the internet
|
||||||
|
# allow_until must not go backwards (and not get unset)
|
||||||
|
allow_until: typing.Optional[cptypes.Timestamp]
|
||||||
|
allowed: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
|
||||||
|
if len(msg.mac_address) < 6:
|
||||||
|
raise Exception("Invalid MacState: mac_address too short")
|
||||||
|
addr = cptypes.MacAddress(raw=msg.mac_address)
|
||||||
|
last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
|
||||||
|
if not last_change:
|
||||||
|
raise Exception(f"Invalid MacState[{addr}]: missing last_change")
|
||||||
|
allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
|
||||||
|
return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))
|
||||||
|
|
||||||
|
def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
|
||||||
|
allow_until = 0
|
||||||
|
if self.allow_until:
|
||||||
|
allow_until = self.allow_until.epoch
|
||||||
|
return capport.comm.message.MacState(
|
||||||
|
mac_address=addr.raw,
|
||||||
|
last_change=self.last_change.epoch,
|
||||||
|
allow_until=allow_until,
|
||||||
|
allowed=self.allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_json(self) -> dict:
|
||||||
|
allow_until = None
|
||||||
|
if self.allow_until:
|
||||||
|
allow_until = self.allow_until.epoch
|
||||||
|
return dict(
|
||||||
|
last_change=self.last_change.epoch,
|
||||||
|
allow_until=allow_until,
|
||||||
|
allowed=self.allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def merge(self, new: MacEntry) -> bool:
|
||||||
|
changed = False
|
||||||
|
if new.last_change > self.last_change:
|
||||||
|
changed = True
|
||||||
|
self.last_change = new.last_change
|
||||||
|
self.allowed = new.allowed
|
||||||
|
elif new.last_change == self.last_change:
|
||||||
|
# same last_change: set allowed if one allowed
|
||||||
|
if new.allowed and not self.allowed:
|
||||||
|
changed = True
|
||||||
|
self.allowed = True
|
||||||
|
# set allow_until to max of both
|
||||||
|
if new.allow_until: # if not set nothing to change in local data
|
||||||
|
if not self.allow_until or self.allow_until < new.allow_until:
|
||||||
|
changed = True
|
||||||
|
self.allow_until = new.allow_until
|
||||||
|
return changed
|
||||||
|
|
||||||
|
def timeout(self) -> cptypes.Timestamp:
|
||||||
|
elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
|
||||||
|
if self.allow_until:
|
||||||
|
eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
|
||||||
|
if eau > elc:
|
||||||
|
return cptypes.Timestamp(epoch=eau)
|
||||||
|
return cptypes.Timestamp(epoch=elc)
|
||||||
|
|
||||||
|
# returns 0 if not allowed
|
||||||
|
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp]=None) -> int:
|
||||||
|
if not self.allowed or not self.allow_until:
|
||||||
|
return 0
|
||||||
|
if not now:
|
||||||
|
now = cptypes.Timestamp.now()
|
||||||
|
assert self.allow_until
|
||||||
|
return max(self.allow_until.epoch - now.epoch, 0)
|
||||||
|
|
||||||
|
def outdated(self, now: typing.Optional[cptypes.Timestamp]=None) -> bool:
|
||||||
|
if not now:
|
||||||
|
now = cptypes.Timestamp.now()
|
||||||
|
return now.epoch > self.timeout().epoch
|
||||||
|
|
||||||
|
|
||||||
|
# might use this to serialize into file - don't need Message variant there
|
||||||
|
def _serialize_mac_states(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
|
||||||
|
result: typing.List[capport.comm.message.MacStates] = []
|
||||||
|
current = capport.comm.message.MacStates()
|
||||||
|
for addr, entry in macs.items():
|
||||||
|
state = entry.to_state(addr)
|
||||||
|
current.states.append(state)
|
||||||
|
if len(current.states) >= 1024: # split into messages with 1024 states
|
||||||
|
result.append(current)
|
||||||
|
current = capport.comm.message.MacStates()
|
||||||
|
if len(current.states):
|
||||||
|
result.append(current)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_mac_states_as_messages(macs: typing.Dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.Message]:
|
||||||
|
return [s.to_message() for s in _serialize_mac_states(macs)]
|
||||||
|
|
||||||
|
|
||||||
|
class NotReadyYet(Exception):
|
||||||
|
def __init__(self, msg: str, wait: int):
|
||||||
|
self.wait = wait # seconds to wait
|
||||||
|
super().__init__(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Database:
|
||||||
|
_macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
def received_mac_state(self, state: capport.comm.message.MacState, *, pending_updates: PendingUpdates):
|
||||||
|
(addr, new_entry) = MacEntry.parse_state(state)
|
||||||
|
old_entry = self._macs.get(addr)
|
||||||
|
if not old_entry:
|
||||||
|
# only redistribute if not outdated
|
||||||
|
if not new_entry.outdated():
|
||||||
|
self._macs[addr] = new_entry
|
||||||
|
pending_updates.macs[addr] = new_entry
|
||||||
|
elif old_entry.merge(new_entry):
|
||||||
|
if old_entry.outdated():
|
||||||
|
# remove local entry, but still redistribute
|
||||||
|
self._macs.pop(addr)
|
||||||
|
pending_updates.macs[addr] = old_entry
|
||||||
|
|
||||||
|
def serialize(self) -> typing.List[capport.comm.message.Message]:
|
||||||
|
return _serialize_mac_states_as_messages(self._macs)
|
||||||
|
|
||||||
|
def as_json(self) -> dict:
|
||||||
|
return {
|
||||||
|
str(addr): entry.as_json()
|
||||||
|
for addr, entry in self._macs.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
|
||||||
|
entry = self._macs.get(mac)
|
||||||
|
if entry:
|
||||||
|
allowed_remaining = entry.allowed_remaining()
|
||||||
|
else:
|
||||||
|
allowed_remaining = 0
|
||||||
|
return cptypes.MacPublicState(
|
||||||
|
address=address,
|
||||||
|
mac=mac,
|
||||||
|
allowed_remaining=allowed_remaining,
|
||||||
|
)
|
||||||
|
|
||||||
|
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, pending_updates: PendingUpdates, renew_maximum: float=0.8):
|
||||||
|
now = cptypes.Timestamp.now()
|
||||||
|
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
|
||||||
|
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
|
||||||
|
|
||||||
|
entry = self._macs.get(mac)
|
||||||
|
if not entry:
|
||||||
|
self._macs[mac] = new_entry
|
||||||
|
pending_updates.macs[mac] = new_entry
|
||||||
|
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
|
||||||
|
# too much time left on clock, not renewing session
|
||||||
|
return
|
||||||
|
elif entry.merge(new_entry):
|
||||||
|
pending_updates.macs[mac] = entry
|
||||||
|
elif not entry.allowed_remaining() > 0:
|
||||||
|
# entry should have been updated - can only fail due to `now < entry.last_change`
|
||||||
|
# i.e. out of sync clocks
|
||||||
|
wait = entry.last_change.epoch - now.epoch
|
||||||
|
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
|
||||||
|
|
||||||
|
def logout(self, mac: cptypes.MacAddress, *, pending_updates: PendingUpdates):
|
||||||
|
now = cptypes.Timestamp.now()
|
||||||
|
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
|
||||||
|
entry = self._macs.get(mac)
|
||||||
|
if entry:
|
||||||
|
if entry.merge(new_entry):
|
||||||
|
pending_updates.macs[mac] = entry
|
||||||
|
elif entry.allowed_remaining() > 0:
|
||||||
|
# still logged in. can only happen with `now <= entry.last_change`
|
||||||
|
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
|
||||||
|
wait = entry.last_change.epoch - now.epoch + 1
|
||||||
|
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PendingUpdates:
|
||||||
|
macs: typing.Dict[cptypes.MacAddress, MacEntry] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
def serialize(self) -> typing.List[capport.comm.message.Message]:
|
||||||
|
return _serialize_mac_states_as_messages(self.macs)
|
0
src/capport/utils/__init__.py
Normal file
0
src/capport/utils/__init__.py
Normal file
17
src/capport/utils/cli.py
Normal file
17
src/capport/utils/cli.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import capport.config
|
||||||
|
|
||||||
|
|
||||||
|
def init_logger(config: capport.config.Config):
|
||||||
|
loglevel = logging.INFO
|
||||||
|
if config.debug:
|
||||||
|
loglevel = logging.DEBUG
|
||||||
|
logging.basicConfig(
|
||||||
|
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
|
||||||
|
datefmt='[%Y-%m-%d %H:%M:%S %z]',
|
||||||
|
level=loglevel,
|
||||||
|
)
|
||||||
|
logging.getLogger('hypercorn').propagate = False
|
63
src/capport/utils/ipneigh.py
Normal file
63
src/capport/utils/ipneigh.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import errno
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import pr2modules.iproute.linux
|
||||||
|
import pr2modules.netlink.exceptions
|
||||||
|
from capport import cptypes
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def connect():
|
||||||
|
yield NeighborController()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: run blocking iproute calls in a different thread?
|
||||||
|
class NeighborController:
|
||||||
|
def __init__(self):
|
||||||
|
self.ip = pr2modules.iproute.linux.IPRoute()
|
||||||
|
|
||||||
|
async def get_neighbor(
|
||||||
|
self,
|
||||||
|
address: cptypes.IPAddress,
|
||||||
|
*,
|
||||||
|
index: int=0, # interface index
|
||||||
|
flags: int=0,
|
||||||
|
) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]:
|
||||||
|
if not index:
|
||||||
|
route = await self.get_route(address)
|
||||||
|
if route is None:
|
||||||
|
return None
|
||||||
|
index = route.get_attr(route.name2nla('oif'))
|
||||||
|
try:
|
||||||
|
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
|
||||||
|
except pr2modules.netlink.exceptions.NetlinkError as e:
|
||||||
|
if e.code == errno.ENOENT:
|
||||||
|
return None
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_neighbor_mac(
|
||||||
|
self,
|
||||||
|
address: cptypes.IPAddress,
|
||||||
|
*,
|
||||||
|
index: int=0, # interface index
|
||||||
|
flags: int=0,
|
||||||
|
) -> typing.Optional[cptypes.MacAddress]:
|
||||||
|
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
||||||
|
if neigh is None:
|
||||||
|
return None
|
||||||
|
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
||||||
|
return cptypes.MacAddress.parse(mac)
|
||||||
|
|
||||||
|
async def get_route(
|
||||||
|
self,
|
||||||
|
address: cptypes.IPAddress,
|
||||||
|
) -> typing.Optional[pr2modules.iproute.linux.rtmsg]:
|
||||||
|
try:
|
||||||
|
return self.ip.route('get', dst=str(address))[0]
|
||||||
|
except pr2modules.netlink.exceptions.NetlinkError as e:
|
||||||
|
if e.code == errno.ENOENT:
|
||||||
|
return None
|
||||||
|
raise
|
181
src/capport/utils/nft_set.py
Normal file
181
src/capport/utils/nft_set.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import pr2modules.netlink
|
||||||
|
from capport import cptypes
|
||||||
|
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
|
||||||
|
|
||||||
|
from .nft_socket import NFTSocket
|
||||||
|
|
||||||
|
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
|
||||||
|
|
||||||
|
|
||||||
|
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
|
||||||
|
# to seconds
|
||||||
|
if msecs is None:
|
||||||
|
return None
|
||||||
|
return msecs / 1000.0
|
||||||
|
|
||||||
|
|
||||||
|
class NftSet:
|
||||||
|
def __init__(self):
|
||||||
|
self._socket = NFTSocket()
|
||||||
|
self._socket.bind()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
||||||
|
attrs: typing.Dict[str, typing.Any] = {
|
||||||
|
'NFTA_SET_ELEM_KEY': dict(
|
||||||
|
NFTA_DATA_VALUE=mac.raw,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if timeout:
|
||||||
|
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
|
||||||
|
ser_entries = [
|
||||||
|
self._set_elem(mac)
|
||||||
|
for mac, _timeout in entries
|
||||||
|
]
|
||||||
|
ser_entries_with_timeout = [
|
||||||
|
self._set_elem(mac, timeout)
|
||||||
|
for mac, timeout in entries
|
||||||
|
]
|
||||||
|
with self._socket.begin() as tx:
|
||||||
|
# create doesn't affect existing elements, so:
|
||||||
|
# make sure entries exists
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||||
|
pr2modules.netlink.NLM_F_CREATE,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# drop entries (would fail if it doesn't exist)
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_DELSETELEM,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# now create entries with new timeout value
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||||
|
pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
|
||||||
|
# limit chunk size
|
||||||
|
while len(entries) > 0:
|
||||||
|
self._bulk_insert(entries[:1024])
|
||||||
|
entries = entries[1024:]
|
||||||
|
|
||||||
|
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
|
||||||
|
self.bulk_insert([(mac, timeout)])
|
||||||
|
|
||||||
|
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||||
|
ser_entries = [
|
||||||
|
self._set_elem(mac)
|
||||||
|
for mac in entries
|
||||||
|
]
|
||||||
|
with self._socket.begin() as tx:
|
||||||
|
# make sure entries exists
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||||
|
pr2modules.netlink.NLM_F_CREATE,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# drop entries (would fail if it doesn't exist)
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_DELSETELEM,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||||
|
# limit chunk size
|
||||||
|
while len(entries) > 0:
|
||||||
|
self._bulk_remove(entries[:1024])
|
||||||
|
entries = entries[1024:]
|
||||||
|
|
||||||
|
def remove(self, mac: cptypes.MacAddress) -> None:
|
||||||
|
self.bulk_remove([mac])
|
||||||
|
|
||||||
|
def list(self) -> list:
|
||||||
|
responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
|
||||||
|
responses = self._socket.nft_dump(
|
||||||
|
_nftsocket.NFT_MSG_GETSETELEM,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'mac': cptypes.MacAddress(
|
||||||
|
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
|
||||||
|
),
|
||||||
|
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
|
||||||
|
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
|
||||||
|
}
|
||||||
|
for response in responses
|
||||||
|
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
|
||||||
|
]
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
self._socket.nft_put(
|
||||||
|
_nftsocket.NFT_MSG_DELSETELEM,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def create(self):
|
||||||
|
with self._socket.begin() as tx:
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_NEWTABLE,
|
||||||
|
pr2modules.netlink.NLM_F_CREATE,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_TABLE_NAME='captive_mark',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
tx.put(
|
||||||
|
_nftsocket.NFT_MSG_NEWSET,
|
||||||
|
pr2modules.netlink.NLM_F_CREATE,
|
||||||
|
nfgen_family=NFPROTO_INET,
|
||||||
|
attrs=dict(
|
||||||
|
NFTA_SET_TABLE='captive_mark',
|
||||||
|
NFTA_SET_NAME='allowed',
|
||||||
|
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
|
||||||
|
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
|
||||||
|
NFTA_SET_KEY_LEN=6, # length of key: mac address
|
||||||
|
NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction
|
||||||
|
),
|
||||||
|
)
|
217
src/capport/utils/nft_socket.py
Normal file
217
src/capport/utils/nft_socket.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import typing
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
|
||||||
|
import pr2modules.netlink
|
||||||
|
import pr2modules.netlink.nlsocket
|
||||||
|
from pr2modules.netlink.nfnetlink import nfgen_msg
|
||||||
|
from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
|
||||||
|
|
||||||
|
|
||||||
|
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
|
||||||
|
|
||||||
|
|
||||||
|
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base)
|
||||||
|
|
||||||
|
|
||||||
|
# nft uses NESTED for those.. lets do the same
|
||||||
|
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED
|
||||||
|
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED
|
||||||
|
|
||||||
|
|
||||||
|
def _monkey_patch_pyroute2():
|
||||||
|
import pr2modules.netlink
|
||||||
|
# overwrite setdefault on nlmsg_base class hierarchy
|
||||||
|
_orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue
|
||||||
|
|
||||||
|
def _nlmsg_base__setvalue(self, value):
|
||||||
|
if not self.header or not self['header'] or not isinstance(value, dict):
|
||||||
|
return _orig_setvalue(self, value)
|
||||||
|
header = value.pop('header', {})
|
||||||
|
res = _orig_setvalue(self, value)
|
||||||
|
self['header'].update(header)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def overwrite_methods(cls: typing.Type) -> None:
|
||||||
|
if cls.setvalue is _orig_setvalue:
|
||||||
|
cls.setvalue = _nlmsg_base__setvalue
|
||||||
|
for subcls in cls.__subclasses__():
|
||||||
|
overwrite_methods(subcls)
|
||||||
|
|
||||||
|
overwrite_methods(pr2modules.netlink.nlmsg_base)
|
||||||
|
_monkey_patch_pyroute2()
|
||||||
|
|
||||||
|
|
||||||
|
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase:
|
||||||
|
msg = msg_class()
|
||||||
|
for key, value in header.items():
|
||||||
|
msg['header'][key] = value
|
||||||
|
for key, value in fields.items():
|
||||||
|
msg[key] = value
|
||||||
|
if attrs:
|
||||||
|
attr_list = msg['attrs']
|
||||||
|
r_nla_map = msg_class._nlmsg_base__r_nla_map
|
||||||
|
for key, value in attrs.items():
|
||||||
|
if msg_class.prefix:
|
||||||
|
key = msg_class.name2nla(key)
|
||||||
|
prime = r_nla_map[key]
|
||||||
|
nla_class = prime['class']
|
||||||
|
if issubclass(nla_class, pr2modules.netlink.nla):
|
||||||
|
# support passing nested attributes as dicts of subattributes (or lists of those)
|
||||||
|
if prime['nla_array']:
|
||||||
|
value = [
|
||||||
|
_build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem
|
||||||
|
for elem in value
|
||||||
|
]
|
||||||
|
elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict):
|
||||||
|
value = _build(nla_class, attrs=value)
|
||||||
|
attr_list.append([key, value])
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket):
|
||||||
|
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER)
|
||||||
|
policy = {
|
||||||
|
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
|
||||||
|
for (x, y) in self.policy.items()
|
||||||
|
}
|
||||||
|
self.register_policy(policy)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def begin(self) -> typing.Generator[NFTTransaction, None, None]:
|
||||||
|
try:
|
||||||
|
tx = NFTTransaction(socket=self)
|
||||||
|
yield tx
|
||||||
|
# autocommit when no exception was raised
|
||||||
|
# (only commits if it wasn't aborted)
|
||||||
|
tx.autocommit()
|
||||||
|
finally:
|
||||||
|
# abort does nothing if commit went through
|
||||||
|
tx.abort()
|
||||||
|
|
||||||
|
def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||||
|
with self.begin() as tx:
|
||||||
|
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
|
||||||
|
|
||||||
|
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||||
|
msg_flags |= pr2modules.netlink.NLM_F_DUMP
|
||||||
|
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
|
||||||
|
|
||||||
|
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||||
|
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
|
||||||
|
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
|
||||||
|
msg_flags |= pr2modules.netlink.NLM_F_REQUEST
|
||||||
|
msg = _build(msg_class, attrs=attrs, **fields)
|
||||||
|
return self.nlm_request(msg, msg_type, msg_flags)
|
||||||
|
|
||||||
|
|
||||||
|
class NFTTransaction:
|
||||||
|
def __init__(self, socket: NFTSocket) -> None:
|
||||||
|
self._socket = socket
|
||||||
|
self._data = b''
|
||||||
|
self._seqnum = self._socket.addr_pool.alloc()
|
||||||
|
self._closed = False
|
||||||
|
# neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
|
||||||
|
# at the end of an transaction to make sure it worked.
|
||||||
|
# we could use a different sequence number for all changes, and wait for an ACK for each of them
|
||||||
|
# (but we'd also need to check for errors on the BEGIN sequence number).
|
||||||
|
# the other solution: use the same sequence number for all messages in the batch, and add ACK
|
||||||
|
# only to the final message (before END) - if we get the ACK we known all other messages before
|
||||||
|
# worked out.
|
||||||
|
self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
|
||||||
|
begin_msg = _build(
|
||||||
|
nfgen_msg,
|
||||||
|
res_id=NFNL_SUBSYS_NFTABLES,
|
||||||
|
header=dict(
|
||||||
|
type=0x10, # NFNL_MSG_BATCH_BEGIN
|
||||||
|
flags=pr2modules.netlink.NLM_F_REQUEST,
|
||||||
|
sequence_number=self._seqnum,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
begin_msg.encode()
|
||||||
|
self._data += begin_msg.data
|
||||||
|
|
||||||
|
def abort(self) -> None:
|
||||||
|
"""
|
||||||
|
Aborts if transaction wasn't already committed or aborted
|
||||||
|
"""
|
||||||
|
if not self._closed:
|
||||||
|
self._closed = True
|
||||||
|
# unused seqnum
|
||||||
|
self._socket.addr_pool.free(self._seqnum)
|
||||||
|
|
||||||
|
def autocommit(self) -> None:
|
||||||
|
"""
|
||||||
|
Commits if transaction wasn't already committed or aborted
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
if self._closed:
|
||||||
|
raise Exception("Transaction already closed")
|
||||||
|
if not self._final_msg:
|
||||||
|
# no inner messages were queued... just abort transaction
|
||||||
|
self.abort()
|
||||||
|
return
|
||||||
|
self._closed = True
|
||||||
|
# request ACK only on the last message (before END)
|
||||||
|
self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK
|
||||||
|
self._final_msg.encode()
|
||||||
|
self._data += self._final_msg.data
|
||||||
|
self._final_msg = None
|
||||||
|
# batch end
|
||||||
|
end_msg = _build(
|
||||||
|
nfgen_msg,
|
||||||
|
res_id=NFNL_SUBSYS_NFTABLES,
|
||||||
|
header=dict(
|
||||||
|
type=0x11, # NFNL_MSG_BATCH_END
|
||||||
|
flags=pr2modules.netlink.NLM_F_REQUEST,
|
||||||
|
sequence_number=self._seqnum,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
end_msg.encode()
|
||||||
|
self._data += end_msg.data
|
||||||
|
# need to create backlog for our sequence number
|
||||||
|
with self._socket.lock[self._seqnum]:
|
||||||
|
self._socket.backlog[self._seqnum] = []
|
||||||
|
# send
|
||||||
|
self._socket.sendto(self._data, (0, 0))
|
||||||
|
try:
|
||||||
|
for _msg in self._socket.get(msg_seq=self._seqnum):
|
||||||
|
# we should see at most one ACK - real errors get raised anyway
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
with self._socket.lock[0]:
|
||||||
|
# clear messages from "seq 0" queue - because if there
|
||||||
|
# was an error in our backlog, it got raised and the
|
||||||
|
# remaining messages moved to 0
|
||||||
|
self._socket.backlog[0] = []
|
||||||
|
|
||||||
|
def _put(self, msg: nfgen_msg) -> None:
|
||||||
|
if self._closed:
|
||||||
|
raise Exception("Transaction already closed")
|
||||||
|
if self._final_msg:
|
||||||
|
# previous message wasn't the final one, encode it without ACK
|
||||||
|
self._final_msg.encode()
|
||||||
|
self._data += self._final_msg.data
|
||||||
|
self._final_msg = msg
|
||||||
|
|
||||||
|
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||||
|
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
||||||
|
msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST
|
||||||
|
msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set!
|
||||||
|
header = dict(
|
||||||
|
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
|
||||||
|
flags=msg_flags,
|
||||||
|
sequence_number=self._seqnum,
|
||||||
|
)
|
||||||
|
msg = _build(msg_class, attrs=attrs, header=header, **fields)
|
||||||
|
self._put(msg)
|
8
start-api.sh
Executable file
8
start-api.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
base=$(dirname "$(readlink -f "$0")")
|
||||||
|
cd "${base}"
|
||||||
|
|
||||||
|
exec ./venv/bin/hypercorn -k trio capport.api "$@"
|
8
start-control.sh
Executable file
8
start-control.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
base=$(dirname "$(readlink -f "$0")")
|
||||||
|
cd "${base}"
|
||||||
|
|
||||||
|
exec ./venv/bin/capport-control "$@"
|
Loading…
Reference in New Issue
Block a user