nft_socket.NFTTransaction: use pyroute2 nlm_request_batch instead of hack. fixes leak of sequence ids.
This commit is contained in:
parent
3d74447fd9
commit
1acb693e2b
@ -120,28 +120,18 @@ class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
|
|||||||
class NFTTransaction:
|
class NFTTransaction:
|
||||||
def __init__(self, socket: NFTSocket) -> None:
|
def __init__(self, socket: NFTSocket) -> None:
|
||||||
self._socket = socket
|
self._socket = socket
|
||||||
self._data = b''
|
|
||||||
self._seqnum = self._socket.addr_pool.alloc()
|
|
||||||
self._closed = False
|
self._closed = False
|
||||||
# neither NFNL_MSG_BATCH_BEGIN nor NFNL_MSG_BATCH_END supports ACK, but we need an ACK
|
self._msgs: list[nfgen_msg] = [
|
||||||
# at the end of an transaction to make sure it worked.
|
# batch begin message
|
||||||
# we could use a different sequence number for all changes, and wait for an ACK for each of them
|
_build(
|
||||||
# (but we'd also need to check for errors on the BEGIN sequence number).
|
|
||||||
# the other solution: use the same sequence number for all messages in the batch, and add ACK
|
|
||||||
# only to the final message (before END) - if we get the ACK we known all other messages before
|
|
||||||
# worked out.
|
|
||||||
self._final_msg: typing.Optional[_nftsocket.nft_gen_msg] = None
|
|
||||||
begin_msg = _build(
|
|
||||||
nfgen_msg,
|
nfgen_msg,
|
||||||
res_id=NFNL_SUBSYS_NFTABLES,
|
res_id=NFNL_SUBSYS_NFTABLES,
|
||||||
header=dict(
|
header=dict(
|
||||||
type=0x10, # NFNL_MSG_BATCH_BEGIN
|
type=0x10, # NFNL_MSG_BATCH_BEGIN
|
||||||
flags=pyroute2.netlink.NLM_F_REQUEST,
|
flags=pyroute2.netlink.NLM_F_REQUEST,
|
||||||
sequence_number=self._seqnum,
|
|
||||||
),
|
),
|
||||||
)
|
),
|
||||||
begin_msg.encode()
|
]
|
||||||
self._data += begin_msg.data
|
|
||||||
|
|
||||||
def abort(self) -> None:
|
def abort(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -149,8 +139,6 @@ class NFTTransaction:
|
|||||||
"""
|
"""
|
||||||
if not self._closed:
|
if not self._closed:
|
||||||
self._closed = True
|
self._closed = True
|
||||||
# unused seqnum
|
|
||||||
self._socket.addr_pool.free(self._seqnum)
|
|
||||||
|
|
||||||
def autocommit(self) -> None:
|
def autocommit(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -163,52 +151,31 @@ class NFTTransaction:
|
|||||||
def commit(self) -> None:
|
def commit(self) -> None:
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise Exception("Transaction already closed")
|
raise Exception("Transaction already closed")
|
||||||
if not self._final_msg:
|
|
||||||
# no inner messages were queued... just abort transaction
|
|
||||||
self.abort()
|
|
||||||
return
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
# request ACK only on the last message (before END)
|
if len(self._msgs) == 1:
|
||||||
self._final_msg['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
|
# no inner messages were queued... not sending anything
|
||||||
self._final_msg.encode()
|
return
|
||||||
self._data += self._final_msg.data
|
# request ACK on the last message (before END)
|
||||||
self._final_msg = None
|
self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
|
||||||
# batch end
|
self._msgs.append(
|
||||||
end_msg = _build(
|
# batch end message
|
||||||
|
_build(
|
||||||
nfgen_msg,
|
nfgen_msg,
|
||||||
res_id=NFNL_SUBSYS_NFTABLES,
|
res_id=NFNL_SUBSYS_NFTABLES,
|
||||||
header=dict(
|
header=dict(
|
||||||
type=0x11, # NFNL_MSG_BATCH_END
|
type=0x11, # NFNL_MSG_BATCH_END
|
||||||
flags=pyroute2.netlink.NLM_F_REQUEST,
|
flags=pyroute2.netlink.NLM_F_REQUEST,
|
||||||
sequence_number=self._seqnum,
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
end_msg.encode()
|
for _msg in self._socket.nlm_request_batch(self._msgs):
|
||||||
self._data += end_msg.data
|
|
||||||
# need to create backlog for our sequence number
|
|
||||||
with self._socket.lock[self._seqnum]:
|
|
||||||
self._socket.backlog[self._seqnum] = []
|
|
||||||
# send
|
|
||||||
self._socket.sendto(self._data, (0, 0))
|
|
||||||
try:
|
|
||||||
for _msg in self._socket.get(msg_seq=self._seqnum):
|
|
||||||
# we should see at most one ACK - real errors get raised anyway
|
# we should see at most one ACK - real errors get raised anyway
|
||||||
pass
|
pass
|
||||||
finally:
|
|
||||||
with self._socket.lock[0]:
|
|
||||||
# clear messages from "seq 0" queue - because if there
|
|
||||||
# was an error in our backlog, it got raised and the
|
|
||||||
# remaining messages moved to 0
|
|
||||||
self._socket.backlog[0] = []
|
|
||||||
|
|
||||||
def _put(self, msg: nfgen_msg) -> None:
|
def _put(self, msg: nfgen_msg) -> None:
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise Exception("Transaction already closed")
|
raise Exception("Transaction already closed")
|
||||||
if self._final_msg:
|
self._msgs.append(msg)
|
||||||
# previous message wasn't the final one, encode it without ACK
|
|
||||||
self._final_msg.encode()
|
|
||||||
self._data += self._final_msg.data
|
|
||||||
self._final_msg = msg
|
|
||||||
|
|
||||||
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
def put(self, msg_type: int, msg_flags: int = 0, /, *, attrs: dict = {}, **fields) -> None:
|
||||||
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
|
||||||
@ -217,7 +184,6 @@ class NFTTransaction:
|
|||||||
header = dict(
|
header = dict(
|
||||||
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
|
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
|
||||||
flags=msg_flags,
|
flags=msg_flags,
|
||||||
sequence_number=self._seqnum,
|
|
||||||
)
|
)
|
||||||
msg = _build(msg_class, attrs=attrs, header=header, **fields)
|
msg = _build(msg_class, attrs=attrs, header=header, **fields)
|
||||||
self._put(msg)
|
self._put(msg)
|
||||||
|
Loading…
Reference in New Issue
Block a user