2
0

format with black (mostly quotes)

This commit is contained in:
Stefan Bühler 2023-11-15 10:02:28 +01:00
parent ced589f28a
commit 0f45f89211
22 changed files with 245 additions and 272 deletions

View File

@ -6,9 +6,14 @@ requires = [
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.mypy] [tool.mypy]
python_version = "3.9" python_version = "3.11"
# warn_return_any = true # warn_return_any = true
warn_unused_configs = true warn_unused_configs = true
exclude = [ exclude = [
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary) '_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
] ]
[tool.black]
line-length = 120
target-version = ['py311']
exclude = '_pb2.py'

View File

@ -4,8 +4,8 @@ from .app_cls import MyQuartApp
app = MyQuartApp(__name__) app = MyQuartApp(__name__)
__import__('capport.api.setup') __import__("capport.api.setup")
__import__('capport.api.proxy_fix') __import__("capport.api.proxy_fix")
__import__('capport.api.lang') __import__("capport.api.lang")
__import__('capport.api.template_filters') __import__("capport.api.template_filters")
__import__('capport.api.views') __import__("capport.api.views")

View File

@ -49,16 +49,16 @@ class MyQuartApp(quart_trio.QuartTrio):
def __init__(self, import_name: str, **kwargs) -> None: def __init__(self, import_name: str, **kwargs) -> None:
self.my_config = capport.config.Config.load_default_once() self.my_config = capport.config.Config.load_default_once()
kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates')) kwargs.setdefault("template_folder", os.path.join(os.path.dirname(__file__), "templates"))
cust_templ = os.path.join('custom', 'templates') cust_templ = os.path.join("custom", "templates")
if os.path.exists(cust_templ): if os.path.exists(cust_templ):
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ)) self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
cust_static = os.path.abspath(os.path.join('custom', 'static')) cust_static = os.path.abspath(os.path.join("custom", "static"))
if os.path.exists(cust_static): if os.path.exists(cust_static):
static_folder = cust_static static_folder = cust_static
else: else:
static_folder = os.path.join(os.path.dirname(__file__), 'static') static_folder = os.path.join(os.path.dirname(__file__), "static")
kwargs.setdefault('static_folder', static_folder) kwargs.setdefault("static_folder", static_folder)
super().__init__(import_name, **kwargs) super().__init__(import_name, **kwargs)
self.debug = self.my_config.debug self.debug = self.my_config.debug
self.secret_key = self.my_config.cookie_secret self.secret_key = self.my_config.cookie_secret

View File

@ -12,7 +12,7 @@ import capport.config
def run(config: hypercorn.config.Config) -> None: def run(config: hypercorn.config.Config) -> None:
sockets = config.create_sockets() sockets = config.create_sockets()
assert config.worker_class == 'trio' assert config.worker_class == "trio"
hypercorn.trio.run.trio_worker(config=config, sockets=sockets) hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
@ -28,7 +28,7 @@ class CliArguments:
def __init__(self): def __init__(self):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c') parser.add_argument("--config", "-c")
args = parser.parse_args() args = parser.parse_args()
self.config = args.config self.config = args.config
@ -39,8 +39,8 @@ def main() -> None:
_config = capport.config.Config.load_default_once(filename=args.config) _config = capport.config.Config.load_default_once(filename=args.config)
hypercorn_config = hypercorn.config.Config() hypercorn_config = hypercorn.config.Config()
hypercorn_config.application_path = 'capport.api.app' hypercorn_config.application_path = "capport.api.app"
hypercorn_config.worker_class = 'trio' hypercorn_config.worker_class = "trio"
hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"] hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"]
if _config.server_names: if _config.server_names:

View File

@ -7,22 +7,23 @@ import quart
from .app import app from .app import app
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$') _VALID_LANGUAGE_NAMES = re.compile(r"^[-a-z0-9_]+$")
def parse_accept_language(value: str) -> list[str]: def parse_accept_language(value: str) -> list[str]:
value = value.strip() value = value.strip()
if not value or value == '*': if not value or value == "*":
return [] return []
tuples = [] tuples = []
for entry in value.split(','): for entry in value.split(","):
attrs = entry.split(';') attrs = entry.split(";")
name = attrs.pop(0).strip().lower() name = attrs.pop(0).strip().lower()
q = 1.0 q = 1.0
for attr in attrs: for attr in attrs:
if not '=' in attr: continue if not "=" in attr:
key, value = attr.split('=', maxsplit=1) continue
if key.strip().lower() == 'q': key, value = attr.split("=", maxsplit=1)
if key.strip().lower() == "q":
try: try:
q = float(value.strip()) q = float(value.strip())
except ValueError: except ValueError:
@ -32,14 +33,17 @@ def parse_accept_language(value: str) -> list[str]:
tuples.sort() tuples.sort()
have = set() have = set()
result = [] result = []
for (_q, name) in tuples: for _q, name in tuples:
if name in have: continue if name in have:
if name == '*': break continue
if name == "*":
break
have.add(name) have.add(name)
if _VALID_LANGUAGE_NAMES.match(name): if _VALID_LANGUAGE_NAMES.match(name):
result.append(name) result.append(name)
short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0] short_name = name.split("-", maxsplit=1)[0].split("_", maxsplit=1)[0]
if not short_name or short_name in have: continue if not short_name or short_name in have:
continue
have.add(short_name) have.add(short_name)
result.append(short_name) result.append(short_name)
return result return result
@ -50,23 +54,23 @@ def detect_language():
g = quart.g g = quart.g
r = quart.request r = quart.request
s = quart.session s = quart.session
if 'setlang' in r.args: if "setlang" in r.args:
lang = r.args.get('setlang').strip().lower() lang = r.args.get("setlang").strip().lower()
if lang and _VALID_LANGUAGE_NAMES.match(lang): if lang and _VALID_LANGUAGE_NAMES.match(lang):
if s.get('lang') != lang: if s.get("lang") != lang:
s['lang'] = lang s["lang"] = lang
g.langs = [lang] g.langs = [lang]
return return
else: else:
# reset language # reset language
s.pop('lang', None) s.pop("lang", None)
lang = s.get('lang') lang = s.get("lang")
if lang: if lang:
lang = lang.strip().lower() lang = lang.strip().lower()
if lang and _VALID_LANGUAGE_NAMES.match(lang): if lang and _VALID_LANGUAGE_NAMES.match(lang):
g.langs = [lang] g.langs = [lang]
return return
acc_lang = ','.join(r.headers.get_all('Accept-Language')) acc_lang = ",".join(r.headers.get_all("Accept-Language"))
g.langs = parse_accept_language(acc_lang) g.langs = parse_accept_language(acc_lang)
@ -74,9 +78,6 @@ async def render_i18n_template(template, /, **kwargs) -> str:
langs: list[str] = quart.g.langs langs: list[str] = quart.g.langs
if not langs: if not langs:
return await quart.render_template(template, **kwargs) return await quart.render_template(template, **kwargs)
names = [ names = [os.path.join("i18n", lang, template) for lang in langs]
os.path.join('i18n', lang, template)
for lang in langs
]
names.append(template) names.append(template)
return await quart.render_template(names, **kwargs) return await quart.render_template(names, **kwargs)

