initial commit
This commit is contained in:
0
src/capport/utils/__init__.py
Normal file
0
src/capport/utils/__init__.py
Normal file
17
src/capport/utils/cli.py
Normal file
17
src/capport/utils/cli.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import capport.config
|
||||
|
||||
|
||||
def init_logger(config: capport.config.Config):
|
||||
loglevel = logging.INFO
|
||||
if config.debug:
|
||||
loglevel = logging.DEBUG
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S %z]',
|
||||
level=loglevel,
|
||||
)
|
||||
logging.getLogger('hypercorn').propagate = False
|
63
src/capport/utils/ipneigh.py
Normal file
63
src/capport/utils/ipneigh.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import errno
|
||||
import typing
|
||||
|
||||
import pr2modules.iproute.linux
|
||||
import pr2modules.netlink.exceptions
|
||||
from capport import cptypes
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect():
|
||||
yield NeighborController()
|
||||
|
||||
|
||||
# TODO: run blocking iproute calls in a different thread?
|
||||
class NeighborController:
|
||||
def __init__(self):
|
||||
self.ip = pr2modules.iproute.linux.IPRoute()
|
||||
|
||||
async def get_neighbor(
|
||||
self,
|
||||
address: cptypes.IPAddress,
|
||||
*,
|
||||
index: int=0, # interface index
|
||||
flags: int=0,
|
||||
) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]:
|
||||
if not index:
|
||||
route = await self.get_route(address)
|
||||
if route is None:
|
||||
return None
|
||||
index = route.get_attr(route.name2nla('oif'))
|
||||
try:
|
||||
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
|
||||
except pr2modules.netlink.exceptions.NetlinkError as e:
|
||||
if e.code == errno.ENOENT:
|
||||
return None
|
||||
raise
|
||||
|
||||
async def get_neighbor_mac(
|
||||
self,
|
||||
address: cptypes.IPAddress,
|
||||
*,
|
||||
index: int=0, # interface index
|
||||
flags: int=0,
|
||||
) -> typing.Optional[cptypes.MacAddress]:
|
||||
neigh = await self.get_neighbor(address, index=index, flags=flags)
|
||||
if neigh is None:
|
||||
return None
|
||||
mac = neigh.get_attr(neigh.name2nla('lladdr'))
|
||||
return cptypes.MacAddress.parse(mac)
|
||||
|
||||
async def get_route(
|
||||
self,
|
||||
address: cptypes.IPAddress,
|
||||
) -> typing.Optional[pr2modules.iproute.linux.rtmsg]:
|
||||
try:
|
||||
return self.ip.route('get', dst=str(address))[0]
|
||||
except pr2modules.netlink.exceptions.NetlinkError as e:
|
||||
if e.code == errno.ENOENT:
|
||||
return None
|
||||
raise
|
181
src/capport/utils/nft_set.py
Normal file
181
src/capport/utils/nft_set.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pr2modules.netlink
|
||||
from capport import cptypes
|
||||
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
|
||||
|
||||
from .nft_socket import NFTSocket
|
||||
|
||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
|
||||
|
||||
|
||||
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
|
||||
# to seconds
|
||||
if msecs is None:
|
||||
return None
|
||||
return msecs / 1000.0
|
||||
|
||||
|
||||
class NftSet:
|
||||
def __init__(self):
|
||||
self._socket = NFTSocket()
|
||||
self._socket.bind()
|
||||
|
||||
@staticmethod
|
||||
def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem:
|
||||
attrs: typing.Dict[str, typing.Any] = {
|
||||
'NFTA_SET_ELEM_KEY': dict(
|
||||
NFTA_DATA_VALUE=mac.raw,
|
||||
),
|
||||
}
|
||||
if timeout:
|
||||
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
|
||||
return attrs
|
||||
|
||||
def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
|
||||
ser_entries = [
|
||||
self._set_elem(mac)
|
||||
for mac, _timeout in entries
|
||||
]
|
||||
ser_entries_with_timeout = [
|
||||
self._set_elem(mac, timeout)
|
||||
for mac, timeout in entries
|
||||
]
|
||||
with self._socket.begin() as tx:
|
||||
# create doesn't affect existing elements, so:
|
||||
# make sure entries exists
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||
pr2modules.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
# drop entries (would fail if it doesn't exist)
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
# now create entries with new timeout value
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||
pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
|
||||
),
|
||||
)
|
||||
|
||||
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
|
||||
# limit chunk size
|
||||
while len(entries) > 0:
|
||||
self._bulk_insert(entries[:1024])
|
||||
entries = entries[1024:]
|
||||
|
||||
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
|
||||
self.bulk_insert([(mac, timeout)])
|
||||
|
||||
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||
ser_entries = [
|
||||
self._set_elem(mac)
|
||||
for mac in entries
|
||||
]
|
||||
with self._socket.begin() as tx:
|
||||
# make sure entries exists
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSETELEM,
|
||||
pr2modules.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
# drop entries (would fail if it doesn't exist)
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
|
||||
),
|
||||
)
|
||||
|
||||
def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
|
||||
# limit chunk size
|
||||
while len(entries) > 0:
|
||||
self._bulk_remove(entries[:1024])
|
||||
entries = entries[1024:]
|
||||
|
||||
def remove(self, mac: cptypes.MacAddress) -> None:
|
||||
self.bulk_remove([mac])
|
||||
|
||||
def list(self) -> list:
|
||||
responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
|
||||
responses = self._socket.nft_dump(
|
||||
_nftsocket.NFT_MSG_GETSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
)
|
||||
)
|
||||
return [
|
||||
{
|
||||
'mac': cptypes.MacAddress(
|
||||
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
|
||||
),
|
||||
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
|
||||
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
|
||||
}
|
||||
for response in responses
|
||||
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
|
||||
]
|
||||
|
||||
def flush(self) -> None:
|
||||
self._socket.nft_put(
|
||||
_nftsocket.NFT_MSG_DELSETELEM,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_ELEM_LIST_SET='allowed',
|
||||
)
|
||||
)
|
||||
|
||||
def create(self):
|
||||
with self._socket.begin() as tx:
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWTABLE,
|
||||
pr2modules.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_TABLE_NAME='captive_mark',
|
||||
),
|
||||
)
|
||||
tx.put(
|
||||
_nftsocket.NFT_MSG_NEWSET,
|
||||
pr2modules.netlink.NLM_F_CREATE,
|
||||
nfgen_family=NFPROTO_INET,
|
||||
attrs=dict(
|
||||
NFTA_SET_TABLE='captive_mark',
|
||||
NFTA_SET_NAME='allowed',
|
||||
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
|
||||
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
|
||||
NFTA_SET_KEY_LEN=6, # length of key: mac address
|
||||
NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction
|
||||
),
|
||||
)
|
217
src/capport/utils/nft_socket.py
Normal file
217
src/capport/utils/nft_socket.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
import threading
|
||||
|
||||
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
|
||||
import pr2modules.netlink
|
||||
import pr2modules.netlink.nlsocket
|
||||
from pr2modules.netlink.nfnetlink import nfgen_msg
|
||||
from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
|
||||
|
||||
|
||||
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
|
||||
|
||||
|
||||
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base)
|
||||
|
||||
|
||||
# nft uses NESTED for those.. lets do the same
|
||||
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED
|
||||
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED
|
||||
|
||||
|
||||
def _monkey_patch_pyroute2():
|
||||
import pr2modules.netlink
|
||||
# overwrite setdefault on nlmsg_base class hierarchy
|
||||
_orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue
|
||||
|
||||
def _nlmsg_base__setvalue(self, value):
|
||||
if not self.header or not self['header'] or not isinstance(value, dict):
|
||||
return _orig_setvalue(self, value)
|
||||
header = value.pop('header', {})
|
||||
res = _orig_setvalue(self, value)
|
||||
self['header'].update(header)
|
||||
return res
|
||||
|
||||
def overwrite_methods(cls: typing.Type) -> None:
|
||||
if cls.setvalue is _orig_setvalue:
|
||||
cls.setvalue = _nlmsg_base__setvalue
|
||||
for subcls in cls.__subclasses__():
|
||||
overwrite_methods(subcls)
|
||||
|
||||
overwrite_methods(pr2modules.netlink.nlmsg_base)
|
||||
_monkey_patch_pyroute2()
|
||||
|
||||
|
||||
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase:
|
||||
msg = msg_class()
|
||||
for key, value in header.items():
|
||||
msg['header'][key] = value
|
||||
for key, value in fields.items():
|
||||
msg[key] = value
|
||||
if attrs:
|
||||
attr_list = msg['attrs']
|
||||
r_nla_map = msg_class._nlmsg_base__r_nla_map
|
||||
for key, value in attrs.items():
|
||||
if msg_class.prefix:
|
||||
key = msg_class.name2nla(key)
|
||||
prime = r_nla_map[key]
|
||||
nla_class = prime['class']
|
||||
if issubclass(nla_class, pr2modules.netlink.nla):
|
||||
# support passing nested attributes as dicts of subattributes (or lists of those)
|
||||
if prime['nla_array']:
|
||||
value = [
|
||||
_build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem
|
||||
for elem in value
|
||||
]
|
||||
elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict):
|
||||
value = _build(nla_class, attrs=value)
|
||||
attr_list.append([key, value])
|
||||
return msg
|
||||
|
||||
|
||||
class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket):
|
||||
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER)
|
||||
policy = {
|
||||
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
|
||||
for (x, y) in self.policy.items()
|
||||
}
|
||||
self.register_policy(policy)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def begin(self) -> typing.Generator[NFTTransaction, None, None]:
|
||||
try:
|
||||
tx = NFTTransaction(socket=self)
|
||||
yield tx
|
||||
# autocommit when no exception was raised
|
||||
# (only commits if it wasn't aborted)
|
||||
tx.autocommit()
|
||||
finally:
|
||||
# abort does nothing if commit went through
|
||||
tx.abort()
|
||||
|
||||
def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||
with self.begin() as tx:
|
||||
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
|
||||
|
||||
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||
msg_flags |= pr2modules.netlink.NLM_F_DUMP
|
||||
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
|
||||
|
||||
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
|
||||
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
|
||||
msg_flags |= pr2modules.netlink.NLM_F_REQUEST
|
||||
msg = _build(msg_class, attrs=attrs, **fields)
|
||||
return self.nlm_request(msg, msg_type, msg_flags)
|
||||
|
||||
|
||||
class NFTTransaction:
|
||||
def __init__(self, socket: NFTSocket) -> None:
|
||||
self._socket = socket
|
||||
self._data = b''
|
||||
self._seqnum = self._socket.addr_pool.alloc()
|
||||
self._closed = False
|
||||
# neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
|
||||
# at the end of an transaction to make sure it worked.
|
||||
# we could use a different sequence number for all changes, and wait for an ACK for each of them
|
||||
# (but we'd also need to check for errors on the BEGIN sequence number).
|
||||
# the other solution: use the same sequence number for all messages in the batch, and add ACK
|
||||
# only to the final message (before END) - if we get the ACK we known all other messages before
|
||||
# worked out.
|
||||
self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
|
||||
begin_msg = _build(
|
||||
nfgen_msg,
|
||||
res_id=NFNL_SUBSYS_NFTABLES,
|
||||
header=dict(
|
||||
type=0x10, # NFNL_MSG_BATCH_BEGIN
|
||||
flags=pr2modules.netlink.NLM_F_REQUEST,
|
||||
sequence_number=self._seqnum,
|
||||
),
|
||||
)
|
||||
begin_msg.encode()
|
||||
self._data += begin_msg.data
|
||||
|
||||
def abort(self) -> None:
|
||||
"""
|
||||
Aborts if transaction wasn't already committed or aborted
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
# unused seqnum
|
||||
self._socket.addr_pool.free(self._seqnum)
|
||||
|
||||
def autocommit(self) -> None:
|
||||
"""
|
||||
Commits if transaction wasn't already committed or aborted
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self.commit()
|
||||
|
||||
def commit(self) -> None:
|
||||
if self._closed:
|
||||
raise Exception("Transaction already closed")
|
||||
if not self._final_msg:
|
||||
# no inner messages were queued... just abort transaction
|
||||
self.abort()
|
||||
return
|
||||
self._closed = True
|
||||
# request ACK only on the last message (before END)
|
||||
self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK
|
||||
self._final_msg.encode()
|
||||
self._data += self._final_msg.data
|
||||
self._final_msg = None
|
||||
# batch end
|
||||
end_msg = _build(
|
||||
nfgen_msg,
|
||||
res_id=NFNL_SUBSYS_NFTABLES,
|
||||
header=dict(
|
||||
type=0x11, # NFNL_MSG_BATCH_END
|
||||
flags=pr2modules.netlink.NLM_F_REQUEST,
|
||||
sequence_number=self._seqnum,
|
||||
),
|
||||
)
|
||||
end_msg.encode()
|
||||
self._data += end_msg.data
|
||||
# need to create backlog for our sequence number
|
||||
with self._socket.lock[self._seqnum]:
|
||||
self._socket.backlog[self._seqnum] = []
|
||||
# send
|
||||
self._socket.sendto(self._data, (0, 0))
|
||||
try:
|
||||
for _msg in self._socket.get(msg_seq=self._seqnum):
|
||||
# we should see at most one ACK - real errors get raised anyway
|
||||
pass
|
||||
finally:
|
||||
with self._socket.lock[0]:
|
||||
# clear messages from "seq 0" queue - because if there
|
||||
# was an error in our backlog, it got raised and the
|
||||
# remaining messages moved to 0
|
||||
self._socket.backlog[0] = []
|
||||
|
||||
def _put(self, msg: nfgen_msg) -> None:
|
||||
if self._closed:
|
||||
raise Exception("Transaction already closed")
|
||||
if self._final_msg:
|
||||
# previous message wasn't the final one, encode it without ACK
|
||||
self._final_msg.encode()
|
||||
self._data += self._final_msg.data
|
||||
self._final_msg = msg
|
||||
|
||||
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
|
||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
||||
msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST
|
||||
msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set!
|
||||
header = dict(
|
||||
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
|
||||
flags=msg_flags,
|
||||
sequence_number=self._seqnum,
|
||||
)
|
||||
msg = _build(msg_class, attrs=attrs, header=header, **fields)
|
||||
self._put(msg)
|
Reference in New Issue
Block a user