diff --git a/setup.cfg b/setup.cfg index 444a919..b475dba 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = hypercorn[trio] PyYAML protobuf>=4.21 - pyroute2 + pyroute2~=0.7.3 [options.packages.find] where = src diff --git a/src/capport/utils/ipneigh.py b/src/capport/utils/ipneigh.py index 0a2df3c..797003a 100644 --- a/src/capport/utils/ipneigh.py +++ b/src/capport/utils/ipneigh.py @@ -6,10 +6,10 @@ import ipaddress import socket import typing -import pr2modules.iproute.linux -import pr2modules.netlink.exceptions -import pr2modules.netlink.rtnl -import pr2modules.netlink.rtnl.ndmsg +import pyroute2.iproute.linux +import pyroute2.netlink.exceptions +import pyroute2.netlink.rtnl +import pyroute2.netlink.rtnl.ndmsg from capport import cptypes @@ -21,7 +21,7 @@ async def connect(): # TODO: run blocking iproute calls in a different thread? class NeighborController: def __init__(self): - self.ip = pr2modules.iproute.linux.IPRoute() + self.ip = pyroute2.iproute.linux.IPRoute() async def get_neighbor( self, @@ -29,7 +29,7 @@ class NeighborController: *, index: int=0, # interface index flags: int=0, - ) -> typing.Optional[pr2modules.iproute.linux.ndmsg.ndmsg]: + ) -> typing.Optional[pyroute2.iproute.linux.ndmsg.ndmsg]: if not index: route = await self.get_route(address) if route is None: @@ -37,7 +37,7 @@ class NeighborController: index = route.get_attr(route.name2nla('oif')) try: return self.ip.neigh('get', dst=str(address), ifindex=index, state='none')[0] - except pr2modules.netlink.exceptions.NetlinkError as e: + except pyroute2.netlink.exceptions.NetlinkError as e: if e.code == errno.ENOENT: return None raise @@ -58,17 +58,17 @@ class NeighborController: async def get_route( self, address: cptypes.IPAddress, - ) -> typing.Optional[pr2modules.iproute.linux.rtmsg]: + ) -> typing.Optional[pyroute2.iproute.linux.rtmsg]: try: return self.ip.route('get', dst=str(address))[0] - except pr2modules.netlink.exceptions.NetlinkError as e: + except pyroute2.netlink.exceptions.NetlinkError as e: if e.code == errno.ENOENT: return None raise async def dump_neighbors(self, interface: str) -> typing.AsyncGenerator[typing.Tuple[cptypes.MacAddress, cptypes.IPAddress], None]: ifindex = socket.if_nametoindex(interface) - unicast_num = pr2modules.netlink.rtnl.rt_type['unicast'] + unicast_num = pyroute2.netlink.rtnl.rt_type['unicast'] # ip.neigh doesn't support AF_UNSPEC (as it is 0 and evaluates to `False` and gets forced to AF_INET) for family in (socket.AF_INET, socket.AF_INET6): for neigh in self.ip.neigh('dump', ifindex=ifindex, family=family): diff --git a/src/capport/utils/nft_set.py b/src/capport/utils/nft_set.py index 7b3fcfd..0503ad4 100644 --- a/src/capport/utils/nft_set.py +++ b/src/capport/utils/nft_set.py @@ -2,9 +2,9 @@ from __future__ import annotations import typing -import pr2modules.netlink +import pyroute2.netlink from capport import cptypes -from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket +from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket from .nft_socket import NFTSocket @@ -48,7 +48,7 @@ class NftSet: # make sure entries exists tx.put( _nftsocket.NFT_MSG_NEWSETELEM, - pr2modules.netlink.NLM_F_CREATE, + pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( NFTA_SET_TABLE='captive_mark', @@ -69,7 +69,7 @@ class NftSet: # now create entries with new timeout value tx.put( _nftsocket.NFT_MSG_NEWSETELEM, - pr2modules.netlink.NLM_F_CREATE|pr2modules.netlink.NLM_F_EXCL, + pyroute2.netlink.NLM_F_CREATE|pyroute2.netlink.NLM_F_EXCL, nfgen_family=NFPROTO_INET, attrs=dict( NFTA_SET_TABLE='captive_mark', @@ -96,7 +96,7 @@ class NftSet: # make sure entries exists tx.put( _nftsocket.NFT_MSG_NEWSETELEM, - pr2modules.netlink.NLM_F_CREATE, + pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( NFTA_SET_TABLE='captive_mark', @@ -160,7 +160,7 @@ class NftSet: with self._socket.begin() as tx: tx.put( _nftsocket.NFT_MSG_NEWTABLE, - pr2modules.netlink.NLM_F_CREATE, + pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( NFTA_TABLE_NAME='captive_mark', @@ -168,7 +168,7 @@ class NftSet: ) tx.put( _nftsocket.NFT_MSG_NEWSET, - pr2modules.netlink.NLM_F_CREATE, + pyroute2.netlink.NLM_F_CREATE, nfgen_family=NFPROTO_INET, attrs=dict( NFTA_SET_TABLE='captive_mark', diff --git a/src/capport/utils/nft_socket.py b/src/capport/utils/nft_socket.py index c517dd7..b5e7092 100644 --- a/src/capport/utils/nft_socket.py +++ b/src/capport/utils/nft_socket.py @@ -4,34 +4,35 @@ import contextlib import typing import threading -from pr2modules.netlink.nfnetlink import nftsocket as _nftsocket -import pr2modules.netlink -import pr2modules.netlink.nlsocket -from pr2modules.netlink.nfnetlink import nfgen_msg -from pr2modules.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES +from pyroute2.netlink.nfnetlink import nftsocket as _nftsocket +import pyroute2.netlink +import pyroute2.netlink.nlsocket +from pyroute2.netlink.nfnetlink import nfgen_msg +from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES NFPROTO_INET: int = 1 # nfgen_family "ipv4+ipv6"; strace decodes this as "AF_UNSPEC" -_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pr2modules.netlink.nlmsg_base) +_NlMsgBase = typing.TypeVar('_NlMsgBase', bound=pyroute2.netlink.nlmsg_base) # nft uses NESTED for those.. lets do the same -_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pr2modules.netlink.NLA_F_NESTED -_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pr2modules.netlink.NLA_F_NESTED +_nftsocket.nft_set_elem_list_msg.set_elem.data_attributes.nla_flags = pyroute2.netlink.NLA_F_NESTED +_nftsocket.nft_set_elem_list_msg.set_elem.nla_flags = pyroute2.netlink.NLA_F_NESTED # nftable lists always use `1` as list element attr type _nftsocket.nft_set_elem_list_msg.set_elem.header_type = 1 # NFTA_LIST_ELEM def _monkey_patch_pyroute2(): - import pr2modules.netlink + import pyroute2.netlink # overwrite setdefault on nlmsg_base class hierarchy - _orig_setvalue = pr2modules.netlink.nlmsg_base.setvalue + _orig_setvalue = pyroute2.netlink.nlmsg_base.setvalue def _nlmsg_base__setvalue(self, value): if not self.header or not self['header'] or not isinstance(value, dict): return _orig_setvalue(self, value) + # merge headers instead of overwriting them header = value.pop('header', {}) res = _orig_setvalue(self, value) self['header'].update(header) @@ -43,7 +44,7 @@ def _monkey_patch_pyroute2(): for subcls in cls.__subclasses__(): overwrite_methods(subcls) - overwrite_methods(pr2modules.netlink.nlmsg_base) + overwrite_methods(pyroute2.netlink.nlmsg_base) _monkey_patch_pyroute2() @@ -61,24 +62,24 @@ def _build(msg_class: typing.Type[_NlMsgBase], /, attrs: typing.Dict={}, header: key = msg_class.name2nla(key) prime = r_nla_map[key] nla_class = prime['class'] - if issubclass(nla_class, pr2modules.netlink.nla): + if issubclass(nla_class, pyroute2.netlink.nla): # support passing nested attributes as dicts of subattributes (or lists of those) if prime['nla_array']: value = [ - _build(nla_class, attrs=elem) if not isinstance(elem, pr2modules.netlink.nlmsg_base) and isinstance(elem, dict) else elem + _build(nla_class, attrs=elem) if not isinstance(elem, pyroute2.netlink.nlmsg_base) and isinstance(elem, dict) else elem for elem in value ] - elif not isinstance(value, pr2modules.netlink.nlmsg_base) and isinstance(value, dict): + elif not isinstance(value, pyroute2.netlink.nlmsg_base) and isinstance(value, dict): value = _build(nla_class, attrs=value) attr_list.append([key, value]) return msg -class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket): +class NFTSocket(pyroute2.netlink.nlsocket.NetlinkSocket): policy: typing.Dict[int, typing.Type[_nftsocket.nft_gen_msg]] = _nftsocket.NFTSocket.policy def __init__(self) -> None: - super().__init__(family=pr2modules.netlink.NETLINK_NETFILTER) + super().__init__(family=pyroute2.netlink.NETLINK_NETFILTER) policy = { (x | (NFNL_SUBSYS_NFTABLES << 8)): y for (x, y) in self.policy.items() @@ -102,13 +103,13 @@ class NFTSocket(pr2modules.netlink.nlsocket.NetlinkSocket): tx.put(msg_type, msg_flags, attrs=attrs, **fields) def nft_dump(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: - msg_flags |= pr2modules.netlink.NLM_F_DUMP + msg_flags |= pyroute2.netlink.NLM_F_DUMP return self.nft_get(msg_type, msg_flags, attrs=attrs, **fields) def nft_get(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self.policy[msg_type] msg_type = (NFNL_SUBSYS_NFTABLES << 8) | msg_type - msg_flags |= pr2modules.netlink.NLM_F_REQUEST + msg_flags |= pyroute2.netlink.NLM_F_REQUEST msg = _build(msg_class, attrs=attrs, **fields) return self.nlm_request(msg, msg_type, msg_flags) @@ -132,7 +133,7 @@ class NFTTransaction: res_id=NFNL_SUBSYS_NFTABLES, header=dict( type=0x10, # NFNL_MSG_BATCH_BEGIN - flags=pr2modules.netlink.NLM_F_REQUEST, + flags=pyroute2.netlink.NLM_F_REQUEST, sequence_number=self._seqnum, ), ) @@ -165,7 +166,7 @@ class NFTTransaction: return self._closed = True # request ACK only on the last message (before END) - self._final_msg['header']['flags'] |= pr2modules.netlink.NLM_F_ACK + self._final_msg['header']['flags'] |= pyroute2.netlink.NLM_F_ACK self._final_msg.encode() self._data += self._final_msg.data self._final_msg = None @@ -175,7 +176,7 @@ class NFTTransaction: res_id=NFNL_SUBSYS_NFTABLES, header=dict( type=0x11, # NFNL_MSG_BATCH_END - flags=pr2modules.netlink.NLM_F_REQUEST, + flags=pyroute2.netlink.NLM_F_REQUEST, sequence_number=self._seqnum, ), ) @@ -208,8 +209,8 @@ class NFTTransaction: def put(self, msg_type: int, msg_flags: int=0, /, *, attrs: typing.Dict={}, **fields) -> None: msg_class: typing.Type[_nftsocket.nft_gen_msg] = self._socket.policy[msg_type] - msg_flags |= pr2modules.netlink.NLM_F_REQUEST # always set REQUEST - msg_flags &= ~pr2modules.netlink.NLM_F_ACK # make sure ACK is not set! + msg_flags |= pyroute2.netlink.NLM_F_REQUEST # always set REQUEST + msg_flags &= ~pyroute2.netlink.NLM_F_ACK # make sure ACK is not set! header = dict( type=(NFNL_SUBSYS_NFTABLES << 8) | msg_type, flags=msg_flags,