2
0

format with black (mostly quotes)

This commit is contained in:
Stefan Bühler 2023-11-15 10:02:28 +01:00
parent ced589f28a
commit 0f45f89211
22 changed files with 245 additions and 272 deletions

View File

@ -6,9 +6,14 @@ requires = [
build-backend = "setuptools.build_meta"
[tool.mypy]
python_version = "3.9"
python_version = "3.11"
# warn_return_any = true
warn_unused_configs = true
exclude = [
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
]
[tool.black]
line-length = 120
target-version = ['py311']
exclude = '_pb2.py'

View File

@ -4,8 +4,8 @@ from .app_cls import MyQuartApp
app = MyQuartApp(__name__)
__import__('capport.api.setup')
__import__('capport.api.proxy_fix')
__import__('capport.api.lang')
__import__('capport.api.template_filters')
__import__('capport.api.views')
__import__("capport.api.setup")
__import__("capport.api.proxy_fix")
__import__("capport.api.lang")
__import__("capport.api.template_filters")
__import__("capport.api.views")

View File

@ -49,16 +49,16 @@ class MyQuartApp(quart_trio.QuartTrio):
def __init__(self, import_name: str, **kwargs) -> None:
self.my_config = capport.config.Config.load_default_once()
kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates'))
cust_templ = os.path.join('custom', 'templates')
kwargs.setdefault("template_folder", os.path.join(os.path.dirname(__file__), "templates"))
cust_templ = os.path.join("custom", "templates")
if os.path.exists(cust_templ):
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
cust_static = os.path.abspath(os.path.join('custom', 'static'))
cust_static = os.path.abspath(os.path.join("custom", "static"))
if os.path.exists(cust_static):
static_folder = cust_static
else:
static_folder = os.path.join(os.path.dirname(__file__), 'static')
kwargs.setdefault('static_folder', static_folder)
static_folder = os.path.join(os.path.dirname(__file__), "static")
kwargs.setdefault("static_folder", static_folder)
super().__init__(import_name, **kwargs)
self.debug = self.my_config.debug
self.secret_key = self.my_config.cookie_secret

View File

@ -12,7 +12,7 @@ import capport.config
def run(config: hypercorn.config.Config) -> None:
sockets = config.create_sockets()
assert config.worker_class == 'trio'
assert config.worker_class == "trio"
hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
@ -28,7 +28,7 @@ class CliArguments:
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c')
parser.add_argument("--config", "-c")
args = parser.parse_args()
self.config = args.config
@ -39,8 +39,8 @@ def main() -> None:
_config = capport.config.Config.load_default_once(filename=args.config)
hypercorn_config = hypercorn.config.Config()
hypercorn_config.application_path = 'capport.api.app'
hypercorn_config.worker_class = 'trio'
hypercorn_config.application_path = "capport.api.app"
hypercorn_config.worker_class = "trio"
hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"]
if _config.server_names:

View File

@ -7,22 +7,23 @@ import quart
from .app import app
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
_VALID_LANGUAGE_NAMES = re.compile(r"^[-a-z0-9_]+$")
def parse_accept_language(value: str) -> list[str]:
value = value.strip()
if not value or value == '*':
if not value or value == "*":
return []
tuples = []
for entry in value.split(','):
attrs = entry.split(';')
for entry in value.split(","):
attrs = entry.split(";")
name = attrs.pop(0).strip().lower()
q = 1.0
for attr in attrs:
if not '=' in attr: continue
key, value = attr.split('=', maxsplit=1)
if key.strip().lower() == 'q':
if not "=" in attr:
continue
key, value = attr.split("=", maxsplit=1)
if key.strip().lower() == "q":
try:
q = float(value.strip())
except ValueError:
@ -32,14 +33,17 @@ def parse_accept_language(value: str) -> list[str]:
tuples.sort()
have = set()
result = []
for (_q, name) in tuples:
if name in have: continue
if name == '*': break
for _q, name in tuples:
if name in have:
continue
if name == "*":
break
have.add(name)
if _VALID_LANGUAGE_NAMES.match(name):
result.append(name)
short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0]
if not short_name or short_name in have: continue
short_name = name.split("-", maxsplit=1)[0].split("_", maxsplit=1)[0]
if not short_name or short_name in have:
continue
have.add(short_name)
result.append(short_name)
return result
@ -50,23 +54,23 @@ def detect_language():
g = quart.g
r = quart.request
s = quart.session
if 'setlang' in r.args:
lang = r.args.get('setlang').strip().lower()
if "setlang" in r.args:
lang = r.args.get("setlang").strip().lower()
if lang and _VALID_LANGUAGE_NAMES.match(lang):
if s.get('lang') != lang:
s['lang'] = lang
if s.get("lang") != lang:
s["lang"] = lang
g.langs = [lang]
return
else:
# reset language
s.pop('lang', None)
lang = s.get('lang')
s.pop("lang", None)
lang = s.get("lang")
if lang:
lang = lang.strip().lower()
if lang and _VALID_LANGUAGE_NAMES.match(lang):
g.langs = [lang]
return
acc_lang = ','.join(r.headers.get_all('Accept-Language'))
acc_lang = ",".join(r.headers.get_all("Accept-Language"))
g.langs = parse_accept_language(acc_lang)
@ -74,9 +78,6 @@ async def render_i18n_template(template, /, **kwargs) -> str:
langs: list[str] = quart.g.langs
if not langs:
return await quart.render_template(template, **kwargs)
names = [
os.path.join('i18n', lang, template)
for lang in langs
]
names = [os.path.join("i18n", lang, template) for lang in langs]
names.append(template)
return await quart.render_template(names, **kwargs)

