format with black (mostly quotes)
This commit is contained in:
parent
ced589f28a
commit
0f45f89211
@ -6,9 +6,14 @@ requires = [
|
|||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
python_version = "3.9"
|
python_version = "3.11"
|
||||||
# warn_return_any = true
|
# warn_return_any = true
|
||||||
warn_unused_configs = true
|
warn_unused_configs = true
|
||||||
exclude = [
|
exclude = [
|
||||||
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
|
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 120
|
||||||
|
target-version = ['py311']
|
||||||
|
exclude = '_pb2.py'
|
||||||
|
@ -4,8 +4,8 @@ from .app_cls import MyQuartApp
|
|||||||
|
|
||||||
app = MyQuartApp(__name__)
|
app = MyQuartApp(__name__)
|
||||||
|
|
||||||
__import__('capport.api.setup')
|
__import__("capport.api.setup")
|
||||||
__import__('capport.api.proxy_fix')
|
__import__("capport.api.proxy_fix")
|
||||||
__import__('capport.api.lang')
|
__import__("capport.api.lang")
|
||||||
__import__('capport.api.template_filters')
|
__import__("capport.api.template_filters")
|
||||||
__import__('capport.api.views')
|
__import__("capport.api.views")
|
||||||
|
@ -49,16 +49,16 @@ class MyQuartApp(quart_trio.QuartTrio):
|
|||||||
|
|
||||||
def __init__(self, import_name: str, **kwargs) -> None:
|
def __init__(self, import_name: str, **kwargs) -> None:
|
||||||
self.my_config = capport.config.Config.load_default_once()
|
self.my_config = capport.config.Config.load_default_once()
|
||||||
kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates'))
|
kwargs.setdefault("template_folder", os.path.join(os.path.dirname(__file__), "templates"))
|
||||||
cust_templ = os.path.join('custom', 'templates')
|
cust_templ = os.path.join("custom", "templates")
|
||||||
if os.path.exists(cust_templ):
|
if os.path.exists(cust_templ):
|
||||||
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
|
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
|
||||||
cust_static = os.path.abspath(os.path.join('custom', 'static'))
|
cust_static = os.path.abspath(os.path.join("custom", "static"))
|
||||||
if os.path.exists(cust_static):
|
if os.path.exists(cust_static):
|
||||||
static_folder = cust_static
|
static_folder = cust_static
|
||||||
else:
|
else:
|
||||||
static_folder = os.path.join(os.path.dirname(__file__), 'static')
|
static_folder = os.path.join(os.path.dirname(__file__), "static")
|
||||||
kwargs.setdefault('static_folder', static_folder)
|
kwargs.setdefault("static_folder", static_folder)
|
||||||
super().__init__(import_name, **kwargs)
|
super().__init__(import_name, **kwargs)
|
||||||
self.debug = self.my_config.debug
|
self.debug = self.my_config.debug
|
||||||
self.secret_key = self.my_config.cookie_secret
|
self.secret_key = self.my_config.cookie_secret
|
||||||
|
@ -12,7 +12,7 @@ import capport.config
|
|||||||
|
|
||||||
def run(config: hypercorn.config.Config) -> None:
|
def run(config: hypercorn.config.Config) -> None:
|
||||||
sockets = config.create_sockets()
|
sockets = config.create_sockets()
|
||||||
assert config.worker_class == 'trio'
|
assert config.worker_class == "trio"
|
||||||
|
|
||||||
hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
|
hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class CliArguments:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--config', '-c')
|
parser.add_argument("--config", "-c")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
self.config = args.config
|
self.config = args.config
|
||||||
@ -39,8 +39,8 @@ def main() -> None:
|
|||||||
_config = capport.config.Config.load_default_once(filename=args.config)
|
_config = capport.config.Config.load_default_once(filename=args.config)
|
||||||
|
|
||||||
hypercorn_config = hypercorn.config.Config()
|
hypercorn_config = hypercorn.config.Config()
|
||||||
hypercorn_config.application_path = 'capport.api.app'
|
hypercorn_config.application_path = "capport.api.app"
|
||||||
hypercorn_config.worker_class = 'trio'
|
hypercorn_config.worker_class = "trio"
|
||||||
hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"]
|
hypercorn_config.bind = [f"127.0.0.1:{_config.api_port}"]
|
||||||
|
|
||||||
if _config.server_names:
|
if _config.server_names:
|
||||||
|
@ -7,22 +7,23 @@ import quart
|
|||||||
|
|
||||||
from .app import app
|
from .app import app
|
||||||
|
|
||||||
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
|
_VALID_LANGUAGE_NAMES = re.compile(r"^[-a-z0-9_]+$")
|
||||||
|
|
||||||
|
|
||||||
def parse_accept_language(value: str) -> list[str]:
|
def parse_accept_language(value: str) -> list[str]:
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
if not value or value == '*':
|
if not value or value == "*":
|
||||||
return []
|
return []
|
||||||
tuples = []
|
tuples = []
|
||||||
for entry in value.split(','):
|
for entry in value.split(","):
|
||||||
attrs = entry.split(';')
|
attrs = entry.split(";")
|
||||||
name = attrs.pop(0).strip().lower()
|
name = attrs.pop(0).strip().lower()
|
||||||
q = 1.0
|
q = 1.0
|
||||||
for attr in attrs:
|
for attr in attrs:
|
||||||
if not '=' in attr: continue
|
if not "=" in attr:
|
||||||
key, value = attr.split('=', maxsplit=1)
|
continue
|
||||||
if key.strip().lower() == 'q':
|
key, value = attr.split("=", maxsplit=1)
|
||||||
|
if key.strip().lower() == "q":
|
||||||
try:
|
try:
|
||||||
q = float(value.strip())
|
q = float(value.strip())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -32,14 +33,17 @@ def parse_accept_language(value: str) -> list[str]:
|
|||||||
tuples.sort()
|
tuples.sort()
|
||||||
have = set()
|
have = set()
|
||||||
result = []
|
result = []
|
||||||
for (_q, name) in tuples:
|
for _q, name in tuples:
|
||||||
if name in have: continue
|
if name in have:
|
||||||
if name == '*': break
|
continue
|
||||||
|
if name == "*":
|
||||||
|
break
|
||||||
have.add(name)
|
have.add(name)
|
||||||
if _VALID_LANGUAGE_NAMES.match(name):
|
if _VALID_LANGUAGE_NAMES.match(name):
|
||||||
result.append(name)
|
result.append(name)
|
||||||
short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0]
|
short_name = name.split("-", maxsplit=1)[0].split("_", maxsplit=1)[0]
|
||||||
if not short_name or short_name in have: continue
|
if not short_name or short_name in have:
|
||||||
|
continue
|
||||||
have.add(short_name)
|
have.add(short_name)
|
||||||
result.append(short_name)
|
result.append(short_name)
|
||||||
return result
|
return result
|
||||||
@ -50,23 +54,23 @@ def detect_language():
|
|||||||
g = quart.g
|
g = quart.g
|
||||||
r = quart.request
|
r = quart.request
|
||||||
s = quart.session
|
s = quart.session
|
||||||
if 'setlang' in r.args:
|
if "setlang" in r.args:
|
||||||
lang = r.args.get('setlang').strip().lower()
|
lang = r.args.get("setlang").strip().lower()
|
||||||
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
||||||
if s.get('lang') != lang:
|
if s.get("lang") != lang:
|
||||||
s['lang'] = lang
|
s["lang"] = lang
|
||||||
g.langs = [lang]
|
g.langs = [lang]
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# reset language
|
# reset language
|
||||||
s.pop('lang', None)
|
s.pop("lang", None)
|
||||||
lang = s.get('lang')
|
lang = s.get("lang")
|
||||||
if lang:
|
if lang:
|
||||||
lang = lang.strip().lower()
|
lang = lang.strip().lower()
|
||||||
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
||||||
g.langs = [lang]
|
g.langs = [lang]
|
||||||
return
|
return
|
||||||
acc_lang = ','.join(r.headers.get_all('Accept-Language'))
|
acc_lang = ",".join(r.headers.get_all("Accept-Language"))
|
||||||
g.langs = parse_accept_language(acc_lang)
|
g.langs = parse_accept_language(acc_lang)
|
||||||
|
|
||||||
|
|
||||||
@ -74,9 +78,6 @@ async def render_i18n_template(template, /, **kwargs) -> str:
|
|||||||
langs: list[str] = quart.g.langs
|
langs: list[str] = quart.g.langs
|
||||||
if not langs:
|
if not langs:
|
||||||
return await quart.render_template(template, **kwargs)
|
return await quart.render_template(template, **kwargs)
|
||||||
names = [
|
names = [os.path.join("i18n", lang, template) for lang in langs]
|
||||||
os.path.join('i18n', lang, template)
|
|
||||||
for lang in langs
|
|
||||||
]
|
|
||||||
names.append(template)
|
names.append(template)
|
||||||
return await quart.render_template(names, **kwargs)
|
return await quart.render_template(names, **kwargs)
|
||||||
|
@ -31,28 +31,28 @@ def local_proxy_fix(request: quart.Request):
|
|||||||
return
|
return
|
||||||
if not addr.is_loopback:
|
if not addr.is_loopback:
|
||||||
return
|
return
|
||||||
client = _get_first_in_list(request.headers.get('X-Forwarded-For'))
|
client = _get_first_in_list(request.headers.get("X-Forwarded-For"))
|
||||||
if not client:
|
if not client:
|
||||||
# assume this is always set behind reverse proxies supporting any of the headers
|
# assume this is always set behind reverse proxies supporting any of the headers
|
||||||
return
|
return
|
||||||
request.remote_addr = client
|
request.remote_addr = client
|
||||||
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
|
scheme = _get_first_in_list(request.headers.get("X-Forwarded-Proto"), ("http", "https"))
|
||||||
port: int | None = None
|
port: int | None = None
|
||||||
if scheme:
|
if scheme:
|
||||||
port = 443 if scheme == 'https' else 80
|
port = 443 if scheme == "https" else 80
|
||||||
request.scheme = scheme
|
request.scheme = scheme
|
||||||
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
|
host = _get_first_in_list(request.headers.get("X-Forwarded-Host"))
|
||||||
port_s: str | None
|
port_s: str | None
|
||||||
if host:
|
if host:
|
||||||
request.host = host
|
request.host = host
|
||||||
if ':' in host and not host.endswith(']'):
|
if ":" in host and not host.endswith("]"):
|
||||||
try:
|
try:
|
||||||
_, port_s = host.rsplit(':', maxsplit=1)
|
_, port_s = host.rsplit(":", maxsplit=1)
|
||||||
port = int(port_s)
|
port = int(port_s)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# ignore invalid port in host header
|
# ignore invalid port in host header
|
||||||
pass
|
pass
|
||||||
port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port'))
|
port_s = _get_first_in_list(request.headers.get("X-Forwarded-Port"))
|
||||||
if port_s:
|
if port_s:
|
||||||
try:
|
try:
|
||||||
port = int(port_s)
|
port = int(port_s)
|
||||||
@ -62,7 +62,7 @@ def local_proxy_fix(request: quart.Request):
|
|||||||
if port:
|
if port:
|
||||||
if request.server and len(request.server) == 2:
|
if request.server and len(request.server) == 2:
|
||||||
request.server = (request.server[0], port)
|
request.server = (request.server[0], port)
|
||||||
root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix'))
|
root_path = _get_first_in_list(request.headers.get("X-Forwarded-Prefix"))
|
||||||
if root_path:
|
if root_path:
|
||||||
request.root_path = root_path
|
request.root_path = root_path
|
||||||
|
|
||||||
|
@ -47,10 +47,10 @@ async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
|||||||
|
|
||||||
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
|
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
|
||||||
async with open_sdnotify() as sn:
|
async with open_sdnotify() as sn:
|
||||||
await sn.send('STATUS=Starting hub')
|
await sn.send("STATUS=Starting hub")
|
||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
await nursery.start(_run_hub)
|
await nursery.start(_run_hub)
|
||||||
await sn.send('READY=1', 'STATUS=Ready for client requests')
|
await sn.send("READY=1", "STATUS=Ready for client requests")
|
||||||
task_status.started()
|
task_status.started()
|
||||||
# continue running hub and systemd watchdog handler
|
# continue running hub and systemd watchdog handler
|
||||||
|
|
||||||
|
@ -23,12 +23,12 @@ _logger = logging.getLogger(__name__)
|
|||||||
def get_client_ip() -> cptypes.IPAddress:
|
def get_client_ip() -> cptypes.IPAddress:
|
||||||
remote_addr = quart.request.remote_addr
|
remote_addr = quart.request.remote_addr
|
||||||
if not remote_addr:
|
if not remote_addr:
|
||||||
quart.abort(500, 'Missing client address')
|
quart.abort(500, "Missing client address")
|
||||||
try:
|
try:
|
||||||
addr = ipaddress.ip_address(remote_addr)
|
addr = ipaddress.ip_address(remote_addr)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
_logger.warning(f'Invalid client address {remote_addr!r}: {e}')
|
_logger.warning(f"Invalid client address {remote_addr!r}: {e}")
|
||||||
quart.abort(500, 'Invalid client address')
|
quart.abort(500, "Invalid client address")
|
||||||
return addr
|
return addr
|
||||||
|
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ async def get_client_mac(address: cptypes.IPAddress | None = None) -> cptypes.Ma
|
|||||||
mac = await get_client_mac_if_present(address)
|
mac = await get_client_mac_if_present(address)
|
||||||
if mac is None:
|
if mac is None:
|
||||||
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
||||||
quart.abort(404, 'Unknown client')
|
quart.abort(404, "Unknown client")
|
||||||
return mac
|
return mac
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> Non
|
|||||||
quart.abort(500, str(e))
|
quart.abort(500, str(e))
|
||||||
|
|
||||||
if pu:
|
if pu:
|
||||||
_logger.debug(f'User {mac} (with IP {address}) logged in')
|
_logger.debug(f"User {mac} (with IP {address}) logged in")
|
||||||
for msg in pu.serialized:
|
for msg in pu.serialized:
|
||||||
await app.my_hub.broadcast(msg)
|
await app.my_hub.broadcast(msg)
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ async def user_logout(mac: cptypes.MacAddress) -> None:
|
|||||||
except capport.database.NotReadyYet as e:
|
except capport.database.NotReadyYet as e:
|
||||||
quart.abort(500, str(e))
|
quart.abort(500, str(e))
|
||||||
if pu:
|
if pu:
|
||||||
_logger.debug(f'User {mac} logged out')
|
_logger.debug(f"User {mac} logged out")
|
||||||
for msg in pu.serialized:
|
for msg in pu.serialized:
|
||||||
await app.my_hub.broadcast(msg)
|
await app.my_hub.broadcast(msg)
|
||||||
|
|
||||||
@ -92,73 +92,73 @@ async def user_lookup() -> cptypes.MacPublicState:
|
|||||||
|
|
||||||
|
|
||||||
def check_self_origin():
|
def check_self_origin():
|
||||||
origin = quart.request.headers.get('Origin', None)
|
origin = quart.request.headers.get("Origin", None)
|
||||||
if origin is None:
|
if origin is None:
|
||||||
# not a request by a modern browser - probably curl or something similar. don't care.
|
# not a request by a modern browser - probably curl or something similar. don't care.
|
||||||
return
|
return
|
||||||
origin = origin.lower().strip()
|
origin = origin.lower().strip()
|
||||||
if origin == 'none':
|
if origin == "none":
|
||||||
quart.abort(403, 'Origin is none')
|
quart.abort(403, "Origin is none")
|
||||||
origin_parts = origin.split('/')
|
origin_parts = origin.split("/")
|
||||||
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
|
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
|
||||||
if len(origin_parts) < 3:
|
if len(origin_parts) < 3:
|
||||||
quart.abort(400, 'Broken Origin header')
|
quart.abort(400, "Broken Origin header")
|
||||||
if origin_parts[0] != 'https:' and not app.my_config.debug:
|
if origin_parts[0] != "https:" and not app.my_config.debug:
|
||||||
# -> require https in production
|
# -> require https in production
|
||||||
quart.abort(403, 'Non-https Origin not allowed')
|
quart.abort(403, "Non-https Origin not allowed")
|
||||||
origin_host = origin_parts[2]
|
origin_host = origin_parts[2]
|
||||||
host = quart.request.headers.get('Host', None)
|
host = quart.request.headers.get("Host", None)
|
||||||
if host is None:
|
if host is None:
|
||||||
quart.abort(403, 'Missing Host header')
|
quart.abort(403, "Missing Host header")
|
||||||
if host.lower() != origin_host:
|
if host.lower() != origin_host:
|
||||||
quart.abort(403, 'Origin mismatch')
|
quart.abort(403, "Origin mismatch")
|
||||||
|
|
||||||
|
|
||||||
@app.route('/', methods=['GET'])
|
@app.route("/", methods=["GET"])
|
||||||
async def index(missing_accept: bool = False):
|
async def index(missing_accept: bool = False):
|
||||||
state = await user_lookup()
|
state = await user_lookup()
|
||||||
if not state.mac:
|
if not state.mac:
|
||||||
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept)
|
return await render_i18n_template("index_unknown.html", state=state, missing_accept=missing_accept)
|
||||||
elif state.allowed:
|
elif state.allowed:
|
||||||
return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept)
|
return await render_i18n_template("index_active.html", state=state, missing_accept=missing_accept)
|
||||||
else:
|
else:
|
||||||
return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept)
|
return await render_i18n_template("index_inactive.html", state=state, missing_accept=missing_accept)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/login', methods=['POST'])
|
@app.route("/login", methods=["POST"])
|
||||||
async def login():
|
async def login():
|
||||||
check_self_origin()
|
check_self_origin()
|
||||||
with trio.fail_after(5.0):
|
with trio.fail_after(5.0):
|
||||||
form = await quart.request.form
|
form = await quart.request.form
|
||||||
if form.get('accept') != '1':
|
if form.get("accept") != "1":
|
||||||
return await index(missing_accept=True)
|
return await index(missing_accept=True)
|
||||||
req_mac = form.get('mac')
|
req_mac = form.get("mac")
|
||||||
if not req_mac:
|
if not req_mac:
|
||||||
quart.abort(400, description='Missing MAC in request form data')
|
quart.abort(400, description="Missing MAC in request form data")
|
||||||
address = get_client_ip()
|
address = get_client_ip()
|
||||||
mac = await get_client_mac(address)
|
mac = await get_client_mac(address)
|
||||||
if str(mac) != req_mac:
|
if str(mac) != req_mac:
|
||||||
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
||||||
await user_login(address, mac)
|
await user_login(address, mac)
|
||||||
return quart.redirect('/', code=303)
|
return quart.redirect("/", code=303)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/logout', methods=['POST'])
|
@app.route("/logout", methods=["POST"])
|
||||||
async def logout():
|
async def logout():
|
||||||
check_self_origin()
|
check_self_origin()
|
||||||
with trio.fail_after(5.0):
|
with trio.fail_after(5.0):
|
||||||
form = await quart.request.form
|
form = await quart.request.form
|
||||||
req_mac = form.get('mac')
|
req_mac = form.get("mac")
|
||||||
if not req_mac:
|
if not req_mac:
|
||||||
quart.abort(400, description='Missing MAC in request form data')
|
quart.abort(400, description="Missing MAC in request form data")
|
||||||
mac = await get_client_mac()
|
mac = await get_client_mac()
|
||||||
if str(mac) != req_mac:
|
if str(mac) != req_mac:
|
||||||
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
||||||
await user_logout(mac)
|
await user_logout(mac)
|
||||||
return quart.redirect('/', code=303)
|
return quart.redirect("/", code=303)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/captive-portal', methods=['GET'])
|
@app.route("/api/captive-portal", methods=["GET"])
|
||||||
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
|
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
|
||||||
async def captive_api():
|
async def captive_api():
|
||||||
state = await user_lookup()
|
state = await user_lookup()
|
||||||
|
@ -43,7 +43,7 @@ class Channel:
|
|||||||
_logger.debug(f"{self}: created (server_side={server_side})")
|
_logger.debug(f"{self}: created (server_side={server_side})")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'Channel[0x{id(self):x}]'
|
return f"Channel[0x{id(self):x}]"
|
||||||
|
|
||||||
async def do_handshake(self) -> capport.comm.message.Hello:
|
async def do_handshake(self) -> capport.comm.message.Hello:
|
||||||
try:
|
try:
|
||||||
@ -59,9 +59,8 @@ class Channel:
|
|||||||
peer_hello = (await self.recv_msg()).to_variant()
|
peer_hello = (await self.recv_msg()).to_variant()
|
||||||
if not isinstance(peer_hello, capport.comm.message.Hello):
|
if not isinstance(peer_hello, capport.comm.message.Hello):
|
||||||
raise HubConnectionReadError("Expected Hello as first message")
|
raise HubConnectionReadError("Expected Hello as first message")
|
||||||
auth_succ = \
|
expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)
|
||||||
(peer_hello.authentication ==
|
auth_succ = peer_hello.authentication == expected_auth
|
||||||
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
|
|
||||||
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
|
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
|
||||||
peer_auth = (await self.recv_msg()).to_variant()
|
peer_auth = (await self.recv_msg()).to_variant()
|
||||||
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
|
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
|
||||||
@ -72,7 +71,7 @@ class Channel:
|
|||||||
|
|
||||||
async def _read(self, num: int) -> bytes:
|
async def _read(self, num: int) -> bytes:
|
||||||
assert num > 0
|
assert num > 0
|
||||||
buf = b''
|
buf = b""
|
||||||
# _logger.debug(f"{self}:_read({num})")
|
# _logger.debug(f"{self}:_read({num})")
|
||||||
while num > 0:
|
while num > 0:
|
||||||
try:
|
try:
|
||||||
@ -92,7 +91,7 @@ class Channel:
|
|||||||
|
|
||||||
async def _recv_raw_msg(self) -> bytes:
|
async def _recv_raw_msg(self) -> bytes:
|
||||||
len_bytes = await self._read(4)
|
len_bytes = await self._read(4)
|
||||||
chunk_size, = struct.unpack('!I', len_bytes)
|
(chunk_size,) = struct.unpack("!I", len_bytes)
|
||||||
chunk = await self._read(chunk_size)
|
chunk = await self._read(chunk_size)
|
||||||
if chunk is None:
|
if chunk is None:
|
||||||
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
|
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
|
||||||
@ -116,7 +115,7 @@ class Channel:
|
|||||||
async def send_msg(self, msg: capport.comm.message.Message):
|
async def send_msg(self, msg: capport.comm.message.Message):
|
||||||
chunk = msg.SerializeToString(deterministic=True)
|
chunk = msg.SerializeToString(deterministic=True)
|
||||||
chunk_size = len(chunk)
|
chunk_size = len(chunk)
|
||||||
len_bytes = struct.pack('!I', chunk_size)
|
len_bytes = struct.pack("!I", chunk_size)
|
||||||
chunk = len_bytes + chunk
|
chunk = len_bytes + chunk
|
||||||
await self._send_raw(chunk)
|
await self._send_raw(chunk)
|
||||||
|
|
||||||
@ -158,7 +157,7 @@ class Connection:
|
|||||||
if msg:
|
if msg:
|
||||||
await self._channel.send_msg(msg)
|
await self._channel.send_msg(msg)
|
||||||
else:
|
else:
|
||||||
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
|
await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message())
|
||||||
except trio.TooSlowError:
|
except trio.TooSlowError:
|
||||||
_logger.warning(f"{self._channel}: send timed out")
|
_logger.warning(f"{self._channel}: send timed out")
|
||||||
except ConnectionError as e:
|
except ConnectionError as e:
|
||||||
@ -303,7 +302,7 @@ class Hub:
|
|||||||
if is_controller:
|
if is_controller:
|
||||||
state_filename = config.database_file
|
state_filename = config.database_file
|
||||||
else:
|
else:
|
||||||
state_filename = ''
|
state_filename = ""
|
||||||
self.database = capport.database.Database(state_filename=state_filename)
|
self.database = capport.database.Database(state_filename=state_filename)
|
||||||
self._anon_context = ssl.SSLContext()
|
self._anon_context = ssl.SSLContext()
|
||||||
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
|
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
|
||||||
@ -311,14 +310,14 @@ class Hub:
|
|||||||
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||||
# -> AECDH-AES256-SHA
|
# -> AECDH-AES256-SHA
|
||||||
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
|
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
|
||||||
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
|
self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0")
|
||||||
self._controllers: dict[str, ControllerConn] = {}
|
self._controllers: dict[str, ControllerConn] = {}
|
||||||
self._established: dict[uuid.UUID, Connection] = {}
|
self._established: dict[uuid.UUID, Connection] = {}
|
||||||
|
|
||||||
async def _accept(self, stream):
|
async def _accept(self, stream):
|
||||||
remotename = stream.socket.getpeername()
|
remotename = stream.socket.getpeername()
|
||||||
if isinstance(remotename, tuple) and len(remotename) == 2:
|
if isinstance(remotename, tuple) and len(remotename) == 2:
|
||||||
remote = f'[{remotename[0]}]:{remotename[1]}'
|
remote = f"[{remotename[0]}]:{remotename[1]}"
|
||||||
else:
|
else:
|
||||||
remote = str(remotename)
|
remote = str(remotename)
|
||||||
try:
|
try:
|
||||||
@ -351,11 +350,11 @@ class Hub:
|
|||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
|
|
||||||
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
|
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
|
||||||
m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256)
|
m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256)
|
||||||
if server_side:
|
if server_side:
|
||||||
m.update(b'server$')
|
m.update(b"server$")
|
||||||
else:
|
else:
|
||||||
m.update(b'client$')
|
m.update(b"client$")
|
||||||
m.update(ssl_binding)
|
m.update(ssl_binding)
|
||||||
return m.digest()
|
return m.digest()
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from .protobuf import message_pb2
|
|||||||
|
|
||||||
|
|
||||||
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
|
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
|
||||||
variant_name = self.WhichOneof('oneof')
|
variant_name = self.WhichOneof("oneof")
|
||||||
if variant_name:
|
if variant_name:
|
||||||
return getattr(self, variant_name)
|
return getattr(self, variant_name)
|
||||||
return None
|
return None
|
||||||
@ -16,14 +16,15 @@ def _make_to_message(oneof_field):
|
|||||||
def to_message(self) -> message_pb2.Message:
|
def to_message(self) -> message_pb2.Message:
|
||||||
msg = message_pb2.Message(**{oneof_field: self})
|
msg = message_pb2.Message(**{oneof_field: self})
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
return to_message
|
return to_message
|
||||||
|
|
||||||
|
|
||||||
def _monkey_patch():
|
def _monkey_patch():
|
||||||
g = globals()
|
g = globals()
|
||||||
g['Message'] = message_pb2.Message
|
g["Message"] = message_pb2.Message
|
||||||
message_pb2.Message.to_variant = _message_to_variant
|
message_pb2.Message.to_variant = _message_to_variant
|
||||||
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
|
for field in message_pb2._MESSAGE.oneofs_by_name["oneof"].fields:
|
||||||
type_name = field.message_type.name
|
type_name = field.message_type.name
|
||||||
field_type = getattr(message_pb2, type_name)
|
field_type = getattr(message_pb2, type_name)
|
||||||
field_type.to_message = _make_to_message(field.name)
|
field_type.to_message = _make_to_message(field.name)
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import google.protobuf.message
|
import google.protobuf.message
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
|
||||||
# manually maintained typehints for protobuf created (and monkey-patched) types
|
# manually maintained typehints for protobuf created (and monkey-patched) types
|
||||||
|
|
||||||
|
|
||||||
class Message(google.protobuf.message.Message):
|
class Message(google.protobuf.message.Message):
|
||||||
hello: Hello
|
hello: Hello
|
||||||
authentication_result: AuthenticationResult
|
authentication_result: AuthenticationResult
|
||||||
@ -19,10 +17,8 @@ class Message(google.protobuf.message.Message):
|
|||||||
ping: Ping | None = None,
|
ping: Ping | None = None,
|
||||||
mac_states: MacStates | None = None,
|
mac_states: MacStates | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
|
def to_variant(self) -> Hello | AuthenticationResult | Ping | MacStates: ...
|
||||||
|
|
||||||
|
|
||||||
class Hello(google.protobuf.message.Message):
|
class Hello(google.protobuf.message.Message):
|
||||||
instance_id: bytes
|
instance_id: bytes
|
||||||
hostname: str
|
hostname: str
|
||||||
@ -32,15 +28,13 @@ class Hello(google.protobuf.message.Message):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
instance_id: bytes=b'',
|
instance_id: bytes = b"",
|
||||||
hostname: str='',
|
hostname: str = "",
|
||||||
is_controller: bool = False,
|
is_controller: bool = False,
|
||||||
authentication: bytes=b'',
|
authentication: bytes = b"",
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_message(self) -> Message: ...
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationResult(google.protobuf.message.Message):
|
class AuthenticationResult(google.protobuf.message.Message):
|
||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
@ -49,22 +43,18 @@ class AuthenticationResult(google.protobuf.message.Message):
|
|||||||
*,
|
*,
|
||||||
success: bool = False,
|
success: bool = False,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_message(self) -> Message: ...
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
class Ping(google.protobuf.message.Message):
|
class Ping(google.protobuf.message.Message):
|
||||||
payload: bytes
|
payload: bytes
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
payload: bytes=b'',
|
payload: bytes = b"",
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_message(self) -> Message: ...
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
class MacStates(google.protobuf.message.Message):
|
class MacStates(google.protobuf.message.Message):
|
||||||
states: list[MacState]
|
states: list[MacState]
|
||||||
|
|
||||||
@ -73,10 +63,8 @@ class MacStates(google.protobuf.message.Message):
|
|||||||
*,
|
*,
|
||||||
states: list[MacState] = [],
|
states: list[MacState] = [],
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def to_message(self) -> Message: ...
|
def to_message(self) -> Message: ...
|
||||||
|
|
||||||
|
|
||||||
class MacState(google.protobuf.message.Message):
|
class MacState(google.protobuf.message.Message):
|
||||||
mac_address: bytes
|
mac_address: bytes
|
||||||
last_change: int # Seconds of UTC time since epoch
|
last_change: int # Seconds of UTC time since epoch
|
||||||
@ -86,7 +74,7 @@ class MacState(google.protobuf.message.Message):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
mac_address: bytes=b'',
|
mac_address: bytes = b"",
|
||||||
last_change: int = 0,
|
last_change: int = 0,
|
||||||
allow_until: int = 0,
|
allow_until: int = 0,
|
||||||
allowed: bool = False,
|
allowed: bool = False,
|
||||||
|
@ -39,7 +39,7 @@ class Config:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def load(filename: str | None = None) -> Config:
|
def load(filename: str | None = None) -> Config:
|
||||||
if filename is None:
|
if filename is None:
|
||||||
for name in ('capport.yaml', '/etc/capport.yaml'):
|
for name in ("capport.yaml", "/etc/capport.yaml"):
|
||||||
if os.path.exists(name):
|
if os.path.exists(name):
|
||||||
return Config.load(name)
|
return Config.load(name)
|
||||||
raise RuntimeError("Missing config file")
|
raise RuntimeError("Missing config file")
|
||||||
@ -47,19 +47,19 @@ class Config:
|
|||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}")
|
raise RuntimeError(f"Invalid yaml config data, expected dict: {data!r}")
|
||||||
controllers = list(map(str, data.pop('controllers')))
|
controllers = list(map(str, data.pop("controllers")))
|
||||||
config = Config(
|
config = Config(
|
||||||
_source_filename=filename,
|
_source_filename=filename,
|
||||||
controllers=controllers,
|
controllers=controllers,
|
||||||
server_names=data.pop('server-names', []),
|
server_names=data.pop("server-names", []),
|
||||||
comm_secret=str(data.pop('comm-secret')),
|
comm_secret=str(data.pop("comm-secret")),
|
||||||
cookie_secret=str(data.pop('cookie-secret')),
|
cookie_secret=str(data.pop("cookie-secret")),
|
||||||
venue_info_url=str(data.pop('venue-info-url')),
|
venue_info_url=str(data.pop("venue-info-url")),
|
||||||
session_timeout=data.pop('session-timeout', 3600),
|
session_timeout=data.pop("session-timeout", 3600),
|
||||||
api_port=data.pop('api-port', 8000),
|
api_port=data.pop("api-port", 8000),
|
||||||
controller_port=data.pop('controller-port', 5000),
|
controller_port=data.pop("controller-port", 5000),
|
||||||
database_file=str(data.pop('database-file', 'capport.state')),
|
database_file=str(data.pop("database-file", "capport.state")),
|
||||||
debug=data.pop('debug', False)
|
debug=data.pop("debug", False),
|
||||||
)
|
)
|
||||||
if data:
|
if data:
|
||||||
_logger.error(f"Unknown config elements: {list(data.keys())}")
|
_logger.error(f"Unknown config elements: {list(data.keys())}")
|
||||||
|
@ -58,9 +58,9 @@ async def amain(config: capport.config.Config) -> None:
|
|||||||
async with trio.open_nursery() as nursery:
|
async with trio.open_nursery() as nursery:
|
||||||
# hub.run loads the statefile from disk before signalling it was "started"
|
# hub.run loads the statefile from disk before signalling it was "started"
|
||||||
await nursery.start(hub.run)
|
await nursery.start(hub.run)
|
||||||
await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set')
|
await sn.send("READY=1", "STATUS=Deploying initial entries to nftables set")
|
||||||
app.apply_db_entries(hub.database.entries())
|
app.apply_db_entries(hub.database.entries())
|
||||||
await sn.send('STATUS=Kernel fully synchronized')
|
await sn.send("STATUS=Kernel fully synchronized")
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@ -69,7 +69,7 @@ class CliArguments:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--config', '-c')
|
parser.add_argument("--config", "-c")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
self.config = args.config
|
self.config = args.config
|
||||||
|
@ -23,14 +23,14 @@ class MacAddress:
|
|||||||
raw: bytes
|
raw: bytes
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.raw.hex(':')
|
return self.raw.hex(":")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return repr(str(self))
|
return repr(str(self))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(s: str) -> MacAddress:
|
def parse(s: str) -> MacAddress:
|
||||||
return MacAddress(bytes.fromhex(s.replace(':', '')))
|
return MacAddress(bytes.fromhex(s.replace(":", "")))
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, order=True)
|
@dataclasses.dataclass(frozen=True, order=True)
|
||||||
@ -40,9 +40,9 @@ class Timestamp:
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
try:
|
try:
|
||||||
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
|
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
|
||||||
return ts.isoformat(sep=' ')
|
return ts.isoformat(sep=" ")
|
||||||
except OSError:
|
except OSError:
|
||||||
return f'epoch@{self.epoch}'
|
return f"epoch@{self.epoch}"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return repr(str(self))
|
return repr(str(self))
|
||||||
@ -85,7 +85,7 @@ class MacPublicState:
|
|||||||
def allowed_remaining_duration(self) -> str:
|
def allowed_remaining_duration(self) -> str:
|
||||||
mm, ss = divmod(self.allowed_remaining, 60)
|
mm, ss = divmod(self.allowed_remaining, 60)
|
||||||
hh, mm = divmod(mm, 60)
|
hh, mm = divmod(mm, 60)
|
||||||
return f'{hh}:{mm:02}:{ss:02}'
|
return f"{hh}:{mm:02}:{ss:02}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def allowed_until(self) -> datetime.datetime | None:
|
def allowed_until(self) -> datetime.datetime | None:
|
||||||
@ -95,18 +95,18 @@ class MacPublicState:
|
|||||||
|
|
||||||
def to_rfc8908(self, config: Config) -> quart.Response:
|
def to_rfc8908(self, config: Config) -> quart.Response:
|
||||||
response: dict[str, typing.Any] = {
|
response: dict[str, typing.Any] = {
|
||||||
'user-portal-url': quart.url_for('index', _external=True),
|
"user-portal-url": quart.url_for("index", _external=True),
|
||||||
}
|
}
|
||||||
if config.venue_info_url:
|
if config.venue_info_url:
|
||||||
response['venue-info-url'] = config.venue_info_url
|
response["venue-info-url"] = config.venue_info_url
|
||||||
if self.captive:
|
if self.captive:
|
||||||
response['captive'] = True
|
response["captive"] = True
|
||||||
else:
|
else:
|
||||||
response['captive'] = False
|
response["captive"] = False
|
||||||
response['seconds-remaining'] = self.allowed_remaining
|
response["seconds-remaining"] = self.allowed_remaining
|
||||||
response['can-extend-session'] = True
|
response["can-extend-session"] = True
|
||||||
return quart.Response(
|
return quart.Response(
|
||||||
json.dumps(response),
|
json.dumps(response),
|
||||||
headers={'Cache-Control': 'private'},
|
headers={"Cache-Control": "private"},
|
||||||
content_type='application/captive+json',
|
content_type="application/captive+json",
|
||||||
)
|
)
|
||||||
|
@ -128,7 +128,7 @@ def _serialize_mac_states_as_messages(
|
|||||||
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
|
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
|
||||||
chunk = states.SerializeToString(deterministic=True)
|
chunk = states.SerializeToString(deterministic=True)
|
||||||
chunk_size = len(chunk)
|
chunk_size = len(chunk)
|
||||||
len_bytes = struct.pack('!I', chunk_size)
|
len_bytes = struct.pack("!I", chunk_size)
|
||||||
return len_bytes + chunk
|
return len_bytes + chunk
|
||||||
|
|
||||||
|
|
||||||
@ -204,39 +204,32 @@ class Database:
|
|||||||
return _serialize_mac_states_as_messages(self._macs)
|
return _serialize_mac_states_as_messages(self._macs)
|
||||||
|
|
||||||
def as_json(self) -> dict:
|
def as_json(self) -> dict:
|
||||||
return {
|
return {str(addr): entry.as_json() for addr, entry in self._macs.items()}
|
||||||
str(addr): entry.as_json()
|
|
||||||
for addr, entry in self._macs.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _run_statefile(self) -> None:
|
async def _run_statefile(self) -> None:
|
||||||
rx: trio.MemoryReceiveChannel[
|
rx: trio.MemoryReceiveChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
|
||||||
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
tx: trio.MemorySendChannel[capport.comm.message.MacStates | list[capport.comm.message.MacStates]]
|
||||||
]
|
|
||||||
tx: trio.MemorySendChannel[
|
|
||||||
capport.comm.message.MacStates | list[capport.comm.message.MacStates],
|
|
||||||
]
|
|
||||||
tx, rx = trio.open_memory_channel(64)
|
tx, rx = trio.open_memory_channel(64)
|
||||||
self._send_changes = tx
|
self._send_changes = tx
|
||||||
|
|
||||||
assert self._state_filename
|
assert self._state_filename
|
||||||
filename: str = self._state_filename
|
filename: str = self._state_filename
|
||||||
tmp_filename = f'{filename}.new-{os.getpid()}'
|
tmp_filename = f"{filename}.new-{os.getpid()}"
|
||||||
|
|
||||||
async def resync(all_states: list[capport.comm.message.MacStates]):
|
async def resync(all_states: list[capport.comm.message.MacStates]):
|
||||||
try:
|
try:
|
||||||
async with await trio.open_file(tmp_filename, 'xb') as tf:
|
async with await trio.open_file(tmp_filename, "xb") as tf:
|
||||||
for states in all_states:
|
for states in all_states:
|
||||||
await tf.write(_states_to_chunk(states))
|
await tf.write(_states_to_chunk(states))
|
||||||
os.rename(tmp_filename, filename)
|
os.rename(tmp_filename, filename)
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(tmp_filename):
|
if os.path.exists(tmp_filename):
|
||||||
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
|
_logger.warning(f"Removing (failed) state export file {tmp_filename}")
|
||||||
os.unlink(tmp_filename)
|
os.unlink(tmp_filename)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
async with await trio.open_file(filename, 'ab', buffering=0) as sf:
|
async with await trio.open_file(filename, "ab", buffering=0) as sf:
|
||||||
while True:
|
while True:
|
||||||
update = await rx.receive()
|
update = await rx.receive()
|
||||||
if isinstance(update, list):
|
if isinstance(update, list):
|
||||||
@ -247,10 +240,10 @@ class Database:
|
|||||||
await resync(update)
|
await resync(update)
|
||||||
# now reopen normal statefile and continue appending updates
|
# now reopen normal statefile and continue appending updates
|
||||||
except trio.Cancelled:
|
except trio.Cancelled:
|
||||||
_logger.info('Final sync to disk')
|
_logger.info("Final sync to disk")
|
||||||
with trio.CancelScope(shield=True):
|
with trio.CancelScope(shield=True):
|
||||||
await resync(_serialize_mac_states(self._macs))
|
await resync(_serialize_mac_states(self._macs))
|
||||||
_logger.info('Final sync to disk done')
|
_logger.info("Final sync to disk done")
|
||||||
|
|
||||||
async def _load_statefile(self):
|
async def _load_statefile(self):
|
||||||
if not os.path.exists(self._state_filename):
|
if not os.path.exists(self._state_filename):
|
||||||
@ -259,7 +252,7 @@ class Database:
|
|||||||
# we're going to ignore changes from loading the file
|
# we're going to ignore changes from loading the file
|
||||||
pu = PendingUpdates(self)
|
pu = PendingUpdates(self)
|
||||||
pu._closed = False
|
pu._closed = False
|
||||||
async with await trio.open_file(self._state_filename, 'rb') as sf:
|
async with await trio.open_file(self._state_filename, "rb") as sf:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
len_bytes = await sf.read(4)
|
len_bytes = await sf.read(4)
|
||||||
@ -268,7 +261,7 @@ class Database:
|
|||||||
if len(len_bytes) < 4:
|
if len(len_bytes) < 4:
|
||||||
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
|
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
|
||||||
return
|
return
|
||||||
chunk_size, = struct.unpack('!I', len_bytes)
|
(chunk_size,) = struct.unpack("!I", len_bytes)
|
||||||
chunk = await sf.read(chunk_size)
|
chunk = await sf.read(chunk_size)
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
_logger.error(f"Failed to read next chunk from statefile: {e}")
|
_logger.error(f"Failed to read next chunk from statefile: {e}")
|
||||||
@ -286,7 +279,7 @@ class Database:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors += 1
|
errors += 1
|
||||||
if errors < 5:
|
if errors < 5:
|
||||||
_logger.error(f'Failed to handle state: {e}')
|
_logger.error(f"Failed to handle state: {e}")
|
||||||
|
|
||||||
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
|
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
|
||||||
entry = self._macs.get(mac)
|
entry = self._macs.get(mac)
|
||||||
|
@ -14,20 +14,17 @@ from . import cptypes
|
|||||||
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
|
def print_metric(name: str, mtype: str, value, *, now: int | None = None, help: str | None = None):
|
||||||
# no labels in our names for now, always print help and type
|
# no labels in our names for now, always print help and type
|
||||||
if help:
|
if help:
|
||||||
print(f'# HELP {name} {help}')
|
print(f"# HELP {name} {help}")
|
||||||
print(f'# TYPE {name} {mtype}')
|
print(f"# TYPE {name} {mtype}")
|
||||||
if now:
|
if now:
|
||||||
print(f'{name} {value} {now}')
|
print(f"{name} {value} {now}")
|
||||||
else:
|
else:
|
||||||
print(f'{name} {value}')
|
print(f"{name} {value}")
|
||||||
|
|
||||||
|
|
||||||
async def amain(client_ifname: str):
|
async def amain(client_ifname: str):
|
||||||
ns = capport.utils.nft_set.NftSet()
|
ns = capport.utils.nft_set.NftSet()
|
||||||
captive_allowed_entries: set[cptypes.MacAddress] = {
|
captive_allowed_entries: set[cptypes.MacAddress] = {entry["mac"] for entry in ns.list()}
|
||||||
entry['mac']
|
|
||||||
for entry in ns.list()
|
|
||||||
}
|
|
||||||
seen_allowed_entries: set[cptypes.MacAddress] = set()
|
seen_allowed_entries: set[cptypes.MacAddress] = set()
|
||||||
total_ipv4 = 0
|
total_ipv4 = 0
|
||||||
total_ipv6 = 0
|
total_ipv6 = 0
|
||||||
@ -47,46 +44,46 @@ async def amain(client_ifname: str):
|
|||||||
total_ipv6 += 1
|
total_ipv6 += 1
|
||||||
unique_ipv6.add(mac)
|
unique_ipv6.add(mac)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_allowed_macs',
|
"capport_allowed_macs",
|
||||||
'gauge',
|
"gauge",
|
||||||
len(captive_allowed_entries),
|
len(captive_allowed_entries),
|
||||||
help='Number of allowed client mac addresses',
|
help="Number of allowed client mac addresses",
|
||||||
)
|
)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_allowed_neigh_macs',
|
"capport_allowed_neigh_macs",
|
||||||
'gauge',
|
"gauge",
|
||||||
len(seen_allowed_entries),
|
len(seen_allowed_entries),
|
||||||
help='Number of allowed client mac addresses seen in neighbor cache',
|
help="Number of allowed client mac addresses seen in neighbor cache",
|
||||||
)
|
)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_unique',
|
"capport_unique",
|
||||||
'gauge',
|
"gauge",
|
||||||
len(unique_clients),
|
len(unique_clients),
|
||||||
help='Number of clients (mac addresses) in client network seen in neighbor cache',
|
help="Number of clients (mac addresses) in client network seen in neighbor cache",
|
||||||
)
|
)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_unique_ipv4',
|
"capport_unique_ipv4",
|
||||||
'gauge',
|
"gauge",
|
||||||
len(unique_ipv4),
|
len(unique_ipv4),
|
||||||
help='Number of IPv4 clients (unique per mac) in client network seen in neighbor cache',
|
help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache",
|
||||||
)
|
)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_unique_ipv6',
|
"capport_unique_ipv6",
|
||||||
'gauge',
|
"gauge",
|
||||||
len(unique_ipv6),
|
len(unique_ipv6),
|
||||||
help='Number of IPv6 clients (unique per mac) in client network seen in neighbor cache',
|
help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache",
|
||||||
)
|
)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_total_ipv4',
|
"capport_total_ipv4",
|
||||||
'gauge',
|
"gauge",
|
||||||
total_ipv4,
|
total_ipv4,
|
||||||
help='Number of IPv4 addresses seen in neighbor cache',
|
help="Number of IPv4 addresses seen in neighbor cache",
|
||||||
)
|
)
|
||||||
print_metric(
|
print_metric(
|
||||||
'capport_total_ipv6',
|
"capport_total_ipv6",
|
||||||
'gauge',
|
"gauge",
|
||||||
total_ipv6,
|
total_ipv6,
|
||||||
help='Number of IPv6 addresses seen in neighbor cache',
|
help="Number of IPv6 addresses seen in neighbor cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,8 +10,8 @@ def init_logger(config: capport.config.Config):
|
|||||||
if config.debug:
|
if config.debug:
|
||||||
loglevel = logging.DEBUG
|
loglevel = logging.DEBUG
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
|
format="%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s",
|
||||||
datefmt='[%Y-%m-%d %H:%M:%S %z]',
|
datefmt="[%Y-%m-%d %H:%M:%S %z]",
|
||||||
level=loglevel,
|
level=loglevel,
|
||||||
)
|
)
|
||||||
logging.getLogger('hypercorn').propagate = False
|
logging.getLogger("hypercorn").propagate = False
|
||||||
|
@ -35,9 +35,9 @@ class NeighborController:
|
|||||||
route = await self.get_route(address)
|
route = await self.get_route(address)
|
||||||
if route is None:
|
if route is None:
|
||||||
return None
|
return None
|
||||||
index = route.get_attr(route.name2nla('oif'))
|
index = route.get_attr(route.name2nla("oif"))
|
||||||
try:
|
try:
|
||||||
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
|
return self.ip.neigh("get", dst=str(address), ifindex=index, state="none")[0]
|
||||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||||
if e.code == errno.ENOENT:
|
if e.code == errno.ENOENT:
|
||||||
return None
|
return None
|
||||||
@ -53,7 +53,7 @@ class NeighborController:
|
|||||||
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
||||||
if neigh is None:
|
if neigh is None:
|
||||||
return None
|
return None
|
||||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
mac = neigh.get_attr(neigh.name2nla("lladdr"))
|
||||||
if mac is None:
|
if mac is None:
|
||||||
return None
|
return None
|
||||||
return cptypes.MacAddress.parse(mac)
|
return cptypes.MacAddress.parse(mac)
|
||||||
@ -63,7 +63,7 @@ class NeighborController:
|
|||||||
address: cptypes.IPAddress,
|
address: cptypes.IPAddress,
|
||||||
) -> pyroute2.iproute.linux.rtmsg | None:
|
) -> pyroute2.iproute.linux.rtmsg | None:
|
||||||
try:
|
try:
|
||||||
return self.ip.route('get', dst=str(address))[0]
|
return self.ip.route("get", dst=str(address))[0]
|
||||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||||
if e.code == errno.ENOENT:
|
if e.code == errno.ENOENT:
|
||||||
return None
|
return None
|
||||||
@ -74,16 +74,16 @@ class NeighborController:
|
|||||||
interface: str,
|
interface: str,
|
||||||
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
|
) -> typing.AsyncGenerator[tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
|
||||||
ifindex = socket.if_nametoindex(interface)
|
ifindex = socket.if_nametoindex(interface)
|
||||||
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
|
unicast_num = pyroute2.netlink.rtnl.rt_type["unicast"]
|
||||||
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
|
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
|
||||||
for family in (socket.AF_INET, socket.AF_INET6):
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family):
|
for neigh in self.ip.neigh("dump", ifindex=ifindex, family=family):
|
||||||
if neigh['ndm_type'] != unicast_num:
|
if neigh["ndm_type"] != unicast_num:
|
||||||
continue
|
continue
|
||||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
mac = neigh.get_attr(neigh.name2nla("lladdr"))
|
||||||
if not mac:
|
if not mac:
|
||||||
continue
|
continue
|
||||||
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst')))
|
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla("dst")))
|
||||||
if dst.is_link_local:
|
if dst.is_link_local:
|
||||||
continue
|
continue
|
||||||
yield (cptypes.MacAddress.parse(mac), dst)
|
yield (cptypes.MacAddress.parse(mac), dst)
|
||||||
|
@ -30,26 +30,20 @@ class NftSet:
|
|||||||
timeout: int | float | None = None,
|
timeout: int | float | None = None,
|
||||||
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
||||||
attrs: dict[str, typing.Any] = {
|
attrs: dict[str, typing.Any] = {
|
||||||
'NFTA_SET_ELEM_KEY': dict(
|
"NFTA_SET_ELEM_KEY": dict(
|
||||||
NFTA_DATA_VALUE=mac.raw,
|
NFTA_DATA_VALUE=mac.raw,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if timeout:
|
if timeout:
|
||||||
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
|
attrs["NFTA_SET_ELEM_TIMEOUT"] = int(1000 * timeout)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
def _bulk_insert(
|
def _bulk_insert(
|
||||||
self,
|
self,
|
||||||
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
|
entries: typing.Sequence[tuple[cptypes.MacAddress, int | float]],
|
||||||
) -> None:
|
) -> None:
|
||||||
ser_entries = [
|
ser_entries = [self._set_elem(mac) for mac, _timeout in entries]
|
||||||
self._set_elem(mac)
|
ser_entries_with_timeout = [self._set_elem(mac, timeout) for mac, timeout in entries]
|
||||||
for mac, _timeout in entries
|
|
||||||
]
|
|
||||||
ser_entries_with_timeout = [
|
|
||||||
self._set_elem(mac, timeout)
|
|
||||||
for mac, timeout in entries
|
|
||||||
]
|
|
||||||
with self._socket.begin() as tx:
|
with self._socket.begin() as tx:
|
||||||
# create doesn't affect existing elements, so:
|
# create doesn't affect existing elements, so:
|
||||||
# make sure entries exists
|
# make sure entries exists
|
||||||
@ -58,8 +52,8 @@ class NftSet:
|
|||||||
pyroute2.netlink.NLM_F_CREATE,
|
pyroute2.netlink.NLM_F_CREATE,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -68,8 +62,8 @@ class NftSet:
|
|||||||
_nftsocket.NFT_MSG_DELSETELEM,
|
_nftsocket.NFT_MSG_DELSETELEM,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -79,8 +73,8 @@ class NftSet:
|
|||||||
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
|
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -95,10 +89,7 @@ class NftSet:
|
|||||||
self.bulk_insert([(mac, timeout)])
|
self.bulk_insert([(mac, timeout)])
|
||||||
|
|
||||||
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||||
ser_entries = [
|
ser_entries = [self._set_elem(mac) for mac in entries]
|
||||||
self._set_elem(mac)
|
|
||||||
for mac in entries
|
|
||||||
]
|
|
||||||
with self._socket.begin() as tx:
|
with self._socket.begin() as tx:
|
||||||
# make sure entries exists
|
# make sure entries exists
|
||||||
tx.put(
|
tx.put(
|
||||||
@ -106,8 +97,8 @@ class NftSet:
|
|||||||
pyroute2.netlink.NLM_F_CREATE,
|
pyroute2.netlink.NLM_F_CREATE,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -116,8 +107,8 @@ class NftSet:
|
|||||||
_nftsocket.NFT_MSG_DELSETELEM,
|
_nftsocket.NFT_MSG_DELSETELEM,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -137,20 +128,20 @@ class NftSet:
|
|||||||
_nftsocket.NFT_MSG_GETSETELEM,
|
_nftsocket.NFT_MSG_GETSETELEM,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
'mac': cptypes.MacAddress(
|
"mac": cptypes.MacAddress(
|
||||||
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
|
elem.get_attr("NFTA_SET_ELEM_KEY").get_attr("NFTA_DATA_VALUE"),
|
||||||
),
|
),
|
||||||
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
|
"timeout": _from_msec(elem.get_attr("NFTA_SET_ELEM_TIMEOUT", None)),
|
||||||
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
|
"expiration": _from_msec(elem.get_attr("NFTA_SET_ELEM_EXPIRATION", None)),
|
||||||
}
|
}
|
||||||
for response in responses
|
for response in responses
|
||||||
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
|
for elem in response.get_attr("NFTA_SET_ELEM_LIST_ELEMENTS", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
@ -158,9 +149,9 @@ class NftSet:
|
|||||||
_nftsocket.NFT_MSG_DELSETELEM,
|
_nftsocket.NFT_MSG_DELSETELEM,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
NFTA_SET_ELEM_LIST_TABLE="captive_mark",
|
||||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
NFTA_SET_ELEM_LIST_SET="allowed",
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
@ -170,7 +161,7 @@ class NftSet:
|
|||||||
pyroute2.netlink.NLM_F_CREATE,
|
pyroute2.netlink.NLM_F_CREATE,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_TABLE_NAME='captive_mark',
|
NFTA_TABLE_NAME="captive_mark",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
tx.put(
|
tx.put(
|
||||||
@ -178,8 +169,8 @@ class NftSet:
|
|||||||
pyroute2.netlink.NLM_F_CREATE,
|
pyroute2.netlink.NLM_F_CREATE,
|
||||||
nfgen_family=NFPROTO_INET,
|
nfgen_family=NFPROTO_INET,
|
||||||
attrs=dict(
|
attrs=dict(
|
||||||
NFTA_SET_TABLE='captive_mark',
|
NFTA_SET_TABLE="captive_mark",
|
||||||
NFTA_SET_NAME='allowed',
|
NFTA_SET_NAME="allowed",
|
||||||
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
|
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
|
||||||
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
|
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
|
||||||
NFTA_SET_KEY_LEN=6, # length of key: mac address
|
NFTA_SET_KEY_LEN=6, # length of key: mac address
|
||||||
|
@ -13,7 +13,7 @@ from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
|
|||||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
|
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
|
||||||
|
|
||||||
|
|
||||||
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
|
_NlMsgBase = typing.TypeVar("_NlMsgBase", bound=pyroute2.netlink.nlmsg_base)
|
||||||
|
|
||||||
|
|
||||||
# nft uses NESTED for those.. lets do the same
|
# nft uses NESTED for those.. lets do the same
|
||||||
@ -25,16 +25,17 @@ _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
|
|||||||
|
|
||||||
def _monkey_patch_pyroute2():
|
def _monkey_patch_pyroute2():
|
||||||
import pyroute2.netlink
|
import pyroute2.netlink
|
||||||
|
|
||||||
# overwrite setdefault on nlmsg_base class hierarchy
|
# overwrite setdefault on nlmsg_base class hierarchy
|
||||||
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
|
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
|
||||||
|
|
||||||
def _nlmsg_base__setvalue(self, value):
|
def _nlmsg_base__setvalue(self, value):
|
||||||
if not self.header or not self['header'] or not isinstance(value, dict):
|
if not self.header or not self["header"] or not isinstance(value, dict):
|
||||||
return _orig_setvalue(self, value)
|
return _orig_setvalue(self, value)
|
||||||
# merge headers instead of overwriting them
|
# merge headers instead of overwriting them
|
||||||
header = value.pop('header', {})
|
header = value.pop("header", {})
|
||||||
res = _orig_setvalue(self, value)
|
res = _orig_setvalue(self, value)
|
||||||
self['header'].update(header)
|
self["header"].update(header)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
|
def overwrite_methods(cls: type[pyroute2.netlink.nlmsg_base]) -> None:
|
||||||
@ -52,20 +53,20 @@ _monkey_patch_pyroute2()
|
|||||||
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
|
def _build(msg_class: type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
|
||||||
msg = msg_class()
|
msg = msg_class()
|
||||||
for key, value in header.items():
|
for key, value in header.items():
|
||||||
msg['header'][key] = value
|
msg["header"][key] = value
|
||||||
for key, value in fields.items():
|
for key, value in fields.items():
|
||||||
msg[key] = value
|
msg[key] = value
|
||||||
if attrs:
|
if attrs:
|
||||||
attr_list = msg['attrs']
|
attr_list = msg["attrs"]
|
||||||
r_nla_map = msg_class._nlmsg_base__r_nla_map
|
r_nla_map = msg_class._nlmsg_base__r_nla_map
|
||||||
for key, value in attrs.items():
|
for key, value in attrs.items():
|
||||||
if msg_class.prefix:
|
if msg_class.prefix:
|
||||||
key = msg_class.name2nla(key)
|
key = msg_class.name2nla(key)
|
||||||
prime = r_nla_map[key]
|
prime = r_nla_map[key]
|
||||||
nla_class = prime['class']
|
nla_class = prime["class"]
|
||||||
if issubclass(nla_class, pyroute2.netlink.nla):
|
if issubclass(nla_class, pyroute2.netlink.nla):
|
||||||
# support passing nested attributes as dicts of subattributes (or lists of those)
|
# support passing nested attributes as dicts of subattributes (or lists of those)
|
||||||
if prime['nla_array']:
|
if prime["nla_array"]:
|
||||||
value = [
|
value = [
|
||||||
_build(nla_class, attrs=elem)
|
_build(nla_class, attrs=elem)
|
||||||
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
|
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
|
||||||
@ -83,10 +84,7 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
|
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
|
||||||
policy = {
|
policy = {(x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items()}
|
||||||
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
|
|
||||||
for (x, y) in self.policy.items()
|
|
||||||
}
|
|
||||||
self.register_policy(policy)
|
self.register_policy(policy)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@ -156,7 +154,7 @@ class NFTTransaction:
|
|||||||
# no inner messages were queued... not sending anything
|
# no inner messages were queued... not sending anything
|
||||||
return
|
return
|
||||||
# request ACK on the last message (before END)
|
# request ACK on the last message (before END)
|
||||||
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
|
self._msgs[-1]["header"]["flags"] |= pyroute2.netlink.NLM_F_ACK
|
||||||
self._msgs.append(
|
self._msgs.append(
|
||||||
# batch end message
|
# batch end message
|
||||||
_build(
|
_build(
|
||||||
|
@ -10,7 +10,7 @@ import trio.socket
|
|||||||
|
|
||||||
|
|
||||||
def _check_watchdog_pid() -> bool:
|
def _check_watchdog_pid() -> bool:
|
||||||
wpid = os.environ.pop('WATCHDOG_PID', None)
|
wpid = os.environ.pop("WATCHDOG_PID", None)
|
||||||
if not wpid:
|
if not wpid:
|
||||||
return True
|
return True
|
||||||
return wpid == str(os.getpid())
|
return wpid == str(os.getpid())
|
||||||
@ -18,16 +18,16 @@ def _check_watchdog_pid() -> bool:
|
|||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
||||||
target = os.environ.pop('NOTIFY_SOCKET', None)
|
target = os.environ.pop("NOTIFY_SOCKET", None)
|
||||||
ns: trio.socket.SocketType | None = None
|
ns: trio.socket.SocketType | None = None
|
||||||
watchdog_usec: int = 0
|
watchdog_usec: int = 0
|
||||||
if target:
|
if target:
|
||||||
if target.startswith('@'):
|
if target.startswith("@"):
|
||||||
# Linux abstract namespace socket
|
# Linux abstract namespace socket
|
||||||
target = '\0' + target[1:]
|
target = "\0" + target[1:]
|
||||||
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
||||||
await ns.connect(target)
|
await ns.connect(target)
|
||||||
watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None)
|
watchdog_usec_s = os.environ.pop("WATCHDOG_USEC", None)
|
||||||
if _check_watchdog_pid() and watchdog_usec_s:
|
if _check_watchdog_pid() and watchdog_usec_s:
|
||||||
watchdog_usec = int(watchdog_usec_s)
|
watchdog_usec = int(watchdog_usec_s)
|
||||||
try:
|
try:
|
||||||
@ -52,18 +52,18 @@ class SdNotify:
|
|||||||
|
|
||||||
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||||
assert self.is_connected(), "Watchdog can't run without socket"
|
assert self.is_connected(), "Watchdog can't run without socket"
|
||||||
await self.send('WATCHDOG=1')
|
await self.send("WATCHDOG=1")
|
||||||
task_status.started()
|
task_status.started()
|
||||||
# send every half of the watchdog timeout
|
# send every half of the watchdog timeout
|
||||||
interval = (watchdog_usec / 1e6) / 2.0
|
interval = (watchdog_usec / 1e6) / 2.0
|
||||||
while True:
|
while True:
|
||||||
await trio.sleep(interval)
|
await trio.sleep(interval)
|
||||||
await self.send('WATCHDOG=1')
|
await self.send("WATCHDOG=1")
|
||||||
|
|
||||||
async def send(self, *msg: str) -> None:
|
async def send(self, *msg: str) -> None:
|
||||||
if not self.is_connected():
|
if not self.is_connected():
|
||||||
return
|
return
|
||||||
dgram = '\n'.join(msg).encode('utf-8')
|
dgram = "\n".join(msg).encode("utf-8")
|
||||||
assert self._ns, "not connected" # checked above
|
assert self._ns, "not connected" # checked above
|
||||||
sent = await self._ns.send(dgram)
|
sent = await self._ns.send(dgram)
|
||||||
if sent != len(dgram):
|
if sent != len(dgram):
|
||||||
|
@ -10,9 +10,9 @@ def get_local_timezone():
|
|||||||
global _zoneinfo
|
global _zoneinfo
|
||||||
if not _zoneinfo:
|
if not _zoneinfo:
|
||||||
try:
|
try:
|
||||||
with open('/etc/timezone') as f:
|
with open("/etc/timezone") as f:
|
||||||
key = f.readline().strip()
|
key = f.readline().strip()
|
||||||
_zoneinfo = zoneinfo.ZoneInfo(key)
|
_zoneinfo = zoneinfo.ZoneInfo(key)
|
||||||
except (OSError, zoneinfo.ZoneInfoNotFoundError):
|
except (OSError, zoneinfo.ZoneInfoNotFoundError):
|
||||||
_zoneinfo = zoneinfo.ZoneInfo('UTC')
|
_zoneinfo = zoneinfo.ZoneInfo("UTC")
|
||||||
return _zoneinfo
|
return _zoneinfo
|
||||||
|
Loading…
Reference in New Issue
Block a user