View File

@ -31,28 +31,28 @@ def local_proxy_fix(request: quart.Request):
return return
if not addr.is_loopback: if not addr.is_loopback:
return return
client = _get_first_in_list(request.headers.get('X-Forwarded-For')) client = _get_first_in_list(request.headers.get("X-Forwarded-For"))
if not client: if not client:
# assume this is always set behind reverse proxies supporting any of the headers # assume this is always set behind reverse proxies supporting any of the headers
return return
request.remote_addr = client request.remote_addr = client
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https')) scheme = _get_first_in_list(request.headers.get("X-Forwarded-Proto"), ("http", "https"))
port: int | None = None port: int | None = None
if scheme: if scheme:
port = 443 if scheme == 'https' else 80 port = 443 if scheme == "https" else 80
request.scheme = scheme request.scheme = scheme
host = _get_first_in_list(request.headers.get('X-Forwarded-Host')) host = _get_first_in_list(request.headers.get("X-Forwarded-Host"))
port_s: str | None port_s: str | None
if host: if host:
request.host = host request.host = host
if ':' in host and not host.endswith(']'): if ":" in host and not host.endswith("]"):
try: try:
_, port_s = host.rsplit(':', maxsplit=1) _, port_s = host.rsplit(":", maxsplit=1)
port = int(port_s) port = int(port_s)
except ValueError: except ValueError:
# ignore invalid port in host header # ignore invalid port in host header
pass pass
port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port')) port_s = _get_first_in_list(request.headers.get("X-Forwarded-Port"))
if port_s: if port_s:
try: try:
port = int(port_s) port = int(port_s)
@ -62,7 +62,7 @@ def local_proxy_fix(request: quart.Request):
if port: if port:
if request.server and len(request.server) == 2: if request.server and len(request.server) == 2:
request.server = (request.server[0], port) request.server = (request.server[0], port)
root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix')) root_path = _get_first_in_list(request.headers.get("X-Forwarded-Prefix"))
if root_path: if root_path:
request.root_path = root_path request.root_path = root_path

View File

@ -47,10 +47,10 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED): async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
async with open_sdnotify() as sn: async with open_sdnotify() as sn:
await sn.send('STATUS=Starting hub') await sn.send("STATUS=Starting hub")
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
await nursery.start(_run_hub) await nursery.start(_run_hub)
await sn.send('READY=1', 'STATUS=Ready for client requests') await sn.send("READY=1", "STATUS=Ready for client requests")
task_status.started() task_status.started()
# continue running hub and systemd watchdog handler # continue running hub and systemd watchdog handler

View File