View File

@ -31,28 +31,28 @@ def local_proxy_fix(request: quart.Request):
return
if not addr.is_loopback:
return
client = _get_first_in_list(request.headers.get('X-Forwarded-For'))
client = _get_first_in_list(request.headers.get("X-Forwarded-For"))
if not client:
# assume this is always set behind reverse proxies supporting any of the headers
return
request.remote_addr = client
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
scheme = _get_first_in_list(request.headers.get("X-Forwarded-Proto"), ("http", "https"))
port: int | None = None
if scheme:
port = 443 if scheme == 'https' else 80
port = 443 if scheme == "https" else 80
request.scheme = scheme
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
host = _get_first_in_list(request.headers.get("X-Forwarded-Host"))
port_s: str | None
if host:
request.host = host
if ':' in host and not host.endswith(']'):
if ":" in host and not host.endswith("]"):
try:
_, port_s = host.rsplit(':', maxsplit=1)
_, port_s = host.rsplit(":", maxsplit=1)
port = int(port_s)
except ValueError:
# ignore invalid port in host header
pass
port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port'))
port_s = _get_first_in_list(request.headers.get("X-Forwarded-Port"))
if port_s:
try:
port = int(port_s)
@ -62,7 +62,7 @@ def local_proxy_fix(request: quart.Request):
if port:
if request.server and len(request.server) == 2:
request.server = (request.server[0], port)
root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix'))
root_path = _get_first_in_list(request.headers.get("X-Forwarded-Prefix"))
if root_path:
request.root_path = root_path

View File

@ -47,10 +47,10 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
async with open_sdnotify() as sn:
await sn.send('STATUS=Starting hub')
await sn.send("STATUS=Starting hub")
async with trio.open_nursery() as nursery:
await nursery.start(_run_hub)
await sn.send('READY=1', 'STATUS=Ready for client requests')
await sn.send("READY=1", "STATUS=Ready for client requests")
task_status.started()
# continue running hub and systemd watchdog handler

View File

