# File-viewer header captured along with this source (not code):
# 3 · 0 · 190 lines · 7.0 KiB · Python

from __future__ import annotations
import contextlib
import typing
import pyroute2.netlink # type: ignore
import pyroute2.netlink.nlsocket # type: ignore
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES # type: ignore
from pyroute2.netlink.nfnetlink import nfgen_msg # type: ignore
from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket # type: ignore
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" -- NOTE(review): 1 is AF_UNIX/AF_LOCAL in AF_* numbering (AF_UNSPEC is 0); confirm what strace prints
# Type variable bound to pyroute2's message base so a builder can return the exact class it was given.
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
# nft uses NESTED for those... let's do the same.
# These assignments patch pyroute2's set-element message classes in place at import time,
# so they affect every user of pyroute2's nftsocket in this process.
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED
# nftables lists always use `1` as the list element attr type
_nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
def _monkey_patch_pyroute2() -> None:
    """
    Make ``setvalue`` merge headers across pyroute2's ``nlmsg_base`` hierarchy.

    After patching, assigning a dict that carries a ``'header'`` entry to a
    message whose header is already populated *merges* the given header fields
    into the existing header instead of replacing it.

    Only classes whose ``setvalue`` is still the original
    ``nlmsg_base.setvalue`` are patched; subclasses that define their own
    ``setvalue`` are left untouched (classes merely inheriting it pick up the
    patched version through normal attribute lookup).
    """
    import pyroute2.netlink

    _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue

    def _nlmsg_base__setvalue(self, value):
        # Use the stock behaviour unless there is an existing header to merge
        # into and the incoming value is a dict that may carry header fields.
        if not self.header or not self['header'] or not isinstance(value, dict):
            return _orig_setvalue(self, value)
        # Pop from a shallow copy: popping from ``value`` itself would strip
        # 'header' from the caller's dict/message, so re-encoding or reusing
        # the same value later would silently lose its header fields.
        value = dict(value)
        header = value.pop('header', {})
        res = _orig_setvalue(self, value)
        # merge headers instead of overwriting them
        self['header'].update(header)
        return res

    def overwrite_methods(cls: typing.Type) -> None:
        if cls.setvalue is _orig_setvalue:
            cls.setvalue = _nlmsg_base__setvalue
        for subcls in cls.__subclasses__():
            overwrite_methods(subcls)

    overwrite_methods(pyroute2.netlink.nlmsg_base)


_monkey_patch_pyroute2()
def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: dict = {}, header: dict = {}, **fields) -> _NlMsgBase:
msg = msg_class()
for key, value in header.items():
msg['header'][key] = value
for key, value in fields.items():
msg[key] = value
if attrs:
attr_list = msg['attrs']
r_nla_map = msg_class._nlmsg_base__r_nla_map
for key, value in attrs.items():
if msg_class.prefix:
key = msg_class.name2nla(key)
prime = r_nla_map[key]
nla_class = prime['class']
if issubclass(nla_class, pyroute2.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']:
value = [
_build(nla_class, attrs=elem)
if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict)
else elem
for elem in value
]
elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict):
value = _build(nla_class, attrs=value)
attr_list.append([key, value])
return msg
class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
    """
    Netlink socket for the nf_tables subsystem of netfilter.

    ``policy`` maps bare nf_tables message types (without the subsystem id in
    the high byte) to pyroute2 message classes; it is reused from pyroute2's
    own ``NFTSocket``.
    """

    policy: dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy

    def __init__(self) -> None:
        super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
        # Register the parser policy under the full netfilter message type
        # (subsystem id in the high byte), the same way nft_get composes
        # request types below.
        policy = {
            (x | (NFNL_SUBSYS_NFTABLES << 8)): y
            for (x, y) in self.policy.items()
        }
        self.register_policy(policy)

    @contextlib.contextmanager
    def begin(self) -> typing.Generator[NFTTransaction, None, None]:
        """
        Open a batch transaction.

        On normal exit of the ``with`` body the transaction is committed
        (unless it was already committed or aborted); if the body raises, the
        transaction is aborted and the queued messages are never sent.
        """
        tx = NFTTransaction(socket=self)
        try:
            yield tx
            # autocommit when no exception was raised
            # (only commits if it wasn't aborted)
            tx.autocommit()
        finally:
            # abort does nothing if commit went through
            tx.abort()

    def nft_put(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: dict | None = None,
        **fields,
    ) -> None:
        """Send a single message in its own transaction and commit it immediately."""
        with self.begin() as tx:
            tx.put(msg_type, msg_flags, attrs=attrs, **fields)

    def nft_dump(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: dict | None = None,
        **fields,
    ) -> typing.Iterable[pyroute2.netlink.nlmsg_base]:
        """Like :meth:`nft_get`, but with ``NLM_F_DUMP`` added to ``msg_flags``."""
        msg_flags |= pyroute2.netlink.NLM_F_DUMP
        return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)

    def nft_get(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: dict | None = None,
        **fields,
    ) -> typing.Iterable[pyroute2.netlink.nlmsg_base]:
        """
        Send a request outside of any batch and return the replies.

        :param msg_type: bare nf_tables message type (a key of ``policy``).
        :param msg_flags: extra netlink flags; ``NLM_F_REQUEST`` is always added.
        :param attrs: NLA attributes for the request (see ``_build``).
        :param fields: top-level message fields (e.g. ``nfgen_family``).
        :return: whatever ``nlm_request`` yields/returns for this request.
        :raises KeyError: if ``msg_type`` is not in ``policy``.
        """
        msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
        full_msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
        msg_flags |= pyroute2.netlink.NLM_F_REQUEST
        msg = _build(msg_class, attrs=attrs, **fields)
        return self.nlm_request(msg, full_msg_type, msg_flags)
