
document new repo URI exclusively

Kilian Krause 2024-06-19 14:19:19 +02:00
parent 14a477cf4e
commit b7232c4a5b
63 changed files with 0 additions and 3537 deletions

.flake8

@@ -1,11 +0,0 @@
[flake8]
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
# E402 module level import not at top of file [ usually on purpose. might use individual overrides instead? ]
# E701 multiple statements on one line [ still quite readable in short forms ]
# E713 test for membership should be not in [ disagree: want `not a in x` ]
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
# W503 line break before binary operator [ gotta pick one way ]
extend-ignore = E266,E402,E701,E713,E714,W503
max-line-length = 120
exclude = *_pb2.py
application-import-names = capport

.gitignore

@@ -1,9 +0,0 @@
.vscode
*.pyc
*.egg-info
__pycache__
venv
capport.yaml
custom
capport.state
capport.state.new-*

LICENSE

@@ -1,19 +0,0 @@
Copyright (c) 2022 Universität Stuttgart
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

@@ -2,85 +2,3 @@
New URL is now at:
https://git-nks-public.tik.uni-stuttgart.de/ac107458/python-capport.git
# python Captive Portal
## Installation
Either clone the repository (and install dependencies either through your distribution or in a virtualenv with `./setup-venv.sh`) or install it as a package.
[`pipx`](https://pypa.github.io/pipx/) (available in debian as package) can be used to install in separate virtual environment:
pipx install https://git-nks-public.tik.uni-stuttgart.de/net/python-capport
In production put a reverse proxy in front of the local web ui (on 127.0.0.1:5001), and point the `/static` path either to `src/capport/api/static/` or to your customized version of the static files.
See the `contrib` directory for configs of the other software needed to set up a captive portal.
## Customization
Create `custom/templates` and put customized templates (from `src/capport/api/templates`) there.
Create `i18n/<langcode>` folders to put localized templates into (localized extends must use the full `i18n/.../basetmpl` paths though).
Requests with a `setlang=<langcode>` query parameter will set the language and try to store the choice in a session cookie.
## Run
Run `./start-api.sh` to start the web ui (listens on 127.0.0.1:5001 by default).
Run `./start-control.sh` to start the "controller" ("enforcement") part; this needs to run as root (i.e. with `CAP_NET_ADMIN` in the current network namespace).
The controller expects this nft set to exist:
```
table inet captive_mark {
set allowed {
type ether_addr
flags timeout
}
}
```
Restarting the controller will push all entries the set should contain again, but won't clean up others.
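For illustration, an entry equivalent to what the controller pushes could be added by hand (hypothetical MAC address):
```
nft add element inet captive_mark allowed '{ 02:00:00:00:00:01 timeout 1h }'
```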
## Internals
### Login/Logout
This is for an "open" network, i.e. no actual logins required, just a "we accept the ToS" form.
Designed to work without cookies; CSRF protection is implemented by verifying the `Origin` header against the `Host` header (a missing `Origin` header is allowed), and by also requiring the client's MAC address (which an attacker on the same L2 could know, or guess from a non-temporary IPv6 address).
### HA
The list of "allowed" clients is stored in a "database"; each instance holds the full database, and each time two instances connect to each other they send their full database for sync (all received updates are also broadcast to the other peers - but only if they actually led to a change in the database).
On each node there are two instances: one "controller" (also responsible for deploying the list to the kernel, aka "enforcement") and the webui (which also contains the RFC 8908 API).
The "controller" also stores updates to disk and loads the state on start.
This synchronization of the database works because it shouldn't matter in which order "changes" to the database are merged (each change is also just the new state from another database); see the `merge` method in the `MacEntry` class in `src/capport/database.py`.
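A minimal sketch of such an order-independent merge (field names follow the `MacState` protobuf; the exact tie-breaking is an assumption - see `src/capport/database.py` for the real code):
```python
import dataclasses

@dataclasses.dataclass
class MacEntry:
    last_change: int  # seconds of UTC time since epoch
    allow_until: int  # seconds of UTC time since epoch
    allowed: bool

    def merge(self, other: "MacEntry") -> "MacEntry":
        # pick the maximum under a total order: the newest change wins, ties
        # prefer the longer session - commutative, so replay order is irrelevant
        if (other.last_change, other.allow_until, other.allowed) > (
            self.last_change,
            self.allow_until,
            self.allowed,
        ):
            return other
        return self
```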
#### Protocol
The controllers should be a full-mesh (be connected to all other controllers), and the webui instances are connected to all controllers (but not to other webui instances).
The controllers listen on the fixed TCP port 5000 for a custom "database sync" protocol.
This protocol is based on an anonymous TLS connection, which then uses a shared secret to verify the connection (not perfect yet; it would be better if python simply supported SRP - https://bugs.python.org/issue11943).
Then both sides can send protobuf messages; each message is prefixed by its 4-byte length. The basic message is defined in `protobuf/message.proto`.
### Web-UI
The UI needs to know the client's MAC address to add it to the database. Right now this means the webui must run on a host connected to the clients' L2 to see them in the neighbor table (and the client's connection to the UI must use this L2 link - the UI doesn't actively query for neighbors, it only looks at the neighbor cache).
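A rough sketch of such a lookup with `pyroute2` (a declared dependency; the real, async implementation lives in `src/capport/utils/ipneigh.py`):
```python
import typing

from pyroute2 import IPRoute

def neighbor_mac(ip: str) -> typing.Optional[str]:
    # query the kernel neighbor cache only - this never probes the network
    with IPRoute() as ipr:
        for neigh in ipr.get_neighbours(dst=ip):
            mac = neigh.get_attr('NDA_LLADDR')
            if mac:
                return mac
    return None
```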
### async
This project uses the `trio` python library for async IO.
Only the netlink handling (`ipneigh.py`, `nft_*.py`) uses blocking IO - but that should be ok, as we only make requests to the kernel which should get answered immediately.
Disk-IO for writing the database to disk is done in a separate thread.
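A sketch of that pattern (function names are made up for the example):
```python
import trio

def _write_state(path: str, blob: bytes) -> None:
    # blocking disk I/O - fine inside a worker thread
    with open(path, 'wb') as f:
        f.write(blob)

async def save_state(path: str, blob: bytes) -> None:
    # hand the blocking write to a thread so the trio event loop keeps running
    await trio.to_thread.run_sync(_write_state, path, blob)
```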


@@ -1,11 +0,0 @@
---
comm-secret: mysecret
cookie-secret: mysecret
controllers:
- capport-controller1.example.com
- capport-controller2.example.com
session-timeout: 3600 # in seconds
venue-info-url: 'https://example.com'
server-names:
- localhost
- ...

contrib/README.md

@@ -1,54 +0,0 @@
# Various other parts of a captive portal setup
Network (HA) setup with IPv4 NAT:
- two nodes
- shared uplink L2, some transfer network (can be private or public)
* virtual addresses on active node for IPv4 and IPv6 to route traffic to
- shared downlink L2
* virtual addresses on active node for IPv4 and IPv6 as gateway for clients
* using `fe80::1` as gateway, but also add a public IPv6 virtual address
* connected: private IPv4 prefix (e.g. CGNAT), not routed
* connected: public IPv6 prefix (routed to virtual uplink address of nodes)
- public IPv4 prefix routed to the virtual uplink address of the nodes, used for NAT
* IPv4 traffic from clients will be (S)NATted from this prefix; size depends
on the number of parallel connections you want to support.
- webserver on nodes:
* port 8080: receives transparent http redirects from the firewall; should return a temporary redirect to your portal page.
* port 80: redirect to https
* port 443: reverse-proxy to 127.0.0.1:5001 (the webui backend), but serve `/static` directly from directory (see main README)
To access the portal page on the clients you'll need a DNS name; it should point to the virtual addresses. In some ways the downlink address is preferable, but you might also want to avoid private addresses - i.e. use the uplink IPv4 address and the downlink IPv6 address.
The management traffic for the virtual addresses should also use the uplink interface if possible (`keepalived` supports this).
## ISC dhcpd
See `dhcpd.conf.erb` and `dhcpd6.conf.erb`.
Note: don't use too large IPv4 pools or dhcpd will take a long time to sync and build up the lease files.
## Firewall / NAT
See `nftables.conf.erb` for forwarding rules; if you want traffic shaping as well see `shape_non_whitelisted.sh`.
Local policies (ssh access and normal "host protection") are not included in the example.
You also might want to set a high `net.netfilter.nf_conntrack_max` with sysctl (e.g. `16777216`).
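For example as a sysctl drop-in (the file name is just a suggestion):
```
# /etc/sysctl.d/90-capport.conf
net.netfilter.nf_conntrack_max = 16777216
```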
## Conntrackd
Active/failover configuration TBD.
I strongly recommend not enabling any conntrack helpers; they often punch significant holes in your stateful firewall (i.e. they make clients reachable from the outside in ways they didn't actually want).
## Keepalived (for virtual addresses)
See `keepalived.conf.erb`.
## Apache2
See `apache2.conf` (it only contains the "interesting" parts and probably won't start as-is).
Any other webserver configured in a similar way should do just as well.
## systemd units
See the `systemd` directory for examples of systemd units.

contrib/apache2.conf

@@ -1,43 +0,0 @@
Listen 80
Listen 443
Listen 8080
<VirtualHost *:8080>
ServerName redirect
Header always set Cache-Control "no-store"
# trailing '?' drops request query string:
RedirectMatch seeother ^.*$ https://portal.example.com?
KeepAlive off
</VirtualHost>
<VirtualHost *:80>
ServerName portal.example.com
ServerAlias portal-node1.example.com
Redirect permanent / https://portal.example.com/
</VirtualHost>
<VirtualHost *:443>
ServerName portal.example.com
ServerAlias portal-node1.example.com
SSLEngine on
SSLCertificateFile "/etc/ssl/certs/portal.example.com-with-chain.crt"
SSLCertificateKeyFile "/etc/ssl/private/portal.example.com.key"
# The static directory of your theme (or the builtin one)
Alias /static "/var/lib/python-capport/custom/static"
Header always set X-Frame-Options DENY
Header always set Referrer-Policy same-origin
Header always set X-Content-Type-Options nosniff
Header always set Strict-Transport-Security "max-age=31556926;"
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
ProxyRequests Off
ProxyPreserveHost On
ProxyPass /static !
ProxyPass / http://127.0.0.1:5001/
</VirtualHost>

contrib/dhcpd.conf.erb

@@ -1,35 +0,0 @@
option domain-name-servers <%= @dns_resolvers_ipv4.join(', ') %>;
option ntp-servers <%= @ntp_servers_ipv4.join(', ') %>;
# specify API server URL (RFC8910)
option default-url "https://<%= @service_name %>/api/captive-portal";
default-lease-time 600;
max-lease-time 3600;
authoritative;
<% if @instances.length == 2 -%>
failover peer "dhcp-peer" {
<% if @instance_index == 0 %>primary<% else %>secondary<% end %>;
address <%= @instances[@instance_index]['external_ipv4'] %>;
peer address <%= @instances[1-@instance_index]['external_ipv4'] %>;
max-response-delay 60;
max-unacked-updates 10;
load balance max seconds 3;
<% if @instance_index == 0 -%>
split 128;
mclt 180;
<%- end %>
}
<%- end %>
subnet <%= @client_ipv4_net %> netmask <%= @client_netmask %> {
option routers <%= @client_ipv4_gateway %>;
pool {
range <%= @client_ipv4_dhcp_from %> <%= @client_ipv4_dhcp_to %>;
<% if @instances.length == 2 -%>
failover peer "dhcp-peer";
<%- end %>
}
}

contrib/dhcpd6.conf.erb

@@ -1,13 +0,0 @@
option dhcp6.name-servers <%= @dns_resolvers_ipv6.join(', ') %>;
option dhcp6.sntp-servers <%= @ntp_servers_ipv6.join(', ') %>;
# specify API server URL (RFC8910)
option dhcp6.v6-captive-portal "https://<%= @service_name %>/api/captive-portal";
# The delay before information-request refresh
# (minimum is 10 minutes, maximum one day, default is to not refresh)
# (set to 1 hour)
option dhcp6.info-refresh-time 3600;
subnet6 <%= @client_ipv6 %> {
}

contrib/keepalived.conf.erb

@@ -1,49 +0,0 @@
global_defs {
vrrp_no_swap
checker_no_swap
script_user nobody
enable_script_security
}
vrrp_instance capport_ipv4_default {
state BACKUP
interface <%= @uplink_interface %>
virtual_router_id 1
priority 100
virtual_ipaddress {
<%= @uplink_virtual_ipv4 %>
<%= @client_virtual_ipv4 %> dev <%= @client_interface %>
}
promote_secondaries
}
vrrp_instance capport_ipv6_default {
state BACKUP
interface <%= @uplink_interface %>
virtual_router_id 2
priority 100
virtual_ipaddress {
fe80::1:1
<%= @uplink_virtual_ipv6 %>
fe80::1 dev <%= @client_interface %>
<%= @client_virtual_ipv6 %> dev <%= @client_interface %>
}
promote_secondaries
}
vrrp_sync_group capport_default {
group {
capport_ipv4_default
capport_ipv6_default
}
}

contrib/nftables.conf.erb

@@ -1,360 +0,0 @@
#!/usr/sbin/nft -f
# Template notes: most variables should have an obvious meaning, but:
# - client_ipv4: private IPv4 network for clients (not routed outside), connected
# - client_ipv4_public: public IPv4 network for clients, must be routed to this host
# and should be blackholed here.
# - client_ipv6: public IPv6 network, must be routed to this host, connected
# NOTE: mustn't flush full ruleset; need to keep the table `captive_mark` and its set around
# DON'T ENABLE THIS:
# flush ruleset
# fully whitelist certain sites, e.g. VPN gateways your users are
# allowed to connect to even without accepting the terms.
define full_ipv4 = {
<%- @whitelist_full_ipv4.each do |n| -%>
<%= n %>,
<%- end -%>
}
define full_ipv6 = {
<%- @whitelist_full_ipv6.each do |n| -%>
<%= n %>,
<%- end -%>
}
# whitelist http[s] traffic to certain sites, e.g. your
# homepage hosting the terms, OCSP responders, websites
# to setup other Wifi configurations (cat.eduroam.org)
define http_server_ipv4 = {
<%- @whitelist_http_ipv4.each do |n| -%>
<%= n %>,
<%- end -%>
}
define http_server_ipv6 = {
<%- @whitelist_http_ipv6.each do |n| -%>
<%= n %>,
<%- end -%>
}
# whitelist your DNS resolvers
define dns_server_ipv4 = {
<%- @dns_resolvers_ipv4.each do |n| -%>
<%= n %>,
<%- end -%>
}
define dns_server_ipv6 = {
<%- @dns_resolvers_ipv6.each do |n| -%>
<%= n %>,
<%- end -%>
}
# whitelist your (and possible other friendly) NTP servers
define ntp_server_ipv4 = {
<%- @ntp_servers_ipv4.each do |n| -%>
<%= n %>,
<%- end -%>
}
define ntp_server_ipv6 = {
<%- @ntp_servers_ipv6.each do |n| -%>
<%= n %>,
<%- end -%>
}
# ports to block traffic for completely
## block traffic from clients to certain server ports
define blacklist_tcp = {
25, # SMTP
161, # SNMP
135, # epmap (netbios/portmapping)
137, # TCP netbios-ns
138, # TCP netbios-dgm
139, # netbios-ssn
445, # microsoft-ds (cifs, samba)
}
define blacklist_udp = {
25, # SMTP
161, # SNMP
135, # UDP epmap (netbios/portmapping)
137, # netbios-ns
138, # netbios-dgm
139, # UDP netbios-ssn
445, # UDP microsoft-ds (cifs, samba)
5353, # mDNS
}
## block traffic from certain client ports
define blacklist_udp_source = {
162, # SNMP trap
}
# once a client accepted the terms:
# * define "good" (whitelisted) traffic with the following lists
# * decide in "chain forward" below whether to block other traffic completely or
# e.g. shape it to low bandwidth (also see shape_non_whitelisted.sh)
define whitelist_tcp = {
22, # ssh
53, # dns
80, # http
443, # https
3128, # http proxy (squid)
8080, # http alt
110, # pop3
995, # pop3s
143, # imap
993, # imaps
587, # submission
465, # submissions
1194, # openvpn default
}
# https://help.webex.com/en-us/article/WBX264/How-Do-I-Allow-Webex-Meetings-Traffic-on-My-Network
define whitelist_udp = {
53, # dns
123, # ntp
443, # http/3
1194, # openvpn default
51820, # wireguard default
500, # IPsec isakmp
4500, # IPsec ipsec-nat-t
10000, # IPSec Cisco NAT-T
9000, # Primary Webex Client Media
5004, # Webex Client Media
}
# whitelist traffic to local sites
define whitelist_dest_ipv4 = {
<%- @local_site_prefixes_ipv4.each do |n| -%>
<%= n %>,
<%- end -%>
}
define whitelist_dest_ipv6 = {
<%- @local_site_prefixes_ipv6.each do |n| -%>
<%= n %>,
<%- end -%>
}
# IPv4 HTTP redirect + SNAT
table ip nat4 {}
flush table ip nat4
table ip nat4 {
chain prerouting {
type nat hook prerouting priority -100;
policy accept;
# needs to be marked from client interface (bit 0), captive (bit 1), and http dnat (bit 2) - otherwise accept
meta mark & 0x00000007 != 0x00000007 accept
tcp dport 80 redirect to 8080
}
chain postrouting {
type nat hook postrouting priority 100;
policy accept;
# needs to be marked from client interface (bit 0) - otherwise no NAT
meta mark & 0x00000001 != 0x00000001 accept
oifname <%= @uplink_interface %> snat to <%= @client_ipv4_public %>
}
}
# IPv6 HTTP redirect
table ip6 nat6 {}
flush table ip6 nat6
table ip6 nat6 {
chain prerouting {
type nat hook prerouting priority -100;
policy accept;
# needs to be marked from client interface (bit 0), captive (bit 1), and http dnat (bit 2) - otherwise accept
meta mark & 0x00000007 != 0x00000007 accept
tcp dport 80 redirect to 8080
}
}
# ipv4 + ipv6
table inet captive_filter {}
flush table inet captive_filter
table inet captive_filter {
chain antispoof_input {
# need to accept packet for input and forward!
meta mark & 0x00000001 == 0 accept comment "accept from non-client interface"
ip saddr { 0.0.0.0, <%= @client_ipv4 %> } return
ip6 saddr { fe80::/64, <%= @client_ipv6 %> } return
drop
}
# we need the "redirect" decision before DNAT in prerouting:-100
chain mark_clients {
type filter hook prerouting priority -110;
policy accept;
meta mark & 0x00000001 == 0 accept comment "accept from non-client interface"
jump antispoof_input
meta mark & 0x00000002 == 0 accept comment "accept packets from non-captive clients"
# now accept all traffic to destinations allowed in captive state, and mark "redirect" packets:
jump captive_allowed
}
chain input {
type filter hook input priority 0;
policy accept;
# TODO: limit services available to clients? iptconf might already be enough
}
chain forward_down {
# only filter uplink -> client here:
iifname != <%= @uplink_interface %> accept
oifname != <%= @client_interface %> accept
# established connections
ct state established,related accept
# allow incoming ipv6 ping (ipv4 ping can't work due to NAT)
icmpv6 type echo-request accept
drop
}
chain antispoof_forward {
meta mark & 0x00000001 == 0 goto forward_down comment "handle forwardings not from client interface"
ip saddr { <%= @client_ipv4 %> } return
ip6 saddr { <%= @client_ipv6 %> } return
drop
}
chain captive_allowed_icmp {
# allow all pings to servers we allow other kind of traffic
icmp type echo-request accept
icmpv6 type echo-request accept
}
chain captive_allowed_http {
# http + https (but not QUIC)
tcp dport { 80, 443 } accept
goto captive_allowed_icmp
}
chain captive_allowed_dns {
# DNS, DoT, DoH
udp dport { 53, 853 } accept
tcp dport { 53, 443, 853 } accept
goto captive_allowed_icmp
}
chain captive_allowed_ntp {
# only NTP
udp dport 123 accept
goto captive_allowed_icmp
}
chain captive_allowed {
# all protocols for fully whitelisted
ip daddr $full_ipv4 accept
ip6 daddr $full_ipv6 accept
ip daddr $http_server_ipv4 jump captive_allowed_http
ip6 daddr $http_server_ipv6 jump captive_allowed_http
ip daddr $dns_server_ipv4 jump captive_allowed_dns
ip6 daddr $dns_server_ipv6 jump captive_allowed_dns
ip daddr $ntp_server_ipv4 jump captive_allowed_ntp
ip6 daddr $ntp_server_ipv6 jump captive_allowed_ntp
# mark (new) http clients to redirect to local http server with bit 2
tcp dport 80 ct state new meta mark set meta mark | 0x00000004 accept comment "DNAT HTTP"
# mark packets to reject in forward with bit 3
meta mark set meta mark | 0x00000008 comment "reject in forwarding"
}
# for DNS+NTP
ct timeout udp-oneshot {
protocol udp;
policy = { unreplied: 10, replied: 0 }
}
chain forward_reject {
# could reject TCP connections with tcp reset, but ICMP unreachable should be good enough
# (and it's also semantically correct):
# ip protocol tcp reject with tcp reset
# ip6 nexthdr tcp reject with tcp reset
# but we need to close existing tcp sessions: (when client moves to captive state)
ct state != new ip protocol tcp reject with tcp reset
ct state != new ip6 nexthdr tcp reject with tcp reset
# default icmp reject
reject with icmpx type admin-prohibited
}
# block some ports always
chain blacklist {
tcp dport $blacklist_tcp goto forward_reject
udp dport $blacklist_udp goto forward_reject
udp sport $blacklist_udp_source goto forward_reject
}
# ports we assume are "proper" - still only allowed in non-captive state
chain whitelist {
ip daddr $whitelist_dest_ipv4 accept
ip6 daddr $whitelist_dest_ipv6 accept
tcp dport $whitelist_tcp accept
udp dport $whitelist_udp accept
ip protocol esp accept
ip6 nexthdr esp accept
icmp type echo-request accept
icmpv6 type echo-request accept
# icmp related to existing connections, ...
ct state established,related accept
}
chain forward {
type filter hook forward priority 0;
policy drop;
jump antispoof_forward
# optimize conntrack timeouts for DNS and NTP
udp dport { 53, 123 } ct timeout set "udp-oneshot"
jump blacklist
# reject packets marked for rejection in mark_clients/captive_allowed (bit 3)
meta mark & 0x00000008 != 0 goto forward_reject
jump whitelist
# optional (policy): reject all remaining traffic (comment out to shape it instead):
goto forward_reject comment "reject non-whitelisted traffic"
# (conntrack) mark connection with bit 4 for shaping
ct mark set ct mark | 0x00000010 counter accept comment "accept shaped traffic"
}
chain forward_con_shaped_mark {
type filter hook forward priority 10;
policy accept;
# copy conntrack mark bit 4 to meta mark bit 4
ct mark & 0x00000010 == 0 accept comment "non shaped connection"
meta mark set meta mark | 0x00000010 counter accept comment "shaped connection"
}
}
# NOTE: mustn't flush this table to keep the set around
# DON'T ENABLE THIS:
# flush table inet captive_mark
# as table wasn't flushed, at least delete the single chain we're expecting
table inet captive_mark {
chain prerouting { }
}
delete chain inet captive_mark prerouting;
table inet captive_mark {
# set isn't recreated, i.e. keeps dynamically added members
set allowed {
type ether_addr
flags timeout
}
chain prerouting {
type filter hook prerouting priority -150;
policy accept;
iifname != <%= @client_interface %> accept
# mark packets from client interface with bit 0
meta mark set meta mark | 0x00000001
ether saddr @allowed accept
# mark "captive" clients with bit 1
meta mark set meta mark | 0x00000002
accept
}
# further existing elements in this table won't be cleared by loading this file!
}


@@ -1,22 +0,0 @@
interface <%= @client_interface %>
{
AdvSendAdvert on;
AdvDefaultPreference high;
AdvSourceLLAddress off;
AdvRASrcAddress {
fe80::1;
};
RDNSS <%= @dns_resolvers_ipv6.join(' ') %> {
FlushRDNSS off;
};
AdvOtherConfigFlag on;
# will require radvd > 2.19 (not released yet)
# AdvCaptivePortalAPI "https://<%= @service_name %>/api/captive-portal";
prefix <%= @client_ipv6 %>
{
};
};

contrib/shape_non_whitelisted.sh

@@ -1,31 +0,0 @@
#!/bin/bash
limit_iface() {
local dev=$1
local shape_rate=$2 # "guaranteed"
local shape_ceil=$3 # "upper limit"
tc qdisc delete dev "${dev}" root 2>/dev/null
tc qdisc add dev "${dev}" root handle 1: htb default 0x11
# basically no limit for default traffic
tc class add dev "${dev}" parent 1: classid 1:11 htb rate 10Gbit ceil 100Gbit quantum 100000
# limit "bad" (not whitelisted) traffic
tc class add dev "${dev}" parent 1: classid 1:12 htb prio 1 rate "${shape_rate}" ceil "${shape_ceil}"
# use "codel" qdisc for both classes, but with larger queue for default traffic
tc qdisc add dev "${dev}" parent 1:11 handle 11: codel limit 20000
tc qdisc add dev "${dev}" parent 1:12 handle 12: codel
# sort into bad class based on netfilter mark (if bit 0x10 is set)
tc filter add dev "${dev}" parent 1: prio 1 basic match 'meta(nf_mark mask 0x10 eq 0x10)' classid 1:12
}
uplink=$1
downlink=$2
if [ -z "${uplink}" -o -z "${downlink}" ]; then
echo >&2 "Missing uplink and downlink interface names"
exit 1
fi
limit_iface "${uplink}" "1Mbit" "1Mbit"
limit_iface "${downlink}" "1Mbit" "1Mbit"


@@ -1,18 +0,0 @@
[Unit]
Description=Captive Portal enforcement service
Wants=basic.target
After=basic.target network.target
ConditionFileIsExecutable=/var/lib/python-capport/start-control.sh
ConditionPathIsDirectory=/var/lib/python-capport/venv
# TODO: start as unprivileged user but with CAP_NET_ADMIN ?
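# (untested sketch, not part of the original unit: this could be solved with a
#  dedicated user plus ambient capabilities in [Service], e.g.
#  User=capport
#  AmbientCapabilities=CAP_NET_ADMIN
#  CapabilityBoundingSet=CAP_NET_ADMIN)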
[Service]
Type=notify
WatchdogSec=10
ExecStart=/var/lib/python-capport/start-control.sh
Restart=always
ProtectSystem=full
ProtectHome=true
[Install]
WantedBy=multi-user.target


@@ -1,18 +0,0 @@
[Unit]
Description=NFT Firewall Shim for Captive Portal
Wants=network-pre.target
Before=network-pre.target shutdown.target
Conflicts=shutdown.target
DefaultDependencies=no
[Service]
Type=oneshot
RemainAfterExit=yes
StandardInput=null
ProtectSystem=full
ProtectHome=true
ExecStart=/usr/sbin/nft -f /etc/nftables.conf
ExecReload=/usr/sbin/nft -f /etc/nftables.conf
[Install]
WantedBy=sysinit.target


@@ -1,15 +0,0 @@
[Unit]
Description=Captive Portal traffic shaping
Wants=basic.target
After=basic.target network.target
[Service]
Type=oneshot
RemainAfterExit=yes
StandardInput=null
ProtectSystem=full
ProtectHome=true
ExecStart=/etc/capport-tc.sh <%= @uplink_interface %> <%= @client_interface %>
[Install]
WantedBy=multi-user.target


@@ -1,18 +0,0 @@
[Unit]
Description=Captive Portal web ui service
Wants=basic.target
After=basic.target network.target
ConditionFileIsExecutable=/var/lib/python-capport/start-api.sh
ConditionPathIsDirectory=/var/lib/python-capport/venv
[Service]
User=capport
Type=notify
WatchdogSec=10
ExecStart=/var/lib/python-capport/start-api.sh
Restart=always
ProtectSystem=full
ProtectHome=true
[Install]
WantedBy=multi-user.target

flake8

@@ -1,19 +0,0 @@
#!/bin/sh
### check code style with flake8
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
exit 1
fi
if [ ! -x ./venv/bin/flake8 ]; then
./venv/bin/pip install flake8 flake8-import-order
fi
./venv/bin/flake8 src

mypy

@@ -1,42 +0,0 @@
#!/bin/sh
### check type annotations with mypy
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
exit 1
fi
if [ ! -x ./venv/bin/mypy ]; then
./venv/bin/pip install mypy trio-typing[mypy] types-PyYAML types-aiofiles types-colorama types-cryptography types-protobuf types-toml
fi
site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])')
if [ ! -d "${site_pkgs}/trio_typing" ]; then
./venv/bin/pip install trio-typing[mypy]
fi
if [ ! -d "${site_pkgs}/yaml-stubs" ]; then
./venv/bin/pip install types-PyYAML
fi
if [ ! -d "${site_pkgs}/aiofiles-stubs" ]; then
./venv/bin/pip install types-aiofiles
fi
if [ ! -d "${site_pkgs}/colorama-stubs" ]; then
./venv/bin/pip install types-colorama
fi
if [ ! -d "${site_pkgs}/cryptography-stubs" ]; then
./venv/bin/pip install types-cryptography
fi
if [ ! -d "${site_pkgs}/google-stubs" ]; then
./venv/bin/pip install types-protobuf
fi
if [ ! -d "${site_pkgs}/toml-stubs" ]; then
./venv/bin/pip install types-toml
fi
./venv/bin/mypy --install-types src


@@ -1,10 +0,0 @@
#!/bin/bash
set -e
cd "$(dirname "$(readlink -f "$0")")"
rm -rf ../src/capport/comm/protobuf/message_pb2.py
mkdir -p ../src/capport/comm/protobuf
protoc --python_out=../src/capport/comm/protobuf message.proto

protobuf/message.proto

@@ -1,42 +0,0 @@
syntax = "proto3";
package capport;
message Message {
oneof oneof {
Hello hello = 1;
AuthenticationResult authentication_result = 2;
Ping ping = 3;
MacStates mac_states = 10;
}
}
// sent by clients and servers as first message
message Hello {
bytes instance_id = 1;
string hostname = 2;
bool is_controller = 3;
bytes authentication = 4;
}
// tell peer whether hello authentication was good
message AuthenticationResult {
bool success = 1;
}
message Ping {
bytes payload = 1;
}
message MacStates {
repeated MacState states = 1;
}
message MacState {
bytes mac_address = 1;
// Seconds of UTC time since epoch
int64 last_change = 2;
// Seconds of UTC time since epoch
int64 allow_until = 3;
bool allowed = 4;
}

pyproject.toml

@@ -1,14 +0,0 @@
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"
[tool.mypy]
python_version = "3.9"
# warn_return_any = true
warn_unused_configs = true
exclude = [
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
]

setup-venv.sh

@@ -1,11 +0,0 @@
#!/bin/bash
set -e
self=$(dirname "$(readlink -f "$0")")
cd "${self}"
python3 -m venv venv
# install cli extras
./venv/bin/pip install --upgrade --upgrade-strategy eager -e '.'

setup.cfg

@@ -1,38 +0,0 @@
[metadata]
name = capport-tik-nks
version = 0.0.1
author = Stefan Bühler
author_email = stefan.buehler@tik.uni-stuttgart.de
description = Captive Portal
long_description = file: README.md
long_description_content_type = text/markdown
url = https://git-nks-public.tik.uni-stuttgart.de/net/python-capport
project_urls =
Bug Tracker = https://git-nks-public.tik.uni-stuttgart.de/net/python-capport/issues
classifiers =
Programming Language :: Python :: 3
License :: OSI Approved :: MIT License
Operating System :: OS Independent
[options]
package_dir =
= src
packages = find:
python_requires = >=3.9
install_requires =
trio
quart
quart-trio
hypercorn[trio]
PyYAML
protobuf>=4.21
pyroute2~=0.7.3
[options.packages.find]
where = src
[options.entry_points]
console_scripts =
capport-control = capport.control.run:main
capport-stats = capport.stats:main
capport-webui = capport.api.hypercorn_run:main

setup.py

@@ -1,6 +0,0 @@
# https://github.com/pypa/setuptools/issues/2816
# allow editable install on older pip versions
from setuptools import setup
if __name__ == "__main__":
setup()

src/capport/api/app.py

@@ -1,11 +0,0 @@
from __future__ import annotations
from .app_cls import MyQuartApp
app = MyQuartApp(__name__)
__import__('capport.api.setup')
__import__('capport.api.proxy_fix')
__import__('capport.api.lang')
__import__('capport.api.template_filters')
__import__('capport.api.views')

src/capport/api/app_cls.py

@@ -1,55 +0,0 @@
from __future__ import annotations
import os
import os.path
import typing
import jinja2
import quart.templating
import quart_trio
import capport.comm.hub
import capport.config
import capport.utils.ipneigh
class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):
app: MyQuartApp
def __init__(self, app: MyQuartApp) -> None:
super().__init__(app)
def _loaders(self) -> typing.Generator[jinja2.BaseLoader, None, None]:
if self.app.custom_loader:
yield self.app.custom_loader
for loader in super()._loaders():
yield loader
class MyQuartApp(quart_trio.QuartTrio):
my_nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
my_hub: typing.Optional[capport.comm.hub.Hub] = None
my_config: capport.config.Config
custom_loader: typing.Optional[jinja2.FileSystemLoader] = None
def __init__(self, import_name: str, **kwargs) -> None:
self.my_config = capport.config.Config.load_default_once()
kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates'))
cust_templ = os.path.join('custom', 'templates')
if os.path.exists(cust_templ):
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
cust_static = os.path.abspath(os.path.join('custom', 'static'))
if os.path.exists(cust_static):
static_folder = cust_static
else:
static_folder = os.path.join(os.path.dirname(__file__), 'static')
kwargs.setdefault('static_folder', static_folder)
super().__init__(import_name, **kwargs)
self.debug = self.my_config.debug
self.secret_key = self.my_config.cookie_secret
def create_global_jinja_loader(self) -> DispatchingJinjaLoader:
"""Create and return a global (not blueprint specific) Jinja loader."""
return DispatchingJinjaLoader(self)

src/capport/api/hypercorn_run.py

@@ -1,38 +0,0 @@
from __future__ import annotations
import hypercorn.config
import hypercorn.trio.run
import hypercorn.utils
import capport.config
def run(config: hypercorn.config.Config) -> None:
sockets = config.create_sockets()
if config.worker_class != 'trio':
raise Exception('Invalid worker class received from constructor')
hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
for sock in sockets.secure_sockets:
sock.close()
for sock in sockets.insecure_sockets:
sock.close()
def main() -> None:
_config = capport.config.Config.load_default_once()
hypercorn_config = hypercorn.config.Config()
hypercorn_config.application_path = 'capport.api.app'
hypercorn_config.worker_class = 'trio'
hypercorn_config.bind = ["127.0.0.1:5001"]
if _config.server_names:
hypercorn_config.server_names = _config.server_names
elif not _config.debug:
raise Exception(
"production setup requires server-names in config (list of accepted hostnames in http requests)"
)
run(hypercorn_config)

src/capport/api/lang.py

@@ -1,83 +0,0 @@
from __future__ import annotations
import os.path
import re
import typing
import quart
from .app import app
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
def parse_accept_language(value: str) -> typing.List[str]:
value = value.strip()
if not value or value == '*':
return []
tuples = []
for entry in value.split(','):
attrs = entry.split(';')
name = attrs.pop(0).strip().lower()
q = 1.0
for attr in attrs:
if not '=' in attr: continue
key, value = attr.split('=', maxsplit=1)
if key.strip().lower() == 'q':
try:
q = float(value.strip())
except ValueError:
q = 0.0
if q >= 0.0:
tuples.append((q, name))
tuples.sort(reverse=True)  # most preferred (highest q) first
have = set()
result = []
for (_q, name) in tuples:
if name in have: continue
if name == '*': break
have.add(name)
if _VALID_LANGUAGE_NAMES.match(name):
result.append(name)
short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0]
if not short_name or short_name in have: continue
have.add(short_name)
result.append(short_name)
return result
@app.before_request
def detect_language():
g = quart.g
r = quart.request
s = quart.session
if 'setlang' in r.args:
lang = r.args.get('setlang').strip().lower()
if lang and _VALID_LANGUAGE_NAMES.match(lang):
if s.get('lang') != lang:
s['lang'] = lang
g.langs = [lang]
return
else:
# reset language
s.pop('lang', None)
lang = s.get('lang')
if lang:
lang = lang.strip().lower()
if lang and _VALID_LANGUAGE_NAMES.match(lang):
g.langs = [lang]
return
acc_lang = ','.join(r.headers.get_all('Accept-Language'))
g.langs = parse_accept_language(acc_lang)
async def render_i18n_template(template, /, **kwargs) -> str:
langs: typing.List[str] = quart.g.langs
if not langs:
return await quart.render_template(template, **kwargs)
names = [
os.path.join('i18n', lang, template)
for lang in langs
]
names.append(template)
return await quart.render_template(names, **kwargs)

src/capport/api/proxy_fix.py

@@ -1,80 +0,0 @@
from __future__ import annotations
import ipaddress
import typing
import quart
import werkzeug
from werkzeug.http import parse_list_header
from .app import app
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]:
if not value_list:
return None
values = parse_list_header(value_list)
if values and values[0]:
if not allowed or values[0] in allowed:
return values[0]
return None
def local_proxy_fix(request: quart.Request):
if not request.remote_addr:
return
try:
addr = ipaddress.ip_address(request.remote_addr)
except ValueError:
# TODO: accept unix sockets somehow too?
return
if not addr.is_loopback:
return
client = _get_first_in_list(request.headers.get('X-Forwarded-For'))
if not client:
# assume this is always set behind reverse proxies supporting any of the headers
return
request.remote_addr = client
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
port: typing.Optional[int] = None
if scheme:
port = 443 if scheme == 'https' else 80
request.scheme = scheme
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
port_s: typing.Optional[str]
if host:
request.host = host
if ':' in host and not host.endswith(']'):
try:
_, port_s = host.rsplit(':', maxsplit=1)
port = int(port_s)
except ValueError:
# ignore invalid port in host header
pass
port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port'))
if port_s:
try:
port = int(port_s)
except ValueError:
# ignore invalid port in header
pass
if port:
if request.server and len(request.server) == 2:
request.server = (request.server[0], port)
root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix'))
if root_path:
request.root_path = root_path
class LocalProxyFixRequestHandler:
def __init__(self, orig_handle_request):
self._orig_handle_request = orig_handle_request
async def __call__(self, request: quart.Request) -> typing.Union[quart.Response, werkzeug.Response]:
# need to patch request before url_adapter is built
local_proxy_fix(request)
return await self._orig_handle_request(request)
app.handle_request = LocalProxyFixRequestHandler(app.handle_request) # type: ignore

src/capport/api/setup.py

@@ -1,63 +0,0 @@
from __future__ import annotations
import logging
import uuid
import trio
import capport.comm.hub
import capport.comm.message
import capport.database
import capport.utils.cli
import capport.utils.ipneigh
from capport.utils.sd_notify import open_sdnotify
from .app import app
_logger = logging.getLogger(__name__)
class ApiHubApp(capport.comm.hub.HubApplication):
async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
# TODO: support websocket notification updates to clients?
pass
async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
try:
async with capport.utils.ipneigh.connect() as mync:
app.my_nc = mync
_logger.info("Running hub for API")
myapp = ApiHubApp()
myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp, is_controller=False)
app.my_hub = myhub
await myhub.run(task_status=task_status)
finally:
app.my_hub = None
app.my_nc = None
_logger.info("Done running hub for API")
await app.shutdown()
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
async with open_sdnotify() as sn:
await sn.send('STATUS=Starting hub')
async with trio.open_nursery() as nursery:
await nursery.start(_run_hub)
await sn.send('READY=1', 'STATUS=Ready for client requests')
task_status.started()
# continue running hub and systemd watchdog handler
@app.before_serving
async def init():
app.debug = app.my_config.debug
app.secret_key = app.my_config.cookie_secret
capport.utils.cli.init_logger(app.my_config)
await app.nursery.start(_setup)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

src/capport/api/template_filters.py

@@ -1,16 +0,0 @@
from __future__ import annotations
import ipaddress
import typing
from .app import app
@app.add_template_test
def ip_in_networks(ip_s: str, networks: typing.Iterable[str]) -> bool:
ip = ipaddress.ip_address(ip_s)
for network in networks:
net = ipaddress.ip_network(network)
if ip in net:
return True
return False

src/capport/api/templates/base.html

@@ -1,19 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>{% block title %}Captive Portal{% endblock %}</title>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<script src="{{ url_for('static', filename='bootstrap/bootstrap.bundle.min.js') }}"></script>
<link rel="stylesheet" href="{{ url_for('static', filename='bootstrap/bootstrap.min.css') }}">
</head>
<body class="container">
<header class="d-flex justify-content-center py-3">
<ul class="nav nav-pills">
<li class="nav-item"><a href="#" class="nav-link active" aria-current="page">Home</a></li>
</ul>
</header>
{% block content %}{% endblock %}
</body>
</html>

src/capport/api/templates/index_active.html

@@ -1,16 +0,0 @@
{% extends "base.html" %}
{% block content %}
You already accepted our conditions and are currently granted access to the internet:
Your current session will last for {{ state.allowed_remaining }} seconds.
<form method="POST" action="/login">
<input type="hidden" name="accept" value="1">
<input type="hidden" name="mac" value="{{ state.mac }}">
<button type="submit" class="btn btn-primary mb-3">Extend session</button>
</form>
<form method="POST" action="/logout">
<input type="hidden" name="mac" value="{{ state.mac }}">
<button type="submit" class="btn btn-danger mb-3">Terminate session</button>
</form>
<br>
{% endblock %}

src/capport/api/templates/index_inactive.html

@@ -1,9 +0,0 @@
{% extends "base.html" %}
{% block content %}
To get access to the internet please accept our usage guidelines by clicking this button:
<form method="POST" action="/login">
<input type="hidden" name="accept" value="1">
<input type="hidden" name="mac" value="{{ state.mac }}">
<button type="submit" class="btn btn-primary mb-3">Accept</button>
</form>
{% endblock %}

src/capport/api/templates/index_unknown.html

@@ -1,5 +0,0 @@
{% extends "base.html" %}
{% block content %}
<p>It seems you're accessing this site from outside the network this captive portal is running for.</p>
<p>Your client's IP address is {{ state.address }}</p>
{% endblock %}

src/capport/api/views.py

@@ -1,166 +0,0 @@
from __future__ import annotations
import ipaddress
import logging
import typing
import quart
import trio
import capport.comm.hub
import capport.comm.message
import capport.database
import capport.utils.cli
import capport.utils.ipneigh
from capport import cptypes
from .app import app
from .lang import render_i18n_template
_logger = logging.getLogger(__name__)
def get_client_ip() -> cptypes.IPAddress:
remote_addr = quart.request.remote_addr
if not remote_addr:
quart.abort(500, 'Missing client address')
try:
addr = ipaddress.ip_address(remote_addr)
except ValueError as e:
_logger.warning(f'Invalid client address {remote_addr!r}: {e}')
quart.abort(500, 'Invalid client address')
return addr
async def get_client_mac_if_present(
address: typing.Optional[cptypes.IPAddress] = None,
) -> typing.Optional[cptypes.MacAddress]:
assert app.my_nc # for mypy
if not address:
address = get_client_ip()
return await app.my_nc.get_neighbor_mac(address)
async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress:
mac = await get_client_mac_if_present(address)
if mac is None:
_logger.warning(f"Couldn't find MAC addresss for {address}")
quart.abort(404, 'Unknown client')
return mac
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
assert app.my_hub # for mypy
async with app.my_hub.database.make_changes() as pu:
try:
pu.login(mac, app.my_config.session_timeout)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu:
_logger.debug(f'User {mac} (with IP {address}) logged in')
for msg in pu.serialized:
await app.my_hub.broadcast(msg)
async def user_logout(mac: cptypes.MacAddress) -> None:
assert app.my_hub # for mypy
async with app.my_hub.database.make_changes() as pu:
try:
pu.logout(mac)
except capport.database.NotReadyYet as e:
quart.abort(500, str(e))
if pu:
_logger.debug(f'User {mac} logged out')
for msg in pu.serialized:
await app.my_hub.broadcast(msg)
async def user_lookup() -> cptypes.MacPublicState:
assert app.my_hub # for mypy
address = get_client_ip()
mac = await get_client_mac_if_present(address)
if not mac:
return cptypes.MacPublicState.from_missing_mac(address)
else:
return app.my_hub.database.lookup(address, mac)
# @app.route('/all')
# async def route_all():
# return app.my_hub.database.as_json()
def check_self_origin():
origin = quart.request.headers.get('Origin', None)
if origin is None:
# not a request by a modern browser - probably curl or something similar. don't care.
return
origin = origin.lower().strip()
if origin == 'none':
quart.abort(403, 'Origin is none')
origin_parts = origin.split('/')
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
if len(origin_parts) < 3:
quart.abort(400, 'Broken Origin header')
if origin_parts[0] != 'https:' and not app.my_config.debug:
# -> require https in production
quart.abort(403, 'Non-https Origin not allowed')
origin_host = origin_parts[2]
host = quart.request.headers.get('Host', None)
if host is None:
quart.abort(403, 'Missing Host header')
if host.lower() != origin_host:
quart.abort(403, 'Origin mismatch')
@app.route('/', methods=['GET'])
async def index(missing_accept: bool = False):
state = await user_lookup()
if not state.mac:
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept)
elif state.allowed:
return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept)
else:
return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept)
@app.route('/login', methods=['POST'])
async def login():
check_self_origin()
with trio.fail_after(5.0):
form = await quart.request.form
if form.get('accept') != '1':
return await index(missing_accept=True)
req_mac = form.get('mac')
if not req_mac:
quart.abort(400, description='Missing MAC in request form data')
address = get_client_ip()
mac = await get_client_mac(address)
if str(mac) != req_mac:
quart.abort(403, description="Passed MAC in request form doesn't match client address")
await user_login(address, mac)
return quart.redirect('/', code=303)
@app.route('/logout', methods=['POST'])
async def logout():
check_self_origin()
with trio.fail_after(5.0):
form = await quart.request.form
req_mac = form.get('mac')
if not req_mac:
quart.abort(400, description='Missing MAC in request form data')
mac = await get_client_mac()
if str(mac) != req_mac:
quart.abort(403, description="Passed MAC in request form doesn't match client address")
await user_logout(mac)
return quart.redirect('/', code=303)
@app.route('/api/captive-portal', methods=['GET'])
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
async def captive_api():
state = await user_lookup()
return state.to_rfc8908(app.my_config)

src/capport/comm/hub.py

@@ -1,453 +0,0 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import random
import socket
import ssl
import struct
import typing
import uuid
import trio
import capport.comm.message
import capport.database
if typing.TYPE_CHECKING:
from ..config import Config
_logger = logging.getLogger(__name__)
class HubConnectionReadError(ConnectionError):
pass
class HubConnectionClosedError(ConnectionError):
pass
class LoopbackConnectionError(Exception):
pass
class Channel:
def __init__(self, hub: Hub, transport_stream, server_side: bool):
self._hub = hub
self._serverside = server_side
self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side)
_logger.debug(f"{self}: created (server_side={server_side})")
def __repr__(self) -> str:
return f'Channel[0x{id(self):x}]'
async def do_handshake(self) -> capport.comm.message.Hello:
try:
await self._ssl.do_handshake()
ssl_binding = self._ssl.get_channel_binding()
if not ssl_binding:
# binding mustn't be None after successful handshake
raise ConnectionError("Missing SSL channel binding")
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message()
await self.send_msg(msg)
peer_hello = (await self.recv_msg()).to_variant()
if not isinstance(peer_hello, capport.comm.message.Hello):
raise HubConnectionReadError("Expected Hello as first message")
auth_succ = \
(peer_hello.authentication ==
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
peer_auth = (await self.recv_msg()).to_variant()
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
raise HubConnectionReadError("Expected AuthenticationResult as second message")
if not auth_succ or not peer_auth.success:
raise HubConnectionReadError("Authentication failed")
return peer_hello
async def _read(self, num: int) -> bytes:
assert num > 0
buf = b''
# _logger.debug(f"{self}:_read({num})")
while num > 0:
try:
part = await self._ssl.receive_some(num)
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
# _logger.debug(f"{self}:_read({num}) got part {part!r}")
if len(part) == 0:
if len(buf) == 0:
raise HubConnectionClosedError()
raise HubConnectionReadError("Unexpected end of TLS stream")
buf += part
num -= len(part)
if num < 0:
raise HubConnectionReadError("TLS receive_some returned too much")
return buf
async def _recv_raw_msg(self) -> bytes:
len_bytes = await self._read(4)
chunk_size, = struct.unpack('!I', len_bytes)
chunk = await self._read(chunk_size)
if chunk is None:
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
return chunk
async def recv_msg(self) -> capport.comm.message.Message:
try:
chunk = await self._recv_raw_msg()
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
msg = capport.comm.message.Message()
msg.ParseFromString(chunk)
return msg
async def _send_raw(self, chunk: bytes) -> None:
try:
await self._ssl.send_all(chunk)
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
async def send_msg(self, msg: capport.comm.message.Message):
chunk = msg.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
chunk = len_bytes + chunk
await self._send_raw(chunk)
async def aclose(self):
try:
await self._ssl.aclose()
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
raise ConnectionError(e) from None
class Connection:
PING_INTERVAL = 10
RECEIVE_TIMEOUT = 15
SEND_TIMEOUT = 5
def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello):
self._channel = channel
self._hub = hub
tx: trio.MemorySendChannel
rx: trio.MemoryReceiveChannel
(tx, rx) = trio.open_memory_channel(64)
self._pending_tx = tx
self._pending_rx = rx
self.peer: capport.comm.message.Hello = peer
self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id)
self.closed = trio.Event() # set by Hub._lost_peer
_logger.debug(f"{self._channel}: authenticated -> {self.peer_id}")
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
try:
msg: typing.Optional[capport.comm.message.Message]
while True:
msg = None
# make sure we send something every PING_INTERVAL
with trio.move_on_after(self.PING_INTERVAL):
msg = await self._pending_rx.receive()
# if send blocks too long we're in trouble
with trio.fail_after(self.SEND_TIMEOUT):
if msg:
await self._channel.send_msg(msg)
else:
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
except trio.TooSlowError:
_logger.warning(f"{self._channel}: send timed out")
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed sending: {e!r}")
except Exception:
_logger.exception(f"{self._channel}: failed sending")
finally:
cancel_scope.cancel()
async def _receive(self, cancel_scope: trio.CancelScope) -> None:
try:
while True:
try:
with trio.fail_after(self.RECEIVE_TIMEOUT):
msg = await self._channel.recv_msg()
except (HubConnectionClosedError, ConnectionResetError):
return
except trio.TooSlowError:
_logger.warning(f"{self._channel}: receive timed out")
return
await self._hub._received_msg(self.peer_id, msg)
except ConnectionError as e:
_logger.warning(f"{self._channel}: failed receiving: {e!r}")
except Exception:
_logger.exception(f"{self._channel}: failed receiving")
finally:
cancel_scope.cancel()
async def _inner_run(self) -> None:
if self.peer_id == self._hub._instance_id:
# connected to ourself, don't need that
raise LoopbackConnectionError()
async with trio.open_nursery() as nursery:
nursery.start_soon(self._sender, nursery.cancel_scope)
# be nice and wait for new_peer before receiving messages
# (won't work on failover to a second connection)
await nursery.start(self._hub._new_peer, self.peer_id, self)
nursery.start_soon(self._receive, nursery.cancel_scope)
async def send_msg(self, *msgs: capport.comm.message.Message):
try:
for msg in msgs:
await self._pending_tx.send(msg)
except trio.ClosedResourceError:
pass
async def _run(self) -> None:
try:
await self._inner_run()
finally:
_logger.debug(f"{self._channel}: finished message handling")
# basic (non-async) cleanup
self._hub._lost_peer(self.peer_id, self)
self._pending_tx.close()
self._pending_rx.close()
# allow 3 seconds for proper cleanup
with trio.CancelScope(shield=True, deadline=trio.current_time() + 3):
try:
await self._channel.aclose()
except OSError:
pass
@staticmethod
async def run(hub: Hub, transport_stream, server_side: bool) -> None:
channel = Channel(hub, transport_stream, server_side)
try:
with trio.fail_after(5):
peer = await channel.do_handshake()
except trio.TooSlowError:
_logger.warning("Handshake timed out")
return
conn = Connection(hub, channel, peer)
await conn._run()
class ControllerConn:
def __init__(self, hub: Hub, hostname: str):
self._hub = hub
self.hostname = hostname
self.loopback = False
async def _connect(self):
_logger.info(f"Connecting to controller at {self.hostname}")
with trio.fail_after(5):
try:
stream = await trio.open_tcp_stream(self.hostname, 5000)
except OSError as e:
_logger.warning(f"Failed to connect to controller at {self.hostname}: {e}")
return
try:
await Connection.run(self._hub, stream, server_side=False)
finally:
_logger.info(f"Connection to {self.hostname} closed")
async def run(self):
while True:
try:
await self._connect()
except LoopbackConnectionError:
_logger.debug(f"Connection to {self.hostname} reached ourself")
self.loopback = True
return
except trio.TooSlowError:
pass
# try again later
retry_splay = random.random() * 5
await trio.sleep(10 + retry_splay)
class HubApplication:
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.info(f"New peer {peer_id}")
def lost_peer(self, *, peer_id: uuid.UUID) -> None:
_logger.warning(f"Lost peer {peer_id}")
async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
_logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}")
async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
if _logger.isEnabledFor(logging.DEBUG):
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
class Hub:
def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None:
self._config = config
self._instance_id = uuid.uuid4()
self._hostname = socket.getfqdn()
self._app = app
self._is_controller = is_controller
state_filename: str
if is_controller:
state_filename = config.database_file
else:
state_filename = ''
self.database = capport.database.Database(state_filename=state_filename)
self._anon_context = ssl.SSLContext()
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
# -> AECDH-AES256-SHA
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
self._controllers: dict[str, ControllerConn] = {}
self._established: dict[uuid.UUID, Connection] = {}
async def _accept(self, stream):
remotename = stream.socket.getpeername()
if isinstance(remotename, tuple) and len(remotename) == 2:
remote = f'[{remotename[0]}]:{remotename[1]}'
else:
remote = str(remotename)
try:
await Connection.run(self, stream, server_side=True)
except LoopbackConnectionError:
pass
except trio.TooSlowError:
pass
finally:
_logger.debug(f"Connection from {remote} closed")
async def _listen(self, task_status=trio.TASK_STATUS_IGNORED):
await trio.serve_tcp(self._accept, 5000, task_status=task_status)
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
async with trio.open_nursery() as nursery:
await nursery.start(self.database.run)
if self._is_controller:
await nursery.start(self._listen)
for name in self._config.controllers:
conn = ControllerConn(self, name)
self._controllers[name] = conn
task_status.started()
for conn in self._controllers.values():
nursery.start_soon(conn.run)
await trio.sleep_forever()
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256)
if server_side:
m.update(b'server$')
else:
m.update(b'client$')
m.update(ssl_binding)
return m.digest()
def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello:
return capport.comm.message.Hello(
instance_id=self._instance_id.bytes,
hostname=self._hostname,
is_controller=self._is_controller,
authentication=self._calc_authentication(ssl_binding, server_side),
)
async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None:
# send database (and all changes) to peers
await self.send(*self.database.serialize(), to=peer_id)
async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None:
have = self._established.get(peer_id, None)
if not have:
# peer unknown, "normal start"
# no "await" between get above and set here!!!
self._established[peer_id] = conn
# first wait for app to handle new peer
await self._app.new_peer(peer_id=peer_id)
task_status.started()
await self._sync_new_connection(peer_id, conn)
return
# peer already known - immediately allow receiving messages, then sync connection
task_status.started()
await self._sync_new_connection(peer_id, conn)
# now try to register connection for outgoing messages
while True:
# recheck whether peer is currently known (due to awaits since last get)
have = self._established.get(peer_id, None)
if have:
# already got a connection, nothing to do as long as it lives
await have.closed.wait()
else:
# make `conn` new outgoing connection for peer
# no "await" between get above and set here!!!
self._established[peer_id] = conn
await self._app.new_peer(peer_id=peer_id)
return
def _lost_peer(self, peer_id: uuid.UUID, conn: Connection):
have = self._established.get(peer_id, None)
lost = False
if have is conn:
lost = True
self._established.pop(peer_id)
conn.closed.set()
# only notify if this was the active connection
if lost:
# even when we fail over to another connection we still need to resync,
# as we don't know which messages might have been lost
# -> always trigger lost_peer
self._app.lost_peer(peer_id=peer_id)
async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
variant = msg.to_variant()
if isinstance(variant, capport.comm.message.Hello):
pass
elif isinstance(variant, capport.comm.message.AuthenticationResult):
pass
elif isinstance(variant, capport.comm.message.Ping):
pass
elif isinstance(variant, capport.comm.message.MacStates):
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
async with self.database.make_changes() as pu:
for state in variant.states:
pu.received_mac_state(state)
if pu:
# re-broadcast all received updates to all peers
await self.broadcast(*pu.serialized, exclude=peer_id)
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
else:
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)
def peer_is_controller(self, peer_id: uuid.UUID) -> bool:
conn = self._established.get(peer_id)
if conn:
return conn.peer.is_controller
return False
async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID):
conn = self._established.get(to)
if conn:
await conn.send_msg(*msgs)
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None):
async with trio.open_nursery() as nursery:
for peer_id, conn in self._established.items():
if peer_id == exclude:
continue
nursery.start_soon(conn.send_msg, *msgs)

View File

@ -1,36 +0,0 @@
from __future__ import annotations
import typing
from .protobuf import message_pb2
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
variant_name = self.WhichOneof('oneof')
if variant_name:
return getattr(self, variant_name)
return None
def _make_to_message(oneof_field):
def to_message(self) -> message_pb2.Message:
msg = message_pb2.Message(**{oneof_field: self})
return msg
return to_message
def _monkey_patch():
g = globals()
g['Message'] = message_pb2.Message
message_pb2.Message.to_variant = _message_to_variant
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
type_name = field.message_type.name
field_type = getattr(message_pb2, type_name)
field_type.to_message = _make_to_message(field.name)
g[type_name] = field_type
# also re-exports all message types
_monkey_patch()
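# Illustrative round trip enabled by the monkey-patching above (names are
# examples only):
#
#   hello = Hello(instance_id=b'...', hostname='portal-1')
#   msg = hello.to_message()            # Message(hello=hello)
#   assert isinstance(msg.to_variant(), Hello)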
# not a variant of Message, still re-export
MacState = message_pb2.MacState

View File

@ -1,93 +0,0 @@
import google.protobuf.message
import typing
# manually maintained typehints for protobuf created (and monkey-patched) types
class Message(google.protobuf.message.Message):
hello: Hello
authentication_result: AuthenticationResult
ping: Ping
mac_states: MacStates
def __init__(
self,
*,
hello: typing.Optional[Hello]=None,
authentication_result: typing.Optional[AuthenticationResult]=None,
ping: typing.Optional[Ping]=None,
mac_states: typing.Optional[MacStates]=None,
) -> None: ...
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
class Hello(google.protobuf.message.Message):
instance_id: bytes
hostname: str
is_controller: bool
authentication: bytes
def __init__(
self,
*,
instance_id: bytes=b'',
hostname: str='',
is_controller: bool=False,
authentication: bytes=b'',
) -> None: ...
def to_message(self) -> Message: ...
class AuthenticationResult(google.protobuf.message.Message):
success: bool
def __init__(
self,
*,
success: bool=False,
) -> None: ...
def to_message(self) -> Message: ...
class Ping(google.protobuf.message.Message):
payload: bytes
def __init__(
self,
*,
payload: bytes=b'',
) -> None: ...
def to_message(self) -> Message: ...
class MacStates(google.protobuf.message.Message):
states: typing.List[MacState]
def __init__(
self,
*,
states: typing.List[MacState]=[],
) -> None: ...
def to_message(self) -> Message: ...
class MacState(google.protobuf.message.Message):
mac_address: bytes
last_change: int # Seconds of UTC time since epoch
allow_until: int # Seconds of UTC time since epoch
allowed: bool
def __init__(
self,
*,
mac_address: bytes=b'',
last_change: int=0,
allow_until: int=0,
allowed: bool=False,
) -> None: ...

View File

@ -1,35 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: message.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x07\x63\x61pport\"\xbc\x01\n\x07Message\x12\x1f\n\x05hello\x18\x01 \x01(\x0b\x32\x0e.capport.HelloH\x00\x12>\n\x15\x61uthentication_result\x18\x02 \x01(\x0b\x32\x1d.capport.AuthenticationResultH\x00\x12\x1d\n\x04ping\x18\x03 \x01(\x0b\x32\r.capport.PingH\x00\x12(\n\nmac_states\x18\n \x01(\x0b\x32\x12.capport.MacStatesH\x00\x42\x07\n\x05oneof\"]\n\x05Hello\x12\x13\n\x0binstance_id\x18\x01 \x01(\x0c\x12\x10\n\x08hostname\x18\x02 \x01(\t\x12\x15\n\ris_controller\x18\x03 \x01(\x08\x12\x16\n\x0e\x61uthentication\x18\x04 \x01(\x0c\"\'\n\x14\x41uthenticationResult\x12\x0f\n\x07success\x18\x01 \x01(\x08\"\x17\n\x04Ping\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\".\n\tMacStates\x12!\n\x06states\x18\x01 \x03(\x0b\x32\x11.capport.MacState\"Z\n\x08MacState\x12\x13\n\x0bmac_address\x18\x01 \x01(\x0c\x12\x13\n\x0blast_change\x18\x02 \x01(\x03\x12\x13\n\x0b\x61llow_until\x18\x03 \x01(\x03\x12\x0f\n\x07\x61llowed\x18\x04 \x01(\x08\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'message_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_MESSAGE._serialized_start=27
_MESSAGE._serialized_end=215
_HELLO._serialized_start=217
_HELLO._serialized_end=310
_AUTHENTICATIONRESULT._serialized_start=312
_AUTHENTICATIONRESULT._serialized_end=351
_PING._serialized_start=353
_PING._serialized_end=376
_MACSTATES._serialized_start=378
_MACSTATES._serialized_end=424
_MACSTATE._serialized_start=426
_MACSTATE._serialized_end=516
# @@protoc_insertion_point(module_scope)

View File

@ -1,55 +0,0 @@
from __future__ import annotations
import dataclasses
import os.path
import sys
import typing
import yaml
_cached_config: typing.Optional[Config] = None
@dataclasses.dataclass
class Config:
controllers: typing.List[str]
server_names: typing.List[str]
comm_secret: str
cookie_secret: str
venue_info_url: typing.Optional[str]
session_timeout: int # in seconds
database_file: str # empty str: disable database
debug: bool
@staticmethod
def load_default_once() -> Config:
global _cached_config
if not _cached_config:
_cached_config = Config.load()
return _cached_config
@staticmethod
def load(filename: typing.Optional[str] = None) -> Config:
if filename is None:
for name in sys.argv[1:]:
if os.path.exists(name):
return Config.load(name)
for name in ('capport.yaml', '/etc/capport.yaml'):
if os.path.exists(name):
return Config.load(name)
raise RuntimeError("Missing config file")
with open(filename) as f:
data = yaml.safe_load(f)
controllers = list(map(str, data['controllers']))
return Config(
controllers=controllers,
server_names=data.get('server-names', []),
comm_secret=str(data['comm-secret']),
cookie_secret=str(data['cookie-secret']),
venue_info_url=str(data['venue-info-url']) if 'venue-info-url' in data else None,
session_timeout=data.get('session-timeout', 3600),
database_file=str(data['database-file']) if 'database-file' in data else 'capport.state',
debug=data.get('debug', False)
)
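# Illustrative capport.yaml matching the lookups above (all values are
# placeholders):
#
#   controllers: [ctrl1.example.com, ctrl2.example.com]
#   server-names: [portal.example.com]
#   comm-secret: some-shared-secret
#   cookie-secret: some-cookie-secret
#   venue-info-url: https://www.example.com/
#   session-timeout: 3600        # seconds
#   database-file: capport.state # empty string disables the statefile
#   debug: false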

View File

@ -1,70 +0,0 @@
from __future__ import annotations
import typing
import uuid
import trio
import capport.comm.hub
import capport.comm.message
import capport.config
import capport.database
import capport.utils.cli
import capport.utils.nft_set
from capport import cptypes
from capport.utils.sd_notify import open_sdnotify
class ControlApp(capport.comm.hub.HubApplication):
hub: capport.comm.hub.Hub
def __init__(self) -> None:
super().__init__()
self.nft_set = capport.utils.nft_set.NftSet()
async def mac_states_changed(
self,
*,
from_peer_id: uuid.UUID,
pending_updates: capport.database.PendingUpdates,
) -> None:
self.apply_db_entries(pending_updates.changes())
def apply_db_entries(
self,
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]],
) -> None:
# deploy changes to netfilter set
inserts = []
removals = []
now = cptypes.Timestamp.now()
for mac, state in entries:
rem = state.allowed_remaining(now)
if rem > 0:
inserts.append((mac, rem))
else:
removals.append(mac)
self.nft_set.bulk_insert(inserts)
self.nft_set.bulk_remove(removals)
async def amain(config: capport.config.Config) -> None:
async with open_sdnotify() as sn:
app = ControlApp()
hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True)
app.hub = hub
async with trio.open_nursery() as nursery:
# hub.run loads the statefile from disk before signalling it was "started"
await nursery.start(hub.run)
await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set')
app.apply_db_entries(hub.database.entries())
await sn.send('STATUS=Kernel fully synchronized')
def main() -> None:
config = capport.config.Config.load_default_once()
capport.utils.cli.init_logger(config)
try:
trio.run(amain, config)
except (KeyboardInterrupt, InterruptedError):
print()

View File

@ -1,112 +0,0 @@
from __future__ import annotations
import dataclasses
import datetime
import ipaddress
import json
import time
import typing
import quart
import capport.utils.zoneinfo
if typing.TYPE_CHECKING:
from .config import Config
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@dataclasses.dataclass(frozen=True)
class MacAddress:
raw: bytes
def __str__(self) -> str:
return self.raw.hex(':')
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def parse(s: str) -> MacAddress:
return MacAddress(bytes.fromhex(s.replace(':', '')))
@dataclasses.dataclass(frozen=True, order=True)
class Timestamp:
epoch: int
def __str__(self) -> str:
try:
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
return ts.isoformat(sep=' ')
except OSError:
return f'epoch@{self.epoch}'
def __repr__(self) -> str:
return repr(str(self))
@staticmethod
def now() -> Timestamp:
now = int(time.time())
return Timestamp(epoch=now)
@staticmethod
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
if epoch:
return Timestamp(epoch=epoch)
return None
@dataclasses.dataclass
class MacPublicState:
address: IPAddress
mac: typing.Optional[MacAddress]
allowed_remaining: int
@staticmethod
def from_missing_mac(address: IPAddress) -> MacPublicState:
return MacPublicState(
address=address,
mac=None,
allowed_remaining=0,
)
@property
def allowed(self) -> bool:
return self.allowed_remaining > 0
@property
def captive(self) -> bool:
return not self.allowed
@property
def allowed_remaining_duration(self) -> str:
mm, ss = divmod(self.allowed_remaining, 60)
hh, mm = divmod(mm, 60)
return f'{hh}:{mm:02}:{ss:02}'
@property
def allowed_until(self) -> typing.Optional[datetime.datetime]:
if not self.allowed:
return None
zone = capport.utils.zoneinfo.get_local_timezone()
now = datetime.datetime.now(zone).replace(microsecond=0)
return now + datetime.timedelta(seconds=self.allowed_remaining)
def to_rfc8908(self, config: Config) -> quart.Response:
response: dict[str, typing.Any] = {
'user-portal-url': quart.url_for('index', _external=True),
}
if config.venue_info_url:
response['venue-info-url'] = config.venue_info_url
if self.captive:
response['captive'] = True
else:
response['captive'] = False
response['seconds-remaining'] = self.allowed_remaining
response['can-extend-session'] = True
return quart.Response(
json.dumps(response),
headers={'Cache-Control': 'private'},
content_type='application/captive+json',
)
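# Illustrative RFC 8908 response body for an allowed (non-captive) client,
# as assembled above (values are examples only):
#
#   {
#     "user-portal-url": "https://portal.example.com/",
#     "venue-info-url": "https://www.example.com/",
#     "captive": false,
#     "seconds-remaining": 2842,
#     "can-extend-session": true
#   }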

View File

@ -1,389 +0,0 @@
from __future__ import annotations
import contextlib
import dataclasses
import logging
import os
import struct
import typing
import google.protobuf.message
import trio
import capport.comm.message
from capport import cptypes
_logger = logging.getLogger(__name__)
@dataclasses.dataclass
class MacEntry:
# an entry can be removed if last_change was some time ago and allow_until
# wasn't set or has passed.
WAIT_LAST_CHANGE_SECONDS = 60
WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10
# last_change: timestamp of last change (sent by system initiating the change)
last_change: cptypes.Timestamp
# only if allowed is true and allow_until is set the device can communicate with the internet
# allow_until must not go backwards (and not get unset)
allow_until: typing.Optional[cptypes.Timestamp]
allowed: bool
@staticmethod
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
if len(msg.mac_address) < 6:
raise Exception("Invalid MacState: mac_address too short")
addr = cptypes.MacAddress(raw=msg.mac_address)
last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
if not last_change:
raise Exception(f"Invalid MacState[{addr}]: missing last_change")
allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))
def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
allow_until = 0
if self.allow_until:
allow_until = self.allow_until.epoch
return capport.comm.message.MacState(
mac_address=addr.raw,
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def as_json(self) -> dict:
allow_until = None
if self.allow_until:
allow_until = self.allow_until.epoch
return dict(
last_change=self.last_change.epoch,
allow_until=allow_until,
allowed=self.allowed,
)
def merge(self, new: MacEntry) -> bool:
changed = False
if new.last_change > self.last_change:
changed = True
self.last_change = new.last_change
self.allowed = new.allowed
elif new.last_change == self.last_change:
# same last_change: set allowed if one allowed
if new.allowed and not self.allowed:
changed = True
self.allowed = True
# set allow_until to max of both
if new.allow_until: # if not set nothing to change in local data
if not self.allow_until or self.allow_until < new.allow_until:
changed = True
self.allow_until = new.allow_until
return changed
def timeout(self) -> cptypes.Timestamp:
elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
if self.allow_until:
eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
if eau > elc:
return cptypes.Timestamp(epoch=eau)
return cptypes.Timestamp(epoch=elc)
# returns 0 if not allowed
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
if not self.allowed or not self.allow_until:
return 0
if not now:
now = cptypes.Timestamp.now()
assert self.allow_until
return max(self.allow_until.epoch - now.epoch, 0)
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
if not now:
now = cptypes.Timestamp.now()
return now.epoch > self.timeout().epoch
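# Worked example for timeout(): with last_change=1000 and allow_until=2000,
# the entry times out at epoch 2010 (allow_until + 10s grace); with
# allow_until unset it would time out at epoch 1060 (last_change + 60s).
# The grace periods keep just-expired entries around long enough to still be
# redistributed to peers before they are dropped.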
# might use this to serialize into file - don't need Message variant there
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
result: typing.List[capport.comm.message.MacStates] = []
current = capport.comm.message.MacStates()
for addr, entry in macs.items():
state = entry.to_state(addr)
current.states.append(state)
if len(current.states) >= 1024: # split into messages with 1024 states
result.append(current)
current = capport.comm.message.MacStates()
if len(current.states):
result.append(current)
return result
def _serialize_mac_states_as_messages(
macs: dict[cptypes.MacAddress, MacEntry],
) -> typing.List[capport.comm.message.Message]:
return [s.to_message() for s in _serialize_mac_states(macs)]
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
chunk = states.SerializeToString(deterministic=True)
chunk_size = len(chunk)
len_bytes = struct.pack('!I', chunk_size)
return len_bytes + chunk
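# Resulting statefile layout: a sequence of length-prefixed protobuf chunks,
# each holding at most 1024 MacState entries:
#
#   [4-byte big-endian length][MacStates protobuf][4-byte length][...] ...
#
# Database._load_statefile() below reads this back with struct.unpack('!I')
# and MacStates.ParseFromString().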
class NotReadyYet(Exception):
def __init__(self, msg: str, wait: int):
self.wait = wait # seconds to wait
super().__init__(msg)
class Database:
def __init__(self, state_filename: typing.Optional[str] = None):
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
self._state_filename = state_filename
self._changed_since_last_cleanup = False
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]] = None
@contextlib.asynccontextmanager
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
pu = PendingUpdates(self)
pu._closed = False
yield pu
pu._finish()
if pu:
self._changed_since_last_cleanup = True
if self._send_changes:
for state in pu.serialized_states:
await self._send_changes.send(state)
def _drop_outdated(self) -> None:
done = False
while not done:
depr: typing.Set[cptypes.MacAddress] = set()
now = cptypes.Timestamp.now()
done = True
for mac, entry in self._macs.items():
if entry.outdated(now):
depr.add(mac)
if len(depr) >= 1024:
# clear entries found so far, then try again
done = False
break
if depr:
self._changed_since_last_cleanup = True
for mac in depr:
del self._macs[mac]
async def run(self, task_status=trio.TASK_STATUS_IGNORED):
if self._state_filename:
await self._load_statefile()
task_status.started()
async with trio.open_nursery() as nursery:
if self._state_filename:
nursery.start_soon(self._run_statefile)
while True:
await trio.sleep(300) # cleanup every 5 minutes
_logger.debug("Running database cleanup")
self._drop_outdated()
if self._changed_since_last_cleanup:
self._changed_since_last_cleanup = False
if self._send_changes:
states = _serialize_mac_states(self._macs)
# trigger a resync
await self._send_changes.send(states)
# for initial handling of all data
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
return list(self._macs.items())
# for initial sync with new peer
def serialize(self) -> typing.List[capport.comm.message.Message]:
return _serialize_mac_states_as_messages(self._macs)
def as_json(self) -> dict:
return {
str(addr): entry.as_json()
for addr, entry in self._macs.items()
}
async def _run_statefile(self) -> None:
rx: trio.MemoryReceiveChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx: trio.MemorySendChannel[typing.Union[
capport.comm.message.MacStates,
typing.List[capport.comm.message.MacStates],
]]
tx, rx = trio.open_memory_channel(64)
self._send_changes = tx
assert self._state_filename
filename: str = self._state_filename
tmp_filename = f'{filename}.new-{os.getpid()}'
async def resync(all_states: typing.List[capport.comm.message.MacStates]):
try:
async with await trio.open_file(tmp_filename, 'xb') as tf:
for states in all_states:
await tf.write(_states_to_chunk(states))
os.rename(tmp_filename, filename)
finally:
if os.path.exists(tmp_filename):
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
os.unlink(tmp_filename)
try:
while True:
async with await trio.open_file(filename, 'ab', buffering=0) as sf:
while True:
update = await rx.receive()
if isinstance(update, list):
break
await sf.write(_states_to_chunk(update))
# got a "list" update - i.e. a resync
with trio.CancelScope(shield=True):
await resync(update)
# now reopen normal statefile and continue appending updates
except trio.Cancelled:
_logger.info('Final sync to disk')
with trio.CancelScope(shield=True):
await resync(_serialize_mac_states(self._macs))
_logger.info('Final sync to disk done')
async def _load_statefile(self):
if not os.path.exists(self._state_filename):
return
_logger.info("Loading statefile")
# we're going to ignore changes from loading the file
pu = PendingUpdates(self)
pu._closed = False
async with await trio.open_file(self._state_filename, 'rb') as sf:
while True:
try:
len_bytes = await sf.read(4)
if not len_bytes:
return
if len(len_bytes) < 4:
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
return
chunk_size, = struct.unpack('!I', len_bytes)
chunk = await sf.read(chunk_size)
except IOError as e:
_logger.error(f"Failed to read next chunk from statefile: {e}")
return
try:
states = capport.comm.message.MacStates()
states.ParseFromString(chunk)
except google.protobuf.message.DecodeError as e:
_logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
continue
errors = 0
for state in states.states:
try:
pu.received_mac_state(state)
except Exception as e:
# count errors across the whole chunk; initializing the counter inside
# the loop would reset it for every state and never reach the limit
errors += 1
if errors < 5:
_logger.error(f'Failed to handle state: {e}')
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
entry = self._macs.get(mac)
if entry:
allowed_remaining = entry.allowed_remaining()
else:
allowed_remaining = 0
return cptypes.MacPublicState(
address=address,
mac=mac,
allowed_remaining=allowed_remaining,
)
class PendingUpdates:
def __init__(self, database: Database):
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
self._database = database
self._closed = True
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
self._serialized: typing.List[capport.comm.message.Message] = []
def __bool__(self) -> bool:
return bool(self._changes)
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
return self._changes.items()
@property
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
assert self._closed
return self._serialized_states
@property
def serialized(self) -> typing.List[capport.comm.message.Message]:
assert self._closed
return self._serialized
def _finish(self):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
self._closed = True
self._serialized_states = _serialize_mac_states(self._changes)
self._serialized = [s.to_message() for s in self._serialized_states]
def received_mac_state(self, state: capport.comm.message.MacState):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
(addr, new_entry) = MacEntry.parse_state(state)
old_entry = self._database._macs.get(addr)
if not old_entry:
# only redistribute if not outdated
if not new_entry.outdated():
self._database._macs[addr] = new_entry
self._changes[addr] = new_entry
elif old_entry.merge(new_entry):
if old_entry.outdated():
# remove local entry, but still redistribute
self._database._macs.pop(addr)
self._changes[addr] = old_entry
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
entry = self._database._macs.get(mac)
if not entry:
self._database._macs[mac] = new_entry
self._changes[mac] = new_entry
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
# too much time left on clock, not renewing session
return
elif entry.merge(new_entry):
self._changes[mac] = entry
elif not entry.allowed_remaining() > 0:
# entry should have been updated - can only fail due to `now < entry.last_change`
# i.e. out of sync clocks
wait = entry.last_change.epoch - now.epoch
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
def logout(self, mac: cptypes.MacAddress):
if self._closed:
raise Exception("Can't change closed PendingUpdates")
now = cptypes.Timestamp.now()
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
entry = self._database._macs.get(mac)
if entry:
if entry.merge(new_entry):
self._changes[mac] = entry
elif entry.allowed_remaining() > 0:
# still logged in. can only happen with `now <= entry.last_change`
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
wait = entry.last_change.epoch - now.epoch + 1
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)

View File

@ -1,104 +0,0 @@
from __future__ import annotations
import ipaddress
import sys
import typing
import trio
import capport.utils.ipneigh
import capport.utils.nft_set
from . import cptypes
def print_metric(
name: str,
mtype: str,
value,
*,
now: typing.Optional[int] = None,
help: typing.Optional[str] = None,
):
# no labels in our names for now, always print help and type
if help:
print(f"# HELP {name} {help}")
print(f"# TYPE {name} {mtype}")
if now:
print(f"{name} {value} {now}")
else:
print(f"{name} {value}")
async def amain(client_ifname: str):
ns = capport.utils.nft_set.NftSet()
captive_allowed_entries: typing.Set[cptypes.MacAddress] = {
entry["mac"] for entry in ns.list()
}
seen_allowed_entries: typing.Set[cptypes.MacAddress] = set()
total_ipv4 = 0
total_ipv6 = 0
unique_clients = set()
unique_ipv4 = set()
unique_ipv6 = set()
async with capport.utils.ipneigh.connect() as ipn:
ipn.ip.strict_check = True
async for (mac, addr) in ipn.dump_neighbors(client_ifname):
if mac in captive_allowed_entries:
seen_allowed_entries.add(mac)
unique_clients.add(mac)
if isinstance(addr, ipaddress.IPv4Address):
total_ipv4 += 1
unique_ipv4.add(mac)
else:
total_ipv6 += 1
unique_ipv6.add(mac)
print_metric(
"capport_allowed_macs",
"gauge",
len(captive_allowed_entries),
help="Number of allowed client mac addresses",
)
print_metric(
"capport_allowed_neigh_macs",
"gauge",
len(seen_allowed_entries),
help="Number of allowed client mac addresses seen in neighbor cache",
)
print_metric(
"capport_unique",
"gauge",
len(unique_clients),
help="Number of clients (mac addresses) in client network seen in neighbor cache",
)
print_metric(
"capport_unique_ipv4",
"gauge",
len(unique_ipv4),
help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache",
)
print_metric(
"capport_unique_ipv6",
"gauge",
len(unique_ipv6),
help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache",
)
print_metric(
"capport_total_ipv4",
"gauge",
total_ipv4,
help="Number of IPv4 addresses seen in neighbor cache",
)
print_metric(
"capport_total_ipv6",
"gauge",
total_ipv6,
help="Number of IPv6 addresses seen in neighbor cache",
)
def main():
if len(sys.argv) != 2:
print("Need name of client interface as argument")
sys.exit(1)
trio.run(amain, sys.argv[1])

View File

@ -1,17 +0,0 @@
from __future__ import annotations
import logging
import capport.config
def init_logger(config: capport.config.Config):
loglevel = logging.INFO
if config.debug:
loglevel = logging.DEBUG
logging.basicConfig(
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S %z]',
level=loglevel,
)
logging.getLogger('hypercorn').propagate = False

View File

@ -1,89 +0,0 @@
from __future__ import annotations
import contextlib
import errno
import ipaddress
import socket
import typing
import pyroute2.iproute.linux # type: ignore
import pyroute2.netlink.exceptions # type: ignore
import pyroute2.netlink.rtnl # type: ignore
import pyroute2.netlink.rtnl.ndmsg # type: ignore
from capport import cptypes
@contextlib.asynccontextmanager
async def connect():
yield NeighborController()
# TODO: run blocking iproute calls in a different thread?
class NeighborController:
def __init__(self):
self.ip = pyroute2.iproute.linux.IPRoute()
async def get_neighbor(
self,
address: cptypes.IPAddress,
*,
index: int = 0, # interface index
flags: int = 0,
) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]:
if not index:
route = await self.get_route(address)
if route is None:
return None
index = route.get_attr(route.name2nla('oif'))
try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise
async def get_neighbor_mac(
self,
address: cptypes.IPAddress,
*,
index: int = 0, # interface index
flags: int = 0,
) -> typing.Optional[cptypes.MacAddress]:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
return None
mac = neigh.get_attr(neigh.name2nla('lladdr'))
if mac is None:
return None
return cptypes.MacAddress.parse(mac)
async def get_route(
self,
address: cptypes.IPAddress,
) -> typing.Optional[pyroute2.iproute.linux.rtmsg]:
try:
return self.ip.route('get', dst=str(address))[0]
except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise
async def dump_neighbors(
self,
interface: str,
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface)
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
for family in (socket.AF_INET, socket.AF_INET6):
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family):
if neigh['ndm_type'] != unicast_num:
continue
mac = neigh.get_attr(neigh.name2nla('lladdr'))
if not mac:
continue
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst')))
if dst.is_link_local:
continue
yield (cptypes.MacAddress.parse(mac), dst)

View File

@ -1,188 +0,0 @@
from __future__ import annotations
import typing
import pyroute2.netlink # type: ignore
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
from capport import cptypes
from .nft_socket import NFTSocket
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
# to seconds
if msecs is None:
return None
return msecs / 1000.0
class NftSet:
def __init__(self):
self._socket = NFTSocket()
self._socket.bind()
@staticmethod
def _set_elem(
mac: cptypes.MacAddress,
timeout: typing.Optional[typing.Union[int, float]] = None,
) -> dict:  # returns an attrs dict; _build() converts it to a set_elem nla
attrs: dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict(
NFTA_DATA_VALUE=mac.raw,
),
}
if timeout:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
return attrs
def _bulk_insert(
self,
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]],
) -> None:
ser_entries = [
self._set_elem(mac)
for mac, _timeout in entries
]
ser_entries_with_timeout = [
self._set_elem(mac, timeout)
for mac, timeout in entries
]
with self._socket.begin() as tx:
# create doesn't affect existing elements, so:
# make sure entries exists
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# drop entries (would fail if it doesn't exist)
tx.put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# now create entries with new timeout value
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
),
)
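# Net effect of the three batched steps above: an upsert that always resets
# the timeout. A plain NFT_MSG_NEWSETELEM would leave the old timeout on an
# element that already exists, so the element is first created (if missing),
# then deleted, then re-created with the fresh timeout, and the kernel
# applies the whole batch atomically.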
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_insert(entries[:1024])
entries = entries[1024:]
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
ser_entries = [
self._set_elem(mac)
for mac in entries
]
with self._socket.begin() as tx:
# make sure entries exists
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# drop entries (would fail if it doesn't exist)
tx.put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_remove(entries[:1024])
entries = entries[1024:]
def remove(self, mac: cptypes.MacAddress) -> None:
self.bulk_remove([mac])
def list(self) -> list:
responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
responses = self._socket.nft_dump(
_nftsocket.NFT_MSG_GETSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
)
return [
{
'mac': cptypes.MacAddress(
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
),
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
}
for response in responses
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
]
def flush(self) -> None:
self._socket.nft_put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
)
def create(self):
with self._socket.begin() as tx:
tx.put(
_nftsocket.NFT_MSG_NEWTABLE,
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_TABLE_NAME='captive_mark',
),
)
tx.put(
_nftsocket.NFT_MSG_NEWSET,
pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_NAME='allowed',
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
NFTA_SET_KEY_LEN=6, # length of key: mac address
NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction
),
)

View File

@ -1,189 +0,0 @@
from __future__ import annotations
import contextlib
import typing
import pyroute2.netlink # type: ignore
import pyroute2.netlink.nlsocket # type: ignore
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore
from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
NFPROTO_INET: int = 1  # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
# nft uses NESTED for these, so let's do the same
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED
# nftable lists always use `1` as list element attr type
_nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
def _monkey_patch_pyroute2():
import pyroute2.netlink
# overwrite setdefault on nlmsg_base class hierarchy
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
def _nlmsg_base__setvalue(self, value):
if not self.header or not self['header'] or not isinstance(value, dict):
return _orig_setvalue(self, value)
# merge headers instead of overwriting them
header = value.pop('header', {})
res = _orig_setvalue(self, value)
self['header'].update(header)
return res
def overwrite_methods(cls: typing.Type) -> None:
if cls.setvalue is _orig_setvalue:
cls.setvalue = _nlmsg_base__setvalue
for subcls in cls.__subclasses__():
overwrite_methods(subcls)
overwrite_methods(pyroute2.netlink.nlmsg_base)
_monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
for key, value in fields.items():
msg[key] = value
if attrs:
attr_list = msg['attrs']
r_nla_map = msg_class._nlmsg_base__r_nla_map
for key, value in attrs.items():
if msg_class.prefix:
key = msg_class.name2nla(key)
prime = r_nla_map[key]
nla_class = prime['class']
if issubclass(nla_class, pyroute2.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']:
value = [
_build(nla_class, attrs=elem)
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
else elem
for elem in value
]
elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict):
value = _build(nla_class, attrs=value)
attr_list.append([key, value])
return msg
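# Illustrative use of _build(), mirroring the calls elsewhere in this module:
#
#   msg = _build(
#       nfgen_msg,
#       res_id=NFNL_SUBSYS_NFTABLES,
#       header=dict(type=0x10, flags=pyroute2.netlink.NLM_F_REQUEST),
#   )
#
# Nested attributes may be given as plain dicts (or lists of dicts); they are
# converted recursively to the matching nla classes via the r_nla_map.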
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None:
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
policy = {
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
for (x, y) in self.policy.items()
}
self.register_policy(policy)
@contextlib.contextmanager
def begin(self) -> typing.Generator[NFTTransaction, None, None]:
tx = NFTTransaction(socket=self)
try:
yield tx
# autocommit when no exception was raised
# (only commits if it wasn't aborted)
tx.autocommit()
finally:
# abort does nothing if commit went through
tx.abort()
def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
with self.begin() as tx:
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> typing.Iterator[_nftsocket.nft_gen_msg]:
msg_flags |= pyroute2.netlink.NLM_F_DUMP
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> typing.Iterator[_nftsocket.nft_gen_msg]:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pyroute2.netlink.NLM_F_REQUEST
msg = _build(msg_class, attrs=attrs, **fields)
return self.nlm_request(msg, msg_type, msg_flags)
class NFTTransaction:
def __init__(self, socket: NFTSocket) -> None:
self._socket = socket
self._closed = False
self._msgs: list[nfgen_msg] = [
# batch begin message
_build(
nfgen_msg,
res_id=NFNL_SUBSYS_NFTABLES,
header=dict(
type=0x10, # NFNL_MSG_BATCH_BEGIN
flags=pyroute2.netlink.NLM_F_REQUEST,
),
),
]
def abort(self) -> None:
"""
Aborts if transaction wasn't already committed or aborted
"""
if not self._closed:
self._closed = True
def autocommit(self) -> None:
"""
Commits if transaction wasn't already committed or aborted
"""
if self._closed:
return
self.commit()
def commit(self) -> None:
if self._closed:
raise Exception("Transaction already closed")
self._closed = True
if len(self._msgs) == 1:
# no inner messages were queued... not sending anything
return
# request ACK on the last message (before END)
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
self._msgs.append(
# batch end message
_build(
nfgen_msg,
res_id=NFNL_SUBSYS_NFTABLES,
header=dict(
type=0x11, # NFNL_MSG_BATCH_END
flags=pyroute2.netlink.NLM_F_REQUEST,
),
),
)
for _msg in self._socket.nlm_request_batch(self._msgs):
# we should see at most one ACK - real errors get raised anyway
pass
def _put(self, msg: nfgen_msg) -> None:
if self._closed:
raise Exception("Transaction already closed")
self._msgs.append(msg)
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
header = dict(
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
flags=msg_flags,
)
msg = _build(msg_class, attrs=attrs, header=header, **fields)
self._put(msg)

View File

@ -1,70 +0,0 @@
from __future__ import annotations
import contextlib
import os
import socket
import typing
import trio
import trio.socket
def _check_watchdog_pid() -> bool:
wpid = os.environ.pop('WATCHDOG_PID', None)
if not wpid:
return True
return wpid == str(os.getpid())
@contextlib.asynccontextmanager
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
target = os.environ.pop('NOTIFY_SOCKET', None)
ns: typing.Optional[trio.socket.SocketType] = None
watchdog_usec: int = 0
if target:
if target.startswith('@'):
# Linux abstract namespace socket
target = '\0' + target[1:]
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
await ns.connect(target)
watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None)
if _check_watchdog_pid() and watchdog_usec_s:
watchdog_usec = int(watchdog_usec_s)
try:
async with trio.open_nursery() as nursery:
sn = SdNotify(_ns=ns)
if watchdog_usec:
await nursery.start(sn._run_watchdog, watchdog_usec)
yield sn
# stop watchdog
nursery.cancel_scope.cancel()
finally:
if ns:
ns.close()
class SdNotify:
def __init__(self, *, _ns: typing.Optional[trio.socket.SocketType]) -> None:
self._ns = _ns
def is_connected(self) -> bool:
return not (self._ns is None)
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
assert self.is_connected(), "Watchdog can't run without socket"
await self.send('WATCHDOG=1')
task_status.started()
# send every half of the watchdog timeout
interval = (watchdog_usec/1e6) / 2.0
while True:
await trio.sleep(interval)
await self.send('WATCHDOG=1')
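# Example: WatchdogSec=30 in the systemd unit yields WATCHDOG_USEC=30000000,
# so the loop above pings every 15 seconds, i.e. half the timeout, leaving
# one missed ping worth of slack.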
async def send(self, *msg: str) -> None:
if not self.is_connected():
return
dgram = '\n'.join(msg).encode('utf-8')
assert self._ns, "not connected" # checked above
sent = await self._ns.send(dgram)
if sent != len(dgram):
raise OSError("Sent incomplete datagram to NOTIFY_SOCKET")

View File

@ -1,19 +0,0 @@
from __future__ import annotations
import typing
import zoneinfo
_zoneinfo: typing.Optional[zoneinfo.ZoneInfo] = None
def get_local_timezone():
global _zoneinfo
if not _zoneinfo:
try:
with open('/etc/timezone') as f:
key = f.readline().strip()
_zoneinfo = zoneinfo.ZoneInfo(key)
except (OSError, zoneinfo.ZoneInfoNotFoundError):
_zoneinfo = zoneinfo.ZoneInfo('UTC')
return _zoneinfo

View File

@ -1,8 +0,0 @@
#!/bin/bash
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
exec ./venv/bin/capport-webui "$@"

View File

@ -1,8 +0,0 @@
#!/bin/bash
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
exec ./venv/bin/capport-control "$@"

View File

@ -1,29 +0,0 @@
#!/bin/sh
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
instance=$1
ifname=$2
if [ -z "${instance}" -o -z "${ifname}" ]; then
echo >&2 "Syntax: $0 instancename clientifname"
exit 1
fi
targetname="/var/lib/prometheus/node-exporter/capport-${instance}.prom"
tmpname="${targetname}.$$"
if [ -f "/run/netns/${instance}" ];then
_run_in_ns="/usr/sbin/ns-enter ${instance} -- "
else
_run_in_ns=""
fi
if ${_run_in_ns} ${base}/stats.sh "${ifname}" > "${tmpname}"; then
mv "${tmpname}" "${targetname}"
else
rm "${tmpname}"
exit 1
fi

View File

@ -1,8 +0,0 @@
#!/bin/bash
set -e
base=$(dirname "$(readlink -f "$0")")
cd "${base}"
exec ./venv/bin/capport-stats "$@"