@ -23,12 +23,12 @@ _logger = logging.getLogger(__name__)
def get_client_ip() -> cptypes.IPAddress:
remote_addr = quart.request.remote_addr
if not remote_addr:
quart.abort(500, 'Missing client address')
quart.abort(500, "Missing client address")
try:
addr = ipaddress.ip_address(remote_addr)
except ValueError as e:
_logger.warning(f'Invalid client address {remote_addr!r}: {e}')
quart.abort(500, 'Invalid client address')
_logger.warning(f"Invalid client address {remote_addr!r}: {e}")
quart.abort(500, "Invalid client address")
return addr
@ -45,7 +45,7 @@ async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.Ma
mac = await get_client_mac_if_present(address)
if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}")
quart.abort(404, 'Unknown client')
quart.abort(404, "Unknown client")
return mac
@ -58,7 +58,7 @@ async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> Non
quart.abort(500, str(e))
if pu:
_logger.debug(f'User {mac} (with IP {address}) logged in')
_logger.debug(f"User {mac} (with IP {address}) logged in")
for msg in pu.serialized:
await app.my_hub.broadcast(msg)
@ -71,7 +71,7 @@ async def user_logout(mac: cptypes.MacAddress) -> None:
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu:
_logger.debug(f'User {mac} logged out')
_logger.debug(f"User {mac} logged out")
for msg in pu.serialized:
await app.my_hub.broadcast(msg)
@ -92,73 +92,73 @@ async def user_lookup() -> cptypes.MacPublicState:
def check_self_origin():
origin = quart.request.headers.get('Origin', None)
origin = quart.request.headers.get("Origin", None)
if origin is None:
# not a request by a modern browser - probably curl or something similar. don't care.
return
origin = origin.lower().strip()
if origin == 'none':
quart.abort(403, 'Origin is none')
origin_parts = origin.split('/')
if origin == "none":
quart.abort(403, "Origin is none")
origin_parts = origin.split("/")
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
if len(origin_parts) < 3:
quart.abort(400, 'Broken Origin header')
if origin_parts[0] != 'https:' and not app.my_config.debug:
quart.abort(400, "Broken Origin header")
if origin_parts[0] != "https:" and not app.my_config.debug:
# -> require https in production
quart.abort(403, 'Non-https Origin not allowed')
quart.abort(403, "Non-https Origin not allowed")
origin_host = origin_parts[2]
host = quart.request.headers.get('Host', None)
host = quart.request.headers.get("Host", None)
if host is None:
quart.abort(403, 'Missing Host header')
quart.abort(403, "Missing Host header")
if host.lower() != origin_host:
quart.abort(403, 'Origin mismatch')
quart.abort(403, "Origin mismatch")
@app.route('/', methods=['GET'])
@app.route("/", methods=["GET"])
async def index(missing_accept: bool = False):
state = await user_lookup()
if not state.mac:
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept)
return await render_i18n_template("index_unknown.html", state=state, missing_accept=missing_accept)
elif state.allowed:
return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept)
return await render_i18n_template("index_active.html", state=state, missing_accept=missing_accept)
else:
return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept)
return await render_i18n_template("index_inactive.html", state=state, missing_accept=missing_accept)
@app.route('/login', methods=['POST'])
@app.route("/login", methods=["POST"])
async def login():
check_self_origin()
with trio.fail_after(5.0):
form = await quart.request.form
if form.get('accept') != '1':
if form.get("accept") != "1":
return await index(missing_accept=True)
req_mac = form.get('mac')
req_mac = form.get("mac")
if not req_mac:
quart.abort(400, description='Missing MAC in request form data')
quart.abort(400, description="Missing MAC in request form data")
address = get_client_ip()
mac = await get_client_mac(address)
if str(mac) != req_mac:
quart.abort(403, description="Passed MAC in request form doesn't match client address")
await user_login(address, mac)
return quart.redirect('/', code=303)
return quart.redirect("/", code=303)
@app.route('/logout', methods=['POST'])
@app.route("/logout", methods=["POST"])
async def logout():
check_self_origin()
with trio.fail_after(5.0):
form = await quart.request.form
req_mac = form.get('mac')
req_mac = form.get("mac")
if not req_mac:
quart.abort(400, description='Missing MAC in request form data')
quart.abort(400, description="Missing MAC in request form data")
mac = await get_client_mac()
if str(mac) != req_mac:
quart.abort(403, description="Passed MAC in request form doesn't match client address")
await user_logout(mac)
return quart.redirect('/', code=303)
return quart.redirect("/", code=303)
@app.route('/api/captive-portal', methods=['GET'])
@app.route("/api/captive-portal", methods=["GET"])
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
async def captive_api():
state = await user_lookup()

View File

@ -43,7 +43,7 @@ class Channel:
_logger.debug(f"{self}: created (server_side={server_side})")
def __repr__(self) -> str:
return f'Channel[0x{id(self):x}]'
return f"Channel[0x{id(self):x}]"
async def do_handshake(self) -> capport.comm.message.Hello:
try:
@ -59,9 +59,8 @@ class Channel:
peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message")
auth_succ = \
(peer_hello.authentication ==
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)
auth_succ = peer_hello.authentication == expected_auth
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
@ -72,7 +71,7 @@ class Channel:
async def _read(self, num: int) -> bytes:
assert num > 0
buf = b''
buf = b""
# _logger.debug(f"{self}:_read({num})")
while num > 0:
try:
@ -92,7 +91,7 @@ class Channel:
async def _recv_raw_msg(self) -> bytes:
len_bytes = await self._read(4)
chunk_size, = struct.unpack('!I', len_bytes)
(chunk_size,) = struct.unpack("!I", len_bytes)
chunk = await self._read(chunk_size)
if chunk is None:
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
@ -116,7 +115,7 @@ class Channel:
async def send_msg(self, msg: capport.comm.message.Message):
chunk = msg.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
len_bytes = struct.pack("!I", chunk_size)
chunk = len_bytes + chunk
await self._send_raw(chunk)
@ -158,7 +157,7 @@ class Connection:
if msg:
await self._channel.send_msg(msg)
else:
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message())
except trio.TooSlowError:
_logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e:
@ -303,7 +302,7 @@ class Hub:
if is_controller:
state_filename = config.database_file
else:
state_filename = ''
state_filename = ""
self.database = capport.database.Database(state_filename=state_filename)
self._anon_context = ssl.SSLContext()
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
@ -311,14 +310,14 @@ class Hub:
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
# -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0")
self._controllers: dict[str, ControllerConn] = {}
self._established: dict[uuid.UUID, Connection] = {}
async def _accept(self, stream):
remotename = stream.socket.getpeername()
if isinstance(remotename, tuple) and len(remotename) == 2:
remote = f'[{remotename[0]}]:{remotename[1]}'
remote = f"[{remotename[0]}]:{remotename[1]}"
else:
remote = str(remotename)
try:
@ -351,11 +350,11 @@ class Hub:
await trio.sleep_forever()
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256)
m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256)
if server_side:
m.update(b'server$')
m.update(b"server$")
else:
m.update(b'client$')
m.update(b"client$")
m.update(ssl_binding)
return m.digest()