@ -23,12 +23,12 @@ _logger = logging.getLogger(__name__)
def get_client_ip() -> cptypes.IPAddress: def get_client_ip() -> cptypes.IPAddress:
remote_addr = quart.request.remote_addr remote_addr = quart.request.remote_addr
if not remote_addr: if not remote_addr:
quart.abort(500, 'Missing client address') quart.abort(500, "Missing client address")
try: try:
addr = ipaddress.ip_address(remote_addr) addr = ipaddress.ip_address(remote_addr)
except ValueError as e: except ValueError as e:
_logger.warning(f'Invalid client address {remote_addr!r}: {e}') _logger.warning(f"Invalid client address {remote_addr!r}: {e}")
quart.abort(500, 'Invalid client address') quart.abort(500, "Invalid client address")
return addr return addr
@ -45,7 +45,7 @@ async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.Ma
mac = await get_client_mac_if_present(address) mac = await get_client_mac_if_present(address)
if mac is None: if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}") _logger.warning(f"Couldn't find MAC addresss for {address}")
quart.abort(404, 'Unknown client') quart.abort(404, "Unknown client")
return mac return mac
@ -58,7 +58,7 @@ async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> Non
quart.abort(500, str(e)) quart.abort(500, str(e))
if pu: if pu:
_logger.debug(f'User {mac} (with IP {address}) logged in') _logger.debug(f"User {mac} (with IP {address}) logged in")
for msg in pu.serialized: for msg in pu.serialized:
await app.my_hub.broadcast(msg) await app.my_hub.broadcast(msg)
@ -71,7 +71,7 @@ async def user_logout(mac: cptypes.MacAddress) -> None:
except capport.database.NotReadyYet as e: except capport.database.NotReadyYet as e:
quart.abort(500, str(e)) quart.abort(500, str(e))
if pu: if pu:
_logger.debug(f'User {mac} logged out') _logger.debug(f"User {mac} logged out")
for msg in pu.serialized: for msg in pu.serialized:
await app.my_hub.broadcast(msg) await app.my_hub.broadcast(msg)
@ -92,73 +92,73 @@ async def user_lookup() -> cptypes.MacPublicState:
def check_self_origin(): def check_self_origin():
origin = quart.request.headers.get('Origin', None) origin = quart.request.headers.get("Origin", None)
if origin is None: if origin is None:
# not a request by a modern browser - probably curl or something similar. don't care. # not a request by a modern browser - probably curl or something similar. don't care.
return return
origin = origin.lower().strip() origin = origin.lower().strip()
if origin == 'none': if origin == "none":
quart.abort(403, 'Origin is none') quart.abort(403, "Origin is none")
origin_parts = origin.split('/') origin_parts = origin.split("/")
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>) # Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
if len(origin_parts) < 3: if len(origin_parts) < 3:
quart.abort(400, 'Broken Origin header') quart.abort(400, "Broken Origin header")
if origin_parts[0] != 'https:' and not app.my_config.debug: if origin_parts[0] != "https:" and not app.my_config.debug:
# -> require https in production # -> require https in production
quart.abort(403, 'Non-https Origin not allowed') quart.abort(403, "Non-https Origin not allowed")
origin_host = origin_parts[2] origin_host = origin_parts[2]
host = quart.request.headers.get('Host', None) host = quart.request.headers.get("Host", None)
if host is None: if host is None:
quart.abort(403, 'Missing Host header') quart.abort(403, "Missing Host header")
if host.lower() != origin_host: if host.lower() != origin_host:
quart.abort(403, 'Origin mismatch') quart.abort(403, "Origin mismatch")
@app.route('/', methods=['GET']) @app.route("/", methods=["GET"])
async def index(missing_accept: bool = False): async def index(missing_accept: bool = False):
state = await user_lookup() state = await user_lookup()
if not state.mac: if not state.mac:
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept) return await render_i18n_template("index_unknown.html", state=state, missing_accept=missing_accept)
elif state.allowed: elif state.allowed:
return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept) return await render_i18n_template("index_active.html", state=state, missing_accept=missing_accept)
else: else:
return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept) return await render_i18n_template("index_inactive.html", state=state, missing_accept=missing_accept)
@app.route('/login', methods=['POST']) @app.route("/login", methods=["POST"])
async def login(): async def login():
check_self_origin() check_self_origin()
with trio.fail_after(5.0): with trio.fail_after(5.0):
form = await quart.request.form form = await quart.request.form
if form.get('accept') != '1': if form.get("accept") != "1":
return await index(missing_accept=True) return await index(missing_accept=True)
req_mac = form.get('mac') req_mac = form.get("mac")
if not req_mac: if not req_mac:
quart.abort(400, description='Missing MAC in request form data') quart.abort(400, description="Missing MAC in request form data")
address = get_client_ip() address = get_client_ip()
mac = await get_client_mac(address) mac = await get_client_mac(address)
if str(mac) != req_mac: if str(mac) != req_mac:
quart.abort(403, description="Passed MAC in request form doesn't match client address") quart.abort(403, description="Passed MAC in request form doesn't match client address")
await user_login(address, mac) await user_login(address, mac)
return quart.redirect('/', code=303) return quart.redirect("/", code=303)
@app.route('/logout', methods=['POST']) @app.route("/logout", methods=["POST"])
async def logout(): async def logout():
check_self_origin() check_self_origin()
with trio.fail_after(5.0): with trio.fail_after(5.0):
form = await quart.request.form form = await quart.request.form
req_mac = form.get('mac') req_mac = form.get("mac")
if not req_mac: if not req_mac:
quart.abort(400, description='Missing MAC in request form data') quart.abort(400, description="Missing MAC in request form data")
mac = await get_client_mac() mac = await get_client_mac()
if str(mac) != req_mac: if str(mac) != req_mac:
quart.abort(403, description="Passed MAC in request form doesn't match client address") quart.abort(403, description="Passed MAC in request form doesn't match client address")
await user_logout(mac) await user_logout(mac)
return quart.redirect('/', code=303) return quart.redirect("/", code=303)
@app.route('/api/captive-portal', methods=['GET']) @app.route("/api/captive-portal", methods=["GET"])
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908 # RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
async def captive_api(): async def captive_api():
state = await user_lookup() state = await user_lookup()

View File

@ -43,7 +43,7 @@ class Channel:
_logger.debug(f"{self}: created (server_side={server_side})") _logger.debug(f"{self}: created (server_side={server_side})")
def __repr__(self) -> str: def __repr__(self) -> str:
return f'Channel[0x{id(self):x}]' return f"Channel[0x{id(self):x}]"
async def do_handshake(self) -> capport.comm.message.Hello: async def do_handshake(self) -> capport.comm.message.Hello:
try: try:
@ -59,9 +59,8 @@ class Channel:
peer_hello = (await self.recv_msg()).to_variant() peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello): if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message") raise HubConnectionReadError("Expected Hello as first message")
auth_succ = \ expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)
(peer_hello.authentication == auth_succ = peer_hello.authentication == expected_auth
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message()) await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant() peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult): if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
@ -72,7 +71,7 @@ class Channel:
async def _read(self, num: int) -> bytes: async def _read(self, num: int) -> bytes:
assert num > 0 assert num > 0
buf = b'' buf = b""
# _logger.debug(f"{self}:_read({num})") # _logger.debug(f"{self}:_read({num})")
while num > 0: while num > 0:
try: try:
@ -92,7 +91,7 @@ class Channel:
async def _recv_raw_msg(self) -> bytes: async def _recv_raw_msg(self) -> bytes:
len_bytes = await self._read(4) len_bytes = await self._read(4)
chunk_size, = struct.unpack('!I', len_bytes) (chunk_size,) = struct.unpack("!I", len_bytes)
chunk = await self._read(chunk_size) chunk = await self._read(chunk_size)
if chunk is None: if chunk is None:
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length") raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
@ -116,7 +115,7 @@ class Channel:
async def send_msg(self, msg: capport.comm.message.Message): async def send_msg(self, msg: capport.comm.message.Message):
chunk = msg.SerializeToString(deterministic=True) chunk = msg.SerializeToString(deterministic=True)
chunk_size = len(chunk) chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size) len_bytes = struct.pack("!I", chunk_size)
chunk = len_bytes + chunk chunk = len_bytes + chunk
await self._send_raw(chunk) await self._send_raw(chunk)
@ -158,7 +157,7 @@ class Connection:
if msg: if msg:
await self._channel.send_msg(msg) await self._channel.send_msg(msg)
else: else:
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message()) await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message())
except trio.TooSlowError: except trio.TooSlowError:
_logger.warning(f"{self._channel}: send timed out") _logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e: except ConnectionError as e:
@ -303,7 +302,7 @@ class Hub:
if is_controller: if is_controller:
state_filename = config.database_file state_filename = config.database_file
else: else:
state_filename = '' state_filename = ""
self.database = capport.database.Database(state_filename=state_filename) self.database = capport.database.Database(state_filename=state_filename)
self._anon_context = ssl.SSLContext() self._anon_context = ssl.SSLContext()
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
@ -311,14 +310,14 @@ class Hub:
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2 self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
# -> AECDH-AES256-SHA # -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way? # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0') self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0")
self._controllers: dict[str, ControllerConn] = {} self._controllers: dict[str, ControllerConn] = {}
self._established: dict[uuid.UUID, Connection] = {} self._established: dict[uuid.UUID, Connection] = {}
async def _accept(self, stream): async def _accept(self, stream):
remotename = stream.socket.getpeername() remotename = stream.socket.getpeername()
if isinstance(remotename, tuple) and len(remotename) == 2: if isinstance(remotename, tuple) and len(remotename) == 2:
remote = f'[{remotename[0]}]:{remotename[1]}' remote = f"[{remotename[0]}]:{remotename[1]}"
else: else:
remote = str(remotename) remote = str(remotename)
try: try:
@ -351,11 +350,11 @@ class Hub:
await trio.sleep_forever() await trio.sleep_forever()
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes: def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256) m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256)
if server_side: if server_side:
m.update(b'server$') m.update(b"server$")
else: else:
m.update(b'client$') m.update(b"client$")
m.update(ssl_binding) m.update(ssl_binding)
return m.digest() return m.digest()

