224 lines
8.7 KiB
Python
224 lines
8.7 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import typing
|
|
|
|
import pyroute2.netlink # type: ignore
|
|
import pyroute2.netlink.nlsocket # type: ignore
|
|
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore
|
|
from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore
|
|
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
|
|
|
|
|
|
# nfgen_family value for the "inet" (ipv4+ipv6) family; strace decodes this as "AF_UNSPEC"
NFPROTO_INET: int = 1


# Type variable for `_build`: any pyroute2 netlink message / attribute class.
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)


# Adjust pyroute2's description of set elements so the messages we produce
# look like the ones the `nft` tool sends.
_set_elem = _nftsocket.nft_set_elem_list_msg.set_elem
# nft marks these attributes as NESTED, so we do the same
_set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED
_set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED
# nftables lists always use `1` as the attribute type of a list element
_set_elem.header_type = 1  # NFTA_LIST_ELEM
# temporary alias only; keep the module namespace as it was
del _set_elem
|
|
|
|
|
|
def _monkey_patch_pyroute2() -> None:
    """
    Patch pyroute2 so that assigning a dict to a message merges its header.

    Replaces ``nlmsg_base.setvalue`` (and ``setvalue`` on every class in the
    ``nlmsg_base`` hierarchy that still refers to the original function) with
    a wrapper. When the target message has a non-empty header and the new
    value is a dict, the dict's ``'header'`` entry is merged into the existing
    header (``dict.update``) instead of overwriting it. Every other case is
    delegated to the original ``setvalue`` unchanged.

    The dict passed in by the caller is left unmodified.
    """
    import pyroute2.netlink

    # overwrite setvalue on the nlmsg_base class hierarchy
    _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
    # sentinel: distinguishes "no 'header' key" from any real header value
    _missing = object()

    def _nlmsg_base__setvalue(self, value):
        if not self.header or not self['header'] or not isinstance(value, dict):
            return _orig_setvalue(self, value)
        # merge headers instead of overwriting them: hide the 'header' entry
        # from the original setvalue, then merge it into our current header
        header = value.pop('header', _missing)
        try:
            res = _orig_setvalue(self, value)
        finally:
            # restore the entry so the caller's dict is not mutated
            if header is not _missing:
                value['header'] = header
        if header is not _missing:
            self['header'].update(header)
        return res

    def overwrite_methods(cls: typing.Type) -> None:
        # only replace classes still using the original implementation;
        # classes that override setvalue themselves are left alone
        if cls.setvalue is _orig_setvalue:
            cls.setvalue = _nlmsg_base__setvalue
        for subcls in cls.__subclasses__():
            overwrite_methods(subcls)

    overwrite_methods(pyroute2.netlink.nlmsg_base)


_monkey_patch_pyroute2()
|
|
|
|
|
|
def _nested_from_dicts(nla_class, is_array, value):
    """
    Build nested attribute(s) of ``nla_class`` from plain dicts in ``value``.

    For array attributes (``is_array`` truthy), ``value`` is iterated and
    returned as a list in which every plain dict element is built via
    ``_build``. Otherwise ``value`` itself is built if it is a plain dict.
    Values that already are pyroute2 messages (``nlmsg_base`` instances) or
    that are not dicts at all are passed through unchanged.
    """
    def is_plain_dict(obj) -> bool:
        # nlmsg_base instances are excluded explicitly: they are already built
        return isinstance(obj, dict) and not isinstance(obj, pyroute2.netlink.nlmsg_base)

    if is_array:
        return [
            _build(nla_class, attrs=elem) if is_plain_dict(elem) else elem
            for elem in value
        ]
    if is_plain_dict(value):
        return _build(nla_class, attrs=value)
    return value


def _build(
    msg_class: typing.Type[_NlMsgBase],
    /,
    attrs: typing.Optional[dict] = None,
    header: typing.Optional[dict] = None,
    **fields,
) -> _NlMsgBase:
    """
    Instantiate ``msg_class`` and fill in header, fields and attributes.

    :param msg_class: pyroute2 message (or nla) class to instantiate.
    :param attrs: mapping of attribute name -> value. If ``msg_class`` has a
        ``prefix``, names are translated with ``msg_class.name2nla``. Values
        of nested attributes may be given as plain dicts of sub-attributes
        (or, for array attributes, lists of such dicts); they are built
        recursively. ``None``/empty means no attributes.
    :param header: values assigned one by one into ``msg['header']``.
    :param fields: values assigned directly as ``msg[key] = value``.
    :returns: the populated (not yet encoded) message.
    :raises KeyError: if an attribute name is unknown to ``msg_class``.
    """
    msg = msg_class()
    if header is not None:
        for key, value in header.items():
            msg['header'][key] = value
    for key, value in fields.items():
        msg[key] = value
    if attrs:
        attr_list = msg['attrs']
        # pyroute2's reverse nla map (name-mangled private attribute)
        r_nla_map = msg_class._nlmsg_base__r_nla_map
        for key, value in attrs.items():
            if msg_class.prefix:
                key = msg_class.name2nla(key)
            prime = r_nla_map[key]
            nla_class = prime['class']
            if issubclass(nla_class, pyroute2.netlink.nla):
                # support passing nested attributes as dicts of subattributes (or lists of those)
                value = _nested_from_dicts(nla_class, prime['nla_array'], value)
            attr_list.append([key, value])
    return msg