View File

@ -6,7 +6,7 @@ from .protobuf import message_pb2
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
variant_name = self.WhichOneof('oneof')
variant_name = self.WhichOneof("oneof")
if variant_name:
return getattr(self, variant_name)
return None
@ -16,14 +16,15 @@ def _make_to_message(oneof_field):
def to_message(self) -> message_pb2.Message:
msg = message_pb2.Message(**{oneof_field: self})
return msg
return to_message
def _monkey_patch():
g = globals()
g['Message'] = message_pb2.Message
g["Message"] = message_pb2.Message
message_pb2.Message.to_variant = _message_to_variant
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
for field in message_pb2._MESSAGE.oneofs_by_name["oneof"].fields:
type_name = field.message_type.name
field_type = getattr(message_pb2, type_name)
field_type.to_message = _make_to_message(field.name)

View File

@ -1,10 +1,8 @@
import google.protobuf.message
import typing
# manually maintained typehints for protobuf created (and monkey-patched) types
class Message(google.protobuf.message.Message):
hello: Hello
authentication_result: AuthenticationResult
@ -14,15 +12,13 @@ class Message(google.protobuf.message.Message):
def __init__(
self,
*,
hello: Hello | None=None,
authentication_result: AuthenticationResult | None=None,
ping: Ping | None=None,
mac_states: MacStates | None=None,
hello: Hello | None = None,
authentication_result: AuthenticationResult | None = None,
ping: Ping | None = None,
mac_states: MacStates | None = None,
) -> None: ...
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
class Hello(google.protobuf.message.Message):
instance_id: bytes
hostname: str
@ -32,51 +28,43 @@ class Hello(google.protobuf.message.Message):
def __init__(
self,
*,
instance_id: bytes=b'',
hostname: str='',
is_controller: bool=False,
authentication: bytes=b'',
instance_id: bytes = b"",
hostname: str = "",
is_controller: bool = False,
authentication: bytes = b"",
) -> None: ...
def to_message(self) -> Message: ...
class AuthenticationResult(google.protobuf.message.Message):
success: bool
def __init__(
self,
*,
success: bool=False,
success: bool = False,
) -> None: ...
def to_message(self) -> Message: ...
class Ping(google.protobuf.message.Message):
payload: bytes
def __init__(
self,
*,
payload: bytes=b'',
payload: bytes = b"",
) -> None: ...
def to_message(self) -> Message: ...
class MacStates(google.protobuf.message.Message):
states: list[MacState]
def __init__(
self,
*,
states: list[MacState]=[],
states: list[MacState] = [],
) -> None: ...
def to_message(self) -> Message: ...
class MacState(google.protobuf.message.Message):
mac_address: bytes
last_change: int # Seconds of UTC time since epoch
@ -86,8 +74,8 @@ class MacState(google.protobuf.message.Message):
def __init__(
self,
*,
mac_address: bytes=b'',
last_change: int=0,
allow_until: int=0,
allowed: bool=False,
mac_address: bytes = b"",
last_change: int = 0,
allow_until: int = 0,
allowed: bool = False,
) -> None: ...

View File