View File

@ -6,7 +6,7 @@ from .protobuf import message_pb2
def _message_to_variant(self: message_pb2.Message) -> typing.Any: def _message_to_variant(self: message_pb2.Message) -> typing.Any:
variant_name = self.WhichOneof('oneof') variant_name = self.WhichOneof("oneof")
if variant_name: if variant_name:
return getattr(self, variant_name) return getattr(self, variant_name)
return None return None
@ -16,14 +16,15 @@ def _make_to_message(oneof_field):
def to_message(self) -> message_pb2.Message: def to_message(self) -> message_pb2.Message:
msg = message_pb2.Message(**{oneof_field: self}) msg = message_pb2.Message(**{oneof_field: self})
return msg return msg
return to_message return to_message
def _monkey_patch(): def _monkey_patch():
g = globals() g = globals()
g['Message'] = message_pb2.Message g["Message"] = message_pb2.Message
message_pb2.Message.to_variant = _message_to_variant message_pb2.Message.to_variant = _message_to_variant
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields: for field in message_pb2._MESSAGE.oneofs_by_name["oneof"].fields:
type_name = field.message_type.name type_name = field.message_type.name
field_type = getattr(message_pb2, type_name) field_type = getattr(message_pb2, type_name)
field_type.to_message = _make_to_message(field.name) field_type.to_message = _make_to_message(field.name)

View File

@ -1,10 +1,8 @@
import google.protobuf.message import google.protobuf.message
import typing import typing
# manually maintained typehints for protobuf created (and monkey-patched) types # manually maintained typehints for protobuf created (and monkey-patched) types
class Message(google.protobuf.message.Message): class Message(google.protobuf.message.Message):
hello: Hello hello: Hello
authentication_result: AuthenticationResult authentication_result: AuthenticationResult
@ -19,10 +17,8 @@ class Message(google.protobuf.message.Message):
ping: Ping | None = None, ping: Ping | None = None,
mac_states: MacStates | None = None, mac_states: MacStates | None = None,
) -> None: ... ) -> None: ...
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ... def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
class Hello(google.protobuf.message.Message): class Hello(google.protobuf.message.Message):
instance_id: bytes instance_id: bytes
hostname: str hostname: str
@ -32,15 +28,13 @@ class Hello(google.protobuf.message.Message):
def __init__( def __init__(
self, self,
*, *,
instance_id: bytes=b'', instance_id: bytes = b"",
hostname: str='', hostname: str = "",
is_controller: bool = False, is_controller: bool = False,
authentication: bytes=b'', authentication: bytes = b"",
) -> None: ... ) -> None: ...
def to_message(self) -> Message: ... def to_message(self) -> Message: ...
class AuthenticationResult(google.protobuf.message.Message): class AuthenticationResult(google.protobuf.message.Message):
success: bool success: bool
@ -49,22 +43,18 @@ class AuthenticationResult(google.protobuf.message.Message):
*, *,
success: bool = False, success: bool = False,
) -> None: ... ) -> None: ...
def to_message(self) -> Message: ... def to_message(self) -> Message: ...
class Ping(google.protobuf.message.Message): class Ping(google.protobuf.message.Message):
payload: bytes payload: bytes
def __init__( def __init__(
self, self,
*, *,
payload: bytes=b'', payload: bytes = b"",
) -> None: ... ) -> None: ...
def to_message(self) -> Message: ... def to_message(self) -> Message: ...
class MacStates(google.protobuf.message.Message): class MacStates(google.protobuf.message.Message):
states: list[MacState] states: list[MacState]
@ -73,10 +63,8 @@ class MacStates(google.protobuf.message.Message):
*, *,
states: list[MacState] = [], states: list[MacState] = [],
) -> None: ... ) -> None: ...
def to_message(self) -> Message: ... def to_message(self) -> Message: ...
class MacState(google.protobuf.message.Message): class MacState(google.protobuf.message.Message):
mac_address: bytes mac_address: bytes
last_change: int # Seconds of UTC time since epoch last_change: int # Seconds of UTC time since epoch
@ -86,7 +74,7 @@ class MacState(google.protobuf.message.Message):
def __init__( def __init__(
self, self,
*, *,
mac_address: bytes=b'', mac_address: bytes = b"",
last_change: int = 0, last_change: int = 0,
allow_until: int = 0, allow_until: int = 0,
allowed: bool = False, allowed: bool = False,

View File

@ -39,7 +39,7 @@ class Config:
@staticmethod @staticmethod
def load(filename: str | None = None) -> Config: def load(filename: str | None = None) -> Config:
if filename is None: if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'): for name in ("capport.yaml", "/etc/capport.yaml"):
if os.path.exists(name): if os.path.exists(name):
return Config.load(name) return Config.load(name)
raise RuntimeError("Missing config file") raise RuntimeError("Missing config file")
@ -47,19 +47,19 @@ class Config:
data = yaml.safe_load(f) data = yaml.safe_load(f)
if not isinstance(data, dict): if not isinstance(data, dict):
raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}") raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}")
controllers = list(map(str, data.pop('controllers'))) controllers = list(map(str, data.pop("controllers")))
config = Config( config = Config(
_source_filename=filename, _source_filename=filename,
controllers=controllers, controllers=controllers,
server_names=data.pop('server-names', []), server_names=data.pop("server-names", []),
comm_secret=str(data.pop('comm-secret')), comm_secret=str(data.pop("comm-secret")),
cookie_secret=str(data.pop('cookie-secret')), cookie_secret=str(data.pop("cookie-secret")),
venue_info_url=str(data.pop('venue-info-url')), venue_info_url=str(data.pop("venue-info-url")),
session_timeout=data.pop('session-timeout', 3600), session_timeout=data.pop("session-timeout", 3600),
api_port=data.pop('api-port', 8000), api_port=data.pop("api-port", 8000),
controller_port=data.pop('controller-port', 5000), controller_port=data.pop("controller-port", 5000),
database_file=str(data.pop('database-file', 'capport.state')), database_file=str(data.pop("database-file", "capport.state")),
debug=data.pop('debug', False) debug=data.pop("debug", False),
) )
if data: if data:
_logger.error(f"Unknown config elements: {list(data.keys())}") _logger.error(f"Unknown config elements: {list(data.keys())}")

