3
0

initial commit

This commit is contained in:
2022-04-04 19:21:51 +02:00
commit d1050d2ee4
34 changed files with 2223 additions and 0 deletions

View File

17
src/capport/utils/cli.py Normal file
View File

@@ -0,0 +1,17 @@
from __future__ import annotations
import logging
import capport.config
def init_logger(config: capport.config.Config):
loglevel = logging.INFO
if config.debug:
loglevel = logging.DEBUG
logging.basicConfig(
format='%(asctime)s [%(name)-25s] [%(levelname)-8s] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S %z]',
level=loglevel,
)
logging.getLogger('hypercorn').propagate = False

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import contextlib
import errno
import typing
import pr2modules.iproute.linux
import pr2modules.netlink.exceptions
from capport import cptypes
@contextlib.asynccontextmanager
async def connect():
yield NeighborController()
# TODO: run blocking iproute calls in a different thread?
class NeighborController:
def __init__(self):
self.ip = pr2modules.iproute.linux.IPRoute()
async def get_neighbor(
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]:
if not index:
route = await self.get_route(address)
if route is None:
return None
index = route.get_attr(route.name2nla('oif'))
try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
except pr2modules.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise
async def get_neighbor_mac(
self,
address: cptypes.IPAddress,
*,
index: int=0, # interface index
flags: int=0,
) -> typing.Optional[cptypes.MacAddress]:
neigh = await self.get_neighbor(address, index=index, flags=flags)
if neigh is None:
return None
mac = neigh.get_attr(neigh.name2nla('lladdr'))
return cptypes.MacAddress.parse(mac)
async def get_route(
self,
address: cptypes.IPAddress,
) -> typing.Optional[pr2modules.iproute.linux.rtmsg]:
try:
return self.ip.route('get', dst=str(address))[0]
except pr2modules.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT:
return None
raise

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
import typing
import pr2modules.netlink
from capport import cptypes
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
from .nft_socket import NFTSocket
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNIX"
def _from_msec(msecs: typing.Optional[int]) -> typing.Optional[float]:
# to seconds
if msecs is None:
return None
return msecs / 1000.0
class NftSet:
def __init__(self):
self._socket = NFTSocket()
self._socket.bind()
@staticmethod
def _set_elem(mac: cptypes.MacAddress, timeout: typing.Optional[typing.Union[int, float]]=None) -> _nftsocket.nft_set_elem_list_msg.set_elem:
attrs: typing.Dict[str, typing.Any] = {
'NFTA_SET_ELEM_KEY': dict(
NFTA_DATA_VALUE=mac.raw,
),
}
if timeout:
attrs['NFTA_SET_ELEM_TIMEOUT'] = int(1000*timeout)
return attrs
def _bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
ser_entries = [
self._set_elem(mac)
for mac, _timeout in entries
]
ser_entries_with_timeout = [
self._set_elem(mac, timeout)
for mac, timeout in entries
]
with self._socket.begin() as tx:
# create doesn't affect existing elements, so:
# make sure entries exists
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# drop entries (would fail if it doesn't exist)
tx.put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# now create entries with new timeout value
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries_with_timeout,
),
)
def bulk_insert(self, entries: typing.Sequence[typing.Tuple[cptypes.MacAddress, typing.Union[int, float]]]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_insert(entries[:1024])
entries = entries[1024:]
def insert(self, mac: cptypes.MacAddress, timeout: typing.Union[int, float]) -> None:
self.bulk_insert([(mac, timeout)])
def _bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
ser_entries = [
self._set_elem(mac)
for mac in entries
]
with self._socket.begin() as tx:
# make sure entries exists
tx.put(
_nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
# drop entries (would fail if it doesn't exist)
tx.put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
NFTA_SET_ELEM_LIST_ELEMENTS=ser_entries,
),
)
def bulk_remove(self, entries: typing.Sequence[cptypes.MacAddress]) -> None:
# limit chunk size
while len(entries) > 0:
self._bulk_remove(entries[:1024])
entries = entries[1024:]
def remove(self, mac: cptypes.MacAddress) -> None:
self.bulk_remove([mac])
def list(self) -> list:
responses: typing.Iterator[_nftsocket.nft_set_elem_list_msg]
responses = self._socket.nft_dump(
_nftsocket.NFT_MSG_GETSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
)
return [
{
'mac': cptypes.MacAddress(
elem.get_attr('NFTA_SET_ELEM_KEY').get_attr('NFTA_DATA_VALUE'),
),
'timeout': _from_msec(elem.get_attr('NFTA_SET_ELEM_TIMEOUT', None)),
'expiration': _from_msec(elem.get_attr('NFTA_SET_ELEM_EXPIRATION', None)),
}
for response in responses
for elem in response.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS', [])
]
def flush(self) -> None:
self._socket.nft_put(
_nftsocket.NFT_MSG_DELSETELEM,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_ELEM_LIST_SET='allowed',
)
)
def create(self):
with self._socket.begin() as tx:
tx.put(
_nftsocket.NFT_MSG_NEWTABLE,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_TABLE_NAME='captive_mark',
),
)
tx.put(
_nftsocket.NFT_MSG_NEWSET,
pr2modules.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET,
attrs=dict(
NFTA_SET_TABLE='captive_mark',
NFTA_SET_NAME='allowed',
NFTA_SET_FLAGS=0x10, # NFT_SET_TIMEOUT
NFTA_SET_KEY_TYPE=9, # nft type for "type ether_addr" - only relevant for userspace nft
NFTA_SET_KEY_LEN=6, # length of key: mac address
NFTA_SET_ID=1, # kernel seems to need a set id unique per transaction
),
)

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
import contextlib
import typing
import threading
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket
import pr2modules.netlink
import pr2modules.netlink.nlsocket
from pr2modules.netlink.nfnetlink import nfgen_msg
from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base)
# nft uses NESTED for those.. lets do the same
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED
def _monkey_patch_pyroute2():
import pr2modules.netlink
# overwrite setdefault on nlmsg_base class hierarchy
_orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue
def _nlmsg_base__setvalue(self, value):
if not self.header or not self['header'] or not isinstance(value, dict):
return _orig_setvalue(self, value)
header = value.pop('header', {})
res = _orig_setvalue(self, value)
self['header'].update(header)
return res
def overwrite_methods(cls: typing.Type) -> None:
if cls.setvalue is _orig_setvalue:
cls.setvalue = _nlmsg_base__setvalue
for subcls in cls.__subclasses__():
overwrite_methods(subcls)
overwrite_methods(pr2modules.netlink.nlmsg_base)
_monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: typing.Dict={}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
for key, value in fields.items():
msg[key] = value
if attrs:
attr_list = msg['attrs']
r_nla_map = msg_class._nlmsg_base__r_nla_map
for key, value in attrs.items():
if msg_class.prefix:
key = msg_class.name2nla(key)
prime = r_nla_map[key]
nla_class = prime['class']
if issubclass(nla_class, pr2modules.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']:
value = [
_build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem
for elem in value
]
elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict):
value = _build(nla_class, attrs=value)
attr_list.append([key, value])
return msg
class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket):
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None:
super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER)
policy = {
(x | (NFNL_SUBSYS_NFTABLES << 8)): y
for (x, y) in self.policy.items()
}
self.register_policy(policy)
@contextlib.contextmanager
def begin(self) -> typing.Generator[NFTTransaction, None, None]:
try:
tx = NFTTransaction(socket=self)
yield tx
# autocommit when no exception was raised
# (only commits if it wasn't aborted)
tx.autocommit()
finally:
# abort does nothing if commit went through
tx.abort()
def nft_put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
with self.begin() as tx:
tx.put(msg_type, msg_flags, attrs=attrs, **fields)
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_flags |= pr2modules.netlink.NLM_F_DUMP
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pr2modules.netlink.NLM_F_REQUEST
msg = _build(msg_class, attrs=attrs, **fields)
return self.nlm_request(msg, msg_type, msg_flags)
class NFTTransaction:
def __init__(self, socket: NFTSocket) -> None:
self._socket = socket
self._data = b''
self._seqnum = self._socket.addr_pool.alloc()
self._closed = False
# neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
# at the end of an transaction to make sure it worked.
# we could use a different sequence number for all changes, and wait for an ACK for each of them
# (but we'd also need to check for errors on the BEGIN sequence number).
# the other solution: use the same sequence number for all messages in the batch, and add ACK
# only to the final message (before END) - if we get the ACK we known all other messages before
# worked out.
self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
begin_msg = _build(
nfgen_msg,
res_id=NFNL_SUBSYS_NFTABLES,
header=dict(
type=0x10, # NFNL_MSG_BATCH_BEGIN
flags=pr2modules.netlink.NLM_F_REQUEST,
sequence_number=self._seqnum,
),
)
begin_msg.encode()
self._data += begin_msg.data
def abort(self) -> None:
"""
Aborts if transaction wasn't already committed or aborted
"""
if not self._closed:
self._closed = True
# unused seqnum
self._socket.addr_pool.free(self._seqnum)
def autocommit(self) -> None:
"""
Commits if transaction wasn't already committed or aborted
"""
if self._closed:
return
self.commit()
def commit(self) -> None:
if self._closed:
raise Exception("Transaction already closed")
if not self._final_msg:
# no inner messages were queued... just abort transaction
self.abort()
return
self._closed = True
# request ACK only on the last message (before END)
self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK
self._final_msg.encode()
self._data += self._final_msg.data
self._final_msg = None
# batch end
end_msg = _build(
nfgen_msg,
res_id=NFNL_SUBSYS_NFTABLES,
header=dict(
type=0x11, # NFNL_MSG_BATCH_END
flags=pr2modules.netlink.NLM_F_REQUEST,
sequence_number=self._seqnum,
),
)
end_msg.encode()
self._data += end_msg.data
# need to create backlog for our sequence number
with self._socket.lock[self._seqnum]:
self._socket.backlog[self._seqnum] = []
# send
self._socket.sendto(self._data, (0, 0))
try:
for _msg in self._socket.get(msg_seq=self._seqnum):
# we should see at most one ACK - real errors get raised anyway
pass
finally:
with self._socket.lock[0]:
# clear messages from "seq 0" queue - because if there
# was an error in our backlog, it got raised and the
# remaining messages moved to 0
self._socket.backlog[0] = []
def _put(self, msg: nfgen_msg) -> None:
if self._closed:
raise Exception("Transaction already closed")
if self._final_msg:
# previous message wasn't the final one, encode it without ACK
self._final_msg.encode()
self._data += self._final_msg.data
self._final_msg = msg
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set!
header = dict(
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
flags=msg_flags,
sequence_number=self._seqnum,
)
msg = _build(msg_class, attrs=attrs, header=header, **fields)
self._put(msg)