@ -39,7 +39,7 @@ class Config:
@staticmethod
def load(filename: str | None = None) -> Config:
if filename is None:
for name in ('capport.yaml', '/etc/capport.yaml'):
for name in ("capport.yaml", "/etc/capport.yaml"):
if os.path.exists(name):
return Config.load(name)
raise RuntimeError("Missing config file")
@ -47,19 +47,19 @@ class Config:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}")
controllers = list(map(str, data.pop('controllers')))
controllers = list(map(str, data.pop("controllers")))
config = Config(
_source_filename=filename,
controllers=controllers,
server_names=data.pop('server-names', []),
comm_secret=str(data.pop('comm-secret')),
cookie_secret=str(data.pop('cookie-secret')),
venue_info_url=str(data.pop('venue-info-url')),
session_timeout=data.pop('session-timeout', 3600),
api_port=data.pop('api-port', 8000),
controller_port=data.pop('controller-port', 5000),
database_file=str(data.pop('database-file', 'capport.state')),
debug=data.pop('debug', False)
server_names=data.pop("server-names", []),
comm_secret=str(data.pop("comm-secret")),
cookie_secret=str(data.pop("cookie-secret")),
venue_info_url=str(data.pop("venue-info-url")),
session_timeout=data.pop("session-timeout", 3600),
api_port=data.pop("api-port", 8000),
controller_port=data.pop("controller-port", 5000),
database_file=str(data.pop("database-file", "capport.state")),
debug=data.pop("debug", False),
)
if data:
_logger.error(f"Unknown config elements: {list(data.keys())}")

View File

@ -58,9 +58,9 @@ async def amain(config: capport.config.Config) -> None:
async with trio.open_nursery() as nursery:
# hub.run loads the statefile from disk before signalling it was "started"
await nursery.start(hub.run)
await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set')
await sn.send("READY=1", "STATUS=Deploying initial entries to nftables set")
app.apply_db_entries(hub.database.entries())
await sn.send('STATUS=Kernel fully synchronized')
await sn.send("STATUS=Kernel fully synchronized")
@dataclasses.dataclass
@ -69,7 +69,7 @@ class CliArguments:
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c')
parser.add_argument("--config", "-c")
args = parser.parse_args()
self.config = args.config

View File

@ -23,14 +23,14 @@ class MacAddress:
raw: bytes
def __str__(self) -> str:
return self.raw.hex(':')
return self.raw.hex(":")
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def parse(s: str) -> MacAddress:
return MacAddress(bytes.fromhex(s.replace(':', '')))
return MacAddress(bytes.fromhex(s.replace(":", "")))
@dataclasses.dataclass(frozen=True, order=True)
@ -40,9 +40,9 @@ class Timestamp:
def __str__(self) -> str:
try:
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
return ts.isoformat(sep=' ')
return ts.isoformat(sep=" ")
except OSError:
return f'epoch@{self.epoch}'
return f"epoch@{self.epoch}"
def __repr__(self) -> str:
return repr(str(self))
@ -85,7 +85,7 @@ class MacPublicState:
def allowed_remaining_duration(self) -> str:
mm, ss = divmod(self.allowed_remaining, 60)
hh, mm = divmod(mm, 60)
return f'{hh}:{mm:02}:{ss:02}'
return f"{hh}:{mm:02}:{ss:02}"
@property
def allowed_until(self) -> datetime.datetime | None:
@ -95,18 +95,18 @@ class MacPublicState:
def to_rfc8908(self, config: Config) -> quart.Response:
response: dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True),
"user-portal-url": quart.url_for("index", _external=True),
}
if config.venue_info_url:
response['venue-info-url'] = config.venue_info_url
response["venue-info-url"] = config.venue_info_url
if self.captive:
response['captive'] = True
response["captive"] = True
else:
response['captive'] = False
response['seconds-remaining'] = self.allowed_remaining
response['can-extend-session'] = True
response["captive"] = False
response["seconds-remaining"] = self.allowed_remaining
response["can-extend-session"] = True
return quart.Response(
json.dumps(response),
headers={'Cache-Control': 'private'},
content_type='application/captive+json',
headers={"Cache-Control": "private"},
content_type="application/captive+json",
)

View File