View File

@ -58,9 +58,9 @@ async def amain(config: capport.config.Config) -> None:
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
# hub.run loads the statefile from disk before signalling it was "started" # hub.run loads the statefile from disk before signalling it was "started"
await nursery.start(hub.run) await nursery.start(hub.run)
await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set') await sn.send("READY=1", "STATUS=Deploying initial entries to nftables set")
app.apply_db_entries(hub.database.entries()) app.apply_db_entries(hub.database.entries())
await sn.send('STATUS=Kernel fully synchronized') await sn.send("STATUS=Kernel fully synchronized")
@dataclasses.dataclass @dataclasses.dataclass
@ -69,7 +69,7 @@ class CliArguments:
def __init__(self): def __init__(self):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c') parser.add_argument("--config", "-c")
args = parser.parse_args() args = parser.parse_args()
self.config = args.config self.config = args.config

View File

@ -23,14 +23,14 @@ class MacAddress:
raw: bytes raw: bytes
def __str__(self) -> str: def __str__(self) -> str:
return self.raw.hex(':') return self.raw.hex(":")
def __repr__(self) -> str: def __repr__(self) -> str:
return repr(str(self)) return repr(str(self))
@staticmethod @staticmethod
def parse(s: str) -> MacAddress: def parse(s: str) -> MacAddress:
return MacAddress(bytes.fromhex(s.replace(':', ''))) return MacAddress(bytes.fromhex(s.replace(":", "")))
@dataclasses.dataclass(frozen=True, order=True) @dataclasses.dataclass(frozen=True, order=True)
@ -40,9 +40,9 @@ class Timestamp:
def __str__(self) -> str: def __str__(self) -> str:
try: try:
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc) ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
return ts.isoformat(sep=' ') return ts.isoformat(sep=" ")
except OSError: except OSError:
return f'epoch@{self.epoch}' return f"epoch@{self.epoch}"
def __repr__(self) -> str: def __repr__(self) -> str:
return repr(str(self)) return repr(str(self))
@ -85,7 +85,7 @@ class MacPublicState:
def allowed_remaining_duration(self) -> str: def allowed_remaining_duration(self) -> str:
mm, ss = divmod(self.allowed_remaining, 60) mm, ss = divmod(self.allowed_remaining, 60)
hh, mm = divmod(mm, 60) hh, mm = divmod(mm, 60)
return f'{hh}:{mm:02}:{ss:02}' return f"{hh}:{mm:02}:{ss:02}"
@property @property
def allowed_until(self) -> datetime.datetime | None: def allowed_until(self) -> datetime.datetime | None:
@ -95,18 +95,18 @@ class MacPublicState:
def to_rfc8908(self, config: Config) -> quart.Response: def to_rfc8908(self, config: Config) -> quart.Response:
response: dict[str, typing.Any] = { response: dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True), "user-portal-url": quart.url_for("index", _external=True),
} }
if config.venue_info_url: if config.venue_info_url:
response['venue-info-url'] = config.venue_info_url response["venue-info-url"] = config.venue_info_url
if self.captive: if self.captive:
response['captive'] = True response["captive"] = True
else: else:
response['captive'] = False response["captive"] = False
response['seconds-remaining'] = self.allowed_remaining response["seconds-remaining"] = self.allowed_remaining
response['can-extend-session'] = True response["can-extend-session"] = True
return quart.Response( return quart.Response(
json.dumps(response), json.dumps(response),
headers={'Cache-Control': 'private'}, headers={"Cache-Control": "private"},
content_type='application/captive+json', content_type="application/captive+json",
) )

View File