|
|
|
|
|
|
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
    """
    Netlink socket for the nftables subsystem of NETLINK_NETFILTER.

    Message types passed to the ``nft_*`` methods (and to transactions from
    :meth:`begin`) are plain nftables message types, i.e. keys of
    :attr:`policy`; the NFNL_SUBSYS_NFTABLES subsystem id is added here.
    """

    # nftables message type (without subsystem bits) -> pyroute2 message class,
    # taken from pyroute2's own NFTSocket
    policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy

    def __init__(self) -> None:
        super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
        # register the decoders under the full netlink message type:
        # subsystem id in the upper byte, nftables message type below it
        policy = {
            (x | (NFNL_SUBSYS_NFTABLES << 8)): y
            for (x, y) in self.policy.items()
        }
        self.register_policy(policy)

    @contextlib.contextmanager
    def begin(self) -> typing.Generator[NFTTransaction, None, None]:
        """
        Open a transaction (batch) on this socket.

        On normal exit from the ``with`` block the transaction is committed,
        unless it was already committed or aborted inside the block. If the
        block raises, the transaction is aborted and the exception propagates.
        """
        # create the transaction outside the try: if construction fails there
        # is nothing to abort (and `tx` would be unbound in the finally)
        tx = NFTTransaction(socket=self)
        try:
            yield tx
            # autocommit when no exception was raised
            # (only commits if it wasn't aborted)
            tx.autocommit()
        finally:
            # abort does nothing if commit went through
            tx.abort()

    def nft_put(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: typing.Optional[dict] = None,
        **fields,
    ) -> None:
        """
        Send a single change message in its own transaction and wait for the ACK.

        :raises KeyError: if ``msg_type`` is not in :attr:`policy`.
        """
        with self.begin() as tx:
            tx.put(msg_type, msg_flags, attrs=attrs, **fields)

    def nft_dump(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: typing.Optional[dict] = None,
        **fields,
    ) -> typing.Any:
        """
        Like :meth:`nft_get`, with NLM_F_DUMP added to ``msg_flags``.

        :returns: whatever pyroute2's ``nlm_request`` returns for the request.
        """
        msg_flags |= pyroute2.netlink.NLM_F_DUMP
        return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)

    def nft_get(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: typing.Optional[dict] = None,
        **fields,
    ) -> typing.Any:
        """
        Send a single request (NLM_F_REQUEST is always set) outside any batch.

        :param msg_type: nftables message type (a key of :attr:`policy`).
        :param msg_flags: extra netlink flags.
        :param attrs: attributes for the message, see ``_build``.
        :returns: whatever pyroute2's ``nlm_request`` returns for the request
            (the response message(s)).
        :raises KeyError: if ``msg_type`` is not in :attr:`policy`.
        """
        msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
        msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
        msg_flags |= pyroute2.netlink.NLM_F_REQUEST
        msg = _build(msg_class, attrs=attrs, **fields)
        return self.nlm_request(msg, msg_type, msg_flags)