@ -128,7 +128,7 @@ def _serialize_mac_states_as_messages(
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
chunk = states.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
len_bytes = struct.pack("!I", chunk_size)
return len_bytes + chunk
@ -204,39 +204,32 @@ class Database:
return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict:
return {
str(addr): entry.as_json()
for addr, entry in self._macs.items()
}
return {str(addr): entry.as_json() for addr, entry in self._macs.items()}
async def _run_statefile(self) -> None:
rx: trio.MemoryReceiveChannel[
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
]
tx: trio.MemorySendChannel[
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
]
rx: trio.MemoryReceiveChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
tx: trio.MemorySendChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
tx, rx = trio.open_memory_channel(64)
self._send_changes = tx
assert self._state_filename
filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}'
tmp_filename = f"{filename}.new-{os.getpid()}"
async def resync(all_states: list[capport.comm.message.MacStates]):
try:
async with await trio.open_file(tmp_filename, 'xb') as tf:
async with await trio.open_file(tmp_filename, "xb") as tf:
for states in all_states:
await tf.write(_states_to_chunk(states))
os.rename(tmp_filename, filename)
finally:
if os.path.exists(tmp_filename):
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
_logger.warning(f"Removing (failed) state export file {tmp_filename}")
os.unlink(tmp_filename)
try:
while True:
async with await trio.open_file(filename, 'ab', buffering=0) as sf:
async with await trio.open_file(filename, "ab", buffering=0) as sf:
while True:
update = await rx.receive()
if isinstance(update, list):
@ -247,10 +240,10 @@ class Database:
await resync(update)
# now reopen normal statefile and continue appending updates
except trio.Cancelled:
_logger.info('Final sync to disk')
_logger.info("Final sync to disk")
with trio.CancelScope(shield=True):
await resync(_serialize_mac_states(self._macs))
_logger.info('Final sync to disk done')
_logger.info("Final sync to disk done")
async def _load_statefile(self):
if not os.path.exists(self._state_filename):
@ -259,7 +252,7 @@ class Database:
# we're going to ignore changes from loading the file
pu = PendingUpdates(self)
pu._closed = False
async with await trio.open_file(self._state_filename, 'rb') as sf:
async with await trio.open_file(self._state_filename, "rb") as sf:
while True:
try:
len_bytes = await sf.read(4)
@ -268,7 +261,7 @@ class Database:
if len(len_bytes) < 4:
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
return
chunk_size, = struct.unpack('!I', len_bytes)
(chunk_size,) = struct.unpack("!I", len_bytes)
chunk = await sf.read(chunk_size)
except IOError as e:
_logger.error(f"Failed to read next chunk from statefile: {e}")
@ -286,7 +279,7 @@ class Database:
except Exception as e:
errors += 1
if errors < 5:
_logger.error(f'Failed to handle state: {e}')
_logger.error(f"Failed to handle state: {e}")
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac)

View File

@ -14,20 +14,17 @@ from . import cptypes
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
# no labels in our names for now, always print help and type
if help:
print(f'# HELP {name} {help}')
print(f'# TYPE {name} {mtype}')
print(f"# HELP {name} {help}")
print(f"# TYPE {name} {mtype}")
if now:
print(f'{name} {value} {now}')
print(f"{name} {value} {now}")
else:
print(f'{name} {value}')
print(f"{name} {value}")
async def amain(client_ifname: str):
ns = capport.utils.nft_set.NftSet()
captive_allowed_entries: set[cptypes.MacAddress] = {
entry['mac']
for entry in ns.list()
}
captive_allowed_entries: set[cptypes.MacAddress] = {entry["mac"] for entry in ns.list()}
seen_allowed_entries: set[cptypes.MacAddress] = set()
total_ipv4 = 0
total_ipv6 = 0
@ -47,46 +44,46 @@ async def amain(client_ifname: str):
total_ipv6 += 1
unique_ipv6.add(mac)
print_metric(
'capport_allowed_macs',
'gauge',
"capport_allowed_macs",
"gauge",
len(captive_allowed_entries),
help='Number of allowed client mac addresses',
help="Number of allowed client mac addresses",
)
print_metric(
'capport_allowed_neigh_macs',
'gauge',
"capport_allowed_neigh_macs",
"gauge",
len(seen_allowed_entries),
help='Number of allowed client mac addresses seen in neighbor cache',
help="Number of allowed client mac addresses seen in neighbor cache",
)
print_metric(
'capport_unique',
'gauge',
"capport_unique",
"gauge",
len(unique_clients),
help='Number of clients (mac addresses) in client network seen in neighbor cache',
help="Number of clients (mac addresses) in client network seen in neighbor cache",
)
print_metric(
'capport_unique_ipv4',
'gauge',
"capport_unique_ipv4",
"gauge",
len(unique_ipv4),
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache",
)
print_metric(
'capport_unique_ipv6',
'gauge',
"capport_unique_ipv6",
"gauge",
len(unique_ipv6),
help='Number of IPv6 clients (unique per mac) in client network seen in neighbor cache',
help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache",
)
print_metric(
'capport_total_ipv4',
'gauge',
"capport_total_ipv4",
"gauge",
total_ipv4,
help='Number of IPv4 addresses seen in neighbor cache',
help="Number of IPv4 addresses seen in neighbor cache",
)
print_metric(
'capport_total_ipv6',
'gauge',
"capport_total_ipv6",
"gauge",
total_ipv6,
help='Number of IPv6 addresses seen in neighbor cache',
help="Number of IPv6 addresses seen in neighbor cache",
)