@ -128,7 +128,7 @@ def _serialize_mac_states_as_messages(
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes: def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
chunk = states.SerializeToString(deterministic=True) chunk = states.SerializeToString(deterministic=True)
chunk_size = len(chunk) chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size) len_bytes = struct.pack("!I", chunk_size)
return len_bytes + chunk return len_bytes + chunk
@ -204,39 +204,32 @@ class Database:
return _serialize_mac_states_as_messages(self._macs) return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict: def as_json(self) -> dict:
return { return {str(addr): entry.as_json() for addr, entry in self._macs.items()}
str(addr): entry.as_json()
for addr, entry in self._macs.items()
}
async def _run_statefile(self) -> None: async def _run_statefile(self) -> None:
rx: trio.MemoryReceiveChannel[ rx: trio.MemoryReceiveChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
capport.comm.message.MacStates | list[capport.comm.message.MacStates], tx: trio.MemorySendChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
]
tx: trio.MemorySendChannel[
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
]
tx, rx = trio.open_memory_channel(64) tx, rx = trio.open_memory_channel(64)
self._send_changes = tx self._send_changes = tx
assert self._state_filename assert self._state_filename
filename: str = self._state_filename filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}' tmp_filename = f"{filename}.new-{os.getpid()}"
async def resync(all_states: list[capport.comm.message.MacStates]): async def resync(all_states: list[capport.comm.message.MacStates]):
try: try:
async with await trio.open_file(tmp_filename, 'xb') as tf: async with await trio.open_file(tmp_filename, "xb") as tf:
for states in all_states: for states in all_states:
await tf.write(_states_to_chunk(states)) await tf.write(_states_to_chunk(states))
os.rename(tmp_filename, filename) os.rename(tmp_filename, filename)
finally: finally:
if os.path.exists(tmp_filename): if os.path.exists(tmp_filename):
_logger.warning(f'Removing (failed) state export file {tmp_filename}') _logger.warning(f"Removing (failed) state export file {tmp_filename}")
os.unlink(tmp_filename) os.unlink(tmp_filename)
try: try:
while True: while True:
async with await trio.open_file(filename, 'ab', buffering=0) as sf: async with await trio.open_file(filename, "ab", buffering=0) as sf:
while True: while True:
update = await rx.receive() update = await rx.receive()
if isinstance(update, list): if isinstance(update, list):
@ -247,10 +240,10 @@ class Database:
await resync(update) await resync(update)
# now reopen normal statefile and continue appending updates # now reopen normal statefile and continue appending updates
except trio.Cancelled: except trio.Cancelled:
_logger.info('Final sync to disk') _logger.info("Final sync to disk")
with trio.CancelScope(shield=True): with trio.CancelScope(shield=True):
await resync(_serialize_mac_states(self._macs)) await resync(_serialize_mac_states(self._macs))
_logger.info('Final sync to disk done') _logger.info("Final sync to disk done")
async def _load_statefile(self): async def _load_statefile(self):
if not os.path.exists(self._state_filename): if not os.path.exists(self._state_filename):
@ -259,7 +252,7 @@ class Database:
# we're going to ignore changes from loading the file # we're going to ignore changes from loading the file
pu = PendingUpdates(self) pu = PendingUpdates(self)
pu._closed = False pu._closed = False
async with await trio.open_file(self._state_filename, 'rb') as sf: async with await trio.open_file(self._state_filename, "rb") as sf:
while True: while True:
try: try:
len_bytes = await sf.read(4) len_bytes = await sf.read(4)
@ -268,7 +261,7 @@ class Database:
if len(len_bytes) < 4: if len(len_bytes) < 4:
_logger.error("Failed to read next chunk from statefile (unexpected EOF)") _logger.error("Failed to read next chunk from statefile (unexpected EOF)")
return return
chunk_size, = struct.unpack('!I', len_bytes) (chunk_size,) = struct.unpack("!I", len_bytes)
chunk = await sf.read(chunk_size) chunk = await sf.read(chunk_size)
except IOError as e: except IOError as e:
_logger.error(f"Failed to read next chunk from statefile: {e}") _logger.error(f"Failed to read next chunk from statefile: {e}")
@ -286,7 +279,7 @@ class Database:
except Exception as e: except Exception as e:
errors += 1 errors += 1
if errors < 5: if errors < 5:
_logger.error(f'Failed to handle state: {e}') _logger.error(f"Failed to handle state: {e}")
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState: def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac) entry = self._macs.get(mac)

View File

@ -14,20 +14,17 @@ from . import cptypes
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None): def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
# no labels in our names for now, always print help and type # no labels in our names for now, always print help and type
if help: if help:
print(f'# HELP {name} {help}') print(f"# HELP {name} {help}")
print(f'# TYPE {name} {mtype}') print(f"# TYPE {name} {mtype}")
if now: if now:
print(f'{name} {value} {now}') print(f"{name} {value} {now}")
else: else:
print(f'{name} {value}') print(f"{name} {value}")
async def amain(client_ifname: str): async def amain(client_ifname: str):
ns = capport.utils.nft_set.NftSet() ns = capport.utils.nft_set.NftSet()
captive_allowed_entries: set[cptypes.MacAddress] = { captive_allowed_entries: set[cptypes.MacAddress] = {entry["mac"] for entry in ns.list()}
entry['mac']
for entry in ns.list()
}
seen_allowed_entries: set[cptypes.MacAddress] = set() seen_allowed_entries: set[cptypes.MacAddress] = set()
total_ipv4 = 0 total_ipv4 = 0
total_ipv6 = 0 total_ipv6 = 0
@ -47,46 +44,46 @@ async def amain(client_ifname: str):
total_ipv6 += 1 total_ipv6 += 1
unique_ipv6.add(mac) unique_ipv6.add(mac)
print_metric( print_metric(
'capport_allowed_macs', "capport_allowed_macs",
'gauge', "gauge",
len(captive_allowed_entries), len(captive_allowed_entries),
help='Number of allowed client mac addresses', help="Number of allowed client mac addresses",
) )
print_metric( print_metric(
'capport_allowed_neigh_macs', "capport_allowed_neigh_macs",
'gauge', "gauge",
len(seen_allowed_entries), len(seen_allowed_entries),
help='Number of allowed client mac addresses seen in neighbor cache', help="Number of allowed client mac addresses seen in neighbor cache",
) )
print_metric( print_metric(
'capport_unique', "capport_unique",
'gauge', "gauge",
len(unique_clients), len(unique_clients),
help='Number of clients (mac addresses) in client network seen in neighbor cache', help="Number of clients (mac addresses) in client network seen in neighbor cache",
) )
print_metric( print_metric(
'capport_unique_ipv4', "capport_unique_ipv4",
'gauge', "gauge",
len(unique_ipv4), len(unique_ipv4),
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache', help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache",
) )
print_metric( print_metric(
'capport_unique_ipv6', "capport_unique_ipv6",
'gauge', "gauge",
len(unique_ipv6), len(unique_ipv6),
help='Number of IPv6 clients (unique per mac) in client network seen in neighbor cache', help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache",
) )
print_metric( print_metric(
'capport_total_ipv4', "capport_total_ipv4",
'gauge', "gauge",
total_ipv4, total_ipv4,
help='Number of IPv4 addresses seen in neighbor cache', help="Number of IPv4 addresses seen in neighbor cache",
) )
print_metric( print_metric(
'capport_total_ipv6', "capport_total_ipv6",
'gauge', "gauge",
total_ipv6, total_ipv6,
help='Number of IPv6 addresses seen in neighbor cache', help="Number of IPv6 addresses seen in neighbor cache",
) )

View File

