from __future__ import annotations import contextlib import typing import pyroute2.netlink # type: ignore import pyroute2.netlink.nlsocket # type: ignore from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" _NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base) # nft uses NESTED for those.. lets do the same _nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED _nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED # nftable lists always use `1` as list element attr type _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM def _monkey_patch_pyroute2(): import pyroute2.netlink # overwrite setdefault on nlmsg_base class hierarchy _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue def _nlmsg_base__setvalue(self, value): if not self.header or not self['header'] or not isinstance(value, dict): return _orig_setvalue(self, value) # merge headers instead of overwriting them header = value.pop('header', {}) res = _orig_setvalue(self, value) self['header'].update(header) return res def overwrite_methods(cls: typing.Type) -> None: if cls.setvalue is _orig_setvalue: cls.setvalue = _nlmsg_base__setvalue for subcls in cls.__subclasses__(): overwrite_methods(subcls) overwrite_methods(pyroute2.netlink.nlmsg_base) _monkey_patch_pyroute2() def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase: msg = msg_class() for key, value in header.items(): msg['header'][key] = value for key, value in fields.items(): msg[key] = value if attrs: attr_list = msg['attrs'] r_nla_map = msg_class._nlmsg_base__r_nla_map for key, value in attrs.items(): if msg_class.prefix: key = msg_class.name2nla(key) prime = r_nla_map[key] nla_class = prime['class'] if issubclass(nla_class, pyroute2.netlink.nla): # support passing nested attributes as dicts of subattributes (or lists of those) if prime['nla_array']: value = [ _build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) else elem for elem in value ] elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict): value = _build(nla_class, attrs=value) attr_list.append([key, value]) return msg class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy def __init__(self) -> None: super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) policy = { (x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items() } self.register_policy(policy) @contextlib.contextmanager def begin(self) -> typing.Generator[NFTTransaction, None, None]: try: tx = NFTTransaction(socket=self) yield tx # autocommit when no exception was raised # (only commits if it wasn't aborted) tx.autocommit() finally: # abort does nothing if commit went through tx.abort() def nft_put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: with self.begin() as tx: tx.put(msg_type, msg_flags, attrs=attrs, **fields) def nft_dump(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_flags |= pyroute2.netlink.NLM_F_DUMP return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) def nft_get(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type msg_flags |= pyroute2.netlink.NLM_F_REQUEST msg = _build(msg_class, attrs=attrs, **fields) return self.nlm_request(msg, msg_type, msg_flags) class NFTTransaction: def __init__(self, socket: NFTSocket) -> None: self._socket = socket self._data = b'' self._seqnum = self._socket.addr_pool.alloc() self._closed = False # neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK # at the end of an transaction to make sure it worked. # we could use a different sequence number for all changes, and wait for an ACK for each of them # (but we'd also need to check for errors on the BEGIN sequence number). # the other solution: use the same sequence number for all messages in the batch, and add ACK # only to the final message (before END) - if we get the ACK we known all other messages before # worked out. self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None begin_msg = _build( nfgen_msg, res_id=NFNL_SUBSYS_NFTABLES, header=dict( type=0x10, # NFNL_MSG_BATCH_BEGIN flags=pyroute2.netlink.NLM_F_REQUEST, sequence_number=self._seqnum, ), ) begin_msg.encode() self._data += begin_msg.data def abort(self) -> None: """ Aborts if transaction wasn't already committed or aborted """ if not self._closed: self._closed = True # unused seqnum self._socket.addr_pool.free(self._seqnum) def autocommit(self) -> None: """ Commits if transaction wasn't already committed or aborted """ if self._closed: return self.commit() def commit(self) -> None: if self._closed: raise Exception("Transaction already closed") if not self._final_msg: # no inner messages were queued... just abort transaction self.abort() return self._closed = True # request ACK only on the last message (before END) self._final_msg['header']['flags'] |= pyroute2.netlink.NLM_F_ACK self._final_msg.encode() self._data += self._final_msg.data self._final_msg = None # batch end end_msg = _build( nfgen_msg, res_id=NFNL_SUBSYS_NFTABLES, header=dict( type=0x11, # NFNL_MSG_BATCH_END flags=pyroute2.netlink.NLM_F_REQUEST, sequence_number=self._seqnum, ), ) end_msg.encode() self._data += end_msg.data # need to create backlog for our sequence number with self._socket.lock[self._seqnum]: self._socket.backlog[self._seqnum] = [] # send self._socket.sendto(self._data, (0, 0)) try: for _msg in self._socket.get(msg_seq=self._seqnum): # we should see at most one ACK - real errors get raised anyway pass finally: with self._socket.lock[0]: # clear messages from "seq 0" queue - because if there # was an error in our backlog, it got raised and the # remaining messages moved to 0 self._socket.backlog[0] = [] def _put(self, msg: nfgen_msg) -> None: if self._closed: raise Exception("Transaction already closed") if self._final_msg: # previous message wasn't the final one, encode it without ACK self._final_msg.encode() self._data += self._final_msg.data self._final_msg = msg def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set! header = dict( type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type, flags=msg_flags, sequence_number=self._seqnum, ) msg = _build(msg_class, attrs=attrs, header=header, **fields) self._put(msg)