View File

@ -10,8 +10,8 @@ def init_logger(config: capport.config.Config):
if config.debug:
loglevel = logging.DEBUG
logging.basicConfig(
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S %z]',
format="%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s",
datefmt="[%Y-%m-%d %H:%M:%S %z]",
level=loglevel,
)
logging.getLogger('hypercorn').propagate = False
logging.getLogger("hypercorn").propagate = False

View File

@ -35,9 +35,9 @@ class NeighborController:
route = await self.get_route(address)
if route is None:
return None
index = route.get_attr(route.name2nla('oif'))
index = route.get_attr(route.name2nla("oif"))
try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
return self.ip.neigh("get", dst=str(address), ifindex=index, state="none")[0]
except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
@ -53,7 +53,7 @@ class NeighborController:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
return None
mac = neigh.get_attr(neigh.name2nla('lladdr'))
mac = neigh.get_attr(neigh.name2nla("lladdr"))
if mac is None:
return None
return cptypes.MacAddress.parse(mac)
@ -63,7 +63,7 @@ class NeighborController:
address: cptypes.IPAddress,
) -> pyroute2.iproute.linux.rtmsg | None:
try:
return self.ip.route('get', dst=str(address))[0]
return self.ip.route("get", dst=str(address))[0]
except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
@ -74,16 +74,16 @@ class NeighborController:
interface: str,
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
unicast_num = pyroute2.netlink.rtnl.rt_type["unicast"]
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
for family in (socket.AF_INET, socket.AF_INET6):
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family):
if neigh['ndm_type'] != unicast_num:
for neigh in self.ip.neigh("dump", ifindex=ifindex, family=family):
if neigh["ndm_type"] != unicast_num:
continue
mac = neigh.get_attr(neigh.name2nla('lladdr'))
mac = neigh.get_attr(neigh.name2nla("lladdr"))
if not mac:
continue
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst')))
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla("dst")))
if dst.is_link_local:
continue
yield (cptypes.MacAddress.parse(mac), dst)

View File