|
|
|
|
|
|
class NFTTransaction:
    """
    A batch of nftables change messages sent to the kernel in one go.

    Messages queued with :meth:`put` are buffered, framed by
    NFNL_MSG_BATCH_BEGIN / NFNL_MSG_BATCH_END and only sent by
    :meth:`commit`. Every message of the batch uses the same sequence number,
    allocated from the socket's ``addr_pool`` when the transaction is created.
    """

    def __init__(self, socket: NFTSocket) -> None:
        self._socket = socket
        # encoded messages queued so far (starts with BATCH_BEGIN)
        self._data = b''
        self._seqnum = self._socket.addr_pool.alloc()
        self._closed = False
        # Neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
        # at the end of a transaction to make sure it worked.
        # We could use a different sequence number for every change and wait for an ACK for each
        # of them (but we'd also need to check for errors on the BEGIN sequence number).
        # The other solution: use the same sequence number for all messages in the batch, and add
        # ACK only to the final message (before END) - if we get the ACK we know all other
        # messages before it worked out.
        # The most recently queued message is therefore kept unencoded here until we know whether
        # it is the final one.
        self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
        begin_msg = _build(
            nfgen_msg,
            res_id=NFNL_SUBSYS_NFTABLES,
            header=dict(
                type=0x10,  # NFNL_MSG_BATCH_BEGIN
                flags=pyroute2.netlink.NLM_F_REQUEST,
                sequence_number=self._seqnum,
            ),
        )
        begin_msg.encode()
        self._data += begin_msg.data

    def abort(self) -> None:
        """
        Aborts if transaction wasn't already committed or aborted.

        Nothing has been sent at this point, so aborting only discards the
        buffered messages and returns the sequence number to the pool.
        """
        if not self._closed:
            self._closed = True
            # unused seqnum
            self._socket.addr_pool.free(self._seqnum)

    def autocommit(self) -> None:
        """
        Commits if transaction wasn't already committed or aborted.
        """
        if not self._closed:
            self.commit()

    def commit(self) -> None:
        """
        Send the batch and wait for the kernel's answer to it.

        If no message was queued, the transaction is simply aborted.

        :raises RuntimeError: if the transaction was already closed.
        Errors reported by the kernel are raised by pyroute2's ``get()``.
        """
        if self._closed:
            raise RuntimeError("Transaction already closed")
        final_msg = self._final_msg
        if final_msg is None:
            # no inner messages were queued... just abort transaction
            self.abort()
            return
        self._closed = True
        try:
            self._send(final_msg)
        except BaseException:
            # we never reached the receive loop, and abort() is a no-op now that
            # the transaction is closed: release seqnum and backlog entry here
            self._release_after_failed_send()
            raise
        try:
            for _msg in self._socket.get(msg_seq=self._seqnum):
                # we should see at most one ACK - real errors get raised anyway
                pass
        finally:
            with self._socket.lock[0]:
                # clear messages from "seq 0" queue - because if there
                # was an error in our backlog, it got raised and the
                # remaining messages moved to 0
                self._socket.backlog[0] = []

    def _send(self, final_msg: _nftsocket.nft_gen_msg) -> None:
        # Finish the batch (final message with ACK, then BATCH_END) and send it.
        # request ACK only on the last message (before END)
        final_msg['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
        final_msg.encode()
        self._data += final_msg.data
        self._final_msg = None
        # batch end
        end_msg = _build(
            nfgen_msg,
            res_id=NFNL_SUBSYS_NFTABLES,
            header=dict(
                type=0x11,  # NFNL_MSG_BATCH_END
                flags=pyroute2.netlink.NLM_F_REQUEST,
                sequence_number=self._seqnum,
            ),
        )
        end_msg.encode()
        self._data += end_msg.data
        # need to create backlog for our sequence number
        with self._socket.lock[self._seqnum]:
            self._socket.backlog[self._seqnum] = []
        # send
        self._socket.sendto(self._data, (0, 0))

    def _release_after_failed_send(self) -> None:
        # Undo the per-seqnum state set up for receiving: drop our backlog
        # queue (if it was created) and give the sequence number back.
        with self._socket.lock[self._seqnum]:
            self._socket.backlog.pop(self._seqnum, None)
        self._socket.addr_pool.free(self._seqnum)

    def _put(self, msg: nfgen_msg) -> None:
        # Queue `msg` as the (current) final message of the batch.
        if self._closed:
            raise RuntimeError("Transaction already closed")
        if self._final_msg is not None:
            # previous message wasn't the final one, encode it without ACK
            self._final_msg.encode()
            self._data += self._final_msg.data
        self._final_msg = msg

    def put(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: typing.Optional[dict] = None,
        **fields,
    ) -> None:
        """
        Queue an nftables message in this batch.

        :param msg_type: nftables message type (a key of the socket's ``policy``).
        :param msg_flags: extra netlink flags; NLM_F_REQUEST is always added and
            NLM_F_ACK is always removed (the ACK is requested by :meth:`commit`).
        :param attrs: attributes for the message, see ``_build``.
        :raises KeyError: if ``msg_type`` is not in the socket's ``policy``.
        :raises RuntimeError: if the transaction was already closed.
        """
        msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
        msg_flags |= pyroute2.netlink.NLM_F_REQUEST  # always set REQUEST
        msg_flags &= ~pyroute2.netlink.NLM_F_ACK  # make sure ACK is not set!
        header = dict(
            type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
            flags=msg_flags,
            sequence_number=self._seqnum,
        )
        msg = _build(msg_class, attrs=attrs, header=header, **fields)
        self._put(msg)
|