From 1acb693e2b810819fdc0fbc513298c16b625f1ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Wed, 22 Mar 2023 13:50:34 +0100 Subject: [PATCH] nft_socket.NFTTransaction: use pyroute2 nlm_request_batch instead of hack. fixes leak of sequence ids. --- src/capport/utils/nft_socket.py | 90 ++++++++++----------------------- 1 file changed, 28 insertions(+), 62 deletions(-) diff --git a/src/capport/utils/nft_socket.py b/src/capport/utils/nft_socket.py index f2cae4b..ffd927a 100644 --- a/src/capport/utils/nft_socket.py +++ b/src/capport/utils/nft_socket.py @@ -120,28 +120,18 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): class NFTTransaction: def __init__(self, socket: NFTSocket) -> None: self._socket = socket - self._data = b'' - self._seqnum = self._socket.addr_pool.alloc() self._closed = False - # neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK - # at the end of an transaction to make sure it worked. - # we could use a different sequence number for all changes, and wait for an ACK for each of them - # (but we'd also need to check for errors on the BEGIN sequence number). - # the other solution: use the same sequence number for all messages in the batch, and add ACK - # only to the final message (before END) - if we get the ACK we known all other messages before - # worked out. - self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None - begin_msg = _build( - nfgen_msg, - res_id=NFNL_SUBSYS_NFTABLES, - header=dict( - type=0x10, # NFNL_MSG_BATCH_BEGIN - flags=pyroute2.netlink.NLM_F_REQUEST, - sequence_number=self._seqnum, + self._msgs: list[nfgen_msg] = [ + # batch begin message + _build( + nfgen_msg, + res_id=NFNL_SUBSYS_NFTABLES, + header=dict( + type=0x10, # NFNL_MSG_BATCH_BEGIN + flags=pyroute2.netlink.NLM_F_REQUEST, + ), ), - ) - begin_msg.encode() - self._data += begin_msg.data + ] def abort(self) -> None: """ @@ -149,8 +139,6 @@ class NFTTransaction: """ if not self._closed: self._closed = True - # unused seqnum - self._socket.addr_pool.free(self._seqnum) def autocommit(self) -> None: """ @@ -163,52 +151,31 @@ class NFTTransaction: def commit(self) -> None: if self._closed: raise Exception("Transaction already closed") - if not self._final_msg: - # no inner messages were queued... just abort transaction - self.abort() - return self._closed = True - # request ACK only on the last message (before END) - self._final_msg['header']['flags'] |= pyroute2.netlink.NLM_F_ACK - self._final_msg.encode() - self._data += self._final_msg.data - self._final_msg = None - # batch end - end_msg = _build( - nfgen_msg, - res_id=NFNL_SUBSYS_NFTABLES, - header=dict( - type=0x11, # NFNL_MSG_BATCH_END - flags=pyroute2.netlink.NLM_F_REQUEST, - sequence_number=self._seqnum, + if len(self._msgs) == 1: + # no inner messages were queued... not sending anything + return + # request ACK on the last message (before END) + self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK + self._msgs.append( + # batch end message + _build( + nfgen_msg, + res_id=NFNL_SUBSYS_NFTABLES, + header=dict( + type=0x11, # NFNL_MSG_BATCH_END + flags=pyroute2.netlink.NLM_F_REQUEST, + ), ), ) - end_msg.encode() - self._data += end_msg.data - # need to create backlog for our sequence number - with self._socket.lock[self._seqnum]: - self._socket.backlog[self._seqnum] = [] - # send - self._socket.sendto(self._data, (0, 0)) - try: - for _msg in self._socket.get(msg_seq=self._seqnum): - # we should see at most one ACK - real errors get raised anyway - pass - finally: - with self._socket.lock[0]: - # clear messages from "seq 0" queue - because if there - # was an error in our backlog, it got raised and the - # remaining messages moved to 0 - self._socket.backlog[0] = [] + for _msg in self._socket.nlm_request_batch(self._msgs): + # we should see at most one ACK - real errors get raised anyway + pass def _put(self, msg: nfgen_msg) -> None: if self._closed: raise Exception("Transaction already closed") - if self._final_msg: - # previous message wasn't the final one, encode it without ACK - self._final_msg.encode() - self._data += self._final_msg.data - self._final_msg = msg + self._msgs.append(msg) def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] @@ -217,7 +184,6 @@ class NFTTransaction: header = dict( type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type, flags=msg_flags, - sequence_number=self._seqnum, ) msg = _build(msg_class, attrs=attrs, header=header, **fields) self._put(msg)