@ -30,26 +30,20 @@ class NftSet:
timeout: int | float | None = None,
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict(
"NFTA_SET_ELEM_KEY": dict(
NFTA_DATA_VALUE=mac.raw,
),
}
if timeout:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
attrs["NFTA_SET_ELEM_TIMEOUT"] = int(1000 * timeout)
return attrs
def _bulk_insert(
self,
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
) -> None:
ser_entries = [
self._set_elem(mac)
for mac, _timeout in entries
]
ser_entries_with_timeout = [
self._set_elem(mac, timeout)
for mac, timeout in entries
]
ser_entries = [self._set_elem(mac) for mac, _timeout in entries]
ser_entries_with_timeout = [self._set_elem(mac, timeout) for mac, timeout in entries]
with self._socket.begin() as tx:
# create doesn't affect existing elements, so:
# make sure entries exists
@ -58,8 +52,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
@ -68,8 +62,8 @@ class NftSet:
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
@ -79,8 +73,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
),
)
@ -95,10 +89,7 @@ class NftSet:
self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
ser_entries = [
self._set_elem(mac)
for mac in entries
]
ser_entries = [self._set_elem(mac) for mac in entries]
with self._socket.begin() as tx:
# make sure entries exists
tx.put(
@ -106,8 +97,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
@ -116,8 +107,8 @@ class NftSet:
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
@ -137,20 +128,20 @@ class NftSet:
_nftsocket.NFT_MSG_GETSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
),
)
return [
{
'mac': cptypes.MacAddress(
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
"mac": cptypes.MacAddress(
elem.get_attr("NFTA_SET_ELEM_KEY").get_attr("NFTA_DATA_VALUE"),
),
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
"timeout": _from_msec(elem.get_attr("NFTA_SET_ELEM_TIMEOUT", None)),
"expiration": _from_msec(elem.get_attr("NFTA_SET_ELEM_EXPIRATION", None)),
}
for response in responses
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
for elem in response.get_attr("NFTA_SET_ELEM_LIST_ELEMENTS", [])
]
def flush(self) -> None:
@ -158,9 +149,9 @@ class NftSet:
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
NFTA_SET_ELEM_LIST_SET="allowed",
),
)
def create(self):
@ -170,7 +161,7 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_TABLE_NAME='captive_mark',
NFTA_TABLE_NAME="captive_mark",
),
)
tx.put(
@ -178,8 +169,8 @@ class NftSet:
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_NAME='allowed',
NFTA_SET_TABLE="captive_mark",
NFTA_SET_NAME="allowed",
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
NFTA_SET_KEY_LEN=6, # length of key: mac address

View File

@ -13,7 +13,7 @@ from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
_NlMsgBase = typing.TypeVar("_NlMsgBase", bound=pyroute2.netlink.nlmsg_base)
# nft uses NESTED for those.. lets do the same
@ -25,16 +25,17 @@ _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
def _monkey_patch_pyroute2():
import pyroute2.netlink
# overwrite setdefault on nlmsg_base class hierarchy
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
def _nlmsg_base__setvalue(self, value):
if not self.header or not self['header'] or not isinstance(value, dict):
if not self.header or not self["header"] or not isinstance(value, dict):
return _orig_setvalue(self, value)
# merge headers instead of overwriting them
header = value.pop('header', {})
header = value.pop("header", {})
res = _orig_setvalue(self, value)
self['header'].update(header)
self["header"].update(header)
return res
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
@ -52,20 +53,20 @@ _monkey_patch_pyroute2()
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
msg["header"][key] = value
for key, value in fields.items():
msg[key] = value
if attrs:
attr_list = msg['attrs']
attr_list = msg["attrs"]
r_nla_map = msg_class._nlmsg_base__r_nla_map
for key, value in attrs.items():
if msg_class.prefix:
key = msg_class.name2nla(key)
prime = r_nla_map[key]
nla_class = prime['class']
nla_class = prime["class"]
if issubclass(nla_class, pyroute2.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']:
if prime["nla_array"]:
value = [
_build(nla_class, attrs=elem)
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
@ -83,10 +84,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
policy = {
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
for (x, y) in self.policy.items()
}
policy = {(x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items()}
self.register_policy(policy)
@contextlib.contextmanager
@ -156,7 +154,7 @@ class NFTTransaction:
# no inner messages were queued... not sending anything
return
# request ACK on the last message (before END)
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
self._msgs[-1]["header"]["flags"] |= pyroute2.netlink.NLM_F_ACK
self._msgs.append(
# batch end message
_build(

View File

@ -10,7 +10,7 @@ import trio.socket
def _check_watchdog_pid() -> bool:
wpid = os.environ.pop('WATCHDOG_PID', None)
wpid = os.environ.pop("WATCHDOG_PID", None)
if not wpid:
return True
return wpid == str(os.getpid())
@ -18,16 +18,16 @@ def _check_watchdog_pid() -> bool:
@contextlib.asynccontextmanager
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
target = os.environ.pop('NOTIFY_SOCKET', None)
target = os.environ.pop("NOTIFY_SOCKET", None)
ns: trio.socket.SocketType | None = None
watchdog_usec: int = 0
if target:
if target.startswith('@'):
if target.startswith("@"):
# Linux abstract namespace socket
target = '\0' + target[1:]
target = "\0" + target[1:]
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
await ns.connect(target)
watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None)
watchdog_usec_s = os.environ.pop("WATCHDOG_USEC", None)
if _check_watchdog_pid() and watchdog_usec_s:
watchdog_usec = int(watchdog_usec_s)
try:
@ -52,18 +52,18 @@ class SdNotify:
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
assert self.is_connected(), "Watchdog can't run without socket"
await self.send('WATCHDOG=1')
await self.send("WATCHDOG=1")
task_status.started()
# send every half of the watchdog timeout
interval = (watchdog_usec/1e6) / 2.0
interval = (watchdog_usec / 1e6) / 2.0
while True:
await trio.sleep(interval)
await self.send('WATCHDOG=1')
await self.send("WATCHDOG=1")
async def send(self, *msg: str) -> None:
if not self.is_connected():
return
dgram = '\n'.join(msg).encode('utf-8')
dgram = "\n".join(msg).encode("utf-8")
assert self._ns, "not connected" # checked above
sent = await self._ns.send(dgram)
if sent != len(dgram):

View File

@ -10,9 +10,9 @@ def get_local_timezone():
global _zoneinfo
if not _zoneinfo:
try:
with open('/etc/timezone') as f:
with open("/etc/timezone") as f:
key = f.readline().strip()
_zoneinfo = zoneinfo.ZoneInfo(key)
except (OSError, zoneinfo.ZoneInfoNotFoundError):
_zoneinfo = zoneinfo.ZoneInfo('UTC')
_zoneinfo = zoneinfo.ZoneInfo("UTC")
return _zoneinfo