@ -10,8 +10,8 @@ def init_logger(config: capport.config.Config):
if config.debug: if config.debug:
loglevel = logging.DEBUG loglevel = logging.DEBUG
logging.basicConfig( logging.basicConfig(
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s', format="%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s",
datefmt='[%Y-%m-%d %H:%M:%S %z]', datefmt="[%Y-%m-%d %H:%M:%S %z]",
level=loglevel, level=loglevel,
) )
logging.getLogger('hypercorn').propagate = False logging.getLogger("hypercorn").propagate = False

View File

@ -35,9 +35,9 @@ class NeighborController:
route = await self.get_route(address) route = await self.get_route(address)
if route is None: if route is None:
return None return None
index = route.get_attr(route.name2nla('oif')) index = route.get_attr(route.name2nla("oif"))
try: try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0] return self.ip.neigh("get", dst=str(address), ifindex=index, state="none")[0]
except pyroute2.netlink.exceptions.NetlinkError as e: except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT: if e.code == errno.ENOENT:
return None return None
@ -53,7 +53,7 @@ class NeighborController:
neigh = await self.get_neighbor(address, index=index, flags=flags) neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None: if neigh is None:
return None return None
mac = neigh.get_attr(neigh.name2nla('lladdr')) mac = neigh.get_attr(neigh.name2nla("lladdr"))
if mac is None: if mac is None:
return None return None
return cptypes.MacAddress.parse(mac) return cptypes.MacAddress.parse(mac)
@ -63,7 +63,7 @@ class NeighborController:
address: cptypes.IPAddress, address: cptypes.IPAddress,
) -> pyroute2.iproute.linux.rtmsg | None: ) -> pyroute2.iproute.linux.rtmsg | None:
try: try:
return self.ip.route('get', dst=str(address))[0] return self.ip.route("get", dst=str(address))[0]
except pyroute2.netlink.exceptions.NetlinkError as e: except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT: if e.code == errno.ENOENT:
return None return None
@ -74,16 +74,16 @@ class NeighborController:
interface: str, interface: str,
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]: ) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface) ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast'] unicast_num = pyroute2.netlink.rtnl.rt_type["unicast"]
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
for family in (socket.AF_INET, socket.AF_INET6): for family in (socket.AF_INET, socket.AF_INET6):
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family): for neigh in self.ip.neigh("dump", ifindex=ifindex, family=family):
if neigh['ndm_type'] != unicast_num: if neigh["ndm_type"] != unicast_num:
continue continue
mac = neigh.get_attr(neigh.name2nla('lladdr')) mac = neigh.get_attr(neigh.name2nla("lladdr"))
if not mac: if not mac:
continue continue
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst'))) dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla("dst")))
if dst.is_link_local: if dst.is_link_local:
continue continue
yield (cptypes.MacAddress.parse(mac), dst) yield (cptypes.MacAddress.parse(mac), dst)

View File

