2
0

update pyroute2 handling (pr2modules is gone)

This commit is contained in:
Stefan Bühler 2023-01-12 12:17:49 +01:00
parent 65116375b3
commit f9c9b98868
4 changed files with 42 additions and 41 deletions

View File

@ -26,7 +26,7 @@ install_requires =
hypercorn[trio] hypercorn[trio]
PyYAML PyYAML
protobuf>=4.21 protobuf>=4.21
pyroute2 pyroute2~=0.7.3
[options.packages.find] [options.packages.find]
where = src where = src

View File

@ -6,10 +6,10 @@ import ipaddress
import socket import socket
import typing import typing
import pr2modules.iproute.linux import pyroute2.iproute.linux
import pr2modules.netlink.exceptions import pyroute2.netlink.exceptions
import pr2modules.netlink.rtnl import pyroute2.netlink.rtnl
import pr2modules.netlink.rtnl.ndmsg import pyroute2.netlink.rtnl.ndmsg
from capport import cptypes from capport import cptypes
@ -21,7 +21,7 @@ async def connect():
# TODO: run blocking iproute calls in a different thread? # TODO: run blocking iproute calls in a different thread?
class NeighborController: class NeighborController:
def __init__(self): def __init__(self):
self.ip = pr2modules.iproute.linux.IPRoute() self.ip = pyroute2.iproute.linux.IPRoute()
async def get_neighbor( async def get_neighbor(
self, self,
@ -29,7 +29,7 @@ class NeighborController:
*, *,
index: int=0, # interface index index: int=0, # interface index
flags: int=0, flags: int=0,
) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]: ) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]:
if not index: if not index:
route = await self.get_route(address) route = await self.get_route(address)
if route is None: if route is None:
@ -37,7 +37,7 @@ class NeighborController:
index = route.get_attr(route.name2nla('oif')) index = route.get_attr(route.name2nla('oif'))
try: try:
return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0] return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0]
except pr2modules.netlink.exceptions.NetlinkError as e: except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT: if e.code == errno.ENOENT:
return None return None
raise raise
@ -58,17 +58,17 @@ class NeighborController:
async def get_route( async def get_route(
self, self,
address: cptypes.IPAddress, address: cptypes.IPAddress,
) -> typing.Optional[pr2modules.iproute.linux.rtmsg]: ) -> typing.Optional[pyroute2.iproute.linux.rtmsg]:
try: try:
return self.ip.route('get', dst=str(address))[0] return self.ip.route('get', dst=str(address))[0]
except pr2modules.netlink.exceptions.NetlinkError as e: except pyroute2.netlink.exceptions.NetlinkError as e:
if e.code == errno.ENOENT: if e.code == errno.ENOENT:
return None return None
raise raise
async def dump_neighbors(self, interface: str) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]: async def dump_neighbors(self, interface: str) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]:
ifindex = socket.if_nametoindex(interface) ifindex = socket.if_nametoindex(interface)
unicast_num = pr2modules.netlink.rtnl.rt_type['unicast'] unicast_num = pyroute2.netlink.rtnl.rt_type['unicast']
# ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET)
for family in (socket.AF_INET, socket.AF_INET6): for family in (socket.AF_INET, socket.AF_INET6):
for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family): for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family):

View File

