format with black (mostly quotes)
This commit is contained in:
parent
ced589f28a
commit
0f45f89211
@ -6,9 +6,14 @@ requires = [
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.9"
|
||||
python_version = "3.11"
|
||||
# warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
exclude = [
|
||||
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
|
||||
]
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py311']
|
||||
exclude = '_pb2.py'
|
||||
|
@ -4,8 +4,8 @@ from .app_cls import MyQuartApp
|
||||
|
||||
app = MyQuartApp(__name__)
|
||||
|
||||
__import__('capport.api.setup')
|
||||
__import__('capport.api.proxy_fix')
|
||||
__import__('capport.api.lang')
|
||||
__import__('capport.api.template_filters')
|
||||
__import__('capport.api.views')
|
||||
__import__("capport.api.setup")
|
||||
__import__("capport.api.proxy_fix")
|
||||
__import__("capport.api.lang")
|
||||
__import__("capport.api.template_filters")
|
||||
__import__("capport.api.views")
|
||||
|
@ -49,16 +49,16 @@ class MyQuartApp(quart_trio.QuartTrio):
|
||||
|
||||
def __init__(self, import_name: str, **kwargs) -> None:
|
||||
self.my_config = capport.config.Config.load_default_once()
|
||||
kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates'))
|
||||
cust_templ = os.path.join('custom', 'templates')
|
||||
kwargs.setdefault("template_folder", os.path.join(os.path.dirname(__file__), "templates"))
|
||||
cust_templ = os.path.join("custom", "templates")
|
||||
if os.path.exists(cust_templ):
|
||||
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
|
||||
cust_static = os.path.abspath(os.path.join('custom', 'static'))
|
||||
cust_static = os.path.abspath(os.path.join("custom", "static"))
|
||||
if os.path.exists(cust_static):
|
||||
static_folder = cust_static
|
||||
else:
|
||||
static_folder = os.path.join(os.path.dirname(__file__), 'static')
|
||||
kwargs.setdefault('static_folder', static_folder)
|
||||
static_folder = os.path.join(os.path.dirname(__file__), "static")
|
||||
kwargs.setdefault("static_folder", static_folder)
|
||||
super().__init__(import_name, **kwargs)
|
||||
self.debug = self.my_config.debug
|
||||
self.secret_key = self.my_config.cookie_secret
|
||||
|
@ -12,7 +12,7 @@ import capport.config
|
||||
|
||||
def run(config: hypercorn.config.Config) -> None:
|
||||
sockets = config.create_sockets()
|
||||
assert config.worker_class == 'trio'
|
||||
assert config.worker_class == "trio"
|
||||
|
||||
hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
|
||||
|
||||
@ -28,7 +28,7 @@ class CliArguments:
|
||||
|
||||
def __init__(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', '-c')
|
||||
parser.add_argument("--config", "-c")
|
||||
args = parser.parse_args()
|
||||
|
||||
self.config = args.config
|
||||
@ -39,8 +39,8 @@ def main() -> None:
|
||||
_config = capport.config.Config.load_default_once(filename=args.config)
|
||||
|
||||
hypercorn_config = hypercorn.config.Config()
|
||||
hypercorn_config.application_path = 'capport.api.app'
|
||||
hypercorn_config.worker_class = 'trio'
|
||||
hypercorn_config.application_path = "capport.api.app"
|
||||
hypercorn_config.worker_class = "trio"
|
||||
hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"]
|
||||
|
||||
if _config.server_names:
|
||||
|
@ -7,22 +7,23 @@ import quart
|
||||
|
||||
from .app import app
|
||||
|
||||
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
|
||||
_VALID_LANGUAGE_NAMES = re.compile(r"^[-a-z0-9_]+$")
|
||||
|
||||
|
||||
def parse_accept_language(value: str) -> list[str]:
|
||||
value = value.strip()
|
||||
if not value or value == '*':
|
||||
if not value or value == "*":
|
||||
return []
|
||||
tuples = []
|
||||
for entry in value.split(','):
|
||||
attrs = entry.split(';')
|
||||
for entry in value.split(","):
|
||||
attrs = entry.split(";")
|
||||
name = attrs.pop(0).strip().lower()
|
||||
q = 1.0
|
||||
for attr in attrs:
|
||||
if not '=' in attr: continue
|
||||
key, value = attr.split('=', maxsplit=1)
|
||||
if key.strip().lower() == 'q':
|
||||
if not "=" in attr:
|
||||
continue
|
||||
key, value = attr.split("=", maxsplit=1)
|
||||
if key.strip().lower() == "q":
|
||||
try:
|
||||
q = float(value.strip())
|
||||
except ValueError:
|
||||
@ -32,14 +33,17 @@ def parse_accept_language(value: str) -> list[str]:
|
||||
tuples.sort()
|
||||
have = set()
|
||||
result = []
|
||||
for (_q, name) in tuples:
|
||||
if name in have: continue
|
||||
if name == '*': break
|
||||
for _q, name in tuples:
|
||||
if name in have:
|
||||
continue
|
||||
if name == "*":
|
||||
break
|
||||
have.add(name)
|
||||
if _VALID_LANGUAGE_NAMES.match(name):
|
||||
result.append(name)
|
||||
short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0]
|
||||
if not short_name or short_name in have: continue
|
||||
short_name = name.split("-", maxsplit=1)[0].split("_", maxsplit=1)[0]
|
||||
if not short_name or short_name in have:
|
||||
continue
|
||||
have.add(short_name)
|
||||
result.append(short_name)
|
||||
return result
|
||||
@ -50,23 +54,23 @@ def detect_language():
|
||||
g = quart.g
|
||||
r = quart.request
|
||||
s = quart.session
|
||||
if 'setlang' in r.args:
|
||||
lang = r.args.get('setlang').strip().lower()
|
||||
if "setlang" in r.args:
|
||||
lang = r.args.get("setlang").strip().lower()
|
||||
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
||||
if s.get('lang') != lang:
|
||||
s['lang'] = lang
|
||||
if s.get("lang") != lang:
|
||||
s["lang"] = lang
|
||||
g.langs = [lang]
|
||||
return
|
||||
else:
|
||||
# reset language
|
||||
s.pop('lang', None)
|
||||
lang = s.get('lang')
|
||||
s.pop("lang", None)
|
||||
lang = s.get("lang")
|
||||
if lang:
|
||||
lang = lang.strip().lower()
|
||||
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
||||
g.langs = [lang]
|
||||
return
|
||||
acc_lang = ','.join(r.headers.get_all('Accept-Language'))
|
||||
acc_lang = ",".join(r.headers.get_all("Accept-Language"))
|
||||
g.langs = parse_accept_language(acc_lang)
|
||||
|
||||
|
||||
@ -74,9 +78,6 @@ async def render_i18n_template(template, /, **kwargs) -> str:
|
||||
langs: list[str] = quart.g.langs
|
||||
if not langs:
|
||||
return await quart.render_template(template, **kwargs)
|
||||
names = [
|
||||
os.path.join('i18n', lang, template)
|
||||
for lang in langs
|
||||
]
|
||||
names = [os.path.join("i18n", lang, template) for lang in langs]
|
||||
names.append(template)
|
||||
return await quart.render_template(names, **kwargs)
|
||||
|
@ -31,28 +31,28 @@ def local_proxy_fix(request: quart.Request):
|
||||
return
|
||||
if not addr.is_loopback:
|
||||
return
|
||||
client = _get_first_in_list(request.headers.get('X-Forwarded-For'))
|
||||
client = _get_first_in_list(request.headers.get("X-Forwarded-For"))
|
||||
if not client:
|
||||
# assume this is always set behind reverse proxies supporting any of the headers
|
||||
return
|
||||
request.remote_addr = client
|
||||
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
|
||||
scheme = _get_first_in_list(request.headers.get("X-Forwarded-Proto"), ("http", "https"))
|
||||
port: int | None = None
|
||||
if scheme:
|
||||
port = 443 if scheme == 'https' else 80
|
||||
port = 443 if scheme == "https" else 80
|
||||
request.scheme = scheme
|
||||
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
|
||||
host = _get_first_in_list(request.headers.get("X-Forwarded-Host"))
|
||||
port_s: str | None
|
||||
if host:
|
||||
request.host = host
|
||||
if ':' in host and not host.endswith(']'):
|
||||
if ":" in host and not host.endswith("]"):
|
||||
try:
|
||||
_, port_s = host.rsplit(':', maxsplit=1)
|
||||
_, port_s = host.rsplit(":", maxsplit=1)
|
||||
port = int(port_s)
|
||||
except ValueError:
|
||||
# ignore invalid port in host header
|
||||
pass
|
||||
port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port'))
|
||||
port_s = _get_first_in_list(request.headers.get("X-Forwarded-Port"))
|
||||
if port_s:
|
||||
try:
|
||||
port = int(port_s)
|
||||
@ -62,7 +62,7 @@ def local_proxy_fix(request: quart.Request):
|
||||
if port:
|
||||
if request.server and len(request.server) == 2:
|
||||
request.server = (request.server[0], port)
|
||||
root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix'))
|
||||
root_path = _get_first_in_list(request.headers.get("X-Forwarded-Prefix"))
|
||||
if root_path:
|
||||
request.root_path = root_path
|
||||
|
||||
|
@ -47,10 +47,10 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||
|
||||
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
|
||||
async with open_sdnotify() as sn:
|
||||
await sn.send('STATUS=Starting hub')
|
||||
await sn.send("STATUS=Starting hub")
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(_run_hub)
|
||||
await sn.send('READY=1', 'STATUS=Ready for client requests')
|
||||
await sn.send("READY=1", "STATUS=Ready for client requests")
|
||||
task_status.started()
|
||||
# continue running hub and systemd watchdog handler
|
||||
|
||||
|
@ -23,12 +23,12 @@ _logger = logging.getLogger(__name__)
|
||||
def get_client_ip() -> cptypes.IPAddress:
|
||||
remote_addr = quart.request.remote_addr
|
||||
if not remote_addr:
|
||||
quart.abort(500, 'Missing client address')
|
||||
quart.abort(500, "Missing client address")
|
||||
try:
|
||||
addr = ipaddress.ip_address(remote_addr)
|
||||
except ValueError as e:
|
||||
_logger.warning(f'Invalid client address {remote_addr!r}: {e}')
|
||||
quart.abort(500, 'Invalid client address')
|
||||
_logger.warning(f"Invalid client address {remote_addr!r}: {e}")
|
||||
quart.abort(500, "Invalid client address")
|
||||
return addr
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.Ma
|
||||
mac = await get_client_mac_if_present(address)
|
||||
if mac is None:
|
||||
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
||||
quart.abort(404, 'Unknown client')
|
||||
quart.abort(404, "Unknown client")
|
||||
return mac
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> Non
|
||||
quart.abort(500, str(e))
|
||||
|
||||
if pu:
|
||||
_logger.debug(f'User {mac} (with IP {address}) logged in')
|
||||
_logger.debug(f"User {mac} (with IP {address}) logged in")
|
||||
for msg in pu.serialized:
|
||||
await app.my_hub.broadcast(msg)
|
||||
|
||||
@ -71,7 +71,7 @@ async def user_logout(mac: cptypes.MacAddress) -> None:
|
||||
except capport.database.NotReadyYet as e:
|
||||
quart.abort(500, str(e))
|
||||
if pu:
|
||||
_logger.debug(f'User {mac} logged out')
|
||||
_logger.debug(f"User {mac} logged out")
|
||||
for msg in pu.serialized:
|
||||
await app.my_hub.broadcast(msg)
|
||||
|
||||
@ -92,73 +92,73 @@ async def user_lookup() -> cptypes.MacPublicState:
|
||||
|
||||
|
||||
def check_self_origin():
|
||||
origin = quart.request.headers.get('Origin', None)
|
||||
origin = quart.request.headers.get("Origin", None)
|
||||
if origin is None:
|
||||
# not a request by a modern browser - probably curl or something similar. don't care.
|
||||
return
|
||||
origin = origin.lower().strip()
|
||||
if origin == 'none':
|
||||
quart.abort(403, 'Origin is none')
|
||||
origin_parts = origin.split('/')
|
||||
if origin == "none":
|
||||
quart.abort(403, "Origin is none")
|
||||
origin_parts = origin.split("/")
|
||||
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
|
||||
if len(origin_parts) < 3:
|
||||
quart.abort(400, 'Broken Origin header')
|
||||
if origin_parts[0] != 'https:' and not app.my_config.debug:
|
||||
quart.abort(400, "Broken Origin header")
|
||||
if origin_parts[0] != "https:" and not app.my_config.debug:
|
||||
# -> require https in production
|
||||
quart.abort(403, 'Non-https Origin not allowed')
|
||||
quart.abort(403, "Non-https Origin not allowed")
|
||||
origin_host = origin_parts[2]
|
||||
host = quart.request.headers.get('Host', None)
|
||||
host = quart.request.headers.get("Host", None)
|
||||
if host is None:
|
||||
quart.abort(403, 'Missing Host header')
|
||||
quart.abort(403, "Missing Host header")
|
||||
if host.lower() != origin_host:
|
||||
quart.abort(403, 'Origin mismatch')
|
||||
quart.abort(403, "Origin mismatch")
|
||||
|
||||
|
||||
@app.route('/', methods=['GET'])
|
||||
@app.route("/", methods=["GET"])
|
||||
async def index(missing_accept: bool = False):
|
||||
state = await user_lookup()
|
||||
if not state.mac:
|
||||
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept)
|
||||
return await render_i18n_template("index_unknown.html", state=state, missing_accept=missing_accept)
|
||||
elif state.allowed:
|
||||
return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept)
|
||||
return await render_i18n_template("index_active.html", state=state, missing_accept=missing_accept)
|
||||
else:
|
||||
return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept)
|
||||
return await render_i18n_template("index_inactive.html", state=state, missing_accept=missing_accept)
|
||||
|
||||
|
||||
@app.route('/login', methods=['POST'])
|
||||
@app.route("/login", methods=["POST"])
|
||||
async def login():
|
||||
check_self_origin()
|
||||
with trio.fail_after(5.0):
|
||||
form = await quart.request.form
|
||||
if form.get('accept') != '1':
|
||||
if form.get("accept") != "1":
|
||||
return await index(missing_accept=True)
|
||||
req_mac = form.get('mac')
|
||||
req_mac = form.get("mac")
|
||||
if not req_mac:
|
||||
quart.abort(400, description='Missing MAC in request form data')
|
||||
quart.abort(400, description="Missing MAC in request form data")
|
||||
address = get_client_ip()
|
||||
mac = await get_client_mac(address)
|
||||
if str(mac) != req_mac:
|
||||
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
||||
await user_login(address, mac)
|
||||
return quart.redirect('/', code=303)
|
||||
return quart.redirect("/", code=303)
|
||||
|
||||
|
||||
@app.route('/logout', methods=['POST'])
|
||||
@app.route("/logout", methods=["POST"])
|
||||
async def logout():
|
||||
check_self_origin()
|
||||
with trio.fail_after(5.0):
|
||||
form = await quart.request.form
|
||||
req_mac = form.get('mac')
|
||||
req_mac = form.get("mac")
|
||||
if not req_mac:
|
||||
quart.abort(400, description='Missing MAC in request form data')
|
||||
quart.abort(400, description="Missing MAC in request form data")
|
||||
mac = await get_client_mac()
|
||||
if str(mac) != req_mac:
|
||||
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
||||
await user_logout(mac)
|
||||
return quart.redirect('/', code=303)
|
||||
return quart.redirect("/", code=303)
|
||||
|
||||
|
||||
@app.route('/api/captive-portal', methods=['GET'])
|
||||
@app.route("/api/captive-portal", methods=["GET"])
|
||||
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
|
||||
async def captive_api():
|
||||
state = await user_lookup()
|
||||
|
@ -43,7 +43,7 @@ class Channel:
|
||||
_logger.debug(f"{self}: created (server_side={server_side})")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Channel[0x{id(self):x}]'
|
||||
return f"Channel[0x{id(self):x}]"
|
||||
|
||||
async def do_handshake(self) -> capport.comm.message.Hello:
|
||||
try:
|
||||
@ -59,9 +59,8 @@ class Channel:
|
||||
peer_hello = (await self.recv_msg()).to_variant()
|
||||
if not isinstance(peer_hello, capport.comm.message.Hello):
|
||||
raise HubConnectionReadError("Expected Hello as first message")
|
||||
auth_succ = \
|
||||
(peer_hello.authentication ==
|
||||
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
|
||||
expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)
|
||||
auth_succ = peer_hello.authentication == expected_auth
|
||||
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
|
||||
peer_auth = (await self.recv_msg()).to_variant()
|
||||
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
|
||||
@ -72,7 +71,7 @@ class Channel:
|
||||
|
||||
async def _read(self, num: int) -> bytes:
|
||||
assert num > 0
|
||||
buf = b''
|
||||
buf = b""
|
||||
# _logger.debug(f"{self}:_read({num})")
|
||||
while num > 0:
|
||||
try:
|
||||
@ -92,7 +91,7 @@ class Channel:
|
||||
|
||||
async def _recv_raw_msg(self) -> bytes:
|
||||
len_bytes = await self._read(4)
|
||||
chunk_size, = struct.unpack('!I', len_bytes)
|
||||
(chunk_size,) = struct.unpack("!I", len_bytes)
|
||||
chunk = await self._read(chunk_size)
|
||||
if chunk is None:
|
||||
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
|
||||
@ -116,7 +115,7 @@ class Channel:
|
||||
async def send_msg(self, msg: capport.comm.message.Message):
|
||||
chunk = msg.SerializeToString(deterministic=True)
|
||||
chunk_size = len(chunk)
|
||||
len_bytes = struct.pack('!I', chunk_size)
|
||||
len_bytes = struct.pack("!I", chunk_size)
|
||||
chunk = len_bytes + chunk
|
||||
await self._send_raw(chunk)
|
||||
|
||||
@ -158,7 +157,7 @@ class Connection:
|
||||
if msg:
|
||||
await self._channel.send_msg(msg)
|
||||
else:
|
||||
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
|
||||
await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message())
|
||||
except trio.TooSlowError:
|
||||
_logger.warning(f"{self._channel}: send timed out")
|
||||
except ConnectionError as e:
|
||||
@ -303,7 +302,7 @@ class Hub:
|
||||
if is_controller:
|
||||
state_filename = config.database_file
|
||||
else:
|
||||
state_filename = ''
|
||||
state_filename = ""
|
||||
self.database = capport.database.Database(state_filename=state_filename)
|
||||
self._anon_context = ssl.SSLContext()
|
||||
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
|
||||
@ -311,14 +310,14 @@ class Hub:
|
||||
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
# -> AECDH-AES256-SHA
|
||||
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
|
||||
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
|
||||
self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0")
|
||||
self._controllers: dict[str, ControllerConn] = {}
|
||||
self._established: dict[uuid.UUID, Connection] = {}
|
||||
|
||||
async def _accept(self, stream):
|
||||
remotename = stream.socket.getpeername()
|
||||
if isinstance(remotename, tuple) and len(remotename) == 2:
|
||||
remote = f'[{remotename[0]}]:{remotename[1]}'
|
||||
remote = f"[{remotename[0]}]:{remotename[1]}"
|
||||
else:
|
||||
remote = str(remotename)
|
||||
try:
|
||||
@ -351,11 +350,11 @@ class Hub:
|
||||
await trio.sleep_forever()
|
||||
|
||||
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
|
||||
m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256)
|
||||
m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256)
|
||||
if server_side:
|
||||
m.update(b'server$')
|
||||
m.update(b"server$")
|
||||
else:
|
||||
m.update(b'client$')
|
||||
m.update(b"client$")
|
||||
m.update(ssl_binding)
|
||||
return m.digest()
|
||||
|
||||
|
@ -6,7 +6,7 @@ from .protobuf import message_pb2
|
||||
|
||||
|
||||
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
|
||||
variant_name = self.WhichOneof('oneof')
|
||||
variant_name = self.WhichOneof("oneof")
|
||||
if variant_name:
|
||||
return getattr(self, variant_name)
|
||||
return None
|
||||
@ -16,14 +16,15 @@ def _make_to_message(oneof_field):
|
||||
def to_message(self) -> message_pb2.Message:
|
||||
msg = message_pb2.Message(**{oneof_field: self})
|
||||
return msg
|
||||
|
||||
return to_message
|
||||
|
||||
|
||||
def _monkey_patch():
|
||||
g = globals()
|
||||
g['Message'] = message_pb2.Message
|
||||
g["Message"] = message_pb2.Message
|
||||
message_pb2.Message.to_variant = _message_to_variant
|
||||
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
|
||||
for field in message_pb2._MESSAGE.oneofs_by_name["oneof"].fields:
|
||||
type_name = field.message_type.name
|
||||
field_type = getattr(message_pb2, type_name)
|
||||
field_type.to_message = _make_to_message(field.name)
|
||||
|
@ -1,10 +1,8 @@
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
|
||||
# manually maintained typehints for protobuf created (and monkey-patched) types
|
||||
|
||||
|
||||
class Message(google.protobuf.message.Message):
|
||||
hello: Hello
|
||||
authentication_result: AuthenticationResult
|
||||
@ -19,10 +17,8 @@ class Message(google.protobuf.message.Message):
|
||||
ping: Ping | None = None,
|
||||
mac_states: MacStates | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
|
||||
|
||||
|
||||
class Hello(google.protobuf.message.Message):
|
||||
instance_id: bytes
|
||||
hostname: str
|
||||
@ -32,15 +28,13 @@ class Hello(google.protobuf.message.Message):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
instance_id: bytes=b'',
|
||||
hostname: str='',
|
||||
instance_id: bytes = b"",
|
||||
hostname: str = "",
|
||||
is_controller: bool = False,
|
||||
authentication: bytes=b'',
|
||||
authentication: bytes = b"",
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class AuthenticationResult(google.protobuf.message.Message):
|
||||
success: bool
|
||||
|
||||
@ -49,22 +43,18 @@ class AuthenticationResult(google.protobuf.message.Message):
|
||||
*,
|
||||
success: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class Ping(google.protobuf.message.Message):
|
||||
payload: bytes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: bytes=b'',
|
||||
payload: bytes = b"",
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class MacStates(google.protobuf.message.Message):
|
||||
states: list[MacState]
|
||||
|
||||
@ -73,10 +63,8 @@ class MacStates(google.protobuf.message.Message):
|
||||
*,
|
||||
states: list[MacState] = [],
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class MacState(google.protobuf.message.Message):
|
||||
mac_address: bytes
|
||||
last_change: int # Seconds of UTC time since epoch
|
||||
@ -86,7 +74,7 @@ class MacState(google.protobuf.message.Message):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mac_address: bytes=b'',
|
||||
mac_address: bytes = b"",
|
||||
last_change: int = 0,
|
||||
allow_until: int = 0,
|
||||
allowed: bool = False,
|
||||
|
@ -39,7 +39,7 @@ class Config:
|
||||
@staticmethod
|
||||
def load(filename: str | None = None) -> Config:
|
||||
if filename is None:
|
||||
for name in ('capport.yaml', '/etc/capport.yaml'):
|
||||
for name in ("capport.yaml", "/etc/capport.yaml"):
|
||||
if os.path.exists(name):
|
||||
return Config.load(name)
|
||||
raise RuntimeError("Missing config file")
|
||||
@ -47,19 +47,19 @@ class Config:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}")
|
||||
controllers = list(map(str, data.pop('controllers')))
|
||||
controllers = list(map(str, data.pop("controllers")))
|
||||
config = Config(
|
||||
_source_filename=filename,
|
||||
controllers=controllers,
|
||||
server_names=data.pop('server-names', []),
|
||||
comm_secret=str(data.pop('comm-secret')),
|
||||
cookie_secret=str(data.pop('cookie-secret')),
|
||||
venue_info_url=str(data.pop('venue-info-url')),
|
||||
session_timeout=data.pop('session-timeout', 3600),
|
||||
api_port=data.pop('api-port', 8000),
|
||||
controller_port=data.pop('controller-port', 5000),
|
||||
database_file=str(data.pop('database-file', 'capport.state')),
|
||||
debug=data.pop('debug', False)
|
||||
server_names=data.pop("server-names", []),
|
||||
comm_secret=str(data.pop("comm-secret")),
|
||||
cookie_secret=str(data.pop("cookie-secret")),
|
||||
venue_info_url=str(data.pop("venue-info-url")),
|
||||
session_timeout=data.pop("session-timeout", 3600),
|
||||
api_port=data.pop("api-port", 8000),
|
||||
controller_port=data.pop("controller-port", 5000),
|
||||
database_file=str(data.pop("database-file", "capport.state")),
|
||||
debug=data.pop("debug", False),
|
||||
)
|
||||
if data:
|
||||
_logger.error(f"Unknown config elements: {list(data.keys())}")
|
||||
|
@ -58,9 +58,9 @@ async def amain(config: capport.config.Config) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# hub.run loads the statefile from disk before signalling it was "started"
|
||||
await nursery.start(hub.run)
|
||||
await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set')
|
||||
await sn.send("READY=1", "STATUS=Deploying initial entries to nftables set")
|
||||
app.apply_db_entries(hub.database.entries())
|
||||
await sn.send('STATUS=Kernel fully synchronized')
|
||||
await sn.send("STATUS=Kernel fully synchronized")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -69,7 +69,7 @@ class CliArguments:
|
||||
|
||||
def __init__(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', '-c')
|
||||
parser.add_argument("--config", "-c")
|
||||
args = parser.parse_args()
|
||||
|
||||
self.config = args.config
|
||||
|
@ -23,14 +23,14 @@ class MacAddress:
|
||||
raw: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.raw.hex(':')
|
||||
return self.raw.hex(":")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(str(self))
|
||||
|
||||
@staticmethod
|
||||
def parse(s: str) -> MacAddress:
|
||||
return MacAddress(bytes.fromhex(s.replace(':', '')))
|
||||
return MacAddress(bytes.fromhex(s.replace(":", "")))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, order=True)
|
||||
@ -40,9 +40,9 @@ class Timestamp:
|
||||
def __str__(self) -> str:
|
||||
try:
|
||||
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
|
||||
return ts.isoformat(sep=' ')
|
||||
return ts.isoformat(sep=" ")
|
||||
except OSError:
|
||||
return f'epoch@{self.epoch}'
|
||||
return f"epoch@{self.epoch}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(str(self))
|
||||
@ -85,7 +85,7 @@ class MacPublicState:
|
||||
def allowed_remaining_duration(self) -> str:
|
||||
mm, ss = divmod(self.allowed_remaining, 60)
|
||||
hh, mm = divmod(mm, 60)
|
||||
return f'{hh}:{mm:02}:{ss:02}'
|
||||
return f"{hh}:{mm:02}:{ss:02}"
|
||||
|
||||
@property
|
||||
def allowed_until(self) -> datetime.datetime | None:
|
||||
@ -95,18 +95,18 @@ class MacPublicState:
|
||||
|
||||
def to_rfc8908(self, config: Config) -> quart.Response:
|
||||
response: dict[str, typing.Any] = {
|
||||
'user-portal-url': quart.url_for('index', _external=True),
|
||||
"user-portal-url": quart.url_for("index", _external=True),
|
||||
}
|
||||
if config.venue_info_url:
|
||||
response['venue-info-url'] = config.venue_info_url
|
||||
response["venue-info-url"] = config.venue_info_url
|
||||
if self.captive:
|
||||
response['captive'] = True
|
||||
response["captive"] = True
|
||||
else:
|
||||
response['captive'] = False
|
||||
response['seconds-remaining'] = self.allowed_remaining
|
||||
response['can-extend-session'] = True
|
||||
response["captive"] = False
|
||||
response["seconds-remaining"] = self.allowed_remaining
|
||||
response["can-extend-session"] = True
|
||||
return quart.Response(
|
||||
json.dumps(response),
|
||||
headers={'Cache-Control': 'private'},
|
||||
content_type='application/captive+json',
|
||||
headers={"Cache-Control": "private"},
|
||||
content_type="application/captive+json",
|
||||
)
|
||||
|
@ -128,7 +128,7 @@ def _serialize_mac_states_as_messages(
|
||||
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
|
||||
chunk = states.SerializeToString(deterministic=True)
|
||||
chunk_size = len(chunk)
|
||||
len_bytes = struct.pack('!I', chunk_size)
|
||||
len_bytes = struct.pack("!I", chunk_size)
|
||||
return len_bytes + chunk
|
||||
|
||||
|
||||
@ -204,39 +204,32 @@ class Database:
|
||||
return _serialize_mac_states_as_messages(self._macs)
|
||||
|
||||
def as_json(self) -> dict:
|
||||
return {
|
||||
str(addr): entry.as_json()
|
||||
for addr, entry in self._macs.items()
|
||||
}
|
||||
return {str(addr): entry.as_json() for addr, entry in self._macs.items()}
|
||||
|
||||
async def _run_statefile(self) -> None:
|
||||
rx: trio.MemoryReceiveChannel[
|
||||
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
||||
]
|
||||
tx: trio.MemorySendChannel[
|
||||
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
||||
]
|
||||
rx: trio.MemoryReceiveChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
|
||||
tx: trio.MemorySendChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
|
||||
tx, rx = trio.open_memory_channel(64)
|
||||
self._send_changes = tx
|
||||
|
||||
assert self._state_filename
|
||||
filename: str = self._state_filename
|
||||
tmp_filename = f'{filename}.new-{os.getpid()}'
|
||||
tmp_filename = f"{filename}.new-{os.getpid()}"
|
||||
|
||||
async def resync(all_states: list[capport.comm.message.MacStates]):
|
||||
try:
|
||||
async with await trio.open_file(tmp_filename, 'xb') as tf:
|
||||
async with await trio.open_file(tmp_filename, "xb") as tf:
|
||||
for states in all_states:
|
||||
await tf.write(_states_to_chunk(states))
|
||||
os.rename(tmp_filename, filename)
|
||||
finally:
|
||||
if os.path.exists(tmp_filename):
|
||||
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
|
||||
_logger.warning(f"Removing (failed) state export file {tmp_filename}")
|
||||
os.unlink(tmp_filename)
|
||||
|
||||
try:
|
||||
while True:
|
||||
async with await trio.open_file(filename, 'ab', buffering=0) as sf:
|
||||
async with await trio.open_file(filename, "ab", buffering=0) as sf:
|
||||
while True:
|
||||
update = await rx.receive()
|
||||
if isinstance(update, list):
|
||||
@ -247,10 +240,10 @@ class Database:
|
||||
await resync(update)
|
||||
# now reopen normal statefile and continue appending updates
|
||||
except trio.Cancelled:
|
||||
_logger.info('Final sync to disk')
|
||||
_logger.info("Final sync to disk")
|
||||
with trio.CancelScope(shield=True):
|
||||
await resync(_serialize_mac_states(self._macs))
|
||||
_logger.info('Final sync to disk done')
|
||||
_logger.info("Final sync to disk done")
|
||||
|
||||
async def _load_statefile(self):
|
||||
if not os.path.exists(self._state_filename):
|
||||
@ -259,7 +252,7 @@ class Database:
|
||||
# we're going to ignore changes from loading the file
|
||||
pu = PendingUpdates(self)
|
||||
pu._closed = False
|
||||
async with await trio.open_file(self._state_filename, 'rb') as sf:
|
||||
async with await trio.open_file(self._state_filename, "rb") as sf:
|
||||
while True:
|
||||
try:
|
||||
len_bytes = await sf.read(4)
|
||||
@ -268,7 +261,7 @@ class Database:
|
||||
if len(len_bytes) < 4:
|
||||
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
|
||||
return
|
||||
chunk_size, = struct.unpack('!I', len_bytes)
|
||||
(chunk_size,) = struct.unpack("!I", len_bytes)
|
||||
chunk = await sf.read(chunk_size)
|
||||
except IOError as e:
|
||||
_logger.error(f"Failed to read next chunk from statefile: {e}")
|
||||
@ -286,7 +279,7 @@ class Database:
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors < 5:
|
||||
_logger.error(f'Failed to handle state: {e}')
|
||||
_logger.error(f"Failed to handle state: {e}")
|
||||
|
||||
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
|
||||
entry = self._macs.get(mac)
|
||||
|
@ -14,20 +14,17 @@ from . import cptypes
|
||||
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
|
||||
# no labels in our names for now, always print help and type
|
||||
if help:
|
||||
print(f'# HELP {name} {help}')
|
||||
print(f'# TYPE {name} {mtype}')
|
||||
print(f"# HELP {name} {help}")
|
||||
print(f"# TYPE {name} {mtype}")
|
||||
if now:
|
||||
print(f'{name} {value} {now}')
|
||||
print(f"{name} {value} {now}")
|
||||
else:
|
||||
print(f'{name} {value}')
|
||||
print(f"{name} {value}")
|
||||
|
||||
|
||||
async def amain(client_ifname: str):
|
||||
ns = capport.utils.nft_set.NftSet()
|
||||
captive_allowed_entries: set[cptypes.MacAddress] = {
|
||||
entry['mac']
|
||||
for entry in ns.list()
|
||||
}
|
||||
captive_allowed_entries: set[cptypes.MacAddress] = {entry["mac"] for entry in ns.list()}
|
||||
seen_allowed_entries: set[cptypes.MacAddress] = set()
|
||||
total_ipv4 = 0
|
||||
total_ipv6 = 0
|
||||
@ -47,46 +44,46 @@ async def amain(client_ifname: str):
|
||||
total_ipv6 += 1
|
||||
unique_ipv6.add(mac)
|
||||
print_metric(
|
||||
'capport_allowed_macs',
|
||||
'gauge',
|
||||
"capport_allowed_macs",
|
||||
"gauge",
|
||||
len(captive_allowed_entries),
|
||||
help='Number of allowed client mac addresses',
|
||||
help="Number of allowed client mac addresses",
|
||||
)
|
||||
print_metric(
|
||||
'capport_allowed_neigh_macs',
|
||||
'gauge',
|
||||
"capport_allowed_neigh_macs",
|
||||
"gauge",
|
||||
len(seen_allowed_entries),
|
||||
help='Number of allowed client mac addresses seen in neighbor cache',
|
||||
help="Number of allowed client mac addresses seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
'capport_unique',
|
||||
'gauge',
|
||||
"capport_unique",
|
||||
"gauge",
|
||||
len(unique_clients),
|
||||
help='Number of clients (mac addresses) in client network seen in neighbor cache',
|
||||
help="Number of clients (mac addresses) in client network seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
'capport_unique_ipv4',
|
||||
'gauge',
|
||||
"capport_unique_ipv4",
|
||||
"gauge",
|
||||
len(unique_ipv4),
|
||||
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
|
||||
help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
'capport_unique_ipv6',
|
||||
'gauge',
|
||||
"capport_unique_ipv6",
|
||||
"gauge",
|
||||
len(unique_ipv6),
|
||||
help='Number of IPv6 clients (unique per mac) in client network seen in neighbor cache',
|
||||
help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
'capport_total_ipv4',
|
||||
'gauge',
|
||||
"capport_total_ipv4",
|
||||
"gauge",
|
||||
total_ipv4,
|
||||
help='Number of IPv4 addresses seen in neighbor cache',
|
||||
help="Number of IPv4 addresses seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
'capport_total_ipv6',
|
||||
'gauge',
|
||||
"capport_total_ipv6",
|
||||
"gauge",
|
||||
total_ipv6,
|
||||
help='Number of IPv6 addresses seen in neighbor cache',
|
||||
help="Number of IPv6 addresses seen in neighbor cache",
|
||||
)
|
||||
|
||||
|
||||
|
@ -10,8 +10,8 @@ def init_logger(config: capport.config.Config):
|
||||
if config.debug:
|
||||
loglevel = logging.DEBUG
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S %z]',
|
||||
format="%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s",
|
||||
datefmt="[%Y-%m-%d %H:%M:%S %z]",
|
||||
level=loglevel,
|
||||
)
|
||||
logging.getLogger('hypercorn').propagate = False
|
||||
logging.getLogger("hypercorn").propagate = False
|
||||
|
@ -35,9 +35,9 @@ class NeighborController:
|
||||
route = await self.get_route(address)
|
||||
if route is None:
|
||||
return None
|
||||
index = route.get_attr(route.name2nla('oif'))
|
||||
index = route.get_attr(route.name2nla("oif"))
|
||||
try:
|
||||
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
|
||||
return self.ip.neigh("get", dst=str(address), ifindex=index, state="none")[0]
|
||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||
if e.code == errno.ENOENT:
|
||||
return None
|
||||
@ -53,7 +53,7 @@ class NeighborController:
|
||||
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
||||
if neigh is None:
|
||||
return None
|
||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
||||
mac = neigh.get_attr(neigh.name2nla("lladdr"))
|
||||
if mac is None:
|
||||
return None
|
||||
return cptypes.MacAddress.parse(mac)
|
||||
@ -63,7 +63,7 @@ class NeighborController:
|
||||
address: cptypes.IPAddress,
|
||||
) -> pyroute2.iproute.linux.rtmsg | None:
|
||||
try:
|
||||
return self.ip.route('get', dst=str(address))[0]
|
||||
return self.ip.route("get", dst=str(address))[0]
|
||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||
if e.code == errno.ENOENT:
|
||||
return None
|
||||
@ -74,16 +74,16 @@ class NeighborController:
|
||||
interface: str,
|
||||
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
|
||||
ifindex = socket.if_nametoindex(interface)
|
||||
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
|
||||
unicast_num = pyroute2.netlink.rtnl.rt_type["unicast"]
|
||||
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family):
|
||||
if neigh['ndm_type'] != unicast_num:
|
||||
for neigh in self.ip.neigh("dump", ifindex=ifindex, family=family):
|
||||
if neigh["ndm_type"] != unicast_num:
|
||||
continue
|
||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
||||
mac = neigh.get_attr(neigh.name2nla("lladdr"))
|
||||
if not mac:
|
||||
continue
|
||||
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst')))
|
||||
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla("dst")))
|
||||
if dst.is_link_local:
|
||||
continue
|
||||
yield (cptypes.MacAddress.parse(mac), dst)
|
||||
|
@ -30,26 +30,20 @@ class NftSet:
|
||||
timeout: int | float | None = None,
|
||||
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
||||
attrs: dict[str, typing.Any] = {
|
||||
'NFTA_SET_ELEM_KEY': dict(
|
||||
"NFTA_SET_ELEM_KEY": dict(
|
||||
NFTA_DATA_VALUE=mac.raw,
|
||||
),
|
||||
}
|
||||
if timeout:
|
||||
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
|
||||
attrs["NFTA_SET_ELEM_TIMEOUT"] = int(1000 * timeout)
|
||||
return attrs
|
||||
|
||||
def _bulk_insert(
|
||||
self,
|
||||
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
|
||||
) -> None:
|
||||
ser_entries = [
|
||||
self._set_elem(mac)
|
||||
for mac, _timeout in entries
|
||||
]
|
||||
ser_entries_with_timeout = [
|
||||
self._set_elem(mac, timeout)
|
||||
for mac, timeout in entries
|
||||
]
|
||||
ser_entries = [self._set_elem(mac) for mac, _timeout in entries]
|
||||
ser_entries_with_timeout = [self._set_elem(mac, timeout) for mac, timeout in entries]
|
||||
with self._socket.begin() as tx:
|
||||
# create doesn't affect existing elements, so:
|
||||
# make sure entries exists
|
||||
@ -58,8 +52,8 @@ class NftSet:
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
@ -68,8 +62,8 @@ class NftSet:
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
@ -79,8 +73,8 @@ class NftSet:
|
||||
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
|
||||
),
|
||||
)
|
||||
@ -95,10 +89,7 @@ class NftSet:
|
||||
self.bulk_insert([(mac, timeout)])
|
||||
|
||||
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||
ser_entries = [
|
||||
self._set_elem(mac)
|
||||
for mac in entries
|
||||
]
|
||||
ser_entries = [self._set_elem(mac) for mac in entries]
|
||||
with self._socket.begin() as tx:
|
||||
# make sure entries exists
|
||||
tx.put(
|
||||
@ -106,8 +97,8 @@ class NftSet:
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
@ -116,8 +107,8 @@ class NftSet:
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
@ -137,20 +128,20 @@ class NftSet:
|
||||
_nftsocket.NFT_MSG_GETSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
)
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
),
|
||||
)
|
||||
return [
|
||||
{
|
||||
'mac': cptypes.MacAddress(
|
||||
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
|
||||
"mac": cptypes.MacAddress(
|
||||
elem.get_attr("NFTA_SET_ELEM_KEY").get_attr("NFTA_DATA_VALUE"),
|
||||
),
|
||||
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
|
||||
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
|
||||
"timeout": _from_msec(elem.get_attr("NFTA_SET_ELEM_TIMEOUT", None)),
|
||||
"expiration": _from_msec(elem.get_attr("NFTA_SET_ELEM_EXPIRATION", None)),
|
||||
}
|
||||
for response in responses
|
||||
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
|
||||
for elem in response.get_attr("NFTA_SET_ELEM_LIST_ELEMENTS", [])
|
||||
]
|
||||
|
||||
def flush(self) -> None:
|
||||
@ -158,9 +149,9 @@ class NftSet:
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
)
|
||||
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||
),
|
||||
)
|
||||
|
||||
def create(self):
|
||||
@ -170,7 +161,7 @@ class NftSet:
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_TABLE_NAME='captive_mark',
|
||||
NFTA_TABLE_NAME="captive_mark",
|
||||
),
|
||||
)
|
||||
tx.put(
|
||||
@ -178,8 +169,8 @@ class NftSet:
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_NAME='allowed',
|
||||
NFTA_SET_TABLE="captive_mark",
|
||||
NFTA_SET_NAME="allowed",
|
||||
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
|
||||
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
|
||||
NFTA_SET_KEY_LEN=6, # length of key: mac address
|
||||
|
@ -13,7 +13,7 @@ from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
|
||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
|
||||
|
||||
|
||||
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
|
||||
_NlMsgBase = typing.TypeVar("_NlMsgBase", bound=pyroute2.netlink.nlmsg_base)
|
||||
|
||||
|
||||
# nft uses NESTED for those.. lets do the same
|
||||
@ -25,16 +25,17 @@ _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
|
||||
|
||||
def _monkey_patch_pyroute2():
|
||||
import pyroute2.netlink
|
||||
|
||||
# overwrite setdefault on nlmsg_base class hierarchy
|
||||
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
|
||||
|
||||
def _nlmsg_base__setvalue(self, value):
|
||||
if not self.header or not self['header'] or not isinstance(value, dict):
|
||||
if not self.header or not self["header"] or not isinstance(value, dict):
|
||||
return _orig_setvalue(self, value)
|
||||
# merge headers instead of overwriting them
|
||||
header = value.pop('header', {})
|
||||
header = value.pop("header", {})
|
||||
res = _orig_setvalue(self, value)
|
||||
self['header'].update(header)
|
||||
self["header"].update(header)
|
||||
return res
|
||||
|
||||
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
|
||||
@ -52,20 +53,20 @@ _monkey_patch_pyroute2()
|
||||
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
|
||||
msg = msg_class()
|
||||
for key, value in header.items():
|
||||
msg['header'][key] = value
|
||||
msg["header"][key] = value
|
||||
for key, value in fields.items():
|
||||
msg[key] = value
|
||||
if attrs:
|
||||
attr_list = msg['attrs']
|
||||
attr_list = msg["attrs"]
|
||||
r_nla_map = msg_class._nlmsg_base__r_nla_map
|
||||
for key, value in attrs.items():
|
||||
if msg_class.prefix:
|
||||
key = msg_class.name2nla(key)
|
||||
prime = r_nla_map[key]
|
||||
nla_class = prime['class']
|
||||
nla_class = prime["class"]
|
||||
if issubclass(nla_class, pyroute2.netlink.nla):
|
||||
# support passing nested attributes as dicts of subattributes (or lists of those)
|
||||
if prime['nla_array']:
|
||||
if prime["nla_array"]:
|
||||
value = [
|
||||
_build(nla_class, attrs=elem)
|
||||
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
|
||||
@ -83,10 +84,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
|
||||
policy = {
|
||||
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
|
||||
for (x, y) in self.policy.items()
|
||||
}
|
||||
policy = {(x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items()}
|
||||
self.register_policy(policy)
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -156,7 +154,7 @@ class NFTTransaction:
|
||||
# no inner messages were queued... not sending anything
|
||||
return
|
||||
# request ACK on the last message (before END)
|
||||
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
|
||||
self._msgs[-1]["header"]["flags"] |= pyroute2.netlink.NLM_F_ACK
|
||||
self._msgs.append(
|
||||
# batch end message
|
||||
_build(
|
||||
|
@ -10,7 +10,7 @@ import trio.socket
|
||||
|
||||
|
||||
def _check_watchdog_pid() -> bool:
|
||||
wpid = os.environ.pop('WATCHDOG_PID', None)
|
||||
wpid = os.environ.pop("WATCHDOG_PID", None)
|
||||
if not wpid:
|
||||
return True
|
||||
return wpid == str(os.getpid())
|
||||
@ -18,16 +18,16 @@ def _check_watchdog_pid() -> bool:
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
||||
target = os.environ.pop('NOTIFY_SOCKET', None)
|
||||
target = os.environ.pop("NOTIFY_SOCKET", None)
|
||||
ns: trio.socket.SocketType | None = None
|
||||
watchdog_usec: int = 0
|
||||
if target:
|
||||
if target.startswith('@'):
|
||||
if target.startswith("@"):
|
||||
# Linux abstract namespace socket
|
||||
target = '\0' + target[1:]
|
||||
target = "\0" + target[1:]
|
||||
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
||||
await ns.connect(target)
|
||||
watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None)
|
||||
watchdog_usec_s = os.environ.pop("WATCHDOG_USEC", None)
|
||||
if _check_watchdog_pid() and watchdog_usec_s:
|
||||
watchdog_usec = int(watchdog_usec_s)
|
||||
try:
|
||||
@ -52,18 +52,18 @@ class SdNotify:
|
||||
|
||||
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||
assert self.is_connected(), "Watchdog can't run without socket"
|
||||
await self.send('WATCHDOG=1')
|
||||
await self.send("WATCHDOG=1")
|
||||
task_status.started()
|
||||
# send every half of the watchdog timeout
|
||||
interval = (watchdog_usec / 1e6) / 2.0
|
||||
while True:
|
||||
await trio.sleep(interval)
|
||||
await self.send('WATCHDOG=1')
|
||||
await self.send("WATCHDOG=1")
|
||||
|
||||
async def send(self, *msg: str) -> None:
|
||||
if not self.is_connected():
|
||||
return
|
||||
dgram = '\n'.join(msg).encode('utf-8')
|
||||
dgram = "\n".join(msg).encode("utf-8")
|
||||
assert self._ns, "not connected" # checked above
|
||||
sent = await self._ns.send(dgram)
|
||||
if sent != len(dgram):
|
||||
|
@ -10,9 +10,9 @@ def get_local_timezone():
|
||||
global _zoneinfo
|
||||
if not _zoneinfo:
|
||||
try:
|
||||
with open('/etc/timezone') as f:
|
||||
with open("/etc/timezone") as f:
|
||||
key = f.readline().strip()
|
||||
_zoneinfo = zoneinfo.ZoneInfo(key)
|
||||
except (OSError, zoneinfo.ZoneInfoNotFoundError):
|
||||
_zoneinfo = zoneinfo.ZoneInfo('UTC')
|
||||
_zoneinfo = zoneinfo.ZoneInfo("UTC")
|
||||
return _zoneinfo
|
||||
|
Loading…
Reference in New Issue
Block a user