@ -30,26 +30,20 @@ class NftSet:
timeout: int | float | None = None, timeout: int | float | None = None,
) -> _nftsocket.nft_set_elem_list_msg.set_elem: ) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: dict[str, typing.Any] = { attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict( "NFTA_SET_ELEM_KEY": dict(
NFTA_DATA_VALUE=mac.raw, NFTA_DATA_VALUE=mac.raw,
), ),
} }
if timeout: if timeout:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout) attrs["NFTA_SET_ELEM_TIMEOUT"] = int(1000 * timeout)
return attrs return attrs
def _bulk_insert( def _bulk_insert(
self, self,
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]], entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
) -> None: ) -> None:
ser_entries = [ ser_entries = [self._set_elem(mac) for mac, _timeout in entries]
self._set_elem(mac) ser_entries_with_timeout = [self._set_elem(mac, timeout) for mac, timeout in entries]
for mac, _timeout in entries
]
ser_entries_with_timeout = [
self._set_elem(mac, timeout)
for mac, timeout in entries
]
with self._socket.begin() as tx: with self._socket.begin() as tx:
# create doesn't affect existing elements, so: # create doesn't affect existing elements, so:
# make sure entries exists # make sure entries exists
@ -58,8 +52,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
), ),
) )
@ -68,8 +62,8 @@ class NftSet:
_nftsocket.NFT_MSG_DELSETELEM, _nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
), ),
) )
@ -79,8 +73,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL, pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout, NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
), ),
) )
@ -95,10 +89,7 @@ class NftSet:
self.bulk_insert([(mac, timeout)]) self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None: def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
ser_entries = [ ser_entries = [self._set_elem(mac) for mac in entries]
self._set_elem(mac)
for mac in entries
]
with self._socket.begin() as tx: with self._socket.begin() as tx:
# make sure entries exists # make sure entries exists
tx.put( tx.put(
@ -106,8 +97,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
), ),
) )
@ -116,8 +107,8 @@ class NftSet:
_nftsocket.NFT_MSG_DELSETELEM, _nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries, NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
), ),
) )
@ -137,20 +128,20 @@ class NftSet:
_nftsocket.NFT_MSG_GETSETELEM, _nftsocket.NFT_MSG_GETSETELEM,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
) ),
) )
return [ return [
{ {
'mac': cptypes.MacAddress( "mac": cptypes.MacAddress(
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'), elem.get_attr("NFTA_SET_ELEM_KEY").get_attr("NFTA_DATA_VALUE"),
), ),
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)), "timeout": _from_msec(elem.get_attr("NFTA_SET_ELEM_TIMEOUT", None)),
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)), "expiration": _from_msec(elem.get_attr("NFTA_SET_ELEM_EXPIRATION", None)),
} }
for response in responses for response in responses
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', []) for elem in response.get_attr("NFTA_SET_ELEM_LIST_ELEMENTS", [])
] ]
def flush(self) -> None: def flush(self) -> None:
@ -158,9 +149,9 @@ class NftSet:
_nftsocket.NFT_MSG_DELSETELEM, _nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark', NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET='allowed', NFTA_SET_ELEM_LIST_SET="allowed",
) ),
) )
def create(self): def create(self):
@ -170,7 +161,7 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_TABLE_NAME='captive_mark', NFTA_TABLE_NAME="captive_mark",
), ),
) )
tx.put( tx.put(
@ -178,8 +169,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_TABLE='captive_mark', NFTA_SET_TABLE="captive_mark",
NFTA_SET_NAME='allowed', NFTA_SET_NAME="allowed",
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
NFTA_SET_KEY_LEN=6, # length of key: mac address NFTA_SET_KEY_LEN=6, # length of key: mac address

View File

@ -13,7 +13,7 @@ from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base) _NlMsgBase = typing.TypeVar("_NlMsgBase", bound=pyroute2.netlink.nlmsg_base)
# nft uses NESTED for those.. lets do the same # nft uses NESTED for those.. lets do the same
@ -25,16 +25,17 @@ _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
def _monkey_patch_pyroute2(): def _monkey_patch_pyroute2():
import pyroute2.netlink import pyroute2.netlink
# overwrite setdefault on nlmsg_base class hierarchy # overwrite setdefault on nlmsg_base class hierarchy
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
def _nlmsg_base__setvalue(self, value): def _nlmsg_base__setvalue(self, value):
if not self.header or not self['header'] or not isinstance(value, dict): if not self.header or not self["header"] or not isinstance(value, dict):
return _orig_setvalue(self, value) return _orig_setvalue(self, value)
# merge headers instead of overwriting them # merge headers instead of overwriting them
header = value.pop('header', {}) header = value.pop("header", {})
res = _orig_setvalue(self, value) res = _orig_setvalue(self, value)
self['header'].update(header) self["header"].update(header)
return res return res
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None: def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
@ -52,20 +53,20 @@ _monkey_patch_pyroute2()
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase: def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class() msg = msg_class()
for key, value in header.items(): for key, value in header.items():
msg['header'][key] = value msg["header"][key] = value
for key, value in fields.items(): for key, value in fields.items():
msg[key] = value msg[key] = value
if attrs: if attrs:
attr_list = msg['attrs'] attr_list = msg["attrs"]
r_nla_map = msg_class._nlmsg_base__r_nla_map r_nla_map = msg_class._nlmsg_base__r_nla_map
for key, value in attrs.items(): for key, value in attrs.items():
if msg_class.prefix: if msg_class.prefix:
key = msg_class.name2nla(key) key = msg_class.name2nla(key)
prime = r_nla_map[key] prime = r_nla_map[key]
nla_class = prime['class'] nla_class = prime["class"]
if issubclass(nla_class, pyroute2.netlink.nla): if issubclass(nla_class, pyroute2.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those) # support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']: if prime["nla_array"]:
value = [ value = [
_build(nla_class, attrs=elem) _build(nla_class, attrs=elem)
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
@ -83,10 +84,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
policy = { policy = {(x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items()}
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
for (x, y) in self.policy.items()
}
self.register_policy(policy) self.register_policy(policy)
@contextlib.contextmanager @contextlib.contextmanager
@ -156,7 +154,7 @@ class NFTTransaction:
# no inner messages were queued... not sending anything # no inner messages were queued... not sending anything
return return
# request ACK on the last message (before END) # request ACK on the last message (before END)
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK self._msgs[-1]["header"]["flags"] |= pyroute2.netlink.NLM_F_ACK
self._msgs.append( self._msgs.append(
# batch end message # batch end message
_build( _build(

View File

@ -10,7 +10,7 @@ import trio.socket
def _check_watchdog_pid() -> bool: def _check_watchdog_pid() -> bool:
wpid = os.environ.pop('WATCHDOG_PID', None) wpid = os.environ.pop("WATCHDOG_PID", None)
if not wpid: if not wpid:
return True return True
return wpid == str(os.getpid()) return wpid == str(os.getpid())
@ -18,16 +18,16 @@ def _check_watchdog_pid() -> bool:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]: async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
target = os.environ.pop('NOTIFY_SOCKET', None) target = os.environ.pop("NOTIFY_SOCKET", None)
ns: trio.socket.SocketType | None = None ns: trio.socket.SocketType | None = None
watchdog_usec: int = 0 watchdog_usec: int = 0
if target: if target:
if target.startswith('@'): if target.startswith("@"):
# Linux abstract namespace socket # Linux abstract namespace socket
target = '\0' + target[1:] target = "\0" + target[1:]
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
await ns.connect(target) await ns.connect(target)
watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None) watchdog_usec_s = os.environ.pop("WATCHDOG_USEC", None)
if _check_watchdog_pid() and watchdog_usec_s: if _check_watchdog_pid() and watchdog_usec_s:
watchdog_usec = int(watchdog_usec_s) watchdog_usec = int(watchdog_usec_s)
try: try:
@ -52,18 +52,18 @@ class SdNotify:
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None: async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
assert self.is_connected(), "Watchdog can't run without socket" assert self.is_connected(), "Watchdog can't run without socket"
await self.send('WATCHDOG=1') await self.send("WATCHDOG=1")
task_status.started() task_status.started()
# send every half of the watchdog timeout # send every half of the watchdog timeout
interval = (watchdog_usec / 1e6) / 2.0 interval = (watchdog_usec / 1e6) / 2.0
while True: while True:
await trio.sleep(interval) await trio.sleep(interval)
await self.send('WATCHDOG=1') await self.send("WATCHDOG=1")
async def send(self, *msg: str) -> None: async def send(self, *msg: str) -> None:
if not self.is_connected(): if not self.is_connected():
return return
dgram = '\n'.join(msg).encode('utf-8') dgram = "\n".join(msg).encode("utf-8")
assert self._ns, "not connected" # checked above assert self._ns, "not connected" # checked above
sent = await self._ns.send(dgram) sent = await self._ns.send(dgram)
if sent != len(dgram): if sent != len(dgram):

View File

@ -10,9 +10,9 @@ def get_local_timezone():
global _zoneinfo global _zoneinfo
if not _zoneinfo: if not _zoneinfo:
try: try:
with open('/etc/timezone') as f: with open("/etc/timezone") as f:
key = f.readline().strip() key = f.readline().strip()
_zoneinfo = zoneinfo.ZoneInfo(key) _zoneinfo = zoneinfo.ZoneInfo(key)
except (OSError, zoneinfo.ZoneInfoNotFoundError): except (OSError, zoneinfo.ZoneInfoNotFoundError):
_zoneinfo = zoneinfo.ZoneInfo('UTC') _zoneinfo = zoneinfo.ZoneInfo("UTC")
return _zoneinfo return _zoneinfo