document new repo URI exclusively
This commit is contained in:
parent
14a477cf4e
commit
b7232c4a5b
11
.flake8
11
.flake8
@ -1,11 +0,0 @@
|
||||
[flake8]
|
||||
# E266 too many leading '#' for block comment [ I like marking disabled code blocks with '### ' ]
|
||||
# E402 module level import not at top of file [ usually on purpose. might use individual overrides instead? ]
|
||||
# E701 multiple statements on one line [ still quite readable in short forms ]
|
||||
# E713 test for membership should be ‘not in’ [ disagree: want `not a in x` ]
|
||||
# E714 test for object identity should be 'is not' [ disagree: want `not a is x` ]
|
||||
# W503 line break before binary operator [ gotta pick one way ]
|
||||
extend-ignore = E266,E402,E701,E713,E714,W503
|
||||
max-line-length = 120
|
||||
exclude = *_pb2.py
|
||||
application-import-names = capport
|
9
.gitignore
vendored
9
.gitignore
vendored
@ -1,9 +0,0 @@
|
||||
.vscode
|
||||
*.pyc
|
||||
*.egg-info
|
||||
__pycache__
|
||||
venv
|
||||
capport.yaml
|
||||
custom
|
||||
capport.state
|
||||
capport.state.new-*
|
19
LICENSE
19
LICENSE
@ -1,19 +0,0 @@
|
||||
Copyright (c) 2022 Universität Stuttgart
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
82
README.md
82
README.md
@ -2,85 +2,3 @@
|
||||
|
||||
New URL is now at:
|
||||
https://git-nks-public.tik.uni-stuttgart.de/ac107458/python-capport.git
|
||||
|
||||
|
||||
|
||||
# python Captive Portal
|
||||
|
||||
## Installation
|
||||
|
||||
Either clone repository (and install dependencies either through distribution or as virtualenv with `./setup-venv.sh`) or install as package.
|
||||
|
||||
[`pipx`](https://pypa.github.io/pipx/) (available in debian as package) can be used to install in separate virtual environment:
|
||||
|
||||
pipx install https://git-nks-public.tik.uni-stuttgart.de/net/python-capport
|
||||
|
||||
In production put a reverse proxy in front of the local web ui (on 127.0.0.1:5001), and handle `/static` path either to `src/capport/api/static/` or your customized version of static files.
|
||||
|
||||
See the `contrib` directory for config of other software needed to setup a captive portal.
|
||||
|
||||
## Customization
|
||||
|
||||
Create `custom/templates` and put customized templates (from `src/capport/api/templates`) there.
|
||||
|
||||
Create `i18n/<langcode>` folders to put localized templates into (localized extends must use the full `i18n/.../basetmpl` paths though).
|
||||
|
||||
Requests with a `setlang=<langcode>` query parameter will set the language and try to store the choice in a session cookie.
|
||||
|
||||
## Run
|
||||
|
||||
Run `./start-api.sh` to start the web ui (listens on 127.0.0.1:5001 by default).
|
||||
|
||||
Run `./start-control.sh` to start the "controller" ("enforcement") part; this needs to be run as root (i.e. as `CAP_NET_ADMIN` of the current network namespace).
|
||||
|
||||
The controller expects this nft set to exist:
|
||||
|
||||
```
|
||||
table inet captive_mark {
|
||||
set allowed {
|
||||
type ether_addr
|
||||
flags timeout
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Restarting the controller will push all entries it should contain again, but won't cleanup others.
|
||||
|
||||
## Internals
|
||||
|
||||
### Login/Logout
|
||||
|
||||
This is for an "open" network, i.e. no actual logins required, just an "we accept the ToS" form.
|
||||
|
||||
Designed to work without cookies; CSRF protection implemented by verifying the `Origin` header against the `Host` header (but allowing missing `Origin` header), and also requiring the clients `MAC` address (which an attacker from the same L2 could know, or guess from a non-temporary IPv6 address).
|
||||
|
||||
### HA
|
||||
|
||||
The list of "allowed" clients is stored in a "database"; each instance has the full database, and each time two instances connect to each other, they will send their full database for sync (and also all received updates will be broadcast to all others - but only if they actually led to a change in the database).
|
||||
|
||||
On each node there are two instances: one "controller" (also responsible for deploying the list to the kernel, aka "enforcement"), and the webui (also contains the RFC 8908 API).
|
||||
|
||||
The "controller" also stores updates to disk, and loads it on start.
|
||||
|
||||
This synchronization of the database works because it shouldn't matter in which order "changes" to the database are merged (each change is also just the new state from another database); see the `merge` method in the `MacEntry` class in `src/capport/database.py`.
|
||||
|
||||
#### Protocol
|
||||
|
||||
The controllers should be a full-mesh (be connected to all other controllers), and the webui instances are connected to all controllers (but not to other webui instances).
|
||||
The controllers listen on the fixed TCP port 5000 for a custom "database sync" protocol.
|
||||
|
||||
This protocol is based on an anonymous TLS connection, which then uses a shared secret to verify the connection (not perfect yet; it would be better if python simply supported SRP - https://bugs.python.org/issue11943).
|
||||
|
||||
Then both sides can send protobuf messages; each message is prefixed by its 4-byte length. The basic message is defined in `protobuf/message.proto`.
|
||||
|
||||
### Web-UI
|
||||
|
||||
The ui needs to know the clients mac address to add it to the database. Right now this means that the webui must run on a host connected to the L2 of the clients to see them in the neighbor table (and client connection to the ui must use this L2 connection - the ui doesn't actively query for neighbors, it only looks at the neighbor cache).
|
||||
|
||||
### async
|
||||
|
||||
This project uses the `trio` python library for async IO.
|
||||
|
||||
Only the netlink handling (`ipneigh.py`, `nft_*.py`) uses blocking IO - but that should be ok, as we only make requests to the kernel which should get answered immediately.
|
||||
|
||||
Disk-IO for writing the database to disk is done in a separate thread.
|
||||
|
@ -1,11 +0,0 @@
|
||||
---
|
||||
comm-secret: mysecret
|
||||
cookie-secret: mysecret
|
||||
controllers:
|
||||
- capport-controller1.example.com
|
||||
- capport-controller2.example.com
|
||||
session-timeout: 3600 # in seconds
|
||||
venue-info-url: 'https://example.com'
|
||||
server-names:
|
||||
- localhost
|
||||
- ...
|
@ -1,54 +0,0 @@
|
||||
# Various other parts of a captive portal setup
|
||||
|
||||
Network (HA) setup with IPv4 NAT:
|
||||
- two nodes
|
||||
- shared uplink L2, some transfer network (can be private or public)
|
||||
* virtual addresses on active node for IPv4 and IPv6 to route traffic to
|
||||
- shared downlink L2
|
||||
* virtual addresses on active node for IPv4 and IPv6 as gateway for clients
|
||||
* using `fe80::1` as gateway, but also add a public IPv6 virtual address
|
||||
* connected: private IPv4 prefix (e.g. CGNAT), not routed
|
||||
* connected: public IPv6 prefix (routed to virtual uplink address of nodes)
|
||||
- public IPv4 prefix routed virtual uplink address of nodes to use for NAT
|
||||
* IPv4-traffic from clients will be (S)NATted from this prefix; size depends
|
||||
on number of parallel connections you want to support.
|
||||
- webserver on nodes:
|
||||
* port 8080: receives transparent http redirects from the firewall; should return a temporary redirect to your portal page.
|
||||
* port 80: redirect to https
|
||||
* port 443: reverse-proxy to 127.0.0.1:5001 (the webui backend), but serve `/static` directly from directory (see main README)
|
||||
|
||||
To access the portal page on the clients you'll need a DNS-name; it should point to the virtual addresses. In some ways downlink address is preferred, but you also might want to avoid private addresses - i.e. use the uplink IPv4 address and the downlink IPv6 address.
|
||||
|
||||
Also the management traffic for the virtual address should use the uplink interface if possible (`keepalived` supports this).
|
||||
|
||||
## ISC dhcpd
|
||||
|
||||
See `dhcpd.conf.erb` and `dhcpd6.conf.erb`.
|
||||
|
||||
Note: don't use too large IPv4 pools or dhcpd will take a long time to sync and build up the leases files.
|
||||
|
||||
## Firewall / NAT
|
||||
|
||||
See `nftables.conf.erb` for forwarding rules; if you want traffic shaping as well see `shape_non_whitelisted.sh`.
|
||||
Local policies (ssh access and normal "host protection") are not included in the example.
|
||||
|
||||
You also might want to set a high `net.netfilter.nf_conntrack_max` with sysctl (e.g. `16777216`).
|
||||
|
||||
## Conntrackd
|
||||
|
||||
Active/failover configuration TBD.
|
||||
|
||||
I strongly recommend not to enable any tracking helpers; they often make significant holes into your stateful firewall (i.e. make clients reachable from the outside in ways they didn't actually want).
|
||||
|
||||
## Keepalived (for virtual addresses)
|
||||
|
||||
See `keepalived.conf.erb`.
|
||||
|
||||
## Apache2
|
||||
|
||||
See `apache2.conf` (only contains "interesting" parts, probably won't start that way).
|
||||
Any other webserver configured in a similar way should do just as well.
|
||||
|
||||
## systemd units
|
||||
|
||||
See the `systemd` directory for examples of systemd units.
|
@ -1,43 +0,0 @@
|
||||
Listen 80
|
||||
Listen 443
|
||||
Listen 8080
|
||||
|
||||
<VirtualHost *:8080>
|
||||
ServerName redirect
|
||||
|
||||
Header always set Cache-Control "no-store"
|
||||
# trailing '?' drops request query string:
|
||||
RedirectMatch seeother ^.*$ https://portal.example.com?
|
||||
KeepAlive off
|
||||
</VirtualHost>
|
||||
|
||||
<VirtualHost *:80>
|
||||
ServerName portal.example.com
|
||||
ServerAlias portal-node1.example.com
|
||||
|
||||
Redirect permanent / https://portal.example.com/
|
||||
</VirtualHost>
|
||||
|
||||
<VirtualHost *:443>
|
||||
ServerName portal.example.com
|
||||
ServerAlias portal-node1.example.com
|
||||
|
||||
SSLEngine on
|
||||
SSLCertificateFile "/etc/ssl/certs/portal.example.com-with-chain.crt"
|
||||
SSLCertificateKeyFile "/etc/ssl/private/portal.example.com.key"
|
||||
|
||||
# The static directory of your theme (or the builtin one)
|
||||
Alias /static "/var/lib/python-capport/custom/static"
|
||||
|
||||
Header always set X-Frame-Options DENY
|
||||
Header always set Referrer-Policy same-origin
|
||||
Header always set X-Content-Type-Options nosniff
|
||||
Header always set Strict-Transport-Security "max-age=31556926;"
|
||||
|
||||
RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME}
|
||||
|
||||
ProxyRequests Off
|
||||
ProxyPreserveHost On
|
||||
ProxyPass /static !
|
||||
ProxyPass / http://127.0.0.1:5001/
|
||||
</VirtualHost>
|
@ -1,35 +0,0 @@
|
||||
option domain-name-servers <%= ', '.join(@dns_resolvers_ipv4) %>;
|
||||
option ntp-servers <%= ', '.join(@ntp_servers_ipv4) %>;
|
||||
|
||||
# specify API server URL (RFC8910)
|
||||
option default-url "https://<%= @service_name %>/api/captive-portal";
|
||||
|
||||
default-lease-time 600;
|
||||
max-lease-time 3600;
|
||||
|
||||
authoritative;
|
||||
|
||||
<% if @instances.length == 2 -%>
|
||||
failover peer "dhcp-peer" {
|
||||
<% if @instance_index == 0 %>primary<% else %>secondary<% end %>;
|
||||
address <%= @instances[@instance_index]['external_ipv4'] %>;
|
||||
peer address <%= @instances[1-@instance_index]['external_ipv4'] %>;
|
||||
max-response-delay 60;
|
||||
max-unacked-updates 10;
|
||||
load balance max seconds 3;
|
||||
<% if @instance_index == 0 -%>
|
||||
split 128;
|
||||
mclt 180;
|
||||
<%- end %>
|
||||
}
|
||||
<%- end %>
|
||||
|
||||
subnet <%= @client_ipv4_net %> netmask <%= @client_netmask %> {
|
||||
option routers <%= @client_ipv4_gateway %>;
|
||||
pool {
|
||||
range <%= @client_ipv4_dhcp_from %> <%= @client_ipv4_dhcp_to %>;
|
||||
<% if @instances.length == 2 -%>
|
||||
failover peer "dhcp-peer";
|
||||
<%- end %>
|
||||
}
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
option dhcp6.name-servers <%= ', '.join(@dns_resolvers_ipv6) %>;
|
||||
option dhcp6.sntp-servers <%= ', '.join(@ntp_servers_ipv6) %>;
|
||||
|
||||
# specify API server URL (RFC8910)
|
||||
option dhcp6.v6-captive-portal "https://<%= @service_name %>/api/captive-portal";
|
||||
|
||||
# The delay before information-request refresh
|
||||
# (minimum is 10 minutes, maximum one day, default is to not refresh)
|
||||
# (set to 6 hours)
|
||||
option dhcp6.info-refresh-time 3600;
|
||||
|
||||
subnet6 <%= @client_ipv6 %> {
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
global_defs {
|
||||
vrrp_no_swap
|
||||
checker_no_swap
|
||||
script_user nobody
|
||||
enable_script_security
|
||||
}
|
||||
|
||||
vrrp_instance capport_ipv4_default {
|
||||
state BACKUP
|
||||
|
||||
interface <%= @uplink_interface %>
|
||||
|
||||
virtual_router_id 1
|
||||
|
||||
priority 100
|
||||
|
||||
virtual_ipaddress {
|
||||
<% @uplink_virtual_ipv4 %>
|
||||
<% @client_virtual_ipv4 %> dev <%= @client_interface %>
|
||||
}
|
||||
|
||||
promote_secondaries
|
||||
}
|
||||
|
||||
vrrp_instance capport_ipv6_default {
|
||||
state BACKUP
|
||||
|
||||
interface <%= @uplink_interface %>
|
||||
|
||||
virtual_router_id 2
|
||||
|
||||
priority 100
|
||||
|
||||
virtual_ipaddress {
|
||||
fe80::1:1
|
||||
<%= @uplink_virtual_ipv6 %>
|
||||
fe80::1 dev <%= @client_interface %>
|
||||
<%= @client_virtual_ipv6 %> dev <%= @client_interface %>
|
||||
}
|
||||
|
||||
promote_secondaries
|
||||
}
|
||||
|
||||
vrrp_sync_group capport_default {
|
||||
group {
|
||||
capport_ipv4_default
|
||||
capport_ipv6_default
|
||||
}
|
||||
}
|
@ -1,360 +0,0 @@
|
||||
#!/usr/sbin/nft -f
|
||||
|
||||
# Template notes: most variables should have an obvious meaning, but:
|
||||
# - client_ipv4: private IPv4 network for clients (not routed outside), connected
|
||||
# - client_ipv4_public: public IPv4 network for clients, must be routed to this host
|
||||
# and should be blackholed here.
|
||||
# - client_ipv6: public IPv6 network, must be routed to this host, connected
|
||||
|
||||
# NOTE: mustn't flush full ruleset; need to keep the table `captive_mark` and its set around
|
||||
# DON'T ENABLE THIS:
|
||||
# flush ruleset
|
||||
|
||||
# fully whitelist certain sites, e.g. VPN gateways your users are
|
||||
# allowed to connect to even without accepting the terms.
|
||||
define full_ipv4 = {
|
||||
<%- @whitelist_full_ipv4.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
define full_ipv6 = {
|
||||
<%- @whitelist_full_ipv6.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
|
||||
# whitelist http[s] traffic to certain sites, e.g. your
|
||||
# homepage hosting the terms, OCSP responders, websites
|
||||
# to setup other Wifi configurations (cat.eduroam.org)
|
||||
define http_server_ipv4 = {
|
||||
<%- @whitelist_http_ipv4.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
define http_server_ipv6 = {
|
||||
<%- @whitelist_http_ipv6.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
|
||||
option ntp-servers <%= ', '.join(@ntp_servers_ipv4) %>;
|
||||
option dhcp6.sntp-servers <%= ', '.join(@ntp_servers_ipv6) %>;
|
||||
|
||||
# whitelist your DNS resolvers
|
||||
define dns_server_ipv4 = {
|
||||
<%- @dns_resolvers_ipv4.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
define dns_server_ipv6 = {
|
||||
<%- @dns_resolvers_ipv6.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
|
||||
# whitelist your (and possible other friendly) NTP servers
|
||||
define ntp_server_ipv4 = {
|
||||
<%- @ntp_servers_ipv4.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
define ntp_server_ipv6 = {
|
||||
<%- @ntp_servers_ipv6.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
|
||||
# ports to block traffic for completely
|
||||
## block traffic from clients to certain server ports
|
||||
define backlist_tcp = {
|
||||
25, # SMTP
|
||||
161, # SNMP
|
||||
135, # epmap (netbios/portmapping)
|
||||
137, # TCP netbios-ns
|
||||
138, # TCP netbios-dgm
|
||||
139, # netbios-ssn
|
||||
445, # microsoft-ds (cifs, samba)
|
||||
}
|
||||
define backlist_udp = {
|
||||
25, # SMTP
|
||||
161, # SNMP
|
||||
135, # UDP epmap (netbios/portmapping)
|
||||
137, # netbios-ns
|
||||
138, # netbios-dgm
|
||||
139, # UDP netbios-ssn
|
||||
445, # UDP microsoft-ds (cifs, samba)
|
||||
5353, # mDNS
|
||||
}
|
||||
## block traffic from certain client ports
|
||||
define backlist_udp_source = {
|
||||
162, # SNMP trap
|
||||
}
|
||||
|
||||
# once a client accepted the terms:
|
||||
# * define "good" (whitelisted) traffic with the following lists
|
||||
# * decide in "chain forward" below whether to block other traffic completely or
|
||||
# e.g. shape it to low bandwidth (also see shape_non_whitelisted.sh)
|
||||
define whitelist_tcp = {
|
||||
22, # ssh
|
||||
53, # dns
|
||||
80, # http
|
||||
443, # https
|
||||
3128, # http proxy (squid)
|
||||
8080, # http alt
|
||||
110, # pop3
|
||||
995, # pop3s
|
||||
143, # imap
|
||||
993, # imaps
|
||||
587, # submission
|
||||
465, # submissions
|
||||
1194, # openvpn default
|
||||
}
|
||||
# https://help.webex.com/en-us/article/WBX264/How-Do-I-Allow-Webex-Meetings-Traffic-on-My-Network
|
||||
define whitelist_udp = {
|
||||
53, # dns
|
||||
123, # ntp
|
||||
443, # http/3
|
||||
1194, # openvpn default
|
||||
51820, # wireguard default
|
||||
500, # IPsec isakmp
|
||||
4500, # IPsec ipsec-nat-t
|
||||
10000, # IPSec Cisco NAT-T
|
||||
9000, # Primary Webex Client Media
|
||||
5004, # Webex Client Media
|
||||
}
|
||||
# whitelist traffic to local sites
|
||||
define whitelist_dest_ipv4 = {
|
||||
<%- @local_site_prefixes_ipv4.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
define whitelist_dest_ipv6 = {
|
||||
<%- @local_site_prefixes_ipv6.each do |n| -%>
|
||||
<%= n %>,
|
||||
<%- end -%>
|
||||
}
|
||||
|
||||
# IPv4 HTTP redirect + SNAT
|
||||
table ip nat4 {}
|
||||
flush table ip nat4
|
||||
table ip nat4 {
|
||||
chain prerouting {
|
||||
type nat hook prerouting priority -100;
|
||||
policy accept;
|
||||
# needs to be marked from client interface (bit 0), captive (bit 1), and http dnat (bit 2) - otherwise accept
|
||||
meta mark & 0x00000007 != 0x00000007 accept
|
||||
tcp dport 80 redirect to 8080
|
||||
}
|
||||
|
||||
chain postrouting {
|
||||
type nat hook postrouting priority 100;
|
||||
policy accept;
|
||||
# needs to be marked from client interface (bit 0) - otherwise no NAT
|
||||
meta mark & 0x00000001 != 0x00000001 accept
|
||||
oifname <%= @uplink_interface %> snat to <%= @client_ipv4_public %>
|
||||
}
|
||||
}
|
||||
|
||||
# IPv6 HTTP redirect
|
||||
table ip6 nat6 {}
|
||||
flush table ip6 nat6
|
||||
table ip6 nat6 {
|
||||
chain prerouting {
|
||||
type nat hook prerouting priority -100;
|
||||
policy accept;
|
||||
# needs to be marked from client interface (bit 0), captive (bit 1), and http dnat (bit 2) - otherwise accept
|
||||
meta mark & 0x00000007 != 0x00000007 accept
|
||||
tcp dport 80 redirect to 8080
|
||||
}
|
||||
}
|
||||
|
||||
# ipv4 + ipv6
|
||||
table inet captive_filter {}
|
||||
flush table inet captive_filter
|
||||
table inet captive_filter {
|
||||
chain antispoof_input {
|
||||
# need to accept packet for input and forward!
|
||||
meta mark & 0x00000001 == 0 accept comment "accept from non-client interface"
|
||||
ip saddr { 0.0.0.0, <%= @client_ipv4 %> } return
|
||||
ip6 saddr { fe80::/64, <%= @client_ipv6 %> } return
|
||||
drop
|
||||
}
|
||||
|
||||
# we need the "redirect" decision before DNAT in prerouting:-100
|
||||
chain mark_clients {
|
||||
type filter hook prerouting priority -110;
|
||||
policy accept;
|
||||
meta mark & 0x00000001 == 0 accept comment "accept from non-client interface"
|
||||
jump antispoof_input
|
||||
meta mark & 0x00000002 == 0 accept comment "accept packets from non-captive clients"
|
||||
# now accept all traffic to destinations allowed in captive state, and mark "redirect" packets:
|
||||
jump captive_allowed
|
||||
}
|
||||
|
||||
chain input {
|
||||
type filter hook input priority 0;
|
||||
policy accept;
|
||||
# TODO: limit services available to clients? iptconf might already be enough
|
||||
}
|
||||
|
||||
chain forward_down {
|
||||
# only filter uplink -> client here:
|
||||
iifname != <%= @uplink_interface %> accept
|
||||
oifname != <%= @client_interface %> accept
|
||||
# established connections
|
||||
ct state established,related accept
|
||||
# allow incoming ipv6 ping (ipv4 ping can't work due to NAT)
|
||||
icmpv6 type echo-request accept
|
||||
drop
|
||||
}
|
||||
|
||||
chain antispoof_forward {
|
||||
meta mark & 0x00000001 == 0 goto forward_down comment "handle forwardings not from client interface"
|
||||
ip saddr { <%= @client_ipv4 %> } return
|
||||
ip6 saddr { <%= @client_ipv6 %> } return
|
||||
drop
|
||||
}
|
||||
|
||||
chain captive_allowed_icmp {
|
||||
# allow all pings to servers we allow other kind of traffic
|
||||
icmp type echo-request accept
|
||||
icmpv6 type echo-request accept
|
||||
}
|
||||
|
||||
chain captive_allowed_http {
|
||||
# http + https (but not QUIC)
|
||||
tcp dport { 80, 443 } accept
|
||||
goto captive_allowed_icmp
|
||||
}
|
||||
|
||||
chain captive_allowed_dns {
|
||||
# DNS, DoT, DoH
|
||||
udp dport { 53, 853 } accept
|
||||
tcp dport { 53, 443, 853 } accept
|
||||
goto captive_allowed_icmp
|
||||
}
|
||||
|
||||
chain captive_allowed_ntp {
|
||||
# only NTP
|
||||
udp dport 123 accept
|
||||
goto captive_allowed_icmp
|
||||
}
|
||||
|
||||
chain captive_allowed {
|
||||
# all protocols for fully whitelisted
|
||||
ip daddr $full_ipv4 accept
|
||||
ip6 daddr $full_ipv6 accept
|
||||
|
||||
ip daddr $http_server_ipv4 jump captive_allowed_http
|
||||
ip6 daddr $http_server_ipv6 jump captive_allowed_http
|
||||
|
||||
ip daddr $dns_server_ipv4 jump captive_allowed_dns
|
||||
ip6 daddr $dns_server_ipv6 jump captive_allowed_dns
|
||||
|
||||
ip daddr $ntp_server_ipv4 jump captive_allowed_ntp
|
||||
ip6 daddr $ntp_server_ipv6 jump captive_allowed_ntp
|
||||
|
||||
# mark (new) http clients to redirect to local http server with bit 2
|
||||
tcp dport 80 ct state new meta mark set meta mark | 0x00000004 accept comment "DNAT HTTP"
|
||||
# mark packets to reject in forward with bit 3
|
||||
meta mark set meta mark | 0x00000008 comment "reject in forwarding"
|
||||
}
|
||||
|
||||
# for DNS+NTP
|
||||
ct timeout udp-oneshot {
|
||||
protocol udp;
|
||||
policy = { unreplied: 10, replied: 0 }
|
||||
}
|
||||
|
||||
chain forward_reject {
|
||||
# could reject TCP connections with tcp reset, but ICMP unreachable should be good enough
|
||||
# (and it's also semantically correct):
|
||||
# ip protocol tcp reject with tcp reset
|
||||
# ip6 nexthdr tcp reject with tcp reset
|
||||
|
||||
# but we need to close existing tcp sessions: (when client moves to captive state)
|
||||
ct state != new ip protocol tcp reject with tcp reset
|
||||
ct state != new ip6 nexthdr tcp reject with tcp reset
|
||||
|
||||
# default icmp reject
|
||||
reject with icmpx type admin-prohibited
|
||||
}
|
||||
|
||||
# block some ports always
|
||||
chain blacklist {
|
||||
tcp dport $backlist_tcp goto forward_reject
|
||||
udp dport $backlist_udp goto forward_reject
|
||||
udp sport $backlist_udp_source goto forward_reject
|
||||
}
|
||||
|
||||
# ports we assume are "proper" - still only allowed in non-captive state
|
||||
chain whitelist {
|
||||
ip daddr $whitelist_dest_ipv4 accept
|
||||
ip6 daddr $whitelist_dest_ipv6 accept
|
||||
tcp dport $whitelist_tcp accept
|
||||
udp dport $whitelist_udp accept
|
||||
ip protocol esp accept
|
||||
ip6 nexthdr esp accept
|
||||
icmp type echo-request accept
|
||||
icmpv6 type echo-request accept
|
||||
# icmp related to existing connections, ...
|
||||
ct state established,related accept
|
||||
}
|
||||
|
||||
chain forward {
|
||||
type filter hook forward priority 0;
|
||||
policy drop;
|
||||
jump antispoof_forward
|
||||
# optimize conntrack timeouts for DNS and NTP
|
||||
udp dport { 53, 123 } ct timeout set "udp-oneshot"
|
||||
jump blacklist
|
||||
# reject packets marked for rejection in mark_clients/captive_allowed (bit 3)
|
||||
meta mark & 0x00000008 != 0 goto forward_reject
|
||||
jump whitelist
|
||||
# optional (policy):
|
||||
goto forward_reject comment "drop not-whitelisted traffic completely"
|
||||
# (conntrack) mark connection with bit 4 for shaping
|
||||
ct mark set ct mark | 0x00000010 counter accept comment "accept shaped traffic"
|
||||
}
|
||||
|
||||
chain forward_con_shaped_mark {
|
||||
type filter hook forward priority 10;
|
||||
policy accept;
|
||||
# copy conntrack mark bit 4 to meta mark bit 4
|
||||
ct mark & 0x00000010 == 0 accept comment "non shaped connection"
|
||||
meta mark set meta mark | 0x00000010 counter accept comment "shaped connection"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# NOTE: mustn't flush this table to keep the set around
|
||||
# DON'T ENABLE THIS:
|
||||
# flush table inet captive_mark
|
||||
|
||||
# as table wasn't flushed, at least delete the single chain we're expecting
|
||||
table inet captive_mark {
|
||||
chain prerouting { }
|
||||
}
|
||||
delete chain inet captive_mark prerouting;
|
||||
|
||||
table inet captive_mark {
|
||||
# set isn't recreated, i.e. keeps dynamically added members
|
||||
set allowed {
|
||||
type ether_addr
|
||||
flags timeout
|
||||
}
|
||||
|
||||
chain prerouting {
|
||||
type filter hook prerouting priority -150;
|
||||
policy accept;
|
||||
iifname != <%= @client_interface %> accept
|
||||
# mark packets from client interface with bit 0
|
||||
meta mark set meta mark | 0x00000001
|
||||
ether saddr @allowed accept
|
||||
# mark "captive" clients with bit 1
|
||||
meta mark set meta mark | 0x00000002
|
||||
accept
|
||||
}
|
||||
|
||||
# further existing elements in this table won't be cleared by loading this file!
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
interface <%= client_interface %>
|
||||
{
|
||||
AdvSendAdvert on;
|
||||
AdvDefaultPreference high;
|
||||
AdvSourceLLAddress off;
|
||||
AdvRASrcAddress {
|
||||
fe80::1;
|
||||
};
|
||||
|
||||
RDNSS <%= ' '.join(@dns_resolvers_ipv6) %> {
|
||||
FlushRDNSS off;
|
||||
};
|
||||
|
||||
AdvOtherConfigFlag on;
|
||||
|
||||
# will require radvd > 2.19 (not released yet)
|
||||
# AdvCaptivePortalAPI "https://<%= @service_name %>/api/captive-portal";
|
||||
|
||||
prefix <%= client_ipv6 %>
|
||||
{
|
||||
};
|
||||
};
|
@ -1,31 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
limit_iface() {
|
||||
local dev=$1
|
||||
local shape_rate=$2 # "guaranteed"
|
||||
local shape_ceil=$3 # "upper limit"
|
||||
|
||||
tc qdisc delete dev "${dev}" root 2>/dev/null
|
||||
tc qdisc add dev "${dev}" root handle 1: htb default 0x11
|
||||
# basically no limit for default traffic
|
||||
tc class add dev "${dev}" parent 1: classid 1:11 htb rate 10Gbit ceil 100Gbit quantum 100000
|
||||
# limit "bad" (not whitelisted) traffic
|
||||
tc class add dev "${dev}" parent 1: classid 1:12 htb prio 1 rate "${shape_rate}" ceil "${shape_ceil}"
|
||||
# use "codel" qdisc for both classes, but with larger queue for default traffic
|
||||
tc qdisc add dev "${dev}" parent 1:11 handle 11: codel limit 20000
|
||||
tc qdisc add dev "${dev}" parent 1:12 handle 12: codel
|
||||
|
||||
# sort into bad class based on netfilter mark (if bit 0x10 is set)
|
||||
tc filter add dev "${dev}" parent 1: prio 1 basic match 'meta(nf_mark mask 0x10 eq 0x10)' classid 1:12
|
||||
}
|
||||
|
||||
uplink=$1
|
||||
downlink=$2
|
||||
|
||||
if [ -z "${uplink}" -o -z "${downlink}" ]; then
|
||||
echo >&2 "Missing uplink and downlink interface names"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
limit_iface "${uplink}" "1Mbit" "1Mbit"
|
||||
limit_iface "${downlink}" "1Mbit" "1Mbit"
|
@ -1,18 +0,0 @@
|
||||
[Unit]
|
||||
Description=Captive Portal enforcement service
|
||||
Wants=basic.target
|
||||
After=basic.target network.target
|
||||
ConditionFileIsExecutable=/var/lib/python-capport/start-control.sh
|
||||
ConditionPathIsDirectory=/var/lib/python-capport/venv
|
||||
|
||||
# TODO: start as unprivileged user but with CAP_NET_ADMIN ?
|
||||
[Service]
|
||||
Type=notify
|
||||
WatchdogSec=10
|
||||
ExecStart=/var/lib/python-capport/start-control.sh
|
||||
Restart=always
|
||||
ProtectSystem=full
|
||||
ProtectHome=true
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
@ -1,18 +0,0 @@
|
||||
[Unit]
|
||||
Description=NFT Firewall Shim for Captive Portal
|
||||
Wants=network-pre.target
|
||||
Before=network-pre.target shutdown.target
|
||||
Conflicts=shutdown.target
|
||||
DefaultDependencies=no
|
||||
|
||||
[Service]
|
||||
Type=oneshot
|
||||
RemainAfterExit=yes
|
||||
StandardInput=null
|
||||
ProtectSystem=full
|
||||
ProtectHome=true
|
||||
ExecStart=/usr/sbin/nft -f /etc/nftables.conf
|
||||
ExecReload=/usr/sbin/nft -f /etc/nftables.conf
|
||||
|
||||
[Install]
|
||||
WantedBy=sysinit.target
|
@ -1,15 +0,0 @@
|
||||
[Unit]
|
||||
Description=Captive Portal traffic shaping
|
||||
Wants=basic.target
|
||||
After=basic.target network.target
|
||||
|
||||
[Service]
|
||||
Type=oneshot
|
||||
RemainAfterExit=yes
|
||||
StandardInput=null
|
||||
ProtectSystem=full
|
||||
ProtectHome=true
|
||||
ExecStart=/etc/capport-tc.sh <%= @uplink_interface %> <%= @client_interface %>
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
@ -1,18 +0,0 @@
|
||||
[Unit]
|
||||
Description=Captive Portal web ui service
|
||||
Wants=basic.target
|
||||
After=basic.target network.target
|
||||
ConditionFileIsExecutable=/var/lib/python-capport/start-control.sh
|
||||
ConditionPathIsDirectory=/var/lib/python-capport/venv
|
||||
|
||||
[Service]
|
||||
User=capport
|
||||
Type=notify
|
||||
WatchdogSec=10
|
||||
ExecStart=/var/lib/python-capport/start-api.sh
|
||||
Restart=always
|
||||
ProtectSystem=full
|
||||
ProtectHome=true
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
19
flake8
19
flake8
@ -1,19 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
### check type annotations with mypy
|
||||
|
||||
set -e
|
||||
|
||||
base=$(dirname "$(readlink -f "$0")")
|
||||
cd "${base}"
|
||||
|
||||
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
|
||||
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -x ./venv/bin/flake8 ]; then
|
||||
./venv/bin/pip install flake8 flake8-import-order
|
||||
fi
|
||||
|
||||
./venv/bin/flake8 src
|
42
mypy
42
mypy
@ -1,42 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
### check type annotations with mypy
|
||||
|
||||
set -e
|
||||
|
||||
base=$(dirname "$(readlink -f "$0")")
|
||||
cd "${base}"
|
||||
|
||||
if [ ! -d "venv" -o ! -x "venv/bin/python" ]; then
|
||||
echo >&2 "Missing virtualenv in 'venv'; maybe run setup-venv.sh first!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -x ./venv/bin/mypy ]; then
|
||||
./venv/bin/pip install mypy trio-typing[mypy] types-PyYAML types-aiofiles types-colorama types-cryptography types-protobuf types-toml
|
||||
fi
|
||||
|
||||
site_pkgs=$(./venv/bin/python -c 'import site; print(site.getsitepackages()[0])')
|
||||
if [ ! -d "${site_pkgs}/trio_typing" ]; then
|
||||
./venv/bin/pip install trio-typing[mypy]
|
||||
fi
|
||||
if [ ! -d "${site_pkgs}/yaml-stubs" ]; then
|
||||
./venv/bin/pip install types-PyYAML
|
||||
fi
|
||||
if [ ! -d "${site_pkgs}/aiofiles-stubs" ]; then
|
||||
./venv/bin/pip install types-aiofiles
|
||||
fi
|
||||
if [ ! -d "${site_pkgs}/colorama-stubs" ]; then
|
||||
./venv/bin/pip install types-colorama
|
||||
fi
|
||||
if [ ! -d "${site_pkgs}/cryptography-stubs" ]; then
|
||||
./venv/bin/pip install types-cryptography
|
||||
fi
|
||||
if [ ! -d "${site_pkgs}/google-stubs" ]; then
|
||||
./venv/bin/pip install types-protobuf
|
||||
fi
|
||||
if [ ! -d "${site_pkgs}/toml-stubs" ]; then
|
||||
./venv/bin/pip install types-toml
|
||||
fi
|
||||
|
||||
./venv/bin/mypy --install-types src
|
@ -1,10 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$(readlink -f "$0")")"
|
||||
|
||||
rm -rf ../src/capport/comm/protobuf/message_pb2.py
|
||||
mkdir -p ../src/capport/comm/protobuf
|
||||
|
||||
protoc --python_out=../src/capport/comm/protobuf message.proto
|
@ -1,42 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package capport;
|
||||
|
||||
message Message {
|
||||
oneof oneof {
|
||||
Hello hello = 1;
|
||||
AuthenticationResult authentication_result = 2;
|
||||
Ping ping = 3;
|
||||
MacStates mac_states = 10;
|
||||
}
|
||||
}
|
||||
|
||||
// sent by clients and servers as first message
|
||||
message Hello {
|
||||
bytes instance_id = 1;
|
||||
string hostname = 2;
|
||||
bool is_controller = 3;
|
||||
bytes authentication = 4;
|
||||
}
|
||||
|
||||
// tell peer whether hello authentication was good
|
||||
message AuthenticationResult {
|
||||
bool success = 1;
|
||||
}
|
||||
|
||||
message Ping {
|
||||
bytes payload = 1;
|
||||
}
|
||||
|
||||
message MacStates {
|
||||
repeated MacState states = 1;
|
||||
}
|
||||
|
||||
message MacState {
|
||||
bytes mac_address = 1;
|
||||
// Seconds of UTC time since epoch
|
||||
int64 last_change = 2;
|
||||
// Seconds of UTC time since epoch
|
||||
int64 allow_until = 3;
|
||||
bool allowed = 4;
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"wheel"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.9"
|
||||
# warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
exclude = [
|
||||
'_pb2\.py$', # TOML literal string (single-quotes, no escaping necessary)
|
||||
]
|
@ -1,11 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
self=$(dirname "$(readlink -f "$0")")
|
||||
cd "${self}"
|
||||
|
||||
python3 -m venv venv
|
||||
|
||||
# install cli extras
|
||||
./venv/bin/pip install --upgrade --upgrade-strategy eager -e '.'
|
38
setup.cfg
38
setup.cfg
@ -1,38 +0,0 @@
|
||||
[metadata]
|
||||
name = capport-tik-nks
|
||||
version = 0.0.1
|
||||
author = Stefan Bühler
|
||||
author_email = stefan.buehler@tik.uni-stuttgart.de
|
||||
description = Captive Portal
|
||||
long_description = file: README.md
|
||||
long_description_content_type = text/markdown
|
||||
url = https://git-nks-public.tik.uni-stuttgart.de/net/python-capport
|
||||
project_urls =
|
||||
Bug Tracker = https://git-nks-public.tik.uni-stuttgart.de/net/python-capport/issues
|
||||
classifiers =
|
||||
Programming Language :: Python :: 3
|
||||
License :: OSI Approved :: MIT License
|
||||
Operating System :: OS Independent
|
||||
|
||||
[options]
|
||||
package_dir =
|
||||
= src
|
||||
packages = find:
|
||||
python_requires = >=3.9
|
||||
install_requires =
|
||||
trio
|
||||
quart
|
||||
quart-trio
|
||||
hypercorn[trio]
|
||||
PyYAML
|
||||
protobuf>=4.21
|
||||
pyroute2~=0.7.3
|
||||
|
||||
[options.packages.find]
|
||||
where = src
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
capport-control = capport.control.run:main
|
||||
capport-stats = capport.stats:main
|
||||
capport-webui = capport.api.hypercorn_run:main
|
6
setup.py
6
setup.py
@ -1,6 +0,0 @@
|
||||
# https://github.com/pypa/setuptools/issues/2816
|
||||
# allow editable install on older pip versions
|
||||
from setuptools import setup
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup()
|
@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .app_cls import MyQuartApp
|
||||
|
||||
app = MyQuartApp(__name__)
|
||||
|
||||
__import__('capport.api.setup')
|
||||
__import__('capport.api.proxy_fix')
|
||||
__import__('capport.api.lang')
|
||||
__import__('capport.api.template_filters')
|
||||
__import__('capport.api.views')
|
@ -1,55 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import os.path
|
||||
import typing
|
||||
|
||||
import jinja2
|
||||
|
||||
import quart.templating
|
||||
|
||||
import quart_trio
|
||||
|
||||
import capport.comm.hub
|
||||
import capport.config
|
||||
import capport.utils.ipneigh
|
||||
|
||||
|
||||
class DispatchingJinjaLoader(quart.templating.DispatchingJinjaLoader):
|
||||
app: MyQuartApp
|
||||
|
||||
def __init__(self, app: MyQuartApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
def _loaders(self) -> typing.Generator[jinja2.BaseLoader, None, None]:
|
||||
if self.app.custom_loader:
|
||||
yield self.app.custom_loader
|
||||
for loader in super()._loaders():
|
||||
yield loader
|
||||
|
||||
|
||||
class MyQuartApp(quart_trio.QuartTrio):
|
||||
my_nc: typing.Optional[capport.utils.ipneigh.NeighborController] = None
|
||||
my_hub: typing.Optional[capport.comm.hub.Hub] = None
|
||||
my_config: capport.config.Config
|
||||
custom_loader: typing.Optional[jinja2.FileSystemLoader] = None
|
||||
|
||||
def __init__(self, import_name: str, **kwargs) -> None:
|
||||
self.my_config = capport.config.Config.load_default_once()
|
||||
kwargs.setdefault('template_folder', os.path.join(os.path.dirname(__file__), 'templates'))
|
||||
cust_templ = os.path.join('custom', 'templates')
|
||||
if os.path.exists(cust_templ):
|
||||
self.custom_loader = jinja2.FileSystemLoader(os.fspath(cust_templ))
|
||||
cust_static = os.path.abspath(os.path.join('custom', 'static'))
|
||||
if os.path.exists(cust_static):
|
||||
static_folder = cust_static
|
||||
else:
|
||||
static_folder = os.path.join(os.path.dirname(__file__), 'static')
|
||||
kwargs.setdefault('static_folder', static_folder)
|
||||
super().__init__(import_name, **kwargs)
|
||||
self.debug = self.my_config.debug
|
||||
self.secret_key = self.my_config.cookie_secret
|
||||
|
||||
def create_global_jinja_loader(self) -> DispatchingJinjaLoader:
|
||||
"""Create and return a global (not blueprint specific) Jinja loader."""
|
||||
return DispatchingJinjaLoader(self)
|
@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hypercorn.config
|
||||
import hypercorn.trio.run
|
||||
import hypercorn.utils
|
||||
|
||||
import capport.config
|
||||
|
||||
|
||||
def run(config: hypercorn.config.Config) -> None:
|
||||
sockets = config.create_sockets()
|
||||
if config.worker_class != 'trio':
|
||||
raise Exception('Invalid worker class received from constructor')
|
||||
|
||||
hypercorn.trio.run.trio_worker(config=config, sockets=sockets)
|
||||
|
||||
for sock in sockets.secure_sockets:
|
||||
sock.close()
|
||||
for sock in sockets.insecure_sockets:
|
||||
sock.close()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
_config = capport.config.Config.load_default_once()
|
||||
|
||||
hypercorn_config = hypercorn.config.Config()
|
||||
hypercorn_config.application_path = 'capport.api.app'
|
||||
hypercorn_config.worker_class = 'trio'
|
||||
hypercorn_config.bind = ["127.0.0.1:5001"]
|
||||
|
||||
if _config.server_names:
|
||||
hypercorn_config.server_names = _config.server_names
|
||||
elif not _config.debug:
|
||||
raise Exception(
|
||||
"production setup requires server-names in config (list of accepted hostnames in http requests)"
|
||||
)
|
||||
|
||||
run(hypercorn_config)
|
@ -1,83 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os.path
|
||||
import re
|
||||
import typing
|
||||
|
||||
import quart
|
||||
|
||||
from .app import app
|
||||
|
||||
_VALID_LANGUAGE_NAMES = re.compile(r'^[-a-z0-9_]+$')
|
||||
|
||||
|
||||
def parse_accept_language(value: str) -> typing.List[str]:
|
||||
value = value.strip()
|
||||
if not value or value == '*':
|
||||
return []
|
||||
tuples = []
|
||||
for entry in value.split(','):
|
||||
attrs = entry.split(';')
|
||||
name = attrs.pop(0).strip().lower()
|
||||
q = 1.0
|
||||
for attr in attrs:
|
||||
if not '=' in attr: continue
|
||||
key, value = attr.split('=', maxsplit=1)
|
||||
if key.strip().lower() == 'q':
|
||||
try:
|
||||
q = float(value.strip())
|
||||
except ValueError:
|
||||
q = 0.0
|
||||
if q >= 0.0:
|
||||
tuples.append((q, name))
|
||||
tuples.sort()
|
||||
have = set()
|
||||
result = []
|
||||
for (_q, name) in tuples:
|
||||
if name in have: continue
|
||||
if name == '*': break
|
||||
have.add(name)
|
||||
if _VALID_LANGUAGE_NAMES.match(name):
|
||||
result.append(name)
|
||||
short_name = name.split('-', maxsplit=1)[0].split('_', maxsplit=1)[0]
|
||||
if not short_name or short_name in have: continue
|
||||
have.add(short_name)
|
||||
result.append(short_name)
|
||||
return result
|
||||
|
||||
|
||||
@app.before_request
|
||||
def detect_language():
|
||||
g = quart.g
|
||||
r = quart.request
|
||||
s = quart.session
|
||||
if 'setlang' in r.args:
|
||||
lang = r.args.get('setlang').strip().lower()
|
||||
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
||||
if s.get('lang') != lang:
|
||||
s['lang'] = lang
|
||||
g.langs = [lang]
|
||||
return
|
||||
else:
|
||||
# reset language
|
||||
s.pop('lang', None)
|
||||
lang = s.get('lang')
|
||||
if lang:
|
||||
lang = lang.strip().lower()
|
||||
if lang and _VALID_LANGUAGE_NAMES.match(lang):
|
||||
g.langs = [lang]
|
||||
return
|
||||
acc_lang = ','.join(r.headers.get_all('Accept-Language'))
|
||||
g.langs = parse_accept_language(acc_lang)
|
||||
|
||||
|
||||
async def render_i18n_template(template, /, **kwargs) -> str:
|
||||
langs: typing.List[str] = quart.g.langs
|
||||
if not langs:
|
||||
return await quart.render_template(template, **kwargs)
|
||||
names = [
|
||||
os.path.join('i18n', lang, template)
|
||||
for lang in langs
|
||||
]
|
||||
names.append(template)
|
||||
return await quart.render_template(names, **kwargs)
|
@ -1,80 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import typing
|
||||
|
||||
import quart
|
||||
|
||||
import werkzeug
|
||||
from werkzeug.http import parse_list_header
|
||||
|
||||
from .app import app
|
||||
|
||||
|
||||
def _get_first_in_list(value_list: typing.Optional[str], allowed: typing.Sequence[str] = ()) -> typing.Optional[str]:
|
||||
if not value_list:
|
||||
return None
|
||||
values = parse_list_header(value_list)
|
||||
if values and values[0]:
|
||||
if not allowed or values[0] in allowed:
|
||||
return values[0]
|
||||
return None
|
||||
|
||||
|
||||
def local_proxy_fix(request: quart.Request):
|
||||
if not request.remote_addr:
|
||||
return
|
||||
try:
|
||||
addr = ipaddress.ip_address(request.remote_addr)
|
||||
except ValueError:
|
||||
# TODO: accept unix sockets somehow too?
|
||||
return
|
||||
if not addr.is_loopback:
|
||||
return
|
||||
client = _get_first_in_list(request.headers.get('X-Forwarded-For'))
|
||||
if not client:
|
||||
# assume this is always set behind reverse proxies supporting any of the headers
|
||||
return
|
||||
request.remote_addr = client
|
||||
scheme = _get_first_in_list(request.headers.get('X-Forwarded-Proto'), ('http', 'https'))
|
||||
port: typing.Optional[int] = None
|
||||
if scheme:
|
||||
port = 443 if scheme == 'https' else 80
|
||||
request.scheme = scheme
|
||||
host = _get_first_in_list(request.headers.get('X-Forwarded-Host'))
|
||||
port_s: typing.Optional[str]
|
||||
if host:
|
||||
request.host = host
|
||||
if ':' in host and not host.endswith(']'):
|
||||
try:
|
||||
_, port_s = host.rsplit(':', maxsplit=1)
|
||||
port = int(port_s)
|
||||
except ValueError:
|
||||
# ignore invalid port in host header
|
||||
pass
|
||||
port_s = _get_first_in_list(request.headers.get('X-Forwarded-Port'))
|
||||
if port_s:
|
||||
try:
|
||||
port = int(port_s)
|
||||
except ValueError:
|
||||
# ignore invalid port in header
|
||||
pass
|
||||
if port:
|
||||
if request.server and len(request.server) == 2:
|
||||
request.server = (request.server[0], port)
|
||||
root_path = _get_first_in_list(request.headers.get('X-Forwarded-Prefix'))
|
||||
if root_path:
|
||||
request.root_path = root_path
|
||||
|
||||
|
||||
class LocalProxyFixRequestHandler:
|
||||
def __init__(self, orig_handle_request):
|
||||
self._orig_handle_request = orig_handle_request
|
||||
|
||||
async def __call__(self, request: quart.Request) -> typing.Union[quart.Response, werkzeug.Response]:
|
||||
# need to patch request before url_adapter is built
|
||||
local_proxy_fix(request)
|
||||
return await self._orig_handle_request(request)
|
||||
|
||||
|
||||
app.handle_request = LocalProxyFixRequestHandler(app.handle_request) # type: ignore
|
@ -1,63 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import trio
|
||||
|
||||
import capport.comm.hub
|
||||
import capport.comm.message
|
||||
import capport.database
|
||||
import capport.utils.cli
|
||||
import capport.utils.ipneigh
|
||||
from capport.utils.sd_notify import open_sdnotify
|
||||
|
||||
from .app import app
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiHubApp(capport.comm.hub.HubApplication):
|
||||
async def mac_states_changed(
|
||||
self,
|
||||
*,
|
||||
from_peer_id: uuid.UUID,
|
||||
pending_updates: capport.database.PendingUpdates,
|
||||
) -> None:
|
||||
# TODO: support websocket notification updates to clients?
|
||||
pass
|
||||
|
||||
|
||||
async def _run_hub(*, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||
try:
|
||||
async with capport.utils.ipneigh.connect() as mync:
|
||||
app.my_nc = mync
|
||||
_logger.info("Running hub for API")
|
||||
myapp = ApiHubApp()
|
||||
myhub = capport.comm.hub.Hub(config=app.my_config, app=myapp, is_controller=False)
|
||||
app.my_hub = myhub
|
||||
await myhub.run(task_status=task_status)
|
||||
finally:
|
||||
app.my_hub = None
|
||||
app.my_nc = None
|
||||
_logger.info("Done running hub for API")
|
||||
await app.shutdown()
|
||||
|
||||
|
||||
async def _setup(*, task_status=trio.TASK_STATUS_IGNORED):
|
||||
async with open_sdnotify() as sn:
|
||||
await sn.send('STATUS=Starting hub')
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(_run_hub)
|
||||
await sn.send('READY=1', 'STATUS=Ready for client requests')
|
||||
task_status.started()
|
||||
# continue running hub and systemd watchdog handler
|
||||
|
||||
|
||||
@app.before_serving
|
||||
async def init():
|
||||
app.debug = app.my_config.debug
|
||||
app.secret_key = app.my_config.cookie_secret
|
||||
capport.utils.cli.init_logger(app.my_config)
|
||||
await app.nursery.start(_setup)
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,16 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import typing
|
||||
|
||||
from .app import app
|
||||
|
||||
|
||||
@app.add_template_test
|
||||
def ip_in_networks(ip_s: str, networks: typing.Iterable[str]) -> bool:
|
||||
ip = ipaddress.ip_address(ip_s)
|
||||
for network in networks:
|
||||
net = ipaddress.ip_network(network)
|
||||
if ip in net:
|
||||
return True
|
||||
return False
|
@ -1,19 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>{% block title %}Captive Portal{% endblock %}</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<script src="{{ url_for('static', filename='bootstrap/bootstrap.bundle.min.js') }}"></script>
|
||||
<link rel="stylesheet" href="{{ url_for('static', filename='bootstrap/bootstrap.min.css') }}">
|
||||
</head>
|
||||
<body class="container">
|
||||
<header class="d-flex justify-content-center py-3">
|
||||
<ul class="nav nav-pills">
|
||||
<li class="nav-item"><a href="#" class="nav-link active" aria-current="page">Home</a></li>
|
||||
</ul>
|
||||
</header>
|
||||
|
||||
{% block content %}{% endblock %}
|
||||
</body>
|
||||
</html>
|
@ -1,16 +0,0 @@
|
||||
{% extends "base.html" %}
|
||||
{% block content %}
|
||||
You already accepted out conditions and are currently granted access to the internet:
|
||||
|
||||
Your current session will last for {{ state.allowed_remaining }} seconds.
|
||||
<form method="POST" action="/login">
|
||||
<input type="hidden" name="accept" value="1">
|
||||
<input type="hidden" name="mac" value="{{ state.mac }}">
|
||||
<button type="submit" class="btn btn-primary mb-3">Extend session</button>
|
||||
</form>
|
||||
<form method="POST" action="/logout">
|
||||
<input type="hidden" name="mac" value="{{ state.mac }}">
|
||||
<button type="submit" class="btn btn-danger mb-3">Terminate session</button>
|
||||
</form>
|
||||
<br>
|
||||
{% endblock %}
|
@ -1,9 +0,0 @@
|
||||
{% extends "base.html" %}
|
||||
{% block content %}
|
||||
To get access to the internet please accept our usage guidelines by clicking this button:
|
||||
<form method="POST" action="/login">
|
||||
<input type="hidden" name="accept" value="1">
|
||||
<input type="hidden" name="mac" value="{{ state.mac }}">
|
||||
<button type="submit" class="btn btn-primary mb-3">Accept</button>
|
||||
</form>
|
||||
{% endblock %}
|
@ -1,5 +0,0 @@
|
||||
{% extends "base.html" %}
|
||||
{% block content %}
|
||||
<p>It seems you're accessing this site from outside the network this captive portal is running for.</p>
|
||||
<p>Your clients IP address is {{ state.address }}</p>
|
||||
{% endblock %}
|
@ -1,166 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import quart
|
||||
|
||||
import trio
|
||||
|
||||
import capport.comm.hub
|
||||
import capport.comm.message
|
||||
import capport.database
|
||||
import capport.utils.cli
|
||||
import capport.utils.ipneigh
|
||||
from capport import cptypes
|
||||
|
||||
from .app import app
|
||||
from .lang import render_i18n_template
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_client_ip() -> cptypes.IPAddress:
|
||||
remote_addr = quart.request.remote_addr
|
||||
if not remote_addr:
|
||||
quart.abort(500, 'Missing client address')
|
||||
try:
|
||||
addr = ipaddress.ip_address(remote_addr)
|
||||
except ValueError as e:
|
||||
_logger.warning(f'Invalid client address {remote_addr!r}: {e}')
|
||||
quart.abort(500, 'Invalid client address')
|
||||
return addr
|
||||
|
||||
|
||||
async def get_client_mac_if_present(
|
||||
address: typing.Optional[cptypes.IPAddress] = None,
|
||||
) -> typing.Optional[cptypes.MacAddress]:
|
||||
assert app.my_nc # for mypy
|
||||
if not address:
|
||||
address = get_client_ip()
|
||||
return await app.my_nc.get_neighbor_mac(address)
|
||||
|
||||
|
||||
async def get_client_mac(address: typing.Optional[cptypes.IPAddress] = None) -> cptypes.MacAddress:
|
||||
mac = await get_client_mac_if_present(address)
|
||||
if mac is None:
|
||||
_logger.warning(f"Couldn't find MAC addresss for {address}")
|
||||
quart.abort(404, 'Unknown client')
|
||||
return mac
|
||||
|
||||
|
||||
async def user_login(address: cptypes.IPAddress, mac: cptypes.MacAddress) -> None:
|
||||
assert app.my_hub # for mypy
|
||||
async with app.my_hub.database.make_changes() as pu:
|
||||
try:
|
||||
pu.login(mac, app.my_config.session_timeout)
|
||||
except capport.database.NotReadyYet as e:
|
||||
quart.abort(500, str(e))
|
||||
|
||||
if pu:
|
||||
_logger.debug(f'User {mac} (with IP {address}) logged in')
|
||||
for msg in pu.serialized:
|
||||
await app.my_hub.broadcast(msg)
|
||||
|
||||
|
||||
async def user_logout(mac: cptypes.MacAddress) -> None:
|
||||
assert app.my_hub # for mypy
|
||||
async with app.my_hub.database.make_changes() as pu:
|
||||
try:
|
||||
pu.logout(mac)
|
||||
except capport.database.NotReadyYet as e:
|
||||
quart.abort(500, str(e))
|
||||
if pu:
|
||||
_logger.debug(f'User {mac} logged out')
|
||||
for msg in pu.serialized:
|
||||
await app.my_hub.broadcast(msg)
|
||||
|
||||
|
||||
async def user_lookup() -> cptypes.MacPublicState:
|
||||
assert app.my_hub # for mypy
|
||||
address = get_client_ip()
|
||||
mac = await get_client_mac_if_present(address)
|
||||
if not mac:
|
||||
return cptypes.MacPublicState.from_missing_mac(address)
|
||||
else:
|
||||
return app.my_hub.database.lookup(address, mac)
|
||||
|
||||
|
||||
# @app.route('/all')
|
||||
# async def route_all():
|
||||
# return app.my_hub.database.as_json()
|
||||
|
||||
|
||||
def check_self_origin():
|
||||
origin = quart.request.headers.get('Origin', None)
|
||||
if origin is None:
|
||||
# not a request by a modern browser - probably curl or something similar. don't care.
|
||||
return
|
||||
origin = origin.lower().strip()
|
||||
if origin == 'none':
|
||||
quart.abort(403, 'Origin is none')
|
||||
origin_parts = origin.split('/')
|
||||
# Origin should look like: <scheme>://<hostname> (optionally followed by :<port>)
|
||||
if len(origin_parts) < 3:
|
||||
quart.abort(400, 'Broken Origin header')
|
||||
if origin_parts[0] != 'https:' and not app.my_config.debug:
|
||||
# -> require https in production
|
||||
quart.abort(403, 'Non-https Origin not allowed')
|
||||
origin_host = origin_parts[2]
|
||||
host = quart.request.headers.get('Host', None)
|
||||
if host is None:
|
||||
quart.abort(403, 'Missing Host header')
|
||||
if host.lower() != origin_host:
|
||||
quart.abort(403, 'Origin mismatch')
|
||||
|
||||
|
||||
@app.route('/', methods=['GET'])
|
||||
async def index(missing_accept: bool = False):
|
||||
state = await user_lookup()
|
||||
if not state.mac:
|
||||
return await render_i18n_template('index_unknown.html', state=state, missing_accept=missing_accept)
|
||||
elif state.allowed:
|
||||
return await render_i18n_template('index_active.html', state=state, missing_accept=missing_accept)
|
||||
else:
|
||||
return await render_i18n_template('index_inactive.html', state=state, missing_accept=missing_accept)
|
||||
|
||||
|
||||
@app.route('/login', methods=['POST'])
|
||||
async def login():
|
||||
check_self_origin()
|
||||
with trio.fail_after(5.0):
|
||||
form = await quart.request.form
|
||||
if form.get('accept') != '1':
|
||||
return await index(missing_accept=True)
|
||||
req_mac = form.get('mac')
|
||||
if not req_mac:
|
||||
quart.abort(400, description='Missing MAC in request form data')
|
||||
address = get_client_ip()
|
||||
mac = await get_client_mac(address)
|
||||
if str(mac) != req_mac:
|
||||
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
||||
await user_login(address, mac)
|
||||
return quart.redirect('/', code=303)
|
||||
|
||||
|
||||
@app.route('/logout', methods=['POST'])
|
||||
async def logout():
|
||||
check_self_origin()
|
||||
with trio.fail_after(5.0):
|
||||
form = await quart.request.form
|
||||
req_mac = form.get('mac')
|
||||
if not req_mac:
|
||||
quart.abort(400, description='Missing MAC in request form data')
|
||||
mac = await get_client_mac()
|
||||
if str(mac) != req_mac:
|
||||
quart.abort(403, description="Passed MAC in request form doesn't match client address")
|
||||
await user_logout(mac)
|
||||
return quart.redirect('/', code=303)
|
||||
|
||||
|
||||
@app.route('/api/captive-portal', methods=['GET'])
|
||||
# RFC 8908: https://datatracker.ietf.org/doc/html/rfc8908
|
||||
async def captive_api():
|
||||
state = await user_lookup()
|
||||
return state.to_rfc8908(app.my_config)
|
@ -1,453 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import ssl
|
||||
import struct
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
import trio
|
||||
|
||||
import capport.comm.message
|
||||
import capport.database
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..config import Config
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HubConnectionReadError(ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class HubConnectionClosedError(ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class LoopbackConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Channel:
|
||||
def __init__(self, hub: Hub, transport_stream, server_side: bool):
|
||||
self._hub = hub
|
||||
self._serverside = server_side
|
||||
self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side)
|
||||
_logger.debug(f"{self}: created (server_side={server_side})")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Channel[0x{id(self):x}]'
|
||||
|
||||
async def do_handshake(self) -> capport.comm.message.Hello:
|
||||
try:
|
||||
await self._ssl.do_handshake()
|
||||
ssl_binding = self._ssl.get_channel_binding()
|
||||
if not ssl_binding:
|
||||
# binding mustn't be None after successful handshake
|
||||
raise ConnectionError("Missing SSL channel binding")
|
||||
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||
raise ConnectionError(e) from None
|
||||
msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message()
|
||||
await self.send_msg(msg)
|
||||
peer_hello = (await self.recv_msg()).to_variant()
|
||||
if not isinstance(peer_hello, capport.comm.message.Hello):
|
||||
raise HubConnectionReadError("Expected Hello as first message")
|
||||
auth_succ = \
|
||||
(peer_hello.authentication ==
|
||||
self._hub._calc_authentication(ssl_binding, server_side=not self._serverside))
|
||||
await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
|
||||
peer_auth = (await self.recv_msg()).to_variant()
|
||||
if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
|
||||
raise HubConnectionReadError("Expected AuthenticationResult as second message")
|
||||
if not auth_succ or not peer_auth.success:
|
||||
raise HubConnectionReadError("Authentication failed")
|
||||
return peer_hello
|
||||
|
||||
async def _read(self, num: int) -> bytes:
|
||||
assert num > 0
|
||||
buf = b''
|
||||
# _logger.debug(f"{self}:_read({num})")
|
||||
while num > 0:
|
||||
try:
|
||||
part = await self._ssl.receive_some(num)
|
||||
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||
raise ConnectionError(e) from None
|
||||
# _logger.debug(f"{self}:_read({num}) got part {part!r}")
|
||||
if len(part) == 0:
|
||||
if len(buf) == 0:
|
||||
raise HubConnectionClosedError()
|
||||
raise HubConnectionReadError("Unexpected end of TLS stream")
|
||||
buf += part
|
||||
num -= len(part)
|
||||
if num < 0:
|
||||
raise HubConnectionReadError("TLS receive_some returned too much")
|
||||
return buf
|
||||
|
||||
async def _recv_raw_msg(self) -> bytes:
|
||||
len_bytes = await self._read(4)
|
||||
chunk_size, = struct.unpack('!I', len_bytes)
|
||||
chunk = await self._read(chunk_size)
|
||||
if chunk is None:
|
||||
raise HubConnectionReadError("Unexpected end of TLS stream after chunk length")
|
||||
return chunk
|
||||
|
||||
async def recv_msg(self) -> capport.comm.message.Message:
|
||||
try:
|
||||
chunk = await self._recv_raw_msg()
|
||||
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||
raise ConnectionError(e) from None
|
||||
msg = capport.comm.message.Message()
|
||||
msg.ParseFromString(chunk)
|
||||
return msg
|
||||
|
||||
async def _send_raw(self, chunk: bytes) -> None:
|
||||
try:
|
||||
await self._ssl.send_all(chunk)
|
||||
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||
raise ConnectionError(e) from None
|
||||
|
||||
async def send_msg(self, msg: capport.comm.message.Message):
|
||||
chunk = msg.SerializeToString(deterministic=True)
|
||||
chunk_size = len(chunk)
|
||||
len_bytes = struct.pack('!I', chunk_size)
|
||||
chunk = len_bytes + chunk
|
||||
await self._send_raw(chunk)
|
||||
|
||||
async def aclose(self):
|
||||
try:
|
||||
await self._ssl.aclose()
|
||||
except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
|
||||
raise ConnectionError(e) from None
|
||||
|
||||
|
||||
class Connection:
|
||||
PING_INTERVAL = 10
|
||||
RECEIVE_TIMEOUT = 15
|
||||
SEND_TIMEOUT = 5
|
||||
|
||||
def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello):
|
||||
self._channel = channel
|
||||
self._hub = hub
|
||||
tx: trio.MemorySendChannel
|
||||
rx: trio.MemoryReceiveChannel
|
||||
(tx, rx) = trio.open_memory_channel(64)
|
||||
self._pending_tx = tx
|
||||
self._pending_rx = rx
|
||||
self.peer: capport.comm.message.Hello = peer
|
||||
self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id)
|
||||
self.closed = trio.Event() # set by Hub._lost_peer
|
||||
_logger.debug(f"{self._channel}: authenticated -> {self.peer_id}")
|
||||
|
||||
async def _sender(self, cancel_scope: trio.CancelScope) -> None:
|
||||
try:
|
||||
msg: typing.Optional[capport.comm.message.Message]
|
||||
while True:
|
||||
msg = None
|
||||
# make sure we send something every PING_INTERVAL
|
||||
with trio.move_on_after(self.PING_INTERVAL):
|
||||
msg = await self._pending_rx.receive()
|
||||
# if send blocks too long we're in trouble
|
||||
with trio.fail_after(self.SEND_TIMEOUT):
|
||||
if msg:
|
||||
await self._channel.send_msg(msg)
|
||||
else:
|
||||
await self._channel.send_msg(capport.comm.message.Ping(payload=b'ping').to_message())
|
||||
except trio.TooSlowError:
|
||||
_logger.warning(f"{self._channel}: send timed out")
|
||||
except ConnectionError as e:
|
||||
_logger.warning(f"{self._channel}: failed sending: {e!r}")
|
||||
except Exception:
|
||||
_logger.exception(f"{self._channel}: failed sending")
|
||||
finally:
|
||||
cancel_scope.cancel()
|
||||
|
||||
async def _receive(self, cancel_scope: trio.CancelScope) -> None:
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
with trio.fail_after(self.RECEIVE_TIMEOUT):
|
||||
msg = await self._channel.recv_msg()
|
||||
except (HubConnectionClosedError, ConnectionResetError):
|
||||
return
|
||||
except trio.TooSlowError:
|
||||
_logger.warning(f"{self._channel}: receive timed out")
|
||||
return
|
||||
await self._hub._received_msg(self.peer_id, msg)
|
||||
except ConnectionError as e:
|
||||
_logger.warning(f"{self._channel}: failed receiving: {e!r}")
|
||||
except Exception:
|
||||
_logger.exception(f"{self._channel}: failed receiving")
|
||||
finally:
|
||||
cancel_scope.cancel()
|
||||
|
||||
async def _inner_run(self) -> None:
|
||||
if self.peer_id == self._hub._instance_id:
|
||||
# connected to ourself, don't need that
|
||||
raise LoopbackConnectionError()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._sender, nursery.cancel_scope)
|
||||
# be nice and wait for new_peer beforce receiving messages
|
||||
# (won't work on failover to a second connection)
|
||||
await nursery.start(self._hub._new_peer, self.peer_id, self)
|
||||
nursery.start_soon(self._receive, nursery.cancel_scope)
|
||||
|
||||
async def send_msg(self, *msgs: capport.comm.message.Message):
|
||||
try:
|
||||
for msg in msgs:
|
||||
await self._pending_tx.send(msg)
|
||||
except trio.ClosedResourceError:
|
||||
pass
|
||||
|
||||
async def _run(self) -> None:
|
||||
try:
|
||||
await self._inner_run()
|
||||
finally:
|
||||
_logger.debug(f"{self._channel}: finished message handling")
|
||||
# basic (non-async) cleanup
|
||||
self._hub._lost_peer(self.peer_id, self)
|
||||
self._pending_tx.close()
|
||||
self._pending_rx.close()
|
||||
# allow 3 seconds for proper cleanup
|
||||
with trio.CancelScope(shield=True, deadline=trio.current_time() + 3):
|
||||
try:
|
||||
await self._channel.aclose()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def run(hub: Hub, transport_stream, server_side: bool) -> None:
|
||||
channel = Channel(hub, transport_stream, server_side)
|
||||
try:
|
||||
with trio.fail_after(5):
|
||||
peer = await channel.do_handshake()
|
||||
except trio.TooSlowError:
|
||||
_logger.warning("Handshake timed out")
|
||||
return
|
||||
conn = Connection(hub, channel, peer)
|
||||
await conn._run()
|
||||
|
||||
|
||||
class ControllerConn:
|
||||
def __init__(self, hub: Hub, hostname: str):
|
||||
self._hub = hub
|
||||
self.hostname = hostname
|
||||
self.loopback = False
|
||||
|
||||
async def _connect(self):
|
||||
_logger.info(f"Connecting to controller at {self.hostname}")
|
||||
with trio.fail_after(5):
|
||||
try:
|
||||
stream = await trio.open_tcp_stream(self.hostname, 5000)
|
||||
except OSError as e:
|
||||
_logger.warning(f"Failed to connect to controller at {self.hostname}: {e}")
|
||||
return
|
||||
try:
|
||||
await Connection.run(self._hub, stream, server_side=False)
|
||||
finally:
|
||||
_logger.info(f"Connection to {self.hostname} closed")
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
try:
|
||||
await self._connect()
|
||||
except LoopbackConnectionError:
|
||||
_logger.debug(f"Connection to {self.hostname} reached ourself")
|
||||
self.loopback = True
|
||||
return
|
||||
except trio.TooSlowError:
|
||||
pass
|
||||
# try again later
|
||||
retry_splay = random.random() * 5
|
||||
await trio.sleep(10 + retry_splay)
|
||||
|
||||
|
||||
class HubApplication:
|
||||
async def new_peer(self, *, peer_id: uuid.UUID) -> None:
|
||||
_logger.info(f"New peer {peer_id}")
|
||||
|
||||
def lost_peer(self, *, peer_id: uuid.UUID) -> None:
|
||||
_logger.warning(f"Lost peer {peer_id}")
|
||||
|
||||
async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
|
||||
_logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}")
|
||||
|
||||
async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None:
|
||||
if _logger.isEnabledFor(logging.DEBUG):
|
||||
_logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")
|
||||
|
||||
async def mac_states_changed(
|
||||
self,
|
||||
*,
|
||||
from_peer_id: uuid.UUID,
|
||||
pending_updates: capport.database.PendingUpdates,
|
||||
) -> None:
|
||||
if _logger.isEnabledFor(logging.DEBUG):
|
||||
_logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")
|
||||
|
||||
|
||||
class Hub:
|
||||
def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None:
|
||||
self._config = config
|
||||
self._instance_id = uuid.uuid4()
|
||||
self._hostname = socket.getfqdn()
|
||||
self._app = app
|
||||
self._is_controller = is_controller
|
||||
state_filename: str
|
||||
if is_controller:
|
||||
state_filename = config.database_file
|
||||
else:
|
||||
state_filename = ''
|
||||
self.database = capport.database.Database(state_filename=state_filename)
|
||||
self._anon_context = ssl.SSLContext()
|
||||
# python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon
|
||||
self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
# -> AECDH-AES256-SHA
|
||||
# sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
|
||||
self._anon_context.set_ciphers('HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0')
|
||||
self._controllers: dict[str, ControllerConn] = {}
|
||||
self._established: dict[uuid.UUID, Connection] = {}
|
||||
|
||||
async def _accept(self, stream):
|
||||
remotename = stream.socket.getpeername()
|
||||
if isinstance(remotename, tuple) and len(remotename) == 2:
|
||||
remote = f'[{remotename[0]}]:{remotename[1]}'
|
||||
else:
|
||||
remote = str(remotename)
|
||||
try:
|
||||
await Connection.run(self, stream, server_side=True)
|
||||
except LoopbackConnectionError:
|
||||
pass
|
||||
except trio.TooSlowError:
|
||||
pass
|
||||
finally:
|
||||
_logger.debug(f"Connection from {remote} closed")
|
||||
|
||||
async def _listen(self, task_status=trio.TASK_STATUS_IGNORED):
|
||||
await trio.serve_tcp(self._accept, 5000, task_status=task_status)
|
||||
|
||||
async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(self.database.run)
|
||||
if self._is_controller:
|
||||
await nursery.start(self._listen)
|
||||
|
||||
for name in self._config.controllers:
|
||||
conn = ControllerConn(self, name)
|
||||
self._controllers[name] = conn
|
||||
|
||||
task_status.started()
|
||||
|
||||
for conn in self._controllers.values():
|
||||
nursery.start_soon(conn.run)
|
||||
|
||||
await trio.sleep_forever()
|
||||
|
||||
def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
|
||||
m = hmac.new(self._config.comm_secret.encode('utf8'), digestmod=hashlib.sha256)
|
||||
if server_side:
|
||||
m.update(b'server$')
|
||||
else:
|
||||
m.update(b'client$')
|
||||
m.update(ssl_binding)
|
||||
return m.digest()
|
||||
|
||||
def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello:
|
||||
return capport.comm.message.Hello(
|
||||
instance_id=self._instance_id.bytes,
|
||||
hostname=self._hostname,
|
||||
is_controller=self._is_controller,
|
||||
authentication=self._calc_authentication(ssl_binding, server_side),
|
||||
)
|
||||
|
||||
async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None:
|
||||
# send database (and all changes) to peers
|
||||
await self.send(*self.database.serialize(), to=peer_id)
|
||||
|
||||
async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||
have = self._established.get(peer_id, None)
|
||||
if not have:
|
||||
# peer unknown, "normal start"
|
||||
# no "await" between get above and set here!!!
|
||||
self._established[peer_id] = conn
|
||||
# first wait for app to handle new peer
|
||||
await self._app.new_peer(peer_id=peer_id)
|
||||
task_status.started()
|
||||
await self._sync_new_connection(peer_id, conn)
|
||||
return
|
||||
|
||||
# peer already known - immediately allow receiving messages, then sync connection
|
||||
task_status.started()
|
||||
await self._sync_new_connection(peer_id, conn)
|
||||
# now try to register connection for outgoing messages
|
||||
while True:
|
||||
# recheck whether peer is currently known (due to awaits since last get)
|
||||
have = self._established.get(peer_id, None)
|
||||
if have:
|
||||
# already got a connection, nothing to do as long as it lives
|
||||
await have.closed.wait()
|
||||
else:
|
||||
# make `conn` new outgoing connection for peer
|
||||
# no "await" between get above and set here!!!
|
||||
self._established[peer_id] = conn
|
||||
await self._app.new_peer(peer_id=peer_id)
|
||||
return
|
||||
|
||||
def _lost_peer(self, peer_id: uuid.UUID, conn: Connection):
|
||||
have = self._established.get(peer_id, None)
|
||||
lost = False
|
||||
if have is conn:
|
||||
lost = True
|
||||
self._established.pop(peer_id)
|
||||
conn.closed.set()
|
||||
# only notify if this was the active connection
|
||||
if lost:
|
||||
# even when we failover to another connection we still need to resync
|
||||
# as we don't know which messages might have got lost
|
||||
# -> always trigger lost_peer
|
||||
self._app.lost_peer(peer_id=peer_id)
|
||||
|
||||
async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
|
||||
variant = msg.to_variant()
|
||||
if isinstance(variant, capport.comm.message.Hello):
|
||||
pass
|
||||
elif isinstance(variant, capport.comm.message.AuthenticationResult):
|
||||
pass
|
||||
elif isinstance(variant, capport.comm.message.Ping):
|
||||
pass
|
||||
elif isinstance(variant, capport.comm.message.MacStates):
|
||||
await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
|
||||
async with self.database.make_changes() as pu:
|
||||
for state in variant.states:
|
||||
pu.received_mac_state(state)
|
||||
if pu:
|
||||
# re-broadcast all received updates to all peers
|
||||
await self.broadcast(*pu.serialized, exclude=peer_id)
|
||||
await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
|
||||
else:
|
||||
await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)
|
||||
|
||||
def peer_is_controller(self, peer_id: uuid.UUID) -> bool:
|
||||
conn = self._established.get(peer_id)
|
||||
if conn:
|
||||
return conn.peer.is_controller
|
||||
return False
|
||||
|
||||
async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID):
|
||||
conn = self._established.get(to)
|
||||
if conn:
|
||||
await conn.send_msg(*msgs)
|
||||
|
||||
async def broadcast(self, *msgs: capport.comm.message.Message, exclude: typing.Optional[uuid.UUID] = None):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer_id, conn in self._established.items():
|
||||
if peer_id == exclude:
|
||||
continue
|
||||
nursery.start_soon(conn.send_msg, *msgs)
|
@ -1,36 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .protobuf import message_pb2
|
||||
|
||||
|
||||
def _message_to_variant(self: message_pb2.Message) -> typing.Any:
|
||||
variant_name = self.WhichOneof('oneof')
|
||||
if variant_name:
|
||||
return getattr(self, variant_name)
|
||||
return None
|
||||
|
||||
|
||||
def _make_to_message(oneof_field):
|
||||
def to_message(self) -> message_pb2.Message:
|
||||
msg = message_pb2.Message(**{oneof_field: self})
|
||||
return msg
|
||||
return to_message
|
||||
|
||||
|
||||
def _monkey_patch():
|
||||
g = globals()
|
||||
g['Message'] = message_pb2.Message
|
||||
message_pb2.Message.to_variant = _message_to_variant
|
||||
for field in message_pb2._MESSAGE.oneofs_by_name['oneof'].fields:
|
||||
type_name = field.message_type.name
|
||||
field_type = getattr(message_pb2, type_name)
|
||||
field_type.to_message = _make_to_message(field.name)
|
||||
g[type_name] = field_type
|
||||
|
||||
|
||||
# also re-exports all message types
|
||||
_monkey_patch()
|
||||
# not a variant of Message, still re-export
|
||||
MacState = message_pb2.MacState
|
@ -1,93 +0,0 @@
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
|
||||
# manually maintained typehints for protobuf created (and monkey-patched) types
|
||||
|
||||
|
||||
class Message(google.protobuf.message.Message):
|
||||
hello: Hello
|
||||
authentication_result: AuthenticationResult
|
||||
ping: Ping
|
||||
mac_states: MacStates
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
hello: typing.Optional[Hello]=None,
|
||||
authentication_result: typing.Optional[AuthenticationResult]=None,
|
||||
ping: typing.Optional[Ping]=None,
|
||||
mac_states: typing.Optional[MacStates]=None,
|
||||
) -> None: ...
|
||||
|
||||
def to_variant(self) -> typing.Union[Hello, AuthenticationResult, Ping, MacStates]: ...
|
||||
|
||||
|
||||
class Hello(google.protobuf.message.Message):
|
||||
instance_id: bytes
|
||||
hostname: str
|
||||
is_controller: bool
|
||||
authentication: bytes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
instance_id: bytes=b'',
|
||||
hostname: str='',
|
||||
is_controller: bool=False,
|
||||
authentication: bytes=b'',
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class AuthenticationResult(google.protobuf.message.Message):
|
||||
success: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
success: bool=False,
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class Ping(google.protobuf.message.Message):
|
||||
payload: bytes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: bytes=b'',
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class MacStates(google.protobuf.message.Message):
|
||||
states: typing.List[MacState]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
states: typing.List[MacState]=[],
|
||||
) -> None: ...
|
||||
|
||||
def to_message(self) -> Message: ...
|
||||
|
||||
|
||||
class MacState(google.protobuf.message.Message):
|
||||
mac_address: bytes
|
||||
last_change: int # Seconds of UTC time since epoch
|
||||
allow_until: int # Seconds of UTC time since epoch
|
||||
allowed: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mac_address: bytes=b'',
|
||||
last_change: int=0,
|
||||
allow_until: int=0,
|
||||
allowed: bool=False,
|
||||
) -> None: ...
|
@ -1,35 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: message.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x07\x63\x61pport\"\xbc\x01\n\x07Message\x12\x1f\n\x05hello\x18\x01 \x01(\x0b\x32\x0e.capport.HelloH\x00\x12>\n\x15\x61uthentication_result\x18\x02 \x01(\x0b\x32\x1d.capport.AuthenticationResultH\x00\x12\x1d\n\x04ping\x18\x03 \x01(\x0b\x32\r.capport.PingH\x00\x12(\n\nmac_states\x18\n \x01(\x0b\x32\x12.capport.MacStatesH\x00\x42\x07\n\x05oneof\"]\n\x05Hello\x12\x13\n\x0binstance_id\x18\x01 \x01(\x0c\x12\x10\n\x08hostname\x18\x02 \x01(\t\x12\x15\n\ris_controller\x18\x03 \x01(\x08\x12\x16\n\x0e\x61uthentication\x18\x04 \x01(\x0c\"\'\n\x14\x41uthenticationResult\x12\x0f\n\x07success\x18\x01 \x01(\x08\"\x17\n\x04Ping\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\".\n\tMacStates\x12!\n\x06states\x18\x01 \x03(\x0b\x32\x11.capport.MacState\"Z\n\x08MacState\x12\x13\n\x0bmac_address\x18\x01 \x01(\x0c\x12\x13\n\x0blast_change\x18\x02 \x01(\x03\x12\x13\n\x0b\x61llow_until\x18\x03 \x01(\x03\x12\x0f\n\x07\x61llowed\x18\x04 \x01(\x08\x62\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'message_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_MESSAGE._serialized_start=27
|
||||
_MESSAGE._serialized_end=215
|
||||
_HELLO._serialized_start=217
|
||||
_HELLO._serialized_end=310
|
||||
_AUTHENTICATIONRESULT._serialized_start=312
|
||||
_AUTHENTICATIONRESULT._serialized_end=351
|
||||
_PING._serialized_start=353
|
||||
_PING._serialized_end=376
|
||||
_MACSTATES._serialized_start=378
|
||||
_MACSTATES._serialized_end=424
|
||||
_MACSTATE._serialized_start=426
|
||||
_MACSTATE._serialized_end=516
|
||||
# @@protoc_insertion_point(module_scope)
|
@ -1,55 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os.path
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
_cached_config: typing.Optional[Config] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Config:
|
||||
controllers: typing.List[str]
|
||||
server_names: typing.List[str]
|
||||
comm_secret: str
|
||||
cookie_secret: str
|
||||
venue_info_url: typing.Optional[str]
|
||||
session_timeout: int # in seconds
|
||||
database_file: str # empty str: disable database
|
||||
debug: bool
|
||||
|
||||
@staticmethod
|
||||
def load_default_once() -> Config:
|
||||
global _cached_config
|
||||
if not _cached_config:
|
||||
_cached_config = Config.load()
|
||||
return _cached_config
|
||||
|
||||
@staticmethod
|
||||
def load(filename: typing.Optional[str] = None) -> Config:
|
||||
if filename is None:
|
||||
if len(sys.argv) > 0:
|
||||
for name in sys.argv[1:]:
|
||||
if os.path.exists(name):
|
||||
return Config.load(name)
|
||||
for name in ('capport.yaml', '/etc/capport.yaml'):
|
||||
if os.path.exists(name):
|
||||
return Config.load(name)
|
||||
raise RuntimeError("Missing config file")
|
||||
with open(filename) as f:
|
||||
data = yaml.safe_load(f)
|
||||
controllers = list(map(str, data['controllers']))
|
||||
return Config(
|
||||
controllers=controllers,
|
||||
server_names=data.get('server-names', []),
|
||||
comm_secret=str(data['comm-secret']),
|
||||
cookie_secret=str(data['cookie-secret']),
|
||||
venue_info_url=str(data.get('venue-info-url')),
|
||||
session_timeout=data.get('session-timeout', 3600),
|
||||
database_file=str(data['database-file']) if 'database-file' in data else 'capport.state',
|
||||
debug=data.get('debug', False)
|
||||
)
|
@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
import trio
|
||||
|
||||
import capport.comm.hub
|
||||
import capport.comm.message
|
||||
import capport.config
|
||||
import capport.database
|
||||
import capport.utils.cli
|
||||
import capport.utils.nft_set
|
||||
from capport import cptypes
|
||||
from capport.utils.sd_notify import open_sdnotify
|
||||
|
||||
|
||||
class ControlApp(capport.comm.hub.HubApplication):
|
||||
hub: capport.comm.hub.Hub
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.nft_set = capport.utils.nft_set.NftSet()
|
||||
|
||||
async def mac_states_changed(
|
||||
self,
|
||||
*,
|
||||
from_peer_id: uuid.UUID,
|
||||
pending_updates: capport.database.PendingUpdates,
|
||||
) -> None:
|
||||
self.apply_db_entries(pending_updates.changes())
|
||||
|
||||
def apply_db_entries(
|
||||
self,
|
||||
entries: typing.Iterable[typing.Tuple[cptypes.MacAddress, capport.database.MacEntry]],
|
||||
) -> None:
|
||||
# deploy changes to netfilter set
|
||||
inserts = []
|
||||
removals = []
|
||||
now = cptypes.Timestamp.now()
|
||||
for mac, state in entries:
|
||||
rem = state.allowed_remaining(now)
|
||||
if rem > 0:
|
||||
inserts.append((mac, rem))
|
||||
else:
|
||||
removals.append(mac)
|
||||
self.nft_set.bulk_insert(inserts)
|
||||
self.nft_set.bulk_remove(removals)
|
||||
|
||||
|
||||
async def amain(config: capport.config.Config) -> None:
|
||||
async with open_sdnotify() as sn:
|
||||
app = ControlApp()
|
||||
hub = capport.comm.hub.Hub(config=config, app=app, is_controller=True)
|
||||
app.hub = hub
|
||||
async with trio.open_nursery() as nursery:
|
||||
# hub.run loads the statefile from disk before signalling it was "started"
|
||||
await nursery.start(hub.run)
|
||||
await sn.send('READY=1', 'STATUS=Deploying initial entries to nftables set')
|
||||
app.apply_db_entries(hub.database.entries())
|
||||
await sn.send('STATUS=Kernel fully synchronized')
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config = capport.config.Config.load_default_once()
|
||||
capport.utils.cli.init_logger(config)
|
||||
try:
|
||||
trio.run(amain, config)
|
||||
except (KeyboardInterrupt, InterruptedError):
|
||||
print()
|
@ -1,112 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import datetime
|
||||
import ipaddress
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
|
||||
import quart
|
||||
|
||||
import capport.utils.zoneinfo
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .config import Config
|
||||
|
||||
|
||||
IPAddress = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class MacAddress:
|
||||
raw: bytes
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.raw.hex(':')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(str(self))
|
||||
|
||||
@staticmethod
|
||||
def parse(s: str) -> MacAddress:
|
||||
return MacAddress(bytes.fromhex(s.replace(':', '')))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, order=True)
|
||||
class Timestamp:
|
||||
epoch: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
try:
|
||||
ts = datetime.datetime.fromtimestamp(self.epoch, datetime.timezone.utc)
|
||||
return ts.isoformat(sep=' ')
|
||||
except OSError:
|
||||
return f'epoch@{self.epoch}'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(str(self))
|
||||
|
||||
@staticmethod
|
||||
def now() -> Timestamp:
|
||||
now = int(time.time())
|
||||
return Timestamp(epoch=now)
|
||||
|
||||
@staticmethod
|
||||
def from_protobuf(epoch: int) -> typing.Optional[Timestamp]:
|
||||
if epoch:
|
||||
return Timestamp(epoch=epoch)
|
||||
return None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MacPublicState:
|
||||
address: IPAddress
|
||||
mac: typing.Optional[MacAddress]
|
||||
allowed_remaining: int
|
||||
|
||||
@staticmethod
|
||||
def from_missing_mac(address: IPAddress) -> MacPublicState:
|
||||
return MacPublicState(
|
||||
address=address,
|
||||
mac=None,
|
||||
allowed_remaining=0,
|
||||
)
|
||||
|
||||
@property
|
||||
def allowed(self) -> bool:
|
||||
return self.allowed_remaining > 0
|
||||
|
||||
@property
|
||||
def captive(self) -> bool:
|
||||
return not self.allowed
|
||||
|
||||
@property
|
||||
def allowed_remaining_duration(self) -> str:
|
||||
mm, ss = divmod(self.allowed_remaining, 60)
|
||||
hh, mm = divmod(mm, 60)
|
||||
return f'{hh}:{mm:02}:{ss:02}'
|
||||
|
||||
@property
|
||||
def allowed_until(self) -> typing.Optional[datetime.datetime]:
|
||||
zone = capport.utils.zoneinfo.get_local_timezone()
|
||||
now = datetime.datetime.now(zone).replace(microsecond=0)
|
||||
return now + datetime.timedelta(seconds=self.allowed_remaining)
|
||||
|
||||
def to_rfc8908(self, config: Config) -> quart.Response:
|
||||
response: dict[str, typing.Any] = {
|
||||
'user-portal-url': quart.url_for('index', _external=True),
|
||||
}
|
||||
if config.venue_info_url:
|
||||
response['venue-info-url'] = config.venue_info_url
|
||||
if self.captive:
|
||||
response['captive'] = True
|
||||
else:
|
||||
response['captive'] = False
|
||||
response['seconds-remaining'] = self.allowed_remaining
|
||||
response['can-extend-session'] = True
|
||||
return quart.Response(
|
||||
json.dumps(response),
|
||||
headers={'Cache-Control': 'private'},
|
||||
content_type='application/captive+json',
|
||||
)
|
@ -1,389 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import typing
|
||||
|
||||
import google.protobuf.message
|
||||
|
||||
import trio
|
||||
|
||||
import capport.comm.message
|
||||
from capport import cptypes
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MacEntry:
|
||||
# entry can be removed if last_change was some time ago and allow_until wasn't set
|
||||
# or got reached.
|
||||
WAIT_LAST_CHANGE_SECONDS = 60
|
||||
WAIT_ALLOW_UNTIL_PASSED_SECONDS = 10
|
||||
|
||||
# last_change: timestamp of last change (sent by system initiating the change)
|
||||
last_change: cptypes.Timestamp
|
||||
# only if allowed is true and allow_until is set the device can communicate with the internet
|
||||
# allow_until must not go backwards (and not get unset)
|
||||
allow_until: typing.Optional[cptypes.Timestamp]
|
||||
allowed: bool
|
||||
|
||||
@staticmethod
|
||||
def parse_state(msg: capport.comm.message.MacState) -> typing.Tuple[cptypes.MacAddress, MacEntry]:
|
||||
if len(msg.mac_address) < 6:
|
||||
raise Exception("Invalid MacState: mac_address too short")
|
||||
addr = cptypes.MacAddress(raw=msg.mac_address)
|
||||
last_change = cptypes.Timestamp.from_protobuf(msg.last_change)
|
||||
if not last_change:
|
||||
raise Exception(f"Invalid MacState[{addr}]: missing last_change")
|
||||
allow_until = cptypes.Timestamp.from_protobuf(msg.allow_until)
|
||||
return (addr, MacEntry(last_change=last_change, allow_until=allow_until, allowed=msg.allowed))
|
||||
|
||||
def to_state(self, addr: cptypes.MacAddress) -> capport.comm.message.MacState:
|
||||
allow_until = 0
|
||||
if self.allow_until:
|
||||
allow_until = self.allow_until.epoch
|
||||
return capport.comm.message.MacState(
|
||||
mac_address=addr.raw,
|
||||
last_change=self.last_change.epoch,
|
||||
allow_until=allow_until,
|
||||
allowed=self.allowed,
|
||||
)
|
||||
|
||||
def as_json(self) -> dict:
|
||||
allow_until = None
|
||||
if self.allow_until:
|
||||
allow_until = self.allow_until.epoch
|
||||
return dict(
|
||||
last_change=self.last_change.epoch,
|
||||
allow_until=allow_until,
|
||||
allowed=self.allowed,
|
||||
)
|
||||
|
||||
def merge(self, new: MacEntry) -> bool:
|
||||
changed = False
|
||||
if new.last_change > self.last_change:
|
||||
changed = True
|
||||
self.last_change = new.last_change
|
||||
self.allowed = new.allowed
|
||||
elif new.last_change == self.last_change:
|
||||
# same last_change: set allowed if one allowed
|
||||
if new.allowed and not self.allowed:
|
||||
changed = True
|
||||
self.allowed = True
|
||||
# set allow_until to max of both
|
||||
if new.allow_until: # if not set nothing to change in local data
|
||||
if not self.allow_until or self.allow_until < new.allow_until:
|
||||
changed = True
|
||||
self.allow_until = new.allow_until
|
||||
return changed
|
||||
|
||||
def timeout(self) -> cptypes.Timestamp:
|
||||
elc = self.last_change.epoch + self.WAIT_LAST_CHANGE_SECONDS
|
||||
if self.allow_until:
|
||||
eau = self.allow_until.epoch + self.WAIT_ALLOW_UNTIL_PASSED_SECONDS
|
||||
if eau > elc:
|
||||
return cptypes.Timestamp(epoch=eau)
|
||||
return cptypes.Timestamp(epoch=elc)
|
||||
|
||||
# returns 0 if not allowed
|
||||
def allowed_remaining(self, now: typing.Optional[cptypes.Timestamp] = None) -> int:
|
||||
if not self.allowed or not self.allow_until:
|
||||
return 0
|
||||
if not now:
|
||||
now = cptypes.Timestamp.now()
|
||||
assert self.allow_until
|
||||
return max(self.allow_until.epoch - now.epoch, 0)
|
||||
|
||||
def outdated(self, now: typing.Optional[cptypes.Timestamp] = None) -> bool:
|
||||
if not now:
|
||||
now = cptypes.Timestamp.now()
|
||||
return now.epoch > self.timeout().epoch
|
||||
|
||||
|
||||
# might use this to serialize into file - don't need Message variant there
|
||||
def _serialize_mac_states(macs: dict[cptypes.MacAddress, MacEntry]) -> typing.List[capport.comm.message.MacStates]:
|
||||
result: typing.List[capport.comm.message.MacStates] = []
|
||||
current = capport.comm.message.MacStates()
|
||||
for addr, entry in macs.items():
|
||||
state = entry.to_state(addr)
|
||||
current.states.append(state)
|
||||
if len(current.states) >= 1024: # split into messages with 1024 states
|
||||
result.append(current)
|
||||
current = capport.comm.message.MacStates()
|
||||
if len(current.states):
|
||||
result.append(current)
|
||||
return result
|
||||
|
||||
|
||||
def _serialize_mac_states_as_messages(
|
||||
macs: dict[cptypes.MacAddress, MacEntry],
|
||||
) -> typing.List[capport.comm.message.Message]:
|
||||
return [s.to_message() for s in _serialize_mac_states(macs)]
|
||||
|
||||
|
||||
def _states_to_chunk(states: capport.comm.message.MacStates) -> bytes:
|
||||
chunk = states.SerializeToString(deterministic=True)
|
||||
chunk_size = len(chunk)
|
||||
len_bytes = struct.pack('!I', chunk_size)
|
||||
return len_bytes + chunk
|
||||
|
||||
|
||||
class NotReadyYet(Exception):
|
||||
def __init__(self, msg: str, wait: int):
|
||||
self.wait = wait # seconds to wait
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, state_filename: str = None):
|
||||
self._macs: dict[cptypes.MacAddress, MacEntry] = {}
|
||||
self._state_filename = state_filename
|
||||
self._changed_since_last_cleanup = False
|
||||
self._send_changes: typing.Optional[trio.MemorySendChannel[typing.Union[
|
||||
capport.comm.message.MacStates,
|
||||
typing.List[capport.comm.message.MacStates],
|
||||
]]] = None
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_changes(self) -> typing.AsyncGenerator[PendingUpdates, None]:
|
||||
pu = PendingUpdates(self)
|
||||
pu._closed = False
|
||||
yield pu
|
||||
pu._finish()
|
||||
if pu:
|
||||
self._changed_since_last_cleanup = True
|
||||
if self._send_changes:
|
||||
for state in pu.serialized_states:
|
||||
await self._send_changes.send(state)
|
||||
|
||||
def _drop_outdated(self) -> None:
|
||||
done = False
|
||||
while not done:
|
||||
depr: typing.Set[cptypes.MacAddress] = set()
|
||||
now = cptypes.Timestamp.now()
|
||||
done = True
|
||||
for mac, entry in self._macs.items():
|
||||
if entry.outdated(now):
|
||||
depr.add(mac)
|
||||
if len(depr) >= 1024:
|
||||
# clear entries found so far, then try again
|
||||
done = False
|
||||
break
|
||||
if depr:
|
||||
self._changed_since_last_cleanup = True
|
||||
for mac in depr:
|
||||
del self._macs[mac]
|
||||
|
||||
async def run(self, task_status=trio.TASK_STATUS_IGNORED):
|
||||
if self._state_filename:
|
||||
await self._load_statefile()
|
||||
task_status.started()
|
||||
async with trio.open_nursery() as nursery:
|
||||
if self._state_filename:
|
||||
nursery.start_soon(self._run_statefile)
|
||||
while True:
|
||||
await trio.sleep(300) # cleanup every 5 minutes
|
||||
_logger.debug("Running database cleanup")
|
||||
self._drop_outdated()
|
||||
if self._changed_since_last_cleanup:
|
||||
self._changed_since_last_cleanup = False
|
||||
if self._send_changes:
|
||||
states = _serialize_mac_states(self._macs)
|
||||
# trigger a resync
|
||||
await self._send_changes.send(states)
|
||||
|
||||
# for initial handling of all data
|
||||
def entries(self) -> typing.List[typing.Tuple[cptypes.MacAddress, MacEntry]]:
|
||||
return list(self._macs.items())
|
||||
|
||||
# for initial sync with new peer
|
||||
def serialize(self) -> typing.List[capport.comm.message.Message]:
|
||||
return _serialize_mac_states_as_messages(self._macs)
|
||||
|
||||
def as_json(self) -> dict:
|
||||
return {
|
||||
str(addr): entry.as_json()
|
||||
for addr, entry in self._macs.items()
|
||||
}
|
||||
|
||||
async def _run_statefile(self) -> None:
|
||||
rx: trio.MemoryReceiveChannel[typing.Union[
|
||||
capport.comm.message.MacStates,
|
||||
typing.List[capport.comm.message.MacStates],
|
||||
]]
|
||||
tx: trio.MemorySendChannel[typing.Union[
|
||||
capport.comm.message.MacStates,
|
||||
typing.List[capport.comm.message.MacStates],
|
||||
]]
|
||||
tx, rx = trio.open_memory_channel(64)
|
||||
self._send_changes = tx
|
||||
|
||||
assert self._state_filename
|
||||
filename: str = self._state_filename
|
||||
tmp_filename = f'{filename}.new-{os.getpid()}'
|
||||
|
||||
async def resync(all_states: typing.List[capport.comm.message.MacStates]):
|
||||
try:
|
||||
async with await trio.open_file(tmp_filename, 'xb') as tf:
|
||||
for states in all_states:
|
||||
await tf.write(_states_to_chunk(states))
|
||||
os.rename(tmp_filename, filename)
|
||||
finally:
|
||||
if os.path.exists(tmp_filename):
|
||||
_logger.warning(f'Removing (failed) state export file {tmp_filename}')
|
||||
os.unlink(tmp_filename)
|
||||
|
||||
try:
|
||||
while True:
|
||||
async with await trio.open_file(filename, 'ab', buffering=0) as sf:
|
||||
while True:
|
||||
update = await rx.receive()
|
||||
if isinstance(update, list):
|
||||
break
|
||||
await sf.write(_states_to_chunk(update))
|
||||
# got a "list" update - i.e. a resync
|
||||
with trio.CancelScope(shield=True):
|
||||
await resync(update)
|
||||
# now reopen normal statefile and continue appending updates
|
||||
except trio.Cancelled:
|
||||
_logger.info('Final sync to disk')
|
||||
with trio.CancelScope(shield=True):
|
||||
await resync(_serialize_mac_states(self._macs))
|
||||
_logger.info('Final sync to disk done')
|
||||
|
||||
async def _load_statefile(self):
|
||||
if not os.path.exists(self._state_filename):
|
||||
return
|
||||
_logger.info("Loading statefile")
|
||||
# we're going to ignore changes from loading the file
|
||||
pu = PendingUpdates(self)
|
||||
pu._closed = False
|
||||
async with await trio.open_file(self._state_filename, 'rb') as sf:
|
||||
while True:
|
||||
try:
|
||||
len_bytes = await sf.read(4)
|
||||
if not len_bytes:
|
||||
return
|
||||
if len(len_bytes) < 4:
|
||||
_logger.error("Failed to read next chunk from statefile (unexpected EOF)")
|
||||
return
|
||||
chunk_size, = struct.unpack('!I', len_bytes)
|
||||
chunk = await sf.read(chunk_size)
|
||||
except IOError as e:
|
||||
_logger.error(f"Failed to read next chunk from statefile: {e}")
|
||||
return
|
||||
try:
|
||||
states = capport.comm.message.MacStates()
|
||||
states.ParseFromString(chunk)
|
||||
except google.protobuf.message.DecodeError as e:
|
||||
_logger.error(f"Failed to decode chunk from statefile, trying next one: {e}")
|
||||
continue
|
||||
for state in states.states:
|
||||
errors = 0
|
||||
try:
|
||||
pu.received_mac_state(state)
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
if errors < 5:
|
||||
_logger.error(f'Failed to handle state: {e}')
|
||||
|
||||
def lookup(self, address: cptypes.IPAddress, mac: cptypes.MacAddress) -> cptypes.MacPublicState:
|
||||
entry = self._macs.get(mac)
|
||||
if entry:
|
||||
allowed_remaining = entry.allowed_remaining()
|
||||
else:
|
||||
allowed_remaining = 0
|
||||
return cptypes.MacPublicState(
|
||||
address=address,
|
||||
mac=mac,
|
||||
allowed_remaining=allowed_remaining,
|
||||
)
|
||||
|
||||
|
||||
class PendingUpdates:
|
||||
def __init__(self, database: Database):
|
||||
self._changes: dict[cptypes.MacAddress, MacEntry] = {}
|
||||
self._database = database
|
||||
self._closed = True
|
||||
self._serialized_states: typing.List[capport.comm.message.MacStates] = []
|
||||
self._serialized: typing.List[capport.comm.message.Message] = []
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._changes)
|
||||
|
||||
def changes(self) -> typing.Iterable[typing.Tuple[cptypes.MacAddress, MacEntry]]:
|
||||
return self._changes.items()
|
||||
|
||||
@property
|
||||
def serialized_states(self) -> typing.List[capport.comm.message.MacStates]:
|
||||
assert self._closed
|
||||
return self._serialized_states
|
||||
|
||||
@property
|
||||
def serialized(self) -> typing.List[capport.comm.message.Message]:
|
||||
assert self._closed
|
||||
return self._serialized
|
||||
|
||||
def _finish(self):
|
||||
if self._closed:
|
||||
raise Exception("Can't change closed PendingUpdates")
|
||||
self._closed = True
|
||||
self._serialized_states = _serialize_mac_states(self._changes)
|
||||
self._serialized = [s.to_message() for s in self._serialized_states]
|
||||
|
||||
def received_mac_state(self, state: capport.comm.message.MacState):
|
||||
if self._closed:
|
||||
raise Exception("Can't change closed PendingUpdates")
|
||||
(addr, new_entry) = MacEntry.parse_state(state)
|
||||
old_entry = self._database._macs.get(addr)
|
||||
if not old_entry:
|
||||
# only redistribute if not outdated
|
||||
if not new_entry.outdated():
|
||||
self._database._macs[addr] = new_entry
|
||||
self._changes[addr] = new_entry
|
||||
elif old_entry.merge(new_entry):
|
||||
if old_entry.outdated():
|
||||
# remove local entry, but still redistribute
|
||||
self._database._macs.pop(addr)
|
||||
self._changes[addr] = old_entry
|
||||
|
||||
def login(self, mac: cptypes.MacAddress, session_timeout: int, *, renew_maximum: float = 0.8):
|
||||
if self._closed:
|
||||
raise Exception("Can't change closed PendingUpdates")
|
||||
now = cptypes.Timestamp.now()
|
||||
allow_until = cptypes.Timestamp(epoch=now.epoch + session_timeout)
|
||||
new_entry = MacEntry(last_change=now, allow_until=allow_until, allowed=True)
|
||||
|
||||
entry = self._database._macs.get(mac)
|
||||
if not entry:
|
||||
self._database._macs[mac] = new_entry
|
||||
self._changes[mac] = new_entry
|
||||
elif entry.allowed_remaining(now) > renew_maximum * session_timeout:
|
||||
# too much time left on clock, not renewing session
|
||||
return
|
||||
elif entry.merge(new_entry):
|
||||
self._changes[mac] = entry
|
||||
elif not entry.allowed_remaining() > 0:
|
||||
# entry should have been updated - can only fail due to `now < entry.last_change`
|
||||
# i.e. out of sync clocks
|
||||
wait = entry.last_change.epoch - now.epoch
|
||||
raise NotReadyYet(f"can't login yet, try again in {wait} seconds", wait)
|
||||
|
||||
def logout(self, mac: cptypes.MacAddress):
|
||||
if self._closed:
|
||||
raise Exception("Can't change closed PendingUpdates")
|
||||
now = cptypes.Timestamp.now()
|
||||
new_entry = MacEntry(last_change=now, allow_until=None, allowed=False)
|
||||
entry = self._database._macs.get(mac)
|
||||
if entry:
|
||||
if entry.merge(new_entry):
|
||||
self._changes[mac] = entry
|
||||
elif entry.allowed_remaining() > 0:
|
||||
# still logged in. can only happen with `now <= entry.last_change`
|
||||
# clocks not necessarily out of sync, but you can't logout in the same second you logged in
|
||||
wait = entry.last_change.epoch - now.epoch + 1
|
||||
raise NotReadyYet(f"can't logout yet, try again in {wait} seconds", wait)
|
@ -1,104 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import sys
|
||||
import typing
|
||||
|
||||
import trio
|
||||
|
||||
import capport.utils.ipneigh
|
||||
import capport.utils.nft_set
|
||||
|
||||
from . import cptypes
|
||||
|
||||
|
||||
def print_metric(
|
||||
name: str,
|
||||
mtype: str,
|
||||
value,
|
||||
*,
|
||||
now: typing.Optional[int] = None,
|
||||
help: typing.Optional[str] = None,
|
||||
):
|
||||
# no labels in our names for now, always print help and type
|
||||
if help:
|
||||
print(f"# HELP {name} {help}")
|
||||
print(f"# TYPE {name} {mtype}")
|
||||
if now:
|
||||
print(f"{name} {value} {now}")
|
||||
else:
|
||||
print(f"{name} {value}")
|
||||
|
||||
|
||||
async def amain(client_ifname: str):
|
||||
ns = capport.utils.nft_set.NftSet()
|
||||
captive_allowed_entries: typing.Set[cptypes.MacAddress] = {
|
||||
entry["mac"] for entry in ns.list()
|
||||
}
|
||||
seen_allowed_entries: typing.Set[cptypes.MacAddress] = set()
|
||||
total_ipv4 = 0
|
||||
total_ipv6 = 0
|
||||
unique_clients = set()
|
||||
unique_ipv4 = set()
|
||||
unique_ipv6 = set()
|
||||
async with capport.utils.ipneigh.connect() as ipn:
|
||||
ipn.ip.strict_check = True
|
||||
async for (mac, addr) in ipn.dump_neighbors(client_ifname):
|
||||
if mac in captive_allowed_entries:
|
||||
seen_allowed_entries.add(mac)
|
||||
unique_clients.add(mac)
|
||||
if isinstance(addr, ipaddress.IPv4Address):
|
||||
total_ipv4 += 1
|
||||
unique_ipv4.add(mac)
|
||||
else:
|
||||
total_ipv6 += 1
|
||||
unique_ipv6.add(mac)
|
||||
print_metric(
|
||||
"capport_allowed_macs",
|
||||
"gauge",
|
||||
len(captive_allowed_entries),
|
||||
help="Number of allowed client mac addresses",
|
||||
)
|
||||
print_metric(
|
||||
"capport_allowed_neigh_macs",
|
||||
"gauge",
|
||||
len(seen_allowed_entries),
|
||||
help="Number of allowed client mac addresses seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
"capport_unique",
|
||||
"gauge",
|
||||
len(unique_clients),
|
||||
help="Number of clients (mac addresses) in client network seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
"capport_unique_ipv4",
|
||||
"gauge",
|
||||
len(unique_ipv4),
|
||||
help="Number of IPv4 clients (unique per mac) in client network seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
"capport_unique_ipv6",
|
||||
"gauge",
|
||||
len(unique_ipv6),
|
||||
help="Number of IPv6 clients (unique per mac) in client network seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
"capport_total_ipv4",
|
||||
"gauge",
|
||||
total_ipv4,
|
||||
help="Number of IPv4 addresses seen in neighbor cache",
|
||||
)
|
||||
print_metric(
|
||||
"capport_total_ipv6",
|
||||
"gauge",
|
||||
total_ipv6,
|
||||
help="Number of IPv6 addresses seen in neighbor cache",
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("Need name of client interface as argument")
|
||||
sys.exit(1)
|
||||
trio.run(amain, sys.argv[1])
|
@ -1,17 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import capport.config
|
||||
|
||||
|
||||
def init_logger(config: capport.config.Config):
|
||||
loglevel = logging.INFO
|
||||
if config.debug:
|
||||
loglevel = logging.DEBUG
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S %z]',
|
||||
level=loglevel,
|
||||
)
|
||||
logging.getLogger('hypercorn').propagate = False
|
@ -1,89 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import errno
|
||||
import ipaddress
|
||||
import socket
|
||||
import typing
|
||||
|
||||
import pyroute2.iproute.linux # type: ignore
|
||||
import pyroute2.netlink.exceptions # type: ignore
|
||||
import pyroute2.netlink.rtnl # type: ignore
|
||||
import pyroute2.netlink.rtnl.ndmsg # type: ignore
|
||||
|
||||
from capport import cptypes
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect():
|
||||
yield NeighborController()
|
||||
|
||||
|
||||
# TODO: run blocking iproute calls in a different thread?
|
||||
class NeighborController:
|
||||
def __init__(self):
|
||||
self.ip = pyroute2.iproute.linux.IPRoute()
|
||||
|
||||
async def get_neighbor(
|
||||
self,
|
||||
address: cptypes.IPAddress,
|
||||
*,
|
||||
index: int = 0, # interface index
|
||||
flags: int = 0,
|
||||
) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]:
|
||||
if not index:
|
||||
route = await self.get_route(address)
|
||||
if route is None:
|
||||
return None
|
||||
index = route.get_attr(route.name2nla('oif'))
|
||||
try:
|
||||
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
|
||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||
if e.code == errno.ENOENT:
|
||||
return None
|
||||
raise
|
||||
|
||||
async def get_neighbor_mac(
|
||||
self,
|
||||
address: cptypes.IPAddress,
|
||||
*,
|
||||
index: int = 0, # interface index
|
||||
flags: int = 0,
|
||||
) -> typing.Optional[cptypes.MacAddress]:
|
||||
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
||||
if neigh is None:
|
||||
return None
|
||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
||||
if mac is None:
|
||||
return None
|
||||
return cptypes.MacAddress.parse(mac)
|
||||
|
||||
async def get_route(
|
||||
self,
|
||||
address: cptypes.IPAddress,
|
||||
) -> typing.Optional[pyroute2.iproute.linux.rtmsg]:
|
||||
try:
|
||||
return self.ip.route('get', dst=str(address))[0]
|
||||
except pyroute2.netlink.exceptions.NetlinkError as e:
|
||||
if e.code == errno.ENOENT:
|
||||
return None
|
||||
raise
|
||||
|
||||
async def dump_neighbors(
|
||||
self,
|
||||
interface: str,
|
||||
) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
|
||||
ifindex = socket.if_nametoindex(interface)
|
||||
unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
|
||||
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family):
|
||||
if neigh['ndm_type'] != unicast_num:
|
||||
continue
|
||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
||||
if not mac:
|
||||
continue
|
||||
dst = ipaddress.ip_address(neigh.get_attr(neigh.name2nla('dst')))
|
||||
if dst.is_link_local:
|
||||
continue
|
||||
yield (cptypes.MacAddress.parse(mac), dst)
|
@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pyroute2.netlink # type: ignore
|
||||
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
|
||||
|
||||
from capport import cptypes
|
||||
|
||||
from .nft_socket import NFTSocket
|
||||
|
||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
|
||||
|
||||
|
||||
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
|
||||
# to seconds
|
||||
if msecs is None:
|
||||
return None
|
||||
return msecs / 1000.0
|
||||
|
||||
|
||||
class NftSet:
|
||||
def __init__(self):
|
||||
self._socket = NFTSocket()
|
||||
self._socket.bind()
|
||||
|
||||
@staticmethod
|
||||
def _set_elem(
|
||||
mac: cptypes.MacAddress,
|
||||
timeout: typing.Optional[typing.Union[int, float]] = None,
|
||||
) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
||||
attrs: dict[str, typing.Any] = {
|
||||
'NFTA_SET_ELEM_KEY': dict(
|
||||
NFTA_DATA_VALUE=mac.raw,
|
||||
),
|
||||
}
|
||||
if timeout:
|
||||
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
|
||||
return attrs
|
||||
|
||||
def _bulk_insert(
|
||||
self,
|
||||
entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]],
|
||||
) -> None:
|
||||
ser_entries = [
|
||||
self._set_elem(mac)
|
||||
for mac, _timeout in entries
|
||||
]
|
||||
ser_entries_with_timeout = [
|
||||
self._set_elem(mac, timeout)
|
||||
for mac, timeout in entries
|
||||
]
|
||||
with self._socket.begin() as tx:
|
||||
# create doesn't affect existing elements, so:
|
||||
# make sure entries exists
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
# drop entries (would fail if it doesn't exist)
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
# now create entries with new timeout value
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||
pyroute2.netlink.NLM_F_CREATE | pyroute2.netlink.NLM_F_EXCL,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
|
||||
),
|
||||
)
|
||||
|
||||
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
|
||||
# limit chunk size
|
||||
while len(entries) > 0:
|
||||
self._bulk_insert(entries[:1024])
|
||||
entries = entries[1024:]
|
||||
|
||||
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
|
||||
self.bulk_insert([(mac, timeout)])
|
||||
|
||||
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||
ser_entries = [
|
||||
self._set_elem(mac)
|
||||
for mac in entries
|
||||
]
|
||||
with self._socket.begin() as tx:
|
||||
# make sure entries exists
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
# drop entries (would fail if it doesn't exist)
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
|
||||
def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||
# limit chunk size
|
||||
while len(entries) > 0:
|
||||
self._bulk_remove(entries[:1024])
|
||||
entries = entries[1024:]
|
||||
|
||||
def remove(self, mac: cptypes.MacAddress) -> None:
|
||||
self.bulk_remove([mac])
|
||||
|
||||
def list(self) -> list:
|
||||
responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
|
||||
responses = self._socket.nft_dump(
|
||||
_nftsocket.NFT_MSG_GETSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
)
|
||||
)
|
||||
return [
|
||||
{
|
||||
'mac': cptypes.MacAddress(
|
||||
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
|
||||
),
|
||||
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
|
||||
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
|
||||
}
|
||||
for response in responses
|
||||
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
|
||||
]
|
||||
|
||||
def flush(self) -> None:
|
||||
self._socket.nft_put(
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_ELEM_LIST_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
)
|
||||
)
|
||||
|
||||
def create(self):
|
||||
with self._socket.begin() as tx:
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWTABLE,
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_TABLE_NAME='captive_mark',
|
||||
),
|
||||
)
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSET,
|
||||
pyroute2.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_NAME='allowed',
|
||||
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
|
||||
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
|
||||
NFTA_SET_KEY_LEN=6, # length of key: mac address
|
||||
NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction
|
||||
),
|
||||
)
|
@ -1,189 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
import pyroute2.netlink # type: ignore
|
||||
import pyroute2.netlink.nlsocket # type: ignore
|
||||
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore
|
||||
from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore
|
||||
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
|
||||
|
||||
|
||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
|
||||
|
||||
|
||||
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
|
||||
|
||||
|
||||
# nft uses NESTED for those.. lets do the same
|
||||
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED
|
||||
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED
|
||||
# nftable lists always use `1` as list element attr type
|
||||
_nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
|
||||
|
||||
|
||||
def _monkey_patch_pyroute2():
|
||||
import pyroute2.netlink
|
||||
# overwrite setdefault on nlmsg_base class hierarchy
|
||||
_orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
|
||||
|
||||
def _nlmsg_base__setvalue(self, value):
|
||||
if not self.header or not self['header'] or not isinstance(value, dict):
|
||||
return _orig_setvalue(self, value)
|
||||
# merge headers instead of overwriting them
|
||||
header = value.pop('header', {})
|
||||
res = _orig_setvalue(self, value)
|
||||
self['header'].update(header)
|
||||
return res
|
||||
|
||||
def overwrite_methods(cls: typing.Type) -> None:
|
||||
if cls.setvalue is _orig_setvalue:
|
||||
cls.setvalue = _nlmsg_base__setvalue
|
||||
for subcls in cls.__subclasses__():
|
||||
overwrite_methods(subcls)
|
||||
|
||||
overwrite_methods(pyroute2.netlink.nlmsg_base)
|
||||
|
||||
|
||||
_monkey_patch_pyroute2()
|
||||
|
||||
|
||||
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
|
||||
msg = msg_class()
|
||||
for key, value in header.items():
|
||||
msg['header'][key] = value
|
||||
for key, value in fields.items():
|
||||
msg[key] = value
|
||||
if attrs:
|
||||
attr_list = msg['attrs']
|
||||
r_nla_map = msg_class._nlmsg_base__r_nla_map
|
||||
for key, value in attrs.items():
|
||||
if msg_class.prefix:
|
||||
key = msg_class.name2nla(key)
|
||||
prime = r_nla_map[key]
|
||||
nla_class = prime['class']
|
||||
if issubclass(nla_class, pyroute2.netlink.nla):
|
||||
# support passing nested attributes as dicts of subattributes (or lists of those)
|
||||
if prime['nla_array']:
|
||||
value = [
|
||||
_build(nla_class, attrs=elem)
|
||||
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
|
||||
else elem
|
||||
for elem in value
|
||||
]
|
||||
elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict):
|
||||
value = _build(nla_class, attrs=value)
|
||||
attr_list.append([key, value])
|
||||
return msg
|
||||
|
||||
|
||||
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
||||
policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
|
||||
policy = {
|
||||
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
|
||||
for (x, y) in self.policy.items()
|
||||
}
|
||||
self.register_policy(policy)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def begin(self) -> typing.Generator[NFTTransaction, None, None]:
|
||||
tx = NFTTransaction(socket=self)
|
||||
try:
|
||||
yield tx
|
||||
# autocommit when no exception was raised
|
||||
# (only commits if it wasn't aborted)
|
||||
tx.autocommit()
|
||||
finally:
|
||||
# abort does nothing if commit went through
|
||||
tx.abort()
|
||||
|
||||
def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||
with self.begin() as tx:
|
||||
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
|
||||
|
||||
def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||
msg_flags |= pyroute2.netlink.NLM_F_DUMP
|
||||
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
|
||||
|
||||
def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
|
||||
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
|
||||
msg_flags |= pyroute2.netlink.NLM_F_REQUEST
|
||||
msg = _build(msg_class, attrs=attrs, **fields)
|
||||
return self.nlm_request(msg, msg_type, msg_flags)
|
||||
|
||||
|
||||
class NFTTransaction:
|
||||
def __init__(self, socket: NFTSocket) -> None:
|
||||
self._socket = socket
|
||||
self._closed = False
|
||||
self._msgs: list[nfgen_msg] = [
|
||||
# batch begin message
|
||||
_build(
|
||||
nfgen_msg,
|
||||
res_id=NFNL_SUBSYS_NFTABLES,
|
||||
header=dict(
|
||||
type=0x10, # NFNL_MSG_BATCH_BEGIN
|
||||
flags=pyroute2.netlink.NLM_F_REQUEST,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
def abort(self) -> None:
|
||||
"""
|
||||
Aborts if transaction wasn't already committed or aborted
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
|
||||
def autocommit(self) -> None:
|
||||
"""
|
||||
Commits if transaction wasn't already committed or aborted
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self.commit()
|
||||
|
||||
def commit(self) -> None:
|
||||
if self._closed:
|
||||
raise Exception("Transaction already closed")
|
||||
self._closed = True
|
||||
if len(self._msgs) == 1:
|
||||
# no inner messages were queued... not sending anything
|
||||
return
|
||||
# request ACK on the last message (before END)
|
||||
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
|
||||
self._msgs.append(
|
||||
# batch end message
|
||||
_build(
|
||||
nfgen_msg,
|
||||
res_id=NFNL_SUBSYS_NFTABLES,
|
||||
header=dict(
|
||||
type=0x11, # NFNL_MSG_BATCH_END
|
||||
flags=pyroute2.netlink.NLM_F_REQUEST,
|
||||
),
|
||||
),
|
||||
)
|
||||
for _msg in self._socket.nlm_request_batch(self._msgs):
|
||||
# we should see at most one ACK - real errors get raised anyway
|
||||
pass
|
||||
|
||||
def _put(self, msg: nfgen_msg) -> None:
|
||||
if self._closed:
|
||||
raise Exception("Transaction already closed")
|
||||
self._msgs.append(msg)
|
||||
|
||||
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
||||
msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
|
||||
msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
|
||||
header = dict(
|
||||
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
|
||||
flags=msg_flags,
|
||||
)
|
||||
msg = _build(msg_class, attrs=attrs, header=header, **fields)
|
||||
self._put(msg)
|
@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import socket
|
||||
import typing
|
||||
|
||||
import trio
|
||||
import trio.socket
|
||||
|
||||
|
||||
def _check_watchdog_pid() -> bool:
|
||||
wpid = os.environ.pop('WATCHDOG_PID', None)
|
||||
if not wpid:
|
||||
return True
|
||||
return wpid == str(os.getpid())
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def open_sdnotify() -> typing.AsyncGenerator[SdNotify, None]:
|
||||
target = os.environ.pop('NOTIFY_SOCKET', None)
|
||||
ns: typing.Optional[trio.socket.SocketType] = None
|
||||
watchdog_usec: int = 0
|
||||
if target:
|
||||
if target.startswith('@'):
|
||||
# Linux abstract namespace socket
|
||||
target = '\0' + target[1:]
|
||||
ns = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
||||
await ns.connect(target)
|
||||
watchdog_usec_s = os.environ.pop('WATCHDOG_USEC', None)
|
||||
if _check_watchdog_pid() and watchdog_usec_s:
|
||||
watchdog_usec = int(watchdog_usec_s)
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
sn = SdNotify(_ns=ns)
|
||||
if watchdog_usec:
|
||||
await nursery.start(sn._run_watchdog, watchdog_usec)
|
||||
yield sn
|
||||
# stop watchdoch
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
if ns:
|
||||
ns.close()
|
||||
|
||||
|
||||
class SdNotify:
|
||||
def __init__(self, *, _ns: typing.Optional[trio.socket.SocketType]) -> None:
|
||||
self._ns = _ns
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return not (self._ns is None)
|
||||
|
||||
async def _run_watchdog(self, watchdog_usec: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
|
||||
assert self.is_connected(), "Watchdog can't run without socket"
|
||||
await self.send('WATCHDOG=1')
|
||||
task_status.started()
|
||||
# send every half of the watchdog timeout
|
||||
interval = (watchdog_usec/1e6) / 2.0
|
||||
while True:
|
||||
await trio.sleep(interval)
|
||||
await self.send('WATCHDOG=1')
|
||||
|
||||
async def send(self, *msg: str) -> None:
|
||||
if not self.is_connected():
|
||||
return
|
||||
dgram = '\n'.join(msg).encode('utf-8')
|
||||
assert self._ns, "not connected" # checked above
|
||||
sent = await self._ns.send(dgram)
|
||||
if sent != len(dgram):
|
||||
raise OSError("Sent incomplete datagram to NOTIFY_SOCKET")
|
@ -1,19 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import zoneinfo
|
||||
|
||||
|
||||
_zoneinfo: typing.Optional[zoneinfo.ZoneInfo] = None
|
||||
|
||||
|
||||
def get_local_timezone():
|
||||
global _zoneinfo
|
||||
if not _zoneinfo:
|
||||
try:
|
||||
with open('/etc/timezone') as f:
|
||||
key = f.readline().strip()
|
||||
_zoneinfo = zoneinfo.ZoneInfo(key)
|
||||
except (OSError, zoneinfo.ZoneInfoNotFoundError):
|
||||
_zoneinfo = zoneinfo.ZoneInfo('UTC')
|
||||
return _zoneinfo
|
@ -1,8 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
base=$(dirname "$(readlink -f "$0")")
|
||||
cd "${base}"
|
||||
|
||||
exec ./venv/bin/capport-webui "$@"
|
@ -1,8 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
base=$(dirname "$(readlink -f "$0")")
|
||||
cd "${base}"
|
||||
|
||||
exec ./venv/bin/capport-control "$@"
|
@ -1,29 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -e
|
||||
|
||||
base=$(dirname "$(readlink -f "$0")")
|
||||
cd "${base}"
|
||||
|
||||
instance=$1
|
||||
ifname=$2
|
||||
|
||||
if [ -z "${instance}" -o -z "${ifname}" ]; then
|
||||
echo >&2 "Syntax: $0 instancename clientifname"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
targetname="/var/lib/prometheus/node-exporter/capport-${instance}.prom"
|
||||
tmpname="${targetname}.$$"
|
||||
|
||||
if [ -f "/run/netns/${instance}" ];then
|
||||
_run_in_ns="/usr/sbin/ns-enter ${instance} -- "
|
||||
else
|
||||
_run_in_ns=""
|
||||
fi
|
||||
if ${_run_in_ns} ${base}/stats.sh "${ifname}" > "${tmpname}"; then
|
||||
mv "${tmpname}" "${targetname}"
|
||||
else
|
||||
rm "${tmpname}"
|
||||
exit 1
|
||||
fi
|
Loading…
Reference in New Issue
Block a user