class NFTTransaction:
    """
    A batch of nf_tables messages sent to the kernel together on commit.

    Queued messages are framed by NFNL_MSG_BATCH_BEGIN / NFNL_MSG_BATCH_END
    control messages and sent with ``nlm_request_batch``. Nothing is sent
    before :meth:`commit`; :meth:`abort` just closes the transaction so the
    queued messages are discarded.
    """

    _NFNL_MSG_BATCH_BEGIN = 0x10
    _NFNL_MSG_BATCH_END = 0x11

    def __init__(self, socket: NFTSocket) -> None:
        self._socket = socket
        self._closed = False
        # Index 0 is always the batch-begin marker; user messages follow it.
        self._msgs: list[nfgen_msg] = [self._batch_marker(self._NFNL_MSG_BATCH_BEGIN)]

    @staticmethod
    def _batch_marker(msg_type: int) -> nfgen_msg:
        """Build a batch begin/end control message addressed to the nftables subsystem."""
        return _build(
            nfgen_msg,
            res_id=NFNL_SUBSYS_NFTABLES,
            header=dict(
                type=msg_type,
                flags=pyroute2.netlink.NLM_F_REQUEST,
            ),
        )

    def _ensure_open(self) -> None:
        """Raise if the transaction was already committed or aborted."""
        if self._closed:
            raise RuntimeError("Transaction already closed")

    def abort(self) -> None:
        """
        Aborts if transaction wasn't already committed or aborted.

        Since messages are only sent on commit, aborting just marks the
        transaction closed; it is a no-op after a commit or a previous abort.
        """
        self._closed = True

    def autocommit(self) -> None:
        """
        Commits if transaction wasn't already committed or aborted.
        """
        if not self._closed:
            self.commit()

    def commit(self) -> None:
        """
        Send all queued messages as one batch.

        The transaction is marked closed before anything is sent, so it cannot
        be committed twice even if sending fails. If no messages were queued,
        nothing is sent.

        :raises RuntimeError: if the transaction was already committed or aborted.
        """
        self._ensure_open()
        self._closed = True
        if len(self._msgs) == 1:
            # only the batch-begin marker is queued... not sending anything
            return
        # request ACK on the last message (before END)
        self._msgs[-1]['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
        self._msgs.append(self._batch_marker(self._NFNL_MSG_BATCH_END))
        # Drain the replies; we should see at most one ACK.
        # Presumably pyroute2 raises netlink errors while iterating -- real errors surface here.
        for _reply in self._socket.nlm_request_batch(self._msgs):
            pass

    def _put(self, msg: nfgen_msg) -> None:
        """Queue an already-built message; raises if the transaction is closed."""
        self._ensure_open()
        self._msgs.append(msg)

    def put(
        self,
        msg_type: int,
        msg_flags: int = 0,
        /,
        *,
        attrs: dict | None = None,
        **fields,
    ) -> None:
        """
        Build and queue an nf_tables message.

        :param msg_type: bare nf_tables message type (a key of the socket's ``policy``).
        :param msg_flags: extra netlink flags; ``NLM_F_REQUEST`` is always set and
            ``NLM_F_ACK`` is always cleared (commit requests the ACK on the last message).
        :param attrs: NLA attributes (see ``_build``).
        :param fields: top-level message fields (e.g. ``nfgen_family``).
        :raises KeyError: if ``msg_type`` is not in the socket's policy.
        :raises RuntimeError: if the transaction is already closed.
        """
        msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
        msg_flags |= pyroute2.netlink.NLM_F_REQUEST  # always set REQUEST
        msg_flags &= ~pyroute2.netlink.NLM_F_ACK  # make sure ACK is not set!
        header = dict(
            type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
            flags=msg_flags,
        )
        msg = _build(msg_class, attrs=attrs, header=header, **fields)
        self._put(msg)