@ -2,9 +2,9 @@ from __future__ import annotations
import typing import typing
import pr2modules.netlink import pyroute2.netlink
from capport import cptypes from capport import cptypes
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket
from .nft_socket import NFTSocket from .nft_socket import NFTSocket
@ -48,7 +48,7 @@ class NftSet:
# make sure entries exists # make sure entries exists
tx.put( tx.put(
_nftsocket.NFT_MSG_NEWSETELEM, _nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_TABLE='captive_mark', NFTA_SET_TABLE='captive_mark',
@ -69,7 +69,7 @@ class NftSet:
# now create entries with new timeout value # now create entries with new timeout value
tx.put( tx.put(
_nftsocket.NFT_MSG_NEWSETELEM, _nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL, pyroute2.netlink.NLM_F_CREATE|pyroute2.netlink.NLM_F_EXCL,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_TABLE='captive_mark', NFTA_SET_TABLE='captive_mark',
@ -96,7 +96,7 @@ class NftSet:
# make sure entries exists # make sure entries exists
tx.put( tx.put(
_nftsocket.NFT_MSG_NEWSETELEM, _nftsocket.NFT_MSG_NEWSETELEM,
pr2modules.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_TABLE='captive_mark', NFTA_SET_TABLE='captive_mark',
@ -160,7 +160,7 @@ class NftSet:
with self._socket.begin() as tx: with self._socket.begin() as tx:
tx.put( tx.put(
_nftsocket.NFT_MSG_NEWTABLE, _nftsocket.NFT_MSG_NEWTABLE,
pr2modules.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_TABLE_NAME='captive_mark', NFTA_TABLE_NAME='captive_mark',
@ -168,7 +168,7 @@ class NftSet:
) )
tx.put( tx.put(
_nftsocket.NFT_MSG_NEWSET, _nftsocket.NFT_MSG_NEWSET,
pr2modules.netlink.NLM_F_CREATE, pyroute2.netlink.NLM_F_CREATE,
nfgen_family=NFPROTO_INET, nfgen_family=NFPROTO_INET,
attrs=dict( attrs=dict(
NFTA_SET_TABLE='captive_mark', NFTA_SET_TABLE='captive_mark',

View File

@ -4,34 +4,35 @@ import contextlib
import typing import typing
import threading import threading
from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket
import pr2modules.netlink import pyroute2.netlink
import pr2modules.netlink.nlsocket import pyroute2.netlink.nlsocket
from pr2modules.netlink.nfnetlink import nfgen_msg from pyroute2.netlink.nfnetlink import nfgen_msg
from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES
NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC"
_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base) _NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base)
# nft uses NESTED for those.. lets do the same # nft uses NESTED for those.. lets do the same
_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED _nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED
_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED _nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED
# nftable lists always use `1` as list element attr type # nftable lists always use `1` as list element attr type
_nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM
def _monkey_patch_pyroute2(): def _monkey_patch_pyroute2():
import pr2modules.netlink import pyroute2.netlink
# overwrite setdefault on nlmsg_base class hierarchy # overwrite setdefault on nlmsg_base class hierarchy
_orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue
def _nlmsg_base__setvalue(self, value): def _nlmsg_base__setvalue(self, value):
if not self.header or not self['header'] or not isinstance(value, dict): if not self.header or not self['header'] or not isinstance(value, dict):
return _orig_setvalue(self, value) return _orig_setvalue(self, value)
# merge headers instead of overwriting them
header = value.pop('header', {}) header = value.pop('header', {})
res = _orig_setvalue(self, value) res = _orig_setvalue(self, value)
self['header'].update(header) self['header'].update(header)
@ -43,7 +44,7 @@ def _monkey_patch_pyroute2():
for subcls in cls.__subclasses__(): for subcls in cls.__subclasses__():
overwrite_methods(subcls) overwrite_methods(subcls)
overwrite_methods(pr2modules.netlink.nlmsg_base) overwrite_methods(pyroute2.netlink.nlmsg_base)
_monkey_patch_pyroute2() _monkey_patch_pyroute2()
@ -61,24 +62,24 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header:
key = msg_class.name2nla(key) key = msg_class.name2nla(key)
prime = r_nla_map[key] prime = r_nla_map[key]
nla_class = prime['class'] nla_class = prime['class']
if issubclass(nla_class, pr2modules.netlink.nla): if issubclass(nla_class, pyroute2.netlink.nla):
# support passing nested attributes as dicts of subattributes (or lists of those) # support passing nested attributes as dicts of subattributes (or lists of those)
if prime['nla_array']: if prime['nla_array']:
value = [ value = [
_build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem _build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) else elem
for elem in value for elem in value
] ]
elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict): elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict):
value = _build(nla_class, attrs=value) value = _build(nla_class, attrs=value)
attr_list.append([key, value]) attr_list.append([key, value])
return msg return msg
class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket): class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket):
policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER) super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER)
policy = { policy = {
(x | (NFNL_SUBSYS_NFTABLES << 8)): y (x | (NFNL_SUBSYS_NFTABLES << 8)): y
for (x, y) in self.policy.items() for (x, y) in self.policy.items()
@ -102,13 +103,13 @@ class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket):
tx.put(msg_type, msg_flags, attrs=attrs, **fields) tx.put(msg_type, msg_flags, attrs=attrs, **fields)
def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_flags |= pr2modules.netlink.NLM_F_DUMP msg_flags |= pyroute2.netlink.NLM_F_DUMP
return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields)
def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type]
msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type
msg_flags |= pr2modules.netlink.NLM_F_REQUEST msg_flags |= pyroute2.netlink.NLM_F_REQUEST
msg = _build(msg_class, attrs=attrs, **fields) msg = _build(msg_class, attrs=attrs, **fields)
return self.nlm_request(msg, msg_type, msg_flags) return self.nlm_request(msg, msg_type, msg_flags)
@ -132,7 +133,7 @@ class NFTTransaction:
res_id=NFNL_SUBSYS_NFTABLES, res_id=NFNL_SUBSYS_NFTABLES,
header=dict( header=dict(
type=0x10, # NFNL_MSG_BATCH_BEGIN type=0x10, # NFNL_MSG_BATCH_BEGIN
flags=pr2modules.netlink.NLM_F_REQUEST, flags=pyroute2.netlink.NLM_F_REQUEST,
sequence_number=self._seqnum, sequence_number=self._seqnum,
), ),
) )
@ -165,7 +166,7 @@ class NFTTransaction:
return return
self._closed = True self._closed = True
# request ACK only on the last message (before END) # request ACK only on the last message (before END)
self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK self._final_msg['header']['flags'] |= pyroute2.netlink.NLM_F_ACK
self._final_msg.encode() self._final_msg.encode()
self._data += self._final_msg.data self._data += self._final_msg.data
self._final_msg = None self._final_msg = None
@ -175,7 +176,7 @@ class NFTTransaction:
res_id=NFNL_SUBSYS_NFTABLES, res_id=NFNL_SUBSYS_NFTABLES,
header=dict( header=dict(
type=0x11, # NFNL_MSG_BATCH_END type=0x11, # NFNL_MSG_BATCH_END
flags=pr2modules.netlink.NLM_F_REQUEST, flags=pyroute2.netlink.NLM_F_REQUEST,
sequence_number=self._seqnum, sequence_number=self._seqnum,
), ),
) )
@ -208,8 +209,8 @@ class NFTTransaction:
def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None:
msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type]
msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST
msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set! msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set!
header = dict( header = dict(
type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type, type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type,
flags=msg_flags, flags=msg_flags,