From f1a9a74c479fde5fcfce781bbf2bdae075ff7b33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Tue, 14 Jan 2025 17:24:26 +0100 Subject: [PATCH 01/18] WIP: add an async dhcp client --- pyroute2/dhcp/cli.py | 105 ++++++++ pyroute2/dhcp/client.py | 333 ++++++++++++++++++------ pyroute2/dhcp/constants/__init__.py | 0 pyroute2/dhcp/constants/bootp.py | 35 +++ pyroute2/dhcp/constants/dhcp.py | 130 +++++++++ pyroute2/dhcp/dhcp4socket.py | 79 ++++-- pyroute2/dhcp/fsm.py | 62 +++++ pyroute2/dhcp/hooks.py | 30 +++ pyroute2/dhcp/leases.py | 154 +++++++++++ pyroute2/dhcp/messages.py | 45 ++++ pyroute2/dhcp/timers.py | 64 +++++ pyroute2/ext/rawsocket.py | 37 +-- setup.cfg | 2 +- tests/test_linux/conftest.py | 6 +- tests/test_linux/fixtures/dnsmasq.py | 155 +++++++++++ tests/test_linux/fixtures/interfaces.py | 68 +++++ tests/test_linux/test_raw/test_dhcp.py | 110 ++++---- 17 files changed, 1245 insertions(+), 170 deletions(-) create mode 100644 pyroute2/dhcp/cli.py create mode 100644 pyroute2/dhcp/constants/__init__.py create mode 100644 pyroute2/dhcp/constants/bootp.py create mode 100644 pyroute2/dhcp/constants/dhcp.py create mode 100644 pyroute2/dhcp/fsm.py create mode 100644 pyroute2/dhcp/hooks.py create mode 100644 pyroute2/dhcp/leases.py create mode 100644 pyroute2/dhcp/messages.py create mode 100644 pyroute2/dhcp/timers.py create mode 100644 tests/test_linux/fixtures/dnsmasq.py create mode 100644 tests/test_linux/fixtures/interfaces.py diff --git a/pyroute2/dhcp/cli.py b/pyroute2/dhcp/cli.py new file mode 100644 index 000000000..86d3f73aa --- /dev/null +++ b/pyroute2/dhcp/cli.py @@ -0,0 +1,105 @@ +import asyncio +import logging +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from importlib import import_module +from typing import Any + +from pyroute2.dhcp.client import AsyncDHCPClient +from pyroute2.dhcp.hooks import ConfigureIP, Hook +from pyroute2.dhcp.leases import Lease + + +def importable(name: str) -> Any: + '''Imports anything by name. Used by the argument parser.''' + module_name, obj_name = name.rsplit('.', 1) + module = import_module(module_name) + return getattr(module, obj_name) + + +def get_psr() -> ArgumentParser: + psr = ArgumentParser( + description='pyroute2 DHCP client', + formatter_class=ArgumentDefaultsHelpFormatter, + ) + psr.add_argument( + 'interface', help='The interface to request an address for.' + ) + psr.add_argument( + '--lease-type', + help='Class to use for leases. ' + 'Must be a subclass of `pyroute2.dhcp.leases.Lease`.', + type=importable, + default='pyroute2.dhcp.leases.JSONFileLease', + metavar='dotted.name', + ) + psr.add_argument( + '--hook', + help='Hooks to load. ' + 'These are used to run async python code when, ' + 'for example, renewing or expiring a lease.', + nargs='+', + type=importable, + default=[ConfigureIP], + metavar='dotted.name', + ) + psr.add_argument( + '-x', + '--exit-on-lease', + help='Exit as soon as getting a lease.', + default=False, + action='store_true', + ) + psr.add_argument( + '--log-level', + help='Logging level to use.', + choices=('DEBUG', 'INFO', 'WARNING', 'ERROR'), + default='INFO', + ) + return psr + + +async def main(): + psr = get_psr() + args = psr.parse_args() + logging.basicConfig( + level=args.log_level, + format='%(asctime)s %(levelname)s [%(name)s:%(funcName)s] %(message)s', + ) + + if not issubclass(args.lease_type, Lease): + psr.error(f'{args.lease_type!r} must be a Lease subclass') + + # Check hooks are subclasses of Hook + for i in args.hook: + if not issubclass(i, Hook): + psr.error(f'{i!r} must be a Hook subclass') + + acli = AsyncDHCPClient( + interface=args.interface, + lease_type=args.lease_type, + # Instantiate hooks + hooks=[i() for i in args.hook], + ) + + # Open the socket, read existing lease, etc + async with acli: + # Bootstrap the client by sending a DISCOVER or a REQUEST + await acli.bootstrap() + if args.exit_on_lease: + # Wait until we're bound once, then exit + await acli.bound.wait() + else: + # Wait until the client is stopped otherwise + await acli._stopped.wait() + + +def run(): + # for the setup.cfg entrypoint + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + run() diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 7b62bd813..4bb37e30a 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -1,82 +1,261 @@ -import json -import select -import sys - -from pyroute2.dhcp import ( - BOOTREQUEST, - DHCPACK, - DHCPDISCOVER, - DHCPOFFER, - DHCPREQUEST, -) +import asyncio +from logging import getLogger +from typing import ClassVar, Iterable + +from pyroute2.dhcp import fsm, messages +from pyroute2.dhcp.constants import dhcp from pyroute2.dhcp.dhcp4msg import dhcp4msg -from pyroute2.dhcp.dhcp4socket import DHCP4Socket - - -def req(s, poll, msg, expect): - do_req = True - xid = None - - while True: - # get transaction id - if do_req: - xid = s.put(msg)['xid'] - # wait for response - events = poll.poll(2) - for fd, event in events: - response = s.get() - if response['xid'] != xid: - do_req = False - continue - if response['options']['message_type'] != expect: - raise Exception("DHCP protocol error") - return response - do_req = True - - -def action(ifname): - s = DHCP4Socket(ifname) - poll = select.poll() - poll.register(s, select.POLLIN | select.POLLPRI) - - # DISCOVER - discover = dhcp4msg( - { - 'op': BOOTREQUEST, - 'chaddr': s.l2addr, - 'options': { - 'message_type': DHCPDISCOVER, - 'parameter_list': [1, 3, 6, 12, 15, 28], - }, - } - ) - reply = req(s, poll, discover, expect=DHCPOFFER) - - # REQUEST - request = dhcp4msg( - { - 'op': BOOTREQUEST, - 'chaddr': s.l2addr, - 'options': { - 'message_type': DHCPREQUEST, - 'requested_ip': reply['yiaddr'], - 'server_id': reply['options']['server_id'], - 'parameter_list': [1, 3, 6, 12, 15, 28], - }, - } - ) - reply = req(s, poll, request, expect=DHCPACK) - s.close() - return reply +from pyroute2.dhcp.dhcp4socket import AsyncDHCP4Socket +from pyroute2.dhcp.hooks import Hook +from pyroute2.dhcp.leases import JSONFileLease, Lease +from pyroute2.dhcp.timers import Timers + +LOG = getLogger(__name__) + + +class AsyncDHCPClient: + '''A simple async DHCP client based on pyroute2.''' + + DEFAULT_PARAMETERS: ClassVar[tuple[dhcp.Parameter]] = [ + dhcp.Parameter.SUBNET_MASK, + dhcp.Parameter.ROUTER, + dhcp.Parameter.DOMAIN_NAME_SERVER, + dhcp.Parameter.DOMAIN_NAME, + ] + + def __init__( + self, + interface: str, + lease_type: type[Lease] = JSONFileLease, + hooks: Iterable[Hook] = (), + requested_parameters: Iterable[dhcp.Parameter] = (), + ): + self.interface = interface + self.lease_type = lease_type + self.hooks = hooks + self._sock: AsyncDHCP4Socket = AsyncDHCP4Socket(self.interface) + self._state: fsm.fsm.State | None = None + self._lease: Lease | None = None + self.requested_parameters = list( + requested_parameters + if requested_parameters + else self.DEFAULT_PARAMETERS + ) + self._stopped = asyncio.Event() + self._sendq: asyncio.Queue[dhcp4msg | None] = asyncio.Queue() + self._send_task: asyncio.Task | None = None + self.bound = asyncio.Event() + self.timers = Timers() + + async def _renew(self): + '''Called when the renewal timer, as defined in the lease, expires.''' + assert self.lease, 'cannot renew without an existing lease' + LOG.info('Renewal timer expired') + # TODO: send only to server that gave us the current lease + self.timers._reset_timer('renewal') + await self.transition( + to=fsm.State.RENEWING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + + async def _rebind(self): + assert self.lease, 'cannot rebind without an existing lease' + LOG.info('Rebinding timer expired') + self.timers._reset_timer('rebinding') + await self.transition( + to=fsm.State.REBINDING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + + async def _expire_lease(self): + LOG.info('Lease expired') + self.timers._reset_timer('expiration') + self.state = fsm.State.INIT + # FIXME: call hooks in a non blocking way (maybe call_soon ?) + for i in self.hooks: + await i.unbound(self.lease) + self._lease = None + await self.bootstrap() + + @property + def lease(self) -> Lease | None: + return self._lease + @lease.setter + def lease(self, value: Lease): + '''Set a fresh lease; only call this when a server grants one.''' + self._lease = value + self.timers.arm( + lease=self.lease, + renewal=self._renew, + rebinding=self._rebind, + expiration=self._expire_lease, + ) + self.lease.dump() + + @property + def state(self) -> fsm.State | None: + return self._state + + @state.setter + def state(self, value: fsm.State | None): + if value and self._state and value not in fsm.TRANSITIONS[self._state]: + raise ValueError( + f'Cannot transition from {self._state} to {value}' + ) + LOG.info('%s -> %s', self.state, value) + self._state = value + + async def _run(self): + packet_to_send: dhcp4msg | None = None + # TODO: refactor, this is even more awkward than select() + while True: + wait_for_stopped = asyncio.Task( + self._stopped.wait(), name='wait until stopped' + ) + wait_for_received_packet = asyncio.Task( + self._sock.get(), name='wait for received packet' + ) + wait_for_packet_to_send = asyncio.Task( + self._sendq.get(), name='wait for packet to send' + ) + tasks = ( + wait_for_stopped, + wait_for_received_packet, + wait_for_packet_to_send, + ) + done, pending = await asyncio.wait( + tasks, + timeout=5, # TODO interval + return_when=asyncio.FIRST_COMPLETED, + ) + for i in pending: + i.cancel() + + if wait_for_packet_to_send in done: + packet_to_send = wait_for_packet_to_send.result() + + if wait_for_received_packet in done: + received_packet = wait_for_received_packet.result() + msg_type = dhcp.MessageType( + received_packet['options']['message_type'] + ) + LOG.info('Received %s', msg_type.name) + if received_packet.get('xid') != self.xid: + LOG.error('Missing or wrong xid, discarding') + else: + handler_name = f'{msg_type.name.lower()}_received' + handler = getattr(self, handler_name, None) + if not handler: + LOG.debug('%r messages are not handled', msg_type.name) + else: + if await handler(received_packet): + packet_to_send = None + + if packet_to_send: + packet_to_send['xid'] = self.xid + LOG.debug( + 'Sending %s', + packet_to_send['options']['message_type'].name, + ) + await self._sock.put(packet_to_send) + + if wait_for_stopped in done and wait_for_stopped.result() is True: + return + + async def transition(self, to: fsm.State, send: dhcp4msg | None = None): + self.state = to + await self._sendq.put(send) + + @fsm.state_guard(fsm.State.INIT, fsm.State.INIT_REBOOT) + async def bootstrap(self): + '''Send a `DISCOVER` or a `REQUEST`, + + depending on whether we're initializing or rebooting. + ''' + match self.state: + case fsm.State.INIT: + # send discover + await self.transition( + to=fsm.State.SELECTING, + send=messages.discover( + parameter_list=self.requested_parameters + ), + ) + case fsm.State.INIT_REBOOT: + assert self.lease, 'cannot init_reboot without a lease' + # send request for lease + await self.transition( + to=fsm.State.REBOOTING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + + @fsm.state_guard( + fsm.State.REQUESTING, + fsm.State.REBOOTING, + fsm.State.REBINDING, + fsm.State.RENEWING, + ) + async def ack_received(self, pkt: dhcp4msg): + self.lease = self.lease_type(ack=pkt, interface=self.interface) + LOG.info( + 'Got lease for %s from %s', self.lease.ip, self.lease.server_id + ) + await self.transition(to=fsm.State.BOUND) + self.bound.set() + # FIXME: call hooks in a non blocking way (maybe call_soon ?) + for i in self.hooks: + await i.bound(self.lease) + return True -def run(): - if len(sys.argv) > 1: - ifname = sys.argv[1] - else: - ifname = 'eth0' - print(json.dumps(action(ifname), indent=4)) + @fsm.state_guard(fsm.State.SELECTING) + async def offer_received(self, pkt: dhcp4msg): + await self.transition( + to=fsm.State.REQUESTING, + send=messages.request( + requested_ip=pkt['yiaddr'], + server_id=pkt['options']['server_id'], + parameter_list=self.requested_parameters, + ), + ) + return True + async def __aenter__(self): + self._lease = self.lease_type.load(self.interface) + if self.lease: + # TODO check lease is not expired + self.state = fsm.State.INIT_REBOOT + else: + LOG.debug('No current lease') + self.state = fsm.State.INIT + await self._sock.__aenter__() + self._send_task = asyncio.Task(self._run()) + self.xid = self._sock.xid_pool.alloc() + return self -if __name__ == '__main__': - run() + async def __aexit__(self, *_): + self.timers.cancel() + if self.lease: + await self._sendq.put( + messages.release( + requested_ip=self.lease.ip, server_id=self.lease.server_id + ) + ) + self._stopped.set() + await self._send_task + await self._sock.__aexit__() + self.xid = None + self.state = None + self.bound.clear() diff --git a/pyroute2/dhcp/constants/__init__.py b/pyroute2/dhcp/constants/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyroute2/dhcp/constants/bootp.py b/pyroute2/dhcp/constants/bootp.py new file mode 100644 index 000000000..f8a0c14fd --- /dev/null +++ b/pyroute2/dhcp/constants/bootp.py @@ -0,0 +1,35 @@ +from enum import IntEnum + + +class MessageType(IntEnum): + BOOTREQUEST = 1 # Client to server + BOOTREPLY = 2 # Server to client + + +class HardwareType(IntEnum): + ETHERNET = 1 # Ethernet (10Mb) + EXPERIMENTAL_ETHERNET = 2 + AMATEUR_RADIO = 3 + TOKEN_RING = 4 + FDDI = 8 + ATM = 19 + WIRELESS_IEEE_802_11 = 20 + + +class Flag(IntEnum): # TODO: use enum.Flag + UNICAST = 0x0000 # Unicast response requested + BROADCAST = 0x8000 # Broadcast response requested + + +class Option(IntEnum): + PAD = 0 # Padding (no operation) + SUBNET_MASK = 1 # Subnet mask + ROUTER = 3 # Router address + DNS_SERVER = 6 # Domain name server + HOSTNAME = 12 # Hostname + BOOTFILE_SIZE = 13 # Boot file size + DOMAIN_NAME = 15 # Domain name + IP_ADDRESS_LEASE_TIME = 51 # DHCP lease time + MESSAGE_TYPE = 53 # DHCP message type (extended from BOOTP) + SERVER_IDENTIFIER = 54 # DHCP server identifier + END = 255 # End of options diff --git a/pyroute2/dhcp/constants/dhcp.py b/pyroute2/dhcp/constants/dhcp.py new file mode 100644 index 000000000..e5f8dadad --- /dev/null +++ b/pyroute2/dhcp/constants/dhcp.py @@ -0,0 +1,130 @@ +from enum import IntEnum + + +class MessageType(IntEnum): + DISCOVER = 1 + OFFER = 2 + REQUEST = 3 + DECLINE = 4 + ACK = 5 + NAK = 6 + RELEASE = 7 + INFORM = 8 + + +class Option(IntEnum): + SUBNET_MASK = 1 + TIME_OFFSET = 2 + ROUTER = 3 + DNS_SERVER = 6 + HOST_NAME = 12 + DOMAIN_NAME = 15 + BROADCAST_ADDRESS = 28 + REQUESTED_IP_ADDRESS = 50 + IP_ADDRESS_LEASE_TIME = 51 + DHCP_MESSAGE_TYPE = 53 + SERVER_IDENTIFIER = 54 + PARAMETER_REQUEST_LIST = 55 + MAX_MSG_SIZE = 57 + RENEWAL_TIME = 58 + REBINDING_TIME = 59 + VENDOR_CLASS_IDENTIFIER = 60 + CLIENT_IDENTIFIER = 61 + PADDING = 255 + + +class Parameter(IntEnum): + SUBNET_MASK = 1 # Subnet Mask + TIME_OFFSET = 2 # Time Offset + ROUTER = 3 # Router + TIME_SERVER = 4 # Time Server + NAME_SERVER = 5 # Name Server + DOMAIN_NAME_SERVER = 6 # Domain Name Server (DNS) + LOG_SERVER = 7 # Log Server + COOKIE_SERVER = 8 # Cookie Server + LPR_SERVER = 9 # Line Printer Server + IMPRESS_SERVER = 10 # Impress Server + RESOURCE_LOCATION_SERVER = 11 # Resource Location Server + HOST_NAME = 12 # Host Name + BOOT_FILE_SIZE = 13 # Boot File Size + MERIT_DUMP_FILE = 14 # Merit Dump File + DOMAIN_NAME = 15 # Domain Name + SWAP_SERVER = 16 # Swap Server + ROOT_PATH = 17 # Root Path + EXTENSIONS_PATH = 18 # Extensions Path + IP_FORWARDING = 19 # IP Forwarding Enable/Disable + NON_LOCAL_SOURCE_ROUTING = 20 # Non-Local Source Routing Enable/Disable + POLICY_FILTER = 21 # Policy Filter + MAX_DATAGRAM_REASSEMBLY = 22 # Maximum Datagram Reassembly Size + DEFAULT_TTL = 23 # Default IP Time-to-Live + PATH_MTU_AGING_TIMEOUT = 24 # Path MTU Aging Timeout + PATH_MTU_PLATEAU_TABLE = 25 # Path MTU Plateau Table + INTERFACE_MTU = 26 # Interface MTU + ALL_SUBNETS_LOCAL = 27 # All Subnets Are Local + BROADCAST_ADDRESS = 28 # Broadcast Address + PERFORM_MASK_DISCOVERY = 29 # Perform Mask Discovery + MASK_SUPPLIER = 30 # Mask Supplier + PERFORM_ROUTER_DISCOVERY = 31 # Perform Router Discovery + ROUTER_SOLICITATION_ADDRESS = 32 # Router Solicitation Address + STATIC_ROUTE = 33 # Static Route + TRAILER_ENCAPSULATION = 34 # Trailer Encapsulation + ARP_CACHE_TIMEOUT = 35 # ARP Cache Timeout + ETHERNET_ENCAPSULATION = 36 # Ethernet Encapsulation + TCP_DEFAULT_TTL = 37 # TCP Default TTL + TCP_KEEPALIVE_INTERVAL = 38 # TCP Keepalive Interval + TCP_KEEPALIVE_GARBAGE = 39 # TCP Keepalive Garbage + NIS_DOMAIN = 40 # Network Information Service Domain + NIS_SERVERS = 41 # NIS Servers + NTP_SERVERS = 42 # NTP Servers + VENDOR_SPECIFIC_INFORMATION = 43 # Vendor Specific Information + NETBIOS_NAME_SERVER = 44 # NetBIOS over TCP/IP Name Server + NETBIOS_DDG_SERVER = 45 # NetBIOS Datagram Distribution Server + NETBIOS_NODE_TYPE = 46 # NetBIOS Node Type + NETBIOS_SCOPE = 47 # NetBIOS Scope + X_WINDOW_FONT_SERVER = 48 # X Window System Font Server + X_WINDOW_DISPLAY_MANAGER = 49 # X Window System Display Manager + REQUESTED_IP_ADDRESS = 50 # Requested IP Address + IP_ADDRESS_LEASE_TIME = 51 # IP Address Lease Time + OPTION_OVERLOAD = 52 # Option Overload + DHCP_MESSAGE_TYPE = 53 # DHCP Message Type + SERVER_IDENTIFIER = 54 # Server Identifier + PARAMETER_REQUEST_LIST = 55 # Parameter Request List + MESSAGE = 56 # Message + MAX_DHCP_MESSAGE_SIZE = 57 # Maximum DHCP Message Size + RENEWAL_TIME_VALUE = 58 # Renewal (T1) Time Value + REBINDING_TIME_VALUE = 59 # Rebinding (T2) Time Value + CLASS_IDENTIFIER = 60 # Vendor Class Identifier + CLIENT_IDENTIFIER = 61 # Client Identifier + NETWARE_IP_DOMAIN = 62 # NetWare/IP Domain Name + NETWARE_IP_OPTION = 63 # NetWare/IP Option + NIS_PLUS_DOMAIN = 64 # NIS+ Domain + NIS_PLUS_SERVERS = 65 # NIS+ Servers + TFTP_SERVER_NAME = 66 # TFTP Server Name + BOOTFILE_NAME = 67 # Bootfile Name + MOBILE_IP_HOME_AGENT = 68 # Mobile IP Home Agent + SMTP_SERVER = 69 # Simple Mail Transport Protocol Server + POP3_SERVER = 70 # Post Office Protocol v3 Server + NNTP_SERVER = 71 # Network News Transport Protocol Server + DEFAULT_WWW_SERVER = 72 # Default World Wide Web Server + DEFAULT_FINGER_SERVER = 73 # Default Finger Server + DEFAULT_IRC_SERVER = 74 # Default Internet Relay Chat Server + STREETTALK_SERVER = 75 # StreetTalk Server + STDA_SERVER = 76 # StreetTalk Directory Assistance Server + USER_CLASS_INFORMATION = 77 # User Class Information + SLP_DIRECTORY_AGENT = 78 # SLP Directory Agent + SLP_SERVICE_SCOPE = 79 # SLP Service Scope + RAPID_COMMIT = 80 # Rapid Commit + CLIENT_FQDN = 81 # Fully Qualified Domain Name + RELAY_AGENT_INFORMATION = 82 # Relay Agent Information + INTERNET_STORAGE_NAME_SERVICE = 83 # ISNS + NDS_SERVERS = 85 # Novell Directory Services Servers + NDS_TREE_NAME = 86 # Novell Directory Services Tree Name + NDS_CONTEXT = 87 # Novell Directory Services Context + BCMCS_CONTROLLER_DOMAIN = 88 # BCMCS Controller Domain Name List + AUTHENTICATION = 90 + CLIENT_SYSTEM_ARCHITECTURE_TYPE = 93 + CLIENT_NETWORK_INTERFACE_IDENTIFIER = 94 + CLASSLESS_STATIC_ROUTE = 121 + DOMAIN_SEARCH = 119 + PRIVATE_CLASSIC_ROUTE_MS = 249 + PRIVATE_PROXY_AUTODISCOVERY = 252 diff --git a/pyroute2/dhcp/dhcp4socket.py b/pyroute2/dhcp/dhcp4socket.py index 9b2286b08..82e36e57d 100644 --- a/pyroute2/dhcp/dhcp4socket.py +++ b/pyroute2/dhcp/dhcp4socket.py @@ -4,13 +4,22 @@ ''' +import asyncio +import logging +import socket + from pyroute2.common import AddrPool from pyroute2.dhcp.dhcp4msg import dhcp4msg -from pyroute2.ext.rawsocket import RawSocket +from pyroute2.ext.rawsocket import AsyncRawSocket from pyroute2.protocols import ethmsg, ip4msg, udp4_pseudo_header, udpmsg +LOG = logging.getLogger(__name__) + +UDP_HEADER_SIZE = 8 +IPV4_HEADER_SIZE = 20 -def listen_udp_port(port=68): + +def listen_udp_port(port: int = 68) -> list[list[int]]: # pre-scripted BPF code that matches UDP port bpf_code = [ [40, 0, 0, 12], @@ -28,7 +37,7 @@ def listen_udp_port(port=68): return bpf_code -class DHCP4Socket(RawSocket): +class AsyncDHCP4Socket(AsyncRawSocket): ''' Parameters: @@ -45,31 +54,33 @@ class DHCP4Socket(RawSocket): not provided, DHCP4Socket generates it for outgoing messages. ''' - def __init__(self, ifname, port=68): - RawSocket.__init__(self, ifname, listen_udp_port(port)) + def __init__(self, ifname, port: int = 68): + AsyncRawSocket.__init__(self, ifname, listen_udp_port(port)) self.port = port # Create xid pool # # Every allocated xid will be released automatically after 1024 # alloc() calls, there is no need to call free(). Minimal xid == 16 - self.xid_pool = AddrPool(minaddr=16, release=1024) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def put(self, msg=None, dport=67): + self.xid_pool = AddrPool( + minaddr=16, release=1024 + ) # TODO : maybe it should be in the client and not here ? + self.aio_loop = asyncio.get_running_loop() + + async def put( + self, + msg: dhcp4msg, + eth_dst: str = 'ff:ff:ff:ff:ff:ff', + ip_dst: str = '255.255.255.255', + dport: int = 67, + ) -> dhcp4msg: ''' Put DHCP message. Parameters: * msg -- dhcp4msg instance + * eth_dst -- dest MAC address + * ip_dst -- dest IP address * dport -- DHCP server port - If `msg` is not provided, it is constructed as default - BOOTREQUEST + DHCPDISCOVER. - Examples:: sock.put(dhcp4msg({'op': BOOTREQUEST, @@ -79,53 +90,67 @@ def put(self, msg=None, dport=67): 'requested_ip': '172.16.101.2', 'server_id': '172.16.101.1'}})) - The method returns dhcp4msg that was sent, so one can get from - there `xid` (transaction id) and other details. + The method returns the sent dhcp4msg, so one can get from + there the `xid` (transaction id) and other details. ''' # DHCP layer - dhcp = msg or dhcp4msg({'chaddr': self.l2addr}) + dhcp = msg # dhcp transaction id if dhcp['xid'] is None: dhcp['xid'] = self.xid_pool.alloc() + # auto add src addr + if dhcp['chaddr'] is None: + dhcp['chaddr'] = self.l2addr + data = dhcp.encode().buf + dhcp_payload_size = len(data) # UDP layer udp = udpmsg( - {'sport': self.port, 'dport': dport, 'len': 8 + len(data)} + { + 'sport': self.port, + 'dport': dport, + 'len': UDP_HEADER_SIZE + dhcp_payload_size, + } ) + # Pseudo UDP header, only for checksum purposes udph = udp4_pseudo_header( - {'dst': '255.255.255.255', 'len': 8 + len(data)} + {'dst': ip_dst, 'len': UDP_HEADER_SIZE + dhcp_payload_size} ) udp['csum'] = self.csum(udph.encode().buf + udp.encode().buf + data) udp.reset() # IPv4 layer ip4 = ip4msg( - {'len': 20 + 8 + len(data), 'proto': 17, 'dst': '255.255.255.255'} + { + 'len': IPV4_HEADER_SIZE + UDP_HEADER_SIZE + dhcp_payload_size, + 'proto': socket.IPPROTO_UDP, + 'dst': ip_dst, + } ) ip4['csum'] = self.csum(ip4.encode().buf) ip4.reset() # MAC layer eth = ethmsg( - {'dst': 'ff:ff:ff:ff:ff:ff', 'src': self.l2addr, 'type': 0x800} + {'dst': eth_dst, 'src': self.l2addr, 'type': socket.ETHERTYPE_IP} ) data = eth.encode().buf + ip4.encode().buf + udp.encode().buf + data - self.send(data) + await self.aio_loop.sock_sendall(self, data) dhcp.reset() return dhcp - def get(self): + async def get(self) -> dhcp4msg: ''' Get the next incoming packet from the socket and try to decode it as IPv4 DHCP. No analysis is done here, only MAC/IPv4/UDP headers are stripped out, and the rest is interpreted as DHCP. ''' - (data, addr) = self.recvfrom(4096) + data, _ = await self.aio_loop.sock_recvfrom(self, 4096) eth = ethmsg(buf=data).decode() ip4 = ip4msg(buf=data, offset=eth.offset).decode() udp = udpmsg(buf=data, offset=ip4.offset).decode() diff --git a/pyroute2/dhcp/fsm.py b/pyroute2/dhcp/fsm.py new file mode 100644 index 000000000..80c45aaca --- /dev/null +++ b/pyroute2/dhcp/fsm.py @@ -0,0 +1,62 @@ +'''DHCP client state machine helpers.''' + +from enum import StrEnum, auto +from logging import getLogger +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from .client import AsyncDHCPClient + + +LOG = getLogger(__name__) + + +class State(StrEnum): + '''DHCP client states. + + see + http://www.tcpipguide.com/free/t_DHCPGeneralOperationandClientFiniteStateMachine.htm + ''' + + INIT = auto() + INIT_REBOOT = auto() + REBOOTING = auto() + REQUESTING = auto() + SELECTING = auto() + BOUND = auto() + RENEWING = auto() + REBINDING = auto() + + +# allowed transitions between states +TRANSITIONS: Final[dict[State, set[State]]] = { + State.INIT_REBOOT: {State.REBOOTING}, + State.REBOOTING: {State.INIT, State.BOUND}, + State.INIT: {State.SELECTING}, + State.SELECTING: {State.REQUESTING, State.INIT}, + State.REQUESTING: {State.BOUND, State.INIT}, + State.BOUND: {State.INIT, State.RENEWING, State.REBINDING}, + State.RENEWING: {State.BOUND, State.INIT, State.REBINDING}, + State.REBINDING: {State.BOUND, State.INIT}, +} + + +def state_guard(*states: State): + '''Decorator that prevents a method from running + + if the associated instance is not in one of the given States.''' + + def decorator(meth): + async def wrapper(self: 'AsyncDHCPClient', *args, **kwargs): + if self.state not in states: + LOG.debug( + 'Ignoring call to %r in %s state', + meth.__name__, + self.state, + ) + return False + return await meth(self, *args, **kwargs) + + return wrapper + + return decorator diff --git a/pyroute2/dhcp/hooks.py b/pyroute2/dhcp/hooks.py new file mode 100644 index 000000000..cd5d50f9d --- /dev/null +++ b/pyroute2/dhcp/hooks.py @@ -0,0 +1,30 @@ +'''Hooks called by the DHCP client when bound, a leases expires, etc.''' + +from logging import getLogger + +from pyroute2.dhcp.leases import Lease + +LOG = getLogger(__name__) + + +class Hook: + '''Base class for pyroute2 dhcp client hooks.''' + + def __init__(self, **settings): + pass + + async def bound(self, lease: Lease): + '''Called when the client gets a lease.''' + pass + + async def unbound(self, lease: Lease): + '''Called when a leases expires.''' + pass + + +class ConfigureIP(Hook): + async def bound(self, lease: Lease): + LOG.info('STUB: add %s to %s', lease.ip, lease.interface) + + async def unbound(self, lease: Lease): + LOG.info('STUB: remove %s from %s', lease.ip, lease.interface) diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py new file mode 100644 index 000000000..0ebc9bd31 --- /dev/null +++ b/pyroute2/dhcp/leases.py @@ -0,0 +1,154 @@ +'''Lease classes used by the dhcp client.''' + +import abc +import json +from dataclasses import asdict, dataclass, field +from datetime import datetime +from logging import getLogger +from pathlib import Path +from typing import Self + +from pyroute2.dhcp.dhcp4msg import dhcp4msg + +LOG = getLogger(__name__) + + +def _now() -> float: + '''The current timestamp.''' + return datetime.now().timestamp() + + +@dataclass +class Lease(abc.ABC): + '''Represents a lease obtained through DHCP.''' + + # The DHCP ack sent by the server which allocated this lease + ack: dhcp4msg + # Name of the interface for which this lease was requested + interface: str + # Timestamp of when this lease was obtained + obtained: float = field(default_factory=_now) + + def _seconds_til_timer(self, timer_name: str) -> float | None: + '''The number of seconds to wait until the given timer expires. + + The value is fetched from options as `{timer_name}_time`. + (lease -> lease_time, renewal -> renewal_time, ...) + ''' + try: + delta: int = self.ack['options'][f'{timer_name}_time'] + return self.obtained + delta - _now() + except KeyError: + return None + + @property + def expiration_in(self) -> float | None: + return self._seconds_til_timer('lease') + + @property + def renewal_in(self) -> float | None: + '''The amount of seconds before we have to renew the lease. + + Can be negative if it's past due, + or `None` if the server didn't give a renewal time. + ''' + return self._seconds_til_timer('renewal') + + @property + def rebinding_in(self) -> float | None: + '''The amount of seconds before we have to rebind the lease. + + Can be negative if it's past due, + or `None` if the server didn't give a rebinding time. + ''' + return self._seconds_til_timer('rebinding') + + @property + def ip(self) -> str: + '''The IP address assigned to the client.''' + return self.ack['yiaddr'] + + @property + def subnet_mask(self) -> str: + '''The subnet mask assigned to the client.''' + return self.ack['options']['subnet_mask'] + + @property + def routers(self) -> str: + return self.ack['options']['router'] + + @property + def name_servers(self) -> str: # XXX: list ? + return self.ack['options']['name_server'] + + @property + def server_id(self) -> str: + '''The IP address of the server which allocated this lease.''' + return self.ack['options']['server_id'] + + @abc.abstractmethod + def dump(self) -> None: + '''Write a lease, i.e. to disk or to stdout.''' + pass + + @classmethod + @abc.abstractmethod + def load(cls, interface: str) -> 'Self | None': + '''Load an existing lease for an interface, if it exists.''' + pass + + +class JSONStdoutLease(Lease): + '''Just prints the lease to stdout when the client gets a new one.''' + + def dump(self) -> None: + """Writes the lease as json to stdout.""" + print(json.dumps(asdict(self), indent=2)) + + @classmethod + def load(cls, interface: str) -> None: + '''Does not do anything.''' + return None + + +class JSONFileLease(Lease): + '''Write and load the lease from a JSON file in the working directory.''' + + @classmethod + def _get_lease_dir(cls) -> Path: + '''Where to store the lease file, i.e. the working directory.''' + return Path.cwd() + + @classmethod + def _get_path(cls, interface: str) -> Path: + '''The lease file, named after the interface.''' + return ( + cls._get_lease_dir().joinpath(interface).with_suffix('.lease.json') + ) + + def dump(self) -> None: + '''Dump the lease to a file. + + The lease file is named after the interface + and written in the working directory. + ''' + lease_path = self._get_path(self.interface) + LOG.info('Writing lease for %s to %s', self.interface, lease_path) + with lease_path.open('wt') as lf: + json.dump(asdict(self), lf) + + @classmethod + def load(cls, interface: str) -> 'JSONFileLease | None': + '''Load the lease from a file. + + The lease file is named after the interface + and read from the working directory. + ''' + lease_path = cls._get_path(interface) + try: + with lease_path.open('rt') as lf: + LOG.info('Loading lease for %s from %s', interface, lease_path) + return cls(**json.load(lf)) + except FileNotFoundError: + LOG.info('No existing lease at %s for %s', lease_path, interface) + return None diff --git a/pyroute2/dhcp/messages.py b/pyroute2/dhcp/messages.py new file mode 100644 index 000000000..5994bda25 --- /dev/null +++ b/pyroute2/dhcp/messages.py @@ -0,0 +1,45 @@ +"""Helper functions to build dhcp client messages.""" + +from pyroute2.dhcp.constants import bootp, dhcp +from pyroute2.dhcp.dhcp4msg import dhcp4msg + + +def discover(parameter_list: list[dhcp.Parameter]) -> dhcp4msg: + return dhcp4msg( + { + 'op': bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': dhcp.MessageType.DISCOVER, + 'parameter_list': parameter_list, + }, + } + ) + + +def request( + requested_ip: str, server_id: str, parameter_list: list[dhcp.Parameter] +) -> dhcp4msg: + return dhcp4msg( + { + 'op': bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': dhcp.MessageType.REQUEST, + 'requested_ip': requested_ip, + 'server_id': server_id, + 'parameter_list': parameter_list, + }, + } + ) + + +def release(requested_ip: str, server_id: str) -> dhcp4msg: + return dhcp4msg( + { + 'op': bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': dhcp.MessageType.RELEASE, + 'requested_ip': requested_ip, + 'server_id': server_id, + }, + } + ) diff --git a/pyroute2/dhcp/timers.py b/pyroute2/dhcp/timers.py new file mode 100644 index 000000000..65d62ee85 --- /dev/null +++ b/pyroute2/dhcp/timers.py @@ -0,0 +1,64 @@ +'''Timers to manage lease rebinding, renewal & expiration.''' + +import asyncio +import dataclasses +from logging import getLogger +from typing import Awaitable, Callable + +from pyroute2.dhcp.leases import Lease + +LOG = getLogger(__name__) + + +@dataclasses.dataclass +class Timers: + '''Manage callbacks associated with DHCP leases.''' + + renewal: asyncio.TimerHandle | None = None + rebinding: asyncio.TimerHandle | None = None + expiration: asyncio.TimerHandle | None = None + + def cancel(self): + '''Cancel all current timers.''' + for timer_name in ('renewal', 'rebinding', 'expiration'): + self._reset_timer(timer_name) + + def _reset_timer(self, timer_name: str): + '''Cancel a timer and set it to None.''' + if timer := getattr(self, timer_name): + timer: asyncio.TimerHandle + if not timer.cancelled(): + # FIXME: how do we know a timer wasn't cancelled ? + # this causes spurious logs + LOG.debug('Canceling %s timer', timer_name) + timer.cancel() + setattr(self, timer_name, None) + + def arm(self, lease: Lease, **callbacks: Callable[[], Awaitable[None]]): + '''Reset & arm timers from a `Lease`. + + `callbacks` must be async callables with no arguments + that will be called when the associated timer expires. + ''' + self.cancel() + loop = asyncio.get_running_loop() + + for timer_name, async_callback in callbacks.items(): + self._reset_timer(timer_name) + lease_time = getattr(lease, f'{timer_name}_in') + if not lease_time: + LOG.debug('Lease does not set a %s time', timer_name) + continue + if lease_time < 0.0: + LOG.debug('Lease %s is in the past', timer_name) + continue + LOG.info('Scheduling lease %s in %.2fs', timer_name, lease_time) + '''FIXME: calling async_callback() causes a + "coroutine was never awaited" warning. + But deferring its call in a lambda causes the callback to be + the same for all timers, since we're in a loop. + ''' + timer = loop.call_later( + lease_time, asyncio.create_task, async_callback() + ) + setattr(self, timer_name, timer) diff --git a/pyroute2/ext/rawsocket.py b/pyroute2/ext/rawsocket.py index 902c24147..bf08b9305 100644 --- a/pyroute2/ext/rawsocket.py +++ b/pyroute2/ext/rawsocket.py @@ -11,7 +11,7 @@ ) from socket import AF_PACKET, SOCK_RAW, SOL_SOCKET, errno, error, htons, socket -from pyroute2.iproute.linux import IPRoute +from pyroute2.iproute.linux import AsyncIPRoute ETH_P_ALL = 3 SO_ATTACH_FILTER = 26 @@ -34,14 +34,14 @@ class sock_fprog(Structure): _fields_ = [('len', c_ushort), ('filter', c_void_p)] -def compile_bpf(code): +def compile_bpf(code: list[int]): ProgramType = sock_filter * len(code) program = ProgramType(*[sock_filter(*line) for line in code]) sfp = sock_fprog(len(code), addressof(program[0])) return string_at(addressof(sfp), sizeof(sfp)), program -class RawSocket(socket): +class AsyncRawSocket(socket): ''' This raw socket binds to an interface and optionally installs a BPF filter. @@ -55,28 +55,37 @@ class RawSocket(socket): fprog = None - def __init__(self, ifname, bpf=None): - self.ifname = ifname + async def __aexit__(self, *_): + self.close() + + async def __aenter__(self): # lookup the interface details - with IPRoute() as ip: - for link in ip.get_links(): - if link.get_attr('IFLA_IFNAME') == ifname: + async with AsyncIPRoute() as ip: + async for link in await ip.get_links(): + if link.get_attr('IFLA_IFNAME') == self.ifname: break else: raise IOError(2, 'Link not found') - self.l2addr = link.get_attr('IFLA_ADDRESS') - self.ifindex = link['index'] + self.l2addr: str = link.get_attr('IFLA_ADDRESS') + self.ifindex: int = link['index'] # bring up the socket socket.__init__(self, AF_PACKET, SOCK_RAW, htons(ETH_P_ALL)) + socket.setblocking(self, False) socket.bind(self, (self.ifname, ETH_P_ALL)) - if bpf: + if self.bpf: self.clear_buffer() - fstring, self.fprog = compile_bpf(bpf) + fstring, self.fprog = compile_bpf(self.bpf) socket.setsockopt(self, SOL_SOCKET, SO_ATTACH_FILTER, fstring) else: + # FIXME: should be async self.clear_buffer(remove_total_filter=True) + return self + + def __init__(self, ifname: str, bpf: list[list[int]] | None = None): + self.ifname = ifname + self.bpf = bpf - def clear_buffer(self, remove_total_filter=False): + def clear_buffer(self, remove_total_filter: bool = False): # there is a window of time after the socket has been created and # before bind/attaching a filter where packets can be queued onto the # socket buffer @@ -86,7 +95,6 @@ def clear_buffer(self, remove_total_filter=False): # before setting the desired filter total_fstring, prog = compile_bpf(total_filter) socket.setsockopt(self, SOL_SOCKET, SO_ATTACH_FILTER, total_fstring) - self.setblocking(0) while True: try: self.recvfrom(0) @@ -99,7 +107,6 @@ def clear_buffer(self, remove_total_filter=False): break else: raise - self.setblocking(1) if remove_total_filter: # total_fstring ignored socket.setsockopt( diff --git a/setup.cfg b/setup.cfg index a31c73a36..12fb5a2a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,5 +40,5 @@ console_scripts = ss2 = pyroute2.netlink.diag.ss2:run [psutil] pyroute2-cli = pyroute2.ndb.cli:run pyroute2-decoder = pyroute2.decoder.main:run - pyroute2-dhcp-client = pyroute2.dhcp.client:run + pyroute2-dhcp-client = pyroute2.dhcp.cli:run pyroute2-test-platform = pyroute2.config.test_platform:run diff --git a/tests/test_linux/conftest.py b/tests/test_linux/conftest.py index df00a9c68..ece1de2e4 100644 --- a/tests/test_linux/conftest.py +++ b/tests/test_linux/conftest.py @@ -2,6 +2,8 @@ from uuid import uuid4 import pytest +from fixtures.dnsmasq import dnsmasq, dnsmasq_options # noqa: F401 +from fixtures.interfaces import dhcp_range, veth_pair # noqa: F401 from pr2test.context_manager import NDBContextManager, SpecContextManager from utils import require_user @@ -10,7 +12,7 @@ from pyroute2.wiset import COUNT config.nlm_generator = True -pytest_plugins = "pytester" +pytest_plugins = 'pytester' @pytest.fixture @@ -62,7 +64,7 @@ def wiset_sock(request): if request.param is None: yield None else: - before_count = COUNT["count"] + before_count = COUNT['count'] with IPSet() as sock: yield sock assert before_count == COUNT['count'] diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py new file mode 100644 index 000000000..db4a53146 --- /dev/null +++ b/tests/test_linux/fixtures/dnsmasq.py @@ -0,0 +1,155 @@ +import asyncio +from argparse import ArgumentParser +from dataclasses import dataclass +from ipaddress import IPv4Address +from shutil import which +from typing import AsyncGenerator, ClassVar, Literal + +import pytest +import pytest_asyncio +from fixtures.interfaces import DHCPRangeConfig + + +@dataclass +class DnsmasqOptions: + '''Options for the dnsmasq server.''' + + range_start: IPv4Address + range_end: IPv4Address + interface: str + lease_time: str = '12h' + + def __iter__(self): + opts = ( + f'--interface={self.interface}', + f'--dhcp-range={self.range_start},' + f'{self.range_end},{self.lease_time}', + ) + return iter(opts) + + +class DnsmasqFixture: + '''Runs the dnsmasq server as an async context manager.''' + + DNSMASQ_PATH: ClassVar[str | None] = which('dnsmasq') + + def __init__(self, options: DnsmasqOptions) -> None: + self.options = options + self.stdout: list[bytes] = [] + self.stderr: list[bytes] = [] + self.process: asyncio.subprocess.Process | None = None + self.output_poller: asyncio.Task | None = None + + async def _read_output(self, name: Literal['stdout', 'stderr']): + '''Read stdout or stderr until the process exits.''' + stream = getattr(self.process, name) + output = getattr(self, name) + while line := await stream.readline(): + output.append(line) + + async def _read_outputs(self): + '''Read stdout & stderr until the process exits.''' + assert self.process + await asyncio.gather( + self._read_output('stderr'), self._read_output('stdout') + ) + + def _get_base_cmdline_options(self) -> tuple[str]: + '''The base commandline options for dnsmasq.''' + return ( + '--keep-in-foreground', # self explanatory + '--no-resolv', # don't mess w/ resolv.conf + '--log-facility=-', # log to stdout + '--no-hosts', # don't read /etc/hosts + '--bind-interfaces', # don't bind on wildcard + '--no-ping', # don't ping to check if ips are attributed + ) + + def get_cmdline_options(self) -> tuple[str]: + '''All commandline options passed to dnsmasq.''' + return (*self._get_base_cmdline_options(), *self.options) + + async def __aenter__(self): + '''Start the dnsmasq process and start polling its output.''' + assert self.DNSMASQ_PATH + self.process = await asyncio.create_subprocess_exec( + self.DNSMASQ_PATH, + *self.get_cmdline_options(), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + self.output_poller = asyncio.Task(self._read_outputs()) + return self + + async def __aexit__(self, *_): + if self.process: + if self.process.returncode is None: + self.process.terminate() + await self.process.wait() + await self.output_poller + + +@pytest.fixture +def dnsmasq_options( + veth_pair: tuple[str, str], dhcp_range: DHCPRangeConfig +) -> DnsmasqOptions: + '''dnsmasq options useful for test purposes.''' + return DnsmasqOptions( + range_start=dhcp_range.range_start, + range_end=dhcp_range.range_end, + interface=veth_pair[0], + ) + + +@pytest_asyncio.fixture +async def dnsmasq( + dnsmasq_options: DnsmasqOptions, +) -> AsyncGenerator[DnsmasqFixture, None]: + '''A dnsmasq instance running for the duration of the test.''' + async with DnsmasqFixture(options=dnsmasq_options) as dnsf: + yield dnsf + + +def get_psr() -> ArgumentParser: + psr = ArgumentParser() + psr.add_argument('interface', help='Interface to listen on') + psr.add_argument( + '--range-start', + type=IPv4Address, + default=IPv4Address('192.168.186.10'), + help='Start of the DHCP client range.', + ) + psr.add_argument( + '--range-end', + type=IPv4Address, + default=IPv4Address('192.168.186.100'), + help='End of the DHCP client range.', + ) + psr.add_argument( + '--lease-time', + default='2m', + help='DHCP lease time (minimum 2 minutes according to man)', + ) + return psr + + +async def main(): + '''Commandline entrypoint to start dnsmasq the same way the fixture does. + + Useful for debugging. + ''' + args = get_psr().parse_args() + opts = DnsmasqOptions(**args.__dict__) + read_lines: int = 0 + async with DnsmasqFixture(opts) as dnsm: + # quick & dirty stderr polling + while True: + if len(dnsm.stderr) > read_lines: + read_lines += len(lines := dnsm.stderr[read_lines:]) + print(*(i.decode().strip() for i in lines), sep='\n') + else: + await asyncio.sleep(0.2) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/tests/test_linux/fixtures/interfaces.py b/tests/test_linux/fixtures/interfaces.py new file mode 100644 index 000000000..c276b2cbc --- /dev/null +++ b/tests/test_linux/fixtures/interfaces.py @@ -0,0 +1,68 @@ +import asyncio +import random +from ipaddress import IPv4Address, IPv4Interface +from typing import AsyncGenerator, NamedTuple + +import pytest +import pytest_asyncio + + +class DHCPRangeConfig(NamedTuple): + range_start: IPv4Address + range_end: IPv4Address + gw: IPv4Interface + + +async def ip(*args: str): + '''Call `ip` in a subprocess.''' + proc = await asyncio.create_subprocess_exec('ip', *args) + stdout, stderr = await proc.communicate() + assert proc.returncode == 0, stderr + return stdout + + +@pytest.fixture +def dhcp_range() -> DHCPRangeConfig: + ''' 'An IPv4 DHCP range configuration.''' + rangeidx = random.randint(1, 254) + return DHCPRangeConfig( + range_start=IPv4Address(f'10.{rangeidx}.0.10'), + range_end=IPv4Address(f'10.{rangeidx}.0.20'), + gw=IPv4Interface(f'10.{rangeidx}.0.1/16'), + ) + + +class VethPair(NamedTuple): + '''A pair of veth interfaces.''' + + server: str + client: str + + +@pytest_asyncio.fixture +async def veth_pair( + dhcp_range: DHCPRangeConfig, +) -> AsyncGenerator[VethPair, None]: + '''Fixture that creates a temporary veth pair.''' + # FIXME: use pyroute2 + # TODO: /proc/sys/net/ipv4/conf/{interface}/accept_local ? + idx = random.randint(0, 999) + server_ifname = f'dhcptest{idx}-srv' + client_ifname = f'dhcptest{idx}-cli' + try: + await ip( + 'link', + 'add', + server_ifname, + 'type', + 'veth', + 'peer', + 'name', + client_ifname, + ) + await ip('addr', 'add', str(dhcp_range.gw), 'dev', server_ifname) + await ip('link', 'set', server_ifname, 'up') + await ip('link', 'set', client_ifname, 'up') + yield VethPair(server_ifname, client_ifname) + finally: + await ip('link', 'del', server_ifname) diff --git a/tests/test_linux/test_raw/test_dhcp.py b/tests/test_linux/test_raw/test_dhcp.py index 9bc68f89c..225ee55ad 100644 --- a/tests/test_linux/test_raw/test_dhcp.py +++ b/tests/test_linux/test_raw/test_dhcp.py @@ -1,65 +1,79 @@ -import collections +import asyncio import json -import subprocess +from ipaddress import IPv4Address +from pathlib import Path import pytest +from fixtures.dnsmasq import DnsmasqFixture +from fixtures.interfaces import VethPair from pr2test.marks import require_root -from pyroute2 import NDB -from pyroute2.common import dqn2int, hexdump, hexload -from pyroute2.dhcp import client +from pyroute2.dhcp import client, fsm +from pyroute2.dhcp.constants import bootp, dhcp +from pyroute2.dhcp.leases import JSONFileLease pytestmark = [require_root()] -@pytest.fixture -def ctx(): - ndb = NDB() - index = 0 - ifname = '' - # get a DHCP default route, if exists - with ndb.routes.dump() as dump: - dump.select_records(proto=16, dst='') - for route in dump: - index = route.oif - ifname = ndb.interfaces[index]['ifname'] - break - yield collections.namedtuple('Context', ['ndb', 'index', 'ifname'])( - ndb, index, ifname - ) - ndb.close() +@pytest.mark.asyncio +async def test_get_lease( + dnsmasq: DnsmasqFixture, + veth_pair: VethPair, + tmpdir: str, + monkeypatch: pytest.MonkeyPatch, +): + """The client can get a lease and write it to a file.""" + work_dir = Path(tmpdir) + # Patch JSONFileLease so leases get written to the temp dir + # instead of whatever the working directory is + monkeypatch.setattr(JSONFileLease, "_get_lease_dir", lambda: work_dir) + # boot up the dhcp client and wait for a lease + async with client.AsyncDHCPClient(veth_pair.client) as cli: + await cli.bootstrap() + await asyncio.wait_for(cli.bound.wait(), timeout=5) + assert cli.state == fsm.State.BOUND + lease = cli.lease + assert lease.ack["xid"] == cli.xid -def _do_test_client_module(ctx): - if ctx.index == 0: - pytest.skip('no DHCP interfaces detected') - - response = client.action(ctx.ifname) - options = response['options'] - router = response['options']['router'][0] - prefixlen = dqn2int(response['options']['subnet_mask']) - address = response['yiaddr'] - l2addr = response['chaddr'] - - # convert addresses like 96:0:1:45:fa:6c into 96:00:01:45:fa:6c + # check the obtained lease + assert lease.interface == veth_pair.client + assert lease.ack["op"] == bootp.MessageType.BOOTREPLY + assert lease.ack["options"]["message_type"] == dhcp.MessageType.ACK assert ( - hexdump(hexload(l2addr)) == ctx.ndb.interfaces[ctx.ifname]['address'] + dnsmasq.options.range_start + <= IPv4Address(lease.ip) + <= dnsmasq.options.range_end ) - assert router == ctx.ndb.routes['default']['gateway'] - assert options['lease_time'] > 0 - assert prefixlen > 0 - assert address is not None - return response - + assert lease.ack["chaddr"] + # TODO: check chaddr matches veth_pair.client's MAC -def test_client_module(ctx): - _do_test_client_module(ctx) + # check the lease was written to disk and can be loaded + expected_lease_file = JSONFileLease._get_path(lease.interface) + assert expected_lease_file.is_file() + json_lease = json.loads(expected_lease_file.read_bytes()) + assert isinstance(json_lease, dict) + assert JSONFileLease(**json_lease) == lease -def test_client_console(ctx): - response_from_module = json.loads(json.dumps(_do_test_client_module(ctx))) - client = subprocess.run( - ['pyroute2-dhcp-client', ctx.ifname], stdout=subprocess.PIPE +@pytest.mark.asyncio +async def test_client_console(dnsmasq: DnsmasqFixture, veth_pair: VethPair): + """The commandline client can get a lease, print it to stdout and exit.""" + process = await asyncio.create_subprocess_exec( + 'pyroute2-dhcp-client', + veth_pair.client, + '--lease-type', + 'pyroute2.dhcp.leases.JSONStdoutLease', + '--exit-on-lease', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await asyncio.wait_for(process.communicate(), timeout=5) + assert process.returncode == 0 + json_lease = json.loads(stdout) + assert json_lease["interface"] == veth_pair.client + assert ( + dnsmasq.options.range_start + <= IPv4Address(json_lease["ack"]["yiaddr"]) + <= dnsmasq.options.range_end ) - response_from_console = json.loads(client.stdout) - assert response_from_module == response_from_console From 353efde57824cef959cb6376788c92f11ff581d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 09:33:00 +0100 Subject: [PATCH 02/18] add StrEnum and async sock recv compat for python < 3.11 --- pyroute2/compat.py | 14 ++++++++++++++ pyroute2/dhcp/dhcp4socket.py | 2 +- pyroute2/dhcp/fsm.py | 4 +++- 3 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 pyroute2/compat.py diff --git a/pyroute2/compat.py b/pyroute2/compat.py new file mode 100644 index 000000000..6d5b6a0b8 --- /dev/null +++ b/pyroute2/compat.py @@ -0,0 +1,14 @@ +'''Compatibility with older but supported Python versions''' + +try: + from enum import StrEnum +except ImportError: + # StrEnum appeared in python 3.11 + + from enum import Enum + + class StrEnum(str, Enum): + '''Same as enum, but members are also strings.''' + + +__all__ = ('StrEnum',) diff --git a/pyroute2/dhcp/dhcp4socket.py b/pyroute2/dhcp/dhcp4socket.py index 82e36e57d..8209edb02 100644 --- a/pyroute2/dhcp/dhcp4socket.py +++ b/pyroute2/dhcp/dhcp4socket.py @@ -150,7 +150,7 @@ async def get(self) -> dhcp4msg: only MAC/IPv4/UDP headers are stripped out, and the rest is interpreted as DHCP. ''' - data, _ = await self.aio_loop.sock_recvfrom(self, 4096) + data = await self.aio_loop.sock_recv(self, 4096) eth = ethmsg(buf=data).decode() ip4 = ip4msg(buf=data, offset=eth.offset).decode() udp = udpmsg(buf=data, offset=ip4.offset).decode() diff --git a/pyroute2/dhcp/fsm.py b/pyroute2/dhcp/fsm.py index 80c45aaca..74a90875e 100644 --- a/pyroute2/dhcp/fsm.py +++ b/pyroute2/dhcp/fsm.py @@ -1,9 +1,11 @@ '''DHCP client state machine helpers.''' -from enum import StrEnum, auto +from enum import auto from logging import getLogger from typing import TYPE_CHECKING, Final +from pyroute2.compat import StrEnum + if TYPE_CHECKING: from .client import AsyncDHCPClient From 3d009de5615f22a5f8b8d2ef58f0d5a40608db6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 10:04:12 +0100 Subject: [PATCH 03/18] remove typing.Self to restore compat with older pythons --- pyroute2/dhcp/leases.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index 0ebc9bd31..031a6f576 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -6,7 +6,6 @@ from datetime import datetime from logging import getLogger from pathlib import Path -from typing import Self from pyroute2.dhcp.dhcp4msg import dhcp4msg @@ -93,7 +92,7 @@ def dump(self) -> None: @classmethod @abc.abstractmethod - def load(cls, interface: str) -> 'Self | None': + def load(cls, interface: str) -> 'Lease | None': '''Load an existing lease for an interface, if it exists.''' pass From 3f34271761c39c858e6373e28a5bf992a79e39af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 15:04:16 +0100 Subject: [PATCH 04/18] dnsmasq fixture: ensure consistent output language (LANG=C) --- tests/test_linux/fixtures/dnsmasq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py index db4a53146..6e07a8c88 100644 --- a/tests/test_linux/fixtures/dnsmasq.py +++ b/tests/test_linux/fixtures/dnsmasq.py @@ -77,6 +77,7 @@ async def __aenter__(self): *self.get_cmdline_options(), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + env={'LANG': 'C'}, ) self.output_poller = asyncio.Task(self._read_outputs()) return self From e2ce512a27e3e05a35baf2988558668d5b835870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 15:04:54 +0100 Subject: [PATCH 05/18] dhcp client: indent lease files --- pyroute2/dhcp/leases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index 031a6f576..c70173a3d 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -134,7 +134,7 @@ def dump(self) -> None: lease_path = self._get_path(self.interface) LOG.info('Writing lease for %s to %s', self.interface, lease_path) with lease_path.open('wt') as lf: - json.dump(asdict(self), lf) + json.dump(asdict(self), lf, indent=2) @classmethod def load(cls, interface: str) -> 'JSONFileLease | None': From 14eaf41c7130c2adea4fb426697e6df7dde49190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 15:09:06 +0100 Subject: [PATCH 06/18] dhcp client: split send/recv into separate tasks --- pyroute2/dhcp/client.py | 97 ++++++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 4bb37e30a..4047be0bf 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -16,12 +16,12 @@ class AsyncDHCPClient: '''A simple async DHCP client based on pyroute2.''' - DEFAULT_PARAMETERS: ClassVar[tuple[dhcp.Parameter]] = [ + DEFAULT_PARAMETERS: ClassVar[tuple[dhcp.Parameter, ...]] = ( dhcp.Parameter.SUBNET_MASK, dhcp.Parameter.ROUTER, dhcp.Parameter.DOMAIN_NAME_SERVER, dhcp.Parameter.DOMAIN_NAME, - ] + ) def __init__( self, @@ -34,7 +34,7 @@ def __init__( self.lease_type = lease_type self.hooks = hooks self._sock: AsyncDHCP4Socket = AsyncDHCP4Socket(self.interface) - self._state: fsm.fsm.State | None = None + self._state: fsm.State | None = None self._lease: Lease | None = None self.requested_parameters = list( requested_parameters @@ -43,7 +43,8 @@ def __init__( ) self._stopped = asyncio.Event() self._sendq: asyncio.Queue[dhcp4msg | None] = asyncio.Queue() - self._send_task: asyncio.Task | None = None + self._sender_task: asyncio.Task | None = None + self._receiver_task: asyncio.Task | None = None self.bound = asyncio.Event() self.timers = Timers() @@ -92,14 +93,15 @@ def lease(self) -> Lease | None: @lease.setter def lease(self, value: Lease): '''Set a fresh lease; only call this when a server grants one.''' + self._lease = value self.timers.arm( - lease=self.lease, + lease=self._lease, renewal=self._renew, rebinding=self._rebind, expiration=self._expire_lease, ) - self.lease.dump() + self._lease.dump() @property def state(self) -> fsm.State | None: @@ -114,34 +116,47 @@ def state(self, value: fsm.State | None): LOG.info('%s -> %s', self.state, value) self._state = value - async def _run(self): - packet_to_send: dhcp4msg | None = None - # TODO: refactor, this is even more awkward than select() - while True: - wait_for_stopped = asyncio.Task( - self._stopped.wait(), name='wait until stopped' - ) - wait_for_received_packet = asyncio.Task( - self._sock.get(), name='wait for received packet' - ) + def _make_wait_stopped_task(self) -> asyncio.Task: + return asyncio.Task(self._stopped.wait(), name='wait until stopped') + + async def _send_forever(self): + packet_to_send = None + wait_til_stopped = self._make_wait_stopped_task() + interval = 5 # TODO make dynamic ? + while not wait_til_stopped.done(): wait_for_packet_to_send = asyncio.Task( self._sendq.get(), name='wait for packet to send' ) - tasks = ( - wait_for_stopped, - wait_for_received_packet, - wait_for_packet_to_send, - ) done, pending = await asyncio.wait( - tasks, - timeout=5, # TODO interval + (wait_til_stopped, wait_for_packet_to_send), return_when=asyncio.FIRST_COMPLETED, + timeout=interval, ) - for i in pending: - i.cancel() - if wait_for_packet_to_send in done: - packet_to_send = wait_for_packet_to_send.result() + if packet_to_send := wait_for_packet_to_send.result(): + packet_to_send['xid'] = self.xid + elif wait_for_packet_to_send in pending: + wait_for_packet_to_send.cancel() + + if packet_to_send: + LOG.debug( + 'Sending %s', + packet_to_send['options']['message_type'].name, + ) + await self._sock.put(packet_to_send) + + async def _recv_forever(self) -> None: + wait_til_stopped = self._make_wait_stopped_task() + + while not wait_til_stopped.done(): + wait_for_received_packet = asyncio.Task( + coro=self._sock.get(), + name=f'wait for DHCP packet on {self.interface}', + ) + done, pending = await asyncio.wait( + (wait_til_stopped, wait_for_received_packet), + return_when=asyncio.FIRST_COMPLETED, + ) if wait_for_received_packet in done: received_packet = wait_for_received_packet.result() @@ -157,19 +172,10 @@ async def _run(self): if not handler: LOG.debug('%r messages are not handled', msg_type.name) else: - if await handler(received_packet): - packet_to_send = None - - if packet_to_send: - packet_to_send['xid'] = self.xid - LOG.debug( - 'Sending %s', - packet_to_send['options']['message_type'].name, - ) - await self._sock.put(packet_to_send) + await handler(received_packet) - if wait_for_stopped in done and wait_for_stopped.result() is True: - return + elif wait_for_received_packet in pending: + wait_for_received_packet.cancel() async def transition(self, to: fsm.State, send: dhcp4msg | None = None): self.state = to @@ -241,7 +247,15 @@ async def __aenter__(self): LOG.debug('No current lease') self.state = fsm.State.INIT await self._sock.__aenter__() - self._send_task = asyncio.Task(self._run()) + + self._receiver_task = asyncio.Task( + self._recv_forever(), + name=f'Listen for incoming DHCP packets on {self.interface}', + ) + self._sender_task = asyncio.Task( + self._send_forever(), + name=f'Send outgoing DHCP packets on {self.interface}', + ) self.xid = self._sock.xid_pool.alloc() return self @@ -254,7 +268,8 @@ async def __aexit__(self, *_): ) ) self._stopped.set() - await self._send_task + await self._sender_task + await self._receiver_task await self._sock.__aexit__() self.xid = None self.state = None From 2176061e5e73217b863a8734d1b4553fdd483730 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 15:26:15 +0100 Subject: [PATCH 07/18] add ETHERTYPE_IP compat for python < 3.12 --- pyroute2/compat.py | 7 +++++++ pyroute2/dhcp/dhcp4socket.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pyroute2/compat.py b/pyroute2/compat.py index 6d5b6a0b8..d33e6a9d9 100644 --- a/pyroute2/compat.py +++ b/pyroute2/compat.py @@ -11,4 +11,11 @@ class StrEnum(str, Enum): '''Same as enum, but members are also strings.''' +try: + from socket import ETHERTYPE_IP +except ImportError: + # ETHERTYPE_* are new in python 3.12 + ETHERTYPE_IP = 0x800 + + __all__ = ('StrEnum',) diff --git a/pyroute2/dhcp/dhcp4socket.py b/pyroute2/dhcp/dhcp4socket.py index 8209edb02..d0af5cc06 100644 --- a/pyroute2/dhcp/dhcp4socket.py +++ b/pyroute2/dhcp/dhcp4socket.py @@ -9,6 +9,7 @@ import socket from pyroute2.common import AddrPool +from pyroute2.compat import ETHERTYPE_IP from pyroute2.dhcp.dhcp4msg import dhcp4msg from pyroute2.ext.rawsocket import AsyncRawSocket from pyroute2.protocols import ethmsg, ip4msg, udp4_pseudo_header, udpmsg @@ -135,7 +136,7 @@ async def put( # MAC layer eth = ethmsg( - {'dst': eth_dst, 'src': self.l2addr, 'type': socket.ETHERTYPE_IP} + {'dst': eth_dst, 'src': self.l2addr, 'type': ETHERTYPE_IP} ) data = eth.encode().buf + ip4.encode().buf + udp.encode().buf + data From e1d2a4f46d3d3973aa8e50cd9f12ca3672fb777b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 15:42:31 +0100 Subject: [PATCH 08/18] typing fixes and compat --- pyroute2/dhcp/timers.py | 1 - pyroute2/ext/rawsocket.py | 2 +- tests/test_linux/fixtures/dnsmasq.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pyroute2/dhcp/timers.py b/pyroute2/dhcp/timers.py index 65d62ee85..baab58496 100644 --- a/pyroute2/dhcp/timers.py +++ b/pyroute2/dhcp/timers.py @@ -26,7 +26,6 @@ def cancel(self): def _reset_timer(self, timer_name: str): '''Cancel a timer and set it to None.''' if timer := getattr(self, timer_name): - timer: asyncio.TimerHandle if not timer.cancelled(): # FIXME: how do we know a timer wasn't cancelled ? # this causes spurious logs diff --git a/pyroute2/ext/rawsocket.py b/pyroute2/ext/rawsocket.py index bf08b9305..e1f730e0e 100644 --- a/pyroute2/ext/rawsocket.py +++ b/pyroute2/ext/rawsocket.py @@ -34,7 +34,7 @@ class sock_fprog(Structure): _fields_ = [('len', c_ushort), ('filter', c_void_p)] -def compile_bpf(code: list[int]): +def compile_bpf(code: list[list[int]]): ProgramType = sock_filter * len(code) program = ProgramType(*[sock_filter(*line) for line in code]) sfp = sock_fprog(len(code), addressof(program[0])) diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py index 6e07a8c88..888b59676 100644 --- a/tests/test_linux/fixtures/dnsmasq.py +++ b/tests/test_linux/fixtures/dnsmasq.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from ipaddress import IPv4Address from shutil import which -from typing import AsyncGenerator, ClassVar, Literal +from typing import AsyncGenerator, ClassVar, Literal, Optional import pytest import pytest_asyncio @@ -31,7 +31,7 @@ def __iter__(self): class DnsmasqFixture: '''Runs the dnsmasq server as an async context manager.''' - DNSMASQ_PATH: ClassVar[str | None] = which('dnsmasq') + DNSMASQ_PATH: ClassVar[Optional[str]] = which('dnsmasq') def __init__(self, options: DnsmasqOptions) -> None: self.options = options From 52937a2704731cba381b79ed21b082c686a8cf0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 16:03:36 +0100 Subject: [PATCH 09/18] dhcp client: replace match/case with if/else --- pyroute2/dhcp/client.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 4047be0bf..6fc23127d 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -187,26 +187,26 @@ async def bootstrap(self): depending on whether we're initializing or rebooting. ''' - match self.state: - case fsm.State.INIT: - # send discover - await self.transition( - to=fsm.State.SELECTING, - send=messages.discover( - parameter_list=self.requested_parameters - ), - ) - case fsm.State.INIT_REBOOT: - assert self.lease, 'cannot init_reboot without a lease' - # send request for lease - await self.transition( - to=fsm.State.REBOOTING, - send=messages.request( - requested_ip=self.lease.ip, - server_id=self.lease.server_id, - parameter_list=self.requested_parameters, - ), - ) + if self.state is fsm.State.INIT: + # send discover + await self.transition( + to=fsm.State.SELECTING, + send=messages.discover( + parameter_list=self.requested_parameters + ), + ) + elif self.state is fsm.State.INIT_REBOOT: + assert self.lease, 'cannot init_reboot without a lease' + # send request for lease + await self.transition( + to=fsm.State.REBOOTING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + # the decorator prevents the needs for an else @fsm.state_guard( fsm.State.REQUESTING, From b6b752a7c862da35c966943eef35d70adf2949d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Wed, 15 Jan 2025 16:16:08 +0100 Subject: [PATCH 10/18] dhcp client: use Optional everywhere for py39 compat --- pyroute2/dhcp/client.py | 20 ++++++++++---------- pyroute2/dhcp/leases.py | 13 +++++++------ pyroute2/dhcp/timers.py | 8 ++++---- pyroute2/ext/rawsocket.py | 3 ++- tests/test_linux/fixtures/dnsmasq.py | 4 ++-- 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 6fc23127d..235b77ce5 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -1,6 +1,6 @@ import asyncio from logging import getLogger -from typing import ClassVar, Iterable +from typing import ClassVar, Iterable, Optional from pyroute2.dhcp import fsm, messages from pyroute2.dhcp.constants import dhcp @@ -34,17 +34,17 @@ def __init__( self.lease_type = lease_type self.hooks = hooks self._sock: AsyncDHCP4Socket = AsyncDHCP4Socket(self.interface) - self._state: fsm.State | None = None - self._lease: Lease | None = None + self._state: Optional[fsm.State] = None + self._lease: Optional[Lease] = None self.requested_parameters = list( requested_parameters if requested_parameters else self.DEFAULT_PARAMETERS ) self._stopped = asyncio.Event() - self._sendq: asyncio.Queue[dhcp4msg | None] = asyncio.Queue() - self._sender_task: asyncio.Task | None = None - self._receiver_task: asyncio.Task | None = None + self._sendq: asyncio.Queue[Optional[dhcp4msg]] = asyncio.Queue() + self._sender_task: Optional[asyncio.Task] = None + self._receiver_task: Optional[asyncio.Task] = None self.bound = asyncio.Event() self.timers = Timers() @@ -87,7 +87,7 @@ async def _expire_lease(self): await self.bootstrap() @property - def lease(self) -> Lease | None: + def lease(self) -> Optional[Lease]: return self._lease @lease.setter @@ -104,11 +104,11 @@ def lease(self, value: Lease): self._lease.dump() @property - def state(self) -> fsm.State | None: + def state(self) -> Optional[fsm.State]: return self._state @state.setter - def state(self, value: fsm.State | None): + def state(self, value: Optional[fsm.State]): if value and self._state and value not in fsm.TRANSITIONS[self._state]: raise ValueError( f'Cannot transition from {self._state} to {value}' @@ -177,7 +177,7 @@ async def _recv_forever(self) -> None: elif wait_for_received_packet in pending: wait_for_received_packet.cancel() - async def transition(self, to: fsm.State, send: dhcp4msg | None = None): + async def transition(self, to: fsm.State, send: Optional[dhcp4msg] = None): self.state = to await self._sendq.put(send) diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index c70173a3d..d9cd322eb 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -6,6 +6,7 @@ from datetime import datetime from logging import getLogger from pathlib import Path +from typing import Optional from pyroute2.dhcp.dhcp4msg import dhcp4msg @@ -28,7 +29,7 @@ class Lease(abc.ABC): # Timestamp of when this lease was obtained obtained: float = field(default_factory=_now) - def _seconds_til_timer(self, timer_name: str) -> float | None: + def _seconds_til_timer(self, timer_name: str) -> Optional[float]: '''The number of seconds to wait until the given timer expires. The value is fetched from options as `{timer_name}_time`. @@ -41,11 +42,11 @@ def _seconds_til_timer(self, timer_name: str) -> float | None: return None @property - def expiration_in(self) -> float | None: + def expiration_in(self) -> Optional[float]: return self._seconds_til_timer('lease') @property - def renewal_in(self) -> float | None: + def renewal_in(self) -> Optional[float]: '''The amount of seconds before we have to renew the lease. Can be negative if it's past due, @@ -54,7 +55,7 @@ def renewal_in(self) -> float | None: return self._seconds_til_timer('renewal') @property - def rebinding_in(self) -> float | None: + def rebinding_in(self) -> Optional[float]: '''The amount of seconds before we have to rebind the lease. Can be negative if it's past due, @@ -92,7 +93,7 @@ def dump(self) -> None: @classmethod @abc.abstractmethod - def load(cls, interface: str) -> 'Lease | None': + def load(cls, interface: str) -> 'Optional[Lease]': '''Load an existing lease for an interface, if it exists.''' pass @@ -137,7 +138,7 @@ def dump(self) -> None: json.dump(asdict(self), lf, indent=2) @classmethod - def load(cls, interface: str) -> 'JSONFileLease | None': + def load(cls, interface: str) -> 'Optional[JSONFileLease]': '''Load the lease from a file. The lease file is named after the interface diff --git a/pyroute2/dhcp/timers.py b/pyroute2/dhcp/timers.py index baab58496..e7142ba1a 100644 --- a/pyroute2/dhcp/timers.py +++ b/pyroute2/dhcp/timers.py @@ -3,7 +3,7 @@ import asyncio import dataclasses from logging import getLogger -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Optional from pyroute2.dhcp.leases import Lease @@ -14,9 +14,9 @@ class Timers: '''Manage callbacks associated with DHCP leases.''' - renewal: asyncio.TimerHandle | None = None - rebinding: asyncio.TimerHandle | None = None - expiration: asyncio.TimerHandle | None = None + renewal: Optional[asyncio.TimerHandle] = None + rebinding: Optional[asyncio.TimerHandle] = None + expiration: Optional[asyncio.TimerHandle] = None def cancel(self): '''Cancel all current timers.''' diff --git a/pyroute2/ext/rawsocket.py b/pyroute2/ext/rawsocket.py index e1f730e0e..cad4489cf 100644 --- a/pyroute2/ext/rawsocket.py +++ b/pyroute2/ext/rawsocket.py @@ -10,6 +10,7 @@ string_at, ) from socket import AF_PACKET, SOCK_RAW, SOL_SOCKET, errno, error, htons, socket +from typing import Optional from pyroute2.iproute.linux import AsyncIPRoute @@ -81,7 +82,7 @@ async def __aenter__(self): self.clear_buffer(remove_total_filter=True) return self - def __init__(self, ifname: str, bpf: list[list[int]] | None = None): + def __init__(self, ifname: str, bpf: Optional[list[list[int]]] = None): self.ifname = ifname self.bpf = bpf diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py index 888b59676..af695bd6b 100644 --- a/tests/test_linux/fixtures/dnsmasq.py +++ b/tests/test_linux/fixtures/dnsmasq.py @@ -37,8 +37,8 @@ def __init__(self, options: DnsmasqOptions) -> None: self.options = options self.stdout: list[bytes] = [] self.stderr: list[bytes] = [] - self.process: asyncio.subprocess.Process | None = None - self.output_poller: asyncio.Task | None = None + self.process: Optional[asyncio.subprocess.Process] = None + self.output_poller: Optional[asyncio.Task] = None async def _read_output(self, name: Literal['stdout', 'stderr']): '''Read stdout or stderr until the process exits.''' From e0566b0540de89638b173d5a602949c3872f7f5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Thu, 16 Jan 2025 17:28:50 +0100 Subject: [PATCH 11/18] improve dnsmasq fixture & handling of timeouts in dhcp client tests --- tests/test_linux/fixtures/dnsmasq.py | 8 +++---- tests/test_linux/test_raw/test_dhcp.py | 33 ++++++++++++++++---------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py index af695bd6b..e453d2ad8 100644 --- a/tests/test_linux/fixtures/dnsmasq.py +++ b/tests/test_linux/fixtures/dnsmasq.py @@ -35,8 +35,8 @@ class DnsmasqFixture: def __init__(self, options: DnsmasqOptions) -> None: self.options = options - self.stdout: list[bytes] = [] - self.stderr: list[bytes] = [] + self.stdout: list[str] = [] + self.stderr: list[str] = [] self.process: Optional[asyncio.subprocess.Process] = None self.output_poller: Optional[asyncio.Task] = None @@ -45,7 +45,7 @@ async def _read_output(self, name: Literal['stdout', 'stderr']): stream = getattr(self.process, name) output = getattr(self, name) while line := await stream.readline(): - output.append(line) + output.append(line.decode().strip()) async def _read_outputs(self): '''Read stdout & stderr until the process exits.''' @@ -147,7 +147,7 @@ async def main(): while True: if len(dnsm.stderr) > read_lines: read_lines += len(lines := dnsm.stderr[read_lines:]) - print(*(i.decode().strip() for i in lines), sep='\n') + print(*lines, sep='\n') else: await asyncio.sleep(0.2) diff --git a/tests/test_linux/test_raw/test_dhcp.py b/tests/test_linux/test_raw/test_dhcp.py index 225ee55ad..129877153 100644 --- a/tests/test_linux/test_raw/test_dhcp.py +++ b/tests/test_linux/test_raw/test_dhcp.py @@ -22,30 +22,35 @@ async def test_get_lease( tmpdir: str, monkeypatch: pytest.MonkeyPatch, ): - """The client can get a lease and write it to a file.""" + '''The client can get a lease and write it to a file.''' work_dir = Path(tmpdir) # Patch JSONFileLease so leases get written to the temp dir # instead of whatever the working directory is - monkeypatch.setattr(JSONFileLease, "_get_lease_dir", lambda: work_dir) + monkeypatch.setattr(JSONFileLease, '_get_lease_dir', lambda: work_dir) # boot up the dhcp client and wait for a lease async with client.AsyncDHCPClient(veth_pair.client) as cli: await cli.bootstrap() - await asyncio.wait_for(cli.bound.wait(), timeout=5) + try: + await asyncio.wait_for(cli.bound.wait(), timeout=5) + except TimeoutError: + raise AssertionError( + f'Timed out. dnsmasq output: {dnsmasq.stderr}' + ) assert cli.state == fsm.State.BOUND lease = cli.lease - assert lease.ack["xid"] == cli.xid + assert lease.ack['xid'] == cli.xid # check the obtained lease assert lease.interface == veth_pair.client - assert lease.ack["op"] == bootp.MessageType.BOOTREPLY - assert lease.ack["options"]["message_type"] == dhcp.MessageType.ACK + assert lease.ack['op'] == bootp.MessageType.BOOTREPLY + assert lease.ack['options']['message_type'] == dhcp.MessageType.ACK assert ( dnsmasq.options.range_start <= IPv4Address(lease.ip) <= dnsmasq.options.range_end ) - assert lease.ack["chaddr"] + assert lease.ack['chaddr'] # TODO: check chaddr matches veth_pair.client's MAC # check the lease was written to disk and can be loaded @@ -58,7 +63,7 @@ async def test_get_lease( @pytest.mark.asyncio async def test_client_console(dnsmasq: DnsmasqFixture, veth_pair: VethPair): - """The commandline client can get a lease, print it to stdout and exit.""" + '''The commandline client can get a lease, print it to stdout and exit.''' process = await asyncio.create_subprocess_exec( 'pyroute2-dhcp-client', veth_pair.client, @@ -66,14 +71,18 @@ async def test_client_console(dnsmasq: DnsmasqFixture, veth_pair: VethPair): 'pyroute2.dhcp.leases.JSONStdoutLease', '--exit-on-lease', stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + # stderr=asyncio.subprocess.PIPE, ) - stdout, _ = await asyncio.wait_for(process.communicate(), timeout=5) + try: + stdout, _ = await asyncio.wait_for(process.communicate(), timeout=5) + except TimeoutError: + raise AssertionError(f'Timed out. dnsmasq output: {dnsmasq.stderr}') assert process.returncode == 0 + assert stdout json_lease = json.loads(stdout) - assert json_lease["interface"] == veth_pair.client + assert json_lease['interface'] == veth_pair.client assert ( dnsmasq.options.range_start - <= IPv4Address(json_lease["ack"]["yiaddr"]) + <= IPv4Address(json_lease['ack']['yiaddr']) <= dnsmasq.options.range_end ) From 2a45a26faf4347194a84ff41cf905c369b444362 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Thu, 16 Jan 2025 18:28:12 +0100 Subject: [PATCH 12/18] dhcp client: add basic NAK support --- pyroute2/dhcp/client.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 235b77ce5..3b4c84fc8 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -224,7 +224,19 @@ async def ack_received(self, pkt: dhcp4msg): # FIXME: call hooks in a non blocking way (maybe call_soon ?) for i in self.hooks: await i.bound(self.lease) - return True + + @fsm.state_guard( + fsm.State.REQUESTING, + fsm.State.REBOOTING, + fsm.State.RENEWING, + fsm.State.REBINDING, + ) + async def nak_received(self, pkt: dhcp4msg): + await self.transition(to=fsm.State.INIT) + # Reset lease & timers and start again + self._lease = None + self.timers.cancel() + await self.bootstrap() @fsm.state_guard(fsm.State.SELECTING) async def offer_received(self, pkt: dhcp4msg): @@ -236,7 +248,6 @@ async def offer_received(self, pkt: dhcp4msg): parameter_list=self.requested_parameters, ), ) - return True async def __aenter__(self): self._lease = self.lease_type.load(self.interface) From ba3d62a262539e438e9663c4b890efd7ad339d98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Thu, 16 Jan 2025 19:40:56 +0100 Subject: [PATCH 13/18] dhcp client: handle expired leases on loading --- pyroute2/dhcp/client.py | 10 +++++++--- pyroute2/dhcp/leases.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 3b4c84fc8..5512868be 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -250,9 +250,13 @@ async def offer_received(self, pkt: dhcp4msg): ) async def __aenter__(self): - self._lease = self.lease_type.load(self.interface) - if self.lease: + loaded_lease = self.lease_type.load(self.interface) + if loaded_lease.expired: + LOG.info("Discarding stale lease") + loaded_lease = None + if loaded_lease: # TODO check lease is not expired + self._lease = loaded_lease self.state = fsm.State.INIT_REBOOT else: LOG.debug('No current lease') @@ -272,7 +276,7 @@ async def __aenter__(self): async def __aexit__(self, *_): self.timers.cancel() - if self.lease: + if self.lease and not self.lease.expired: await self._sendq.put( messages.release( requested_ip=self.lease.ip, server_id=self.lease.server_id diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index d9cd322eb..9ce51d87e 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -41,6 +41,17 @@ def _seconds_til_timer(self, timer_name: str) -> Optional[float]: except KeyError: return None + @property + def expired(self) -> bool: + '''Whether this lease has expired (its expiration is in the past). + + When loading a persisted lease, this won't be correct if the clock + has been adjusted since the lease was written. + However the worst case scenario is that we send a REQUEST for it, + get a NAK and restart from scratch. + ''' + return self.expiration_in < 0 + @property def expiration_in(self) -> Optional[float]: return self._seconds_til_timer('lease') From e941c980a8b733cb98ead3e183c14f8ba695f623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Thu, 16 Jan 2025 23:04:12 +0100 Subject: [PATCH 14/18] dhcp client: fix coroutine warnings in timer callbacks --- pyroute2/dhcp/timers.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pyroute2/dhcp/timers.py b/pyroute2/dhcp/timers.py index e7142ba1a..3925f18cc 100644 --- a/pyroute2/dhcp/timers.py +++ b/pyroute2/dhcp/timers.py @@ -52,12 +52,15 @@ def arm(self, lease: Lease, **callbacks: Callable[[], Awaitable[None]]): LOG.debug('Lease %s is in the past', timer_name) continue LOG.info('Scheduling lease %s in %.2fs', timer_name, lease_time) - '''FIXME: calling async_callback() causes a - "coroutine was never awaited" warning. - But deferring its call in a lambda causes the callback to be - the same for all timers, since we're in a loop. - ''' + # Since call_later doesn't support async callbacks, we wrap the + # callback in a lambda that will schedule it when it's time timer = loop.call_later( - lease_time, asyncio.create_task, async_callback() + lease_time, + # since lambdas are evaluated when they're run, we have to + # bind variables as argument defaults or they'll have the + # value from the last loop iteration + lambda cb=async_callback, name=lease_time: asyncio.create_task( + cb(), name=f"{name} timer callback" + ), ) setattr(self, timer_name, timer) From 091056d02bc1e5981555105f8b02828dc52a36c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Thu, 16 Jan 2025 23:05:12 +0100 Subject: [PATCH 15/18] dhcp client: fix logic error when loading lease --- pyroute2/dhcp/client.py | 2 +- pyroute2/dhcp/leases.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 5512868be..bef277fde 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -251,7 +251,7 @@ async def offer_received(self, pkt: dhcp4msg): async def __aenter__(self): loaded_lease = self.lease_type.load(self.interface) - if loaded_lease.expired: + if loaded_lease and loaded_lease.expired: LOG.info("Discarding stale lease") loaded_lease = None if loaded_lease: diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index 9ce51d87e..28c77bc06 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -105,7 +105,11 @@ def dump(self) -> None: @classmethod @abc.abstractmethod def load(cls, interface: str) -> 'Optional[Lease]': - '''Load an existing lease for an interface, if it exists.''' + '''Load an existing lease for an interface, if it exists. + + The lease is not checked for freshness, and will be None if no lease + could be loaded. + ''' pass From 1e195c4d6519052699d07bf4de2cba842d4120a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Fri, 17 Jan 2025 08:53:40 +0100 Subject: [PATCH 16/18] dhcp client: add comments, improve waiting for state --- pyroute2/dhcp/cli.py | 5 +- pyroute2/dhcp/client.py | 213 ++++++++++++++++--------- pyroute2/dhcp/leases.py | 2 +- pyroute2/ext/rawsocket.py | 4 +- tests/test_linux/test_raw/test_dhcp.py | 4 +- 5 files changed, 150 insertions(+), 78 deletions(-) diff --git a/pyroute2/dhcp/cli.py b/pyroute2/dhcp/cli.py index 86d3f73aa..7241dc9d3 100644 --- a/pyroute2/dhcp/cli.py +++ b/pyroute2/dhcp/cli.py @@ -5,6 +5,7 @@ from typing import Any from pyroute2.dhcp.client import AsyncDHCPClient +from pyroute2.dhcp.fsm import State from pyroute2.dhcp.hooks import ConfigureIP, Hook from pyroute2.dhcp.leases import Lease @@ -87,10 +88,10 @@ async def main(): await acli.bootstrap() if args.exit_on_lease: # Wait until we're bound once, then exit - await acli.bound.wait() + await acli.wait_for_state(State.BOUND) else: # Wait until the client is stopped otherwise - await acli._stopped.wait() + await acli.wait_for_state(None) def run(): diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index bef277fde..4e56f530f 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -1,6 +1,6 @@ import asyncio from logging import getLogger -from typing import ClassVar, Iterable, Optional +from typing import ClassVar, DefaultDict, Iterable, Optional from pyroute2.dhcp import fsm, messages from pyroute2.dhcp.constants import dhcp @@ -33,27 +33,116 @@ def __init__( self.interface = interface self.lease_type = lease_type self.hooks = hooks - self._sock: AsyncDHCP4Socket = AsyncDHCP4Socket(self.interface) - self._state: Optional[fsm.State] = None - self._lease: Optional[Lease] = None self.requested_parameters = list( requested_parameters if requested_parameters else self.DEFAULT_PARAMETERS ) - self._stopped = asyncio.Event() + # The raw socket used to send and receive packets + self._sock: AsyncDHCP4Socket = AsyncDHCP4Socket(self.interface) + # Current client state + self._state: Optional[fsm.State] = None + # Current lease, read from persistent storage or sent by a server + self._lease: Optional[Lease] = None + # dhcp messages put in this queue are sent by _send_forever self._sendq: asyncio.Queue[Optional[dhcp4msg]] = asyncio.Queue() + # Handle to run _send_forever for the context manager's lifetime self._sender_task: Optional[asyncio.Task] = None + # Handle to run _recv_forever for the context manager's lifetime self._receiver_task: Optional[asyncio.Task] = None - self.bound = asyncio.Event() + # Timers to run callbacks on lease timeouts expiration self.timers = Timers() + # Allows to easily track the state when running the client from python + self._states: DefaultDict[Optional[fsm.State], asyncio.Event] = ( + DefaultDict(asyncio.Event) + ) + + # "public api" + + async def wait_for_state(self, state: Optional[fsm.State]) -> None: + '''Waits until the client is in the target state. + + Since the state is set to None upon exit, + you can also pass None to wait for the client to stop. + ''' + await self._states[state].wait() + + @fsm.state_guard(fsm.State.INIT, fsm.State.INIT_REBOOT) + async def bootstrap(self): + '''Send a `DISCOVER` or a `REQUEST`, + + depending on whether we're initializing or rebooting. + + Use this to get a lease when running the client from Python code. + ''' + if self.state is fsm.State.INIT: + # send discover + await self.transition( + to=fsm.State.SELECTING, + send=messages.discover( + parameter_list=self.requested_parameters + ), + ) + elif self.state is fsm.State.INIT_REBOOT: + assert self.lease, 'cannot init_reboot without a lease' + # send request for lease + await self.transition( + to=fsm.State.REBOOTING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + # the decorator prevents the needs for an else + + # properties + + @property + def lease(self) -> Optional[Lease]: + """The current lease, if we have one.""" + return self._lease + + @lease.setter + def lease(self, value: Lease): + '''Set a fresh lease; only call this when a server grants one.''' + + self._lease = value + self.timers.arm( + lease=self._lease, + renewal=self._renew, + rebinding=self._rebind, + expiration=self._expire_lease, + ) + self._lease.dump() + + @property + def state(self) -> Optional[fsm.State]: + """The current client state.""" + return self._state + + @state.setter + def state(self, value: Optional[fsm.State]): + """Check the client can transition to the state, and set it.""" + old_state = self.state + if value and old_state and value not in fsm.TRANSITIONS[old_state]: + raise ValueError( + f'Cannot transition from {self._state} to {value}' + ) + LOG.info('%s -> %s', old_state, value) + if old_state in self._states: + self._states[old_state].clear() + self._state = value + self._states[value].set() + + # Timer callbacks async def _renew(self): - '''Called when the renewal timer, as defined in the lease, expires.''' + '''Called when the renewal time defined in the lease expires.''' assert self.lease, 'cannot renew without an existing lease' LOG.info('Renewal timer expired') # TODO: send only to server that gave us the current lease - self.timers._reset_timer('renewal') + self.timers._reset_timer('renewal') # FIXME should be automatic await self.transition( to=fsm.State.RENEWING, send=messages.request( @@ -64,6 +153,7 @@ async def _renew(self): ) async def _rebind(self): + ''' 'Called when the rebinding time defined in the lease expires.''' assert self.lease, 'cannot rebind without an existing lease' LOG.info('Rebinding timer expired') self.timers._reset_timer('rebinding') @@ -77,6 +167,7 @@ async def _rebind(self): ) async def _expire_lease(self): + ''' 'Called when the expiration time defined in the lease expires.''' LOG.info('Lease expired') self.timers._reset_timer('expiration') self.state = fsm.State.INIT @@ -86,43 +177,13 @@ async def _expire_lease(self): self._lease = None await self.bootstrap() - @property - def lease(self) -> Optional[Lease]: - return self._lease - - @lease.setter - def lease(self, value: Lease): - '''Set a fresh lease; only call this when a server grants one.''' - - self._lease = value - self.timers.arm( - lease=self._lease, - renewal=self._renew, - rebinding=self._rebind, - expiration=self._expire_lease, - ) - self._lease.dump() - - @property - def state(self) -> Optional[fsm.State]: - return self._state - - @state.setter - def state(self, value: Optional[fsm.State]): - if value and self._state and value not in fsm.TRANSITIONS[self._state]: - raise ValueError( - f'Cannot transition from {self._state} to {value}' - ) - LOG.info('%s -> %s', self.state, value) - self._state = value - - def _make_wait_stopped_task(self) -> asyncio.Task: - return asyncio.Task(self._stopped.wait(), name='wait until stopped') + # DHCP packet sending & receving coroutines async def _send_forever(self): + """Send packets from _sendq until the client stops.""" packet_to_send = None - wait_til_stopped = self._make_wait_stopped_task() - interval = 5 # TODO make dynamic ? + wait_til_stopped = asyncio.Task(self.wait_for_state(None)) + interval = 5 # TODO make dynamic while not wait_til_stopped.done(): wait_for_packet_to_send = asyncio.Task( self._sendq.get(), name='wait for packet to send' @@ -146,7 +207,13 @@ async def _send_forever(self): await self._sock.put(packet_to_send) async def _recv_forever(self) -> None: - wait_til_stopped = self._make_wait_stopped_task() + """Receive & process DHCP packets until the client stops. + + The incoming packet's xid is checked against the client's. + Then, the relevant handler ({type}_received) is called. + """ + + wait_til_stopped = asyncio.Task(self.wait_for_state(None)) while not wait_til_stopped.done(): wait_for_received_packet = asyncio.Task( @@ -178,35 +245,14 @@ async def _recv_forever(self) -> None: wait_for_received_packet.cancel() async def transition(self, to: fsm.State, send: Optional[dhcp4msg] = None): + '''Change the client's state, and start sending a message repeatedly. + + If the message is None, any current message will stop being sent. + ''' self.state = to await self._sendq.put(send) - @fsm.state_guard(fsm.State.INIT, fsm.State.INIT_REBOOT) - async def bootstrap(self): - '''Send a `DISCOVER` or a `REQUEST`, - - depending on whether we're initializing or rebooting. - ''' - if self.state is fsm.State.INIT: - # send discover - await self.transition( - to=fsm.State.SELECTING, - send=messages.discover( - parameter_list=self.requested_parameters - ), - ) - elif self.state is fsm.State.INIT_REBOOT: - assert self.lease, 'cannot init_reboot without a lease' - # send request for lease - await self.transition( - to=fsm.State.REBOOTING, - send=messages.request( - requested_ip=self.lease.ip, - server_id=self.lease.server_id, - parameter_list=self.requested_parameters, - ), - ) - # the decorator prevents the needs for an else + # Callbacks for received DHCP messages @fsm.state_guard( fsm.State.REQUESTING, @@ -215,12 +261,15 @@ async def bootstrap(self): fsm.State.RENEWING, ) async def ack_received(self, pkt: dhcp4msg): + '''Called when an ACK is received. + + Stores the lease and puts the client in the BOUND state. + ''' self.lease = self.lease_type(ack=pkt, interface=self.interface) LOG.info( 'Got lease for %s from %s', self.lease.ip, self.lease.server_id ) await self.transition(to=fsm.State.BOUND) - self.bound.set() # FIXME: call hooks in a non blocking way (maybe call_soon ?) for i in self.hooks: await i.bound(self.lease) @@ -232,6 +281,10 @@ async def ack_received(self, pkt: dhcp4msg): fsm.State.REBINDING, ) async def nak_received(self, pkt: dhcp4msg): + '''Called when a NAK is received. + + Resets the client and starts looking for a new IP. + ''' await self.transition(to=fsm.State.INIT) # Reset lease & timers and start again self._lease = None @@ -240,6 +293,10 @@ async def nak_received(self, pkt: dhcp4msg): @fsm.state_guard(fsm.State.SELECTING) async def offer_received(self, pkt: dhcp4msg): + '''Called when an OFFER is received. + + Sends a REQUEST for the offered IP address. + ''' await self.transition( to=fsm.State.REQUESTING, send=messages.request( @@ -249,7 +306,15 @@ async def offer_received(self, pkt: dhcp4msg): ), ) + # Async context manager methods + async def __aenter__(self): + '''Set up the client so it's ready to obtain an IP. + + Tries to load a lease for the client's interface, + opens the socket, starts the sender & receiver tasks + and allocates a request ID. + ''' loaded_lease = self.lease_type.load(self.interface) if loaded_lease and loaded_lease.expired: LOG.info("Discarding stale lease") @@ -275,6 +340,10 @@ async def __aenter__(self): return self async def __aexit__(self, *_): + '''Shut down the client. + + If there's an active lease, send a RELEASE for it first. + ''' self.timers.cancel() if self.lease and not self.lease.expired: await self._sendq.put( @@ -282,10 +351,8 @@ async def __aexit__(self, *_): requested_ip=self.lease.ip, server_id=self.lease.server_id ) ) - self._stopped.set() + self.state = None await self._sender_task await self._receiver_task await self._sock.__aexit__() self.xid = None - self.state = None - self.bound.clear() diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index 28c77bc06..87b5c1346 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -50,7 +50,7 @@ def expired(self) -> bool: However the worst case scenario is that we send a REQUEST for it, get a NAK and restart from scratch. ''' - return self.expiration_in < 0 + return self.expiration_in and self.expiration_in < 0 @property def expiration_in(self) -> Optional[float]: diff --git a/pyroute2/ext/rawsocket.py b/pyroute2/ext/rawsocket.py index cad4489cf..2103004a3 100644 --- a/pyroute2/ext/rawsocket.py +++ b/pyroute2/ext/rawsocket.py @@ -114,7 +114,9 @@ def clear_buffer(self, remove_total_filter: bool = False): self, SOL_SOCKET, SO_DETACH_FILTER, total_fstring ) - def csum(self, data): + @staticmethod + def csum(data: bytes) -> int: + '''Compute the "Internet checksum" for the given bytes.''' if len(data) % 2: data += b'\x00' csum = sum( diff --git a/tests/test_linux/test_raw/test_dhcp.py b/tests/test_linux/test_raw/test_dhcp.py index 129877153..e53015733 100644 --- a/tests/test_linux/test_raw/test_dhcp.py +++ b/tests/test_linux/test_raw/test_dhcp.py @@ -32,7 +32,9 @@ async def test_get_lease( async with client.AsyncDHCPClient(veth_pair.client) as cli: await cli.bootstrap() try: - await asyncio.wait_for(cli.bound.wait(), timeout=5) + await asyncio.wait_for( + cli.wait_for_state(fsm.State.BOUND), timeout=5 + ) except TimeoutError: raise AssertionError( f'Timed out. dnsmasq output: {dnsmasq.stderr}' From 417a8d9a6668cebf970afbfde17c768c4919145d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Fri, 17 Jan 2025 20:02:53 +0100 Subject: [PATCH 17/18] dhcp client: add a udhcpd fixture and a test with it --- pyroute2/dhcp/client.py | 12 +- tests/test_linux/conftest.py | 3 +- .../fixtures/dhcp_servers/__init__.py | 131 +++++++++++++++ .../fixtures/dhcp_servers/dnsmasq.py | 68 ++++++++ .../fixtures/dhcp_servers/udhcpd.py | 101 ++++++++++++ tests/test_linux/fixtures/dnsmasq.py | 156 ------------------ tests/test_linux/fixtures/interfaces.py | 26 ++- tests/test_linux/test_raw/test_dhcp.py | 72 ++++++-- 8 files changed, 384 insertions(+), 185 deletions(-) create mode 100644 tests/test_linux/fixtures/dhcp_servers/__init__.py create mode 100644 tests/test_linux/fixtures/dhcp_servers/dnsmasq.py create mode 100644 tests/test_linux/fixtures/dhcp_servers/udhcpd.py delete mode 100644 tests/test_linux/fixtures/dnsmasq.py diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 4e56f530f..485aefe05 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -59,13 +59,21 @@ def __init__( # "public api" - async def wait_for_state(self, state: Optional[fsm.State]) -> None: + async def wait_for_state( + self, state: Optional[fsm.State], timeout: Optional[float] = None + ) -> None: '''Waits until the client is in the target state. Since the state is set to None upon exit, you can also pass None to wait for the client to stop. ''' - await self._states[state].wait() + try: + await asyncio.wait_for(self._states[state].wait(), timeout=timeout) + except TimeoutError as err: + raise TimeoutError( + f"Timed out waiting for the {state} state. " + f"Current state: {self.state}" + ) from err @fsm.state_guard(fsm.State.INIT, fsm.State.INIT_REBOOT) async def bootstrap(self): diff --git a/tests/test_linux/conftest.py b/tests/test_linux/conftest.py index ece1de2e4..0d40120df 100644 --- a/tests/test_linux/conftest.py +++ b/tests/test_linux/conftest.py @@ -2,7 +2,8 @@ from uuid import uuid4 import pytest -from fixtures.dnsmasq import dnsmasq, dnsmasq_options # noqa: F401 +from fixtures.dhcp_servers.dnsmasq import dnsmasq, dnsmasq_config # noqa: F401 +from fixtures.dhcp_servers.udhcpd import udhcpd, udhcpd_config # noqa: F401 from fixtures.interfaces import dhcp_range, veth_pair # noqa: F401 from pr2test.context_manager import NDBContextManager, SpecContextManager from utils import require_user diff --git a/tests/test_linux/fixtures/dhcp_servers/__init__.py b/tests/test_linux/fixtures/dhcp_servers/__init__.py new file mode 100644 index 000000000..57eae89aa --- /dev/null +++ b/tests/test_linux/fixtures/dhcp_servers/__init__.py @@ -0,0 +1,131 @@ +import abc +import asyncio +from argparse import ArgumentParser +from dataclasses import dataclass +from ipaddress import IPv4Address +from typing import ClassVar, Generic, Literal, Optional, TypeVar + +from ..interfaces import DHCPRangeConfig + + +@dataclass +class DHCPServerConfig: + range: DHCPRangeConfig + interface: str + lease_time: int = 120 # in seconds + max_leases: int = 50 + + +DHCPServerConfigT = TypeVar("DHCPServerConfigT", bound=DHCPServerConfig) + + +class DHCPServerFixture(abc.ABC, Generic[DHCPServerConfigT]): + + BINARY_PATH: ClassVar[Optional[str]] = None + + @classmethod + def get_config_class(cls) -> type[DHCPServerConfigT]: + return cls.__orig_bases__[0].__args__[0] + + def __init__(self, config: DHCPServerConfigT) -> None: + self.config = config + self.stdout: list[str] = [] + self.stderr: list[str] = [] + self.process: Optional[asyncio.subprocess.Process] = None + self.output_poller: Optional[asyncio.Task] = None + + async def _read_output(self, name: Literal['stdout', 'stderr']): + '''Read stdout or stderr until the process exits.''' + stream = getattr(self.process, name) + output = getattr(self, name) + while line := await stream.readline(): + output.append(line.decode().strip()) + + async def _read_outputs(self): + '''Read stdout & stderr until the process exits.''' + assert self.process + await asyncio.gather( + self._read_output('stderr'), self._read_output('stdout') + ) + + @abc.abstractmethod + def get_cmdline_options(self) -> tuple[str]: + '''All commandline options passed to the server.''' + + async def __aenter__(self): + '''Start the server process and start polling its output.''' + if not self.BINARY_PATH: + raise RuntimeError( + f"server binary is missing for {type(self).__name__}" + ) + self.process = await asyncio.create_subprocess_exec( + self.BINARY_PATH, + *self.get_cmdline_options(), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env={'LANG': 'C'}, # usually ensures the output is in english + ) + self.output_poller = asyncio.Task(self._read_outputs()) + return self + + async def __aexit__(self, *_): + if self.process: + if self.process.returncode is None: + self.process.terminate() + await self.process.wait() + await self.output_poller + + +def get_psr() -> ArgumentParser: + psr = ArgumentParser() + psr.add_argument('interface', help='Interface to listen on') + psr.add_argument( + '--router', type=IPv4Address, default=None, help='Router IPv4 address.' + ) + psr.add_argument( + '--range-start', + type=IPv4Address, + default=IPv4Address('192.168.186.10'), + help='Start of the DHCP client range.', + ) + psr.add_argument( + '--range-end', + type=IPv4Address, + default=IPv4Address('192.168.186.100'), + help='End of the DHCP client range.', + ) + psr.add_argument( + '--lease-time', + default=120, + type=int, + help='DHCP lease time in seconds (minimum 2 minutes)', + ) + psr.add_argument( + '--netmask', type=IPv4Address, default=IPv4Address("255.255.255.0") + ) + return psr + + +async def run_fixture_as_main(fixture_cls: type[DHCPServerFixture]): + config_cls = fixture_cls.get_config_class() + args = get_psr().parse_args() + range_config = DHCPRangeConfig( + start=args.range_start, + end=args.range_end, + router=args.router, + netmask=args.netmask, + ) + conf = config_cls( + range=range_config, + interface=args.interface, + lease_time=args.lease_time, + ) + read_lines: int = 0 + async with fixture_cls(conf) as dhcp_server: + # quick & dirty stderr polling + while True: + if len(dhcp_server.stderr) > read_lines: + read_lines += len(lines := dhcp_server.stderr[read_lines:]) + print(*lines, sep='\n') + else: + await asyncio.sleep(0.2) diff --git a/tests/test_linux/fixtures/dhcp_servers/dnsmasq.py b/tests/test_linux/fixtures/dhcp_servers/dnsmasq.py new file mode 100644 index 000000000..d98e48140 --- /dev/null +++ b/tests/test_linux/fixtures/dhcp_servers/dnsmasq.py @@ -0,0 +1,68 @@ +import asyncio +from dataclasses import dataclass +from shutil import which +from typing import AsyncGenerator, ClassVar, Optional + +import pytest +import pytest_asyncio +from fixtures.interfaces import DHCPRangeConfig + +from . import DHCPServerConfig, DHCPServerFixture, run_fixture_as_main + + +@dataclass +class DnsmasqConfig(DHCPServerConfig): + '''Options for the dnsmasq server.''' + + def __iter__(self): + opts = [ + f'--interface={self.interface}', + f'--dhcp-range={self.range.start},' + f'{self.range.end},{self.lease_time}', + f'--dhcp-lease-max={self.max_leases}', + ] + if router := self.range.router: + opts.append(f"--dhcp-option=option:router,{router}") + return iter(opts) + + +class DnsmasqFixture(DHCPServerFixture[DnsmasqConfig]): + '''Runs the dnsmasq server as an async context manager.''' + + BINARY_PATH: ClassVar[Optional[str]] = which('dnsmasq') + + def _get_base_cmdline_options(self) -> tuple[str]: + '''The base commandline options for dnsmasq.''' + return ( + '--keep-in-foreground', # self explanatory + '--no-resolv', # don't mess w/ resolv.conf + '--log-facility=-', # log to stdout + '--no-hosts', # don't read /etc/hosts + '--bind-interfaces', # don't bind on wildcard + '--no-ping', # don't ping to check if ips are attributed + ) + + def get_cmdline_options(self) -> tuple[str]: + '''All commandline options passed to dnsmasq.''' + return (*self._get_base_cmdline_options(), *self.config) + + +@pytest.fixture +def dnsmasq_config( + veth_pair: tuple[str, str], dhcp_range: DHCPRangeConfig +) -> DnsmasqConfig: + '''dnsmasq options useful for test purposes.''' + return DnsmasqConfig(range=dhcp_range, interface=veth_pair[0]) + + +@pytest_asyncio.fixture +async def dnsmasq( + dnsmasq_config: DnsmasqConfig, +) -> AsyncGenerator[DnsmasqFixture, None]: + '''A dnsmasq instance running for the duration of the test.''' + async with DnsmasqFixture(config=dnsmasq_config) as dnsf: + yield dnsf + + +if __name__ == '__main__': + asyncio.run(run_fixture_as_main(DnsmasqFixture)) diff --git a/tests/test_linux/fixtures/dhcp_servers/udhcpd.py b/tests/test_linux/fixtures/dhcp_servers/udhcpd.py new file mode 100644 index 000000000..8912a0ef5 --- /dev/null +++ b/tests/test_linux/fixtures/dhcp_servers/udhcpd.py @@ -0,0 +1,101 @@ +import asyncio +from dataclasses import dataclass +from pathlib import Path +from shutil import which +from tempfile import TemporaryDirectory +from typing import AsyncGenerator, ClassVar, Optional + +import pytest +import pytest_asyncio + +from ..interfaces import DHCPRangeConfig +from . import DHCPServerConfig, DHCPServerFixture, run_fixture_as_main + + +@dataclass +class UdhcpdConfig(DHCPServerConfig): + arp_ping_timeout_ms: int = 200 # default is 2000 + + +class UdhcpdFixture(DHCPServerFixture[UdhcpdConfig]): + '''Runs the udhcpd server as an async context manager.''' + + BINARY_PATH: ClassVar[Optional[str]] = which('busybox') + + def __init__(self, config): + super().__init__(config) + self._temp_dir: Optional[TemporaryDirectory[str]] = None + + @property + def workdir(self) -> Path: + '''A temporary directory for udhcpd's files.''' + assert self._temp_dir + return Path(self._temp_dir.name) + + @property + def config_file(self) -> Path: + '''The udhcpd config file path.''' + return self.workdir.joinpath("udhcpd.conf") + + async def __aenter__(self): + self._temp_dir = TemporaryDirectory(prefix=type(self).__name__) + self._temp_dir.__enter__() + self.config_file.write_text(self.generate_config()) + return await super().__aenter__() + + def generate_config(self) -> str: + '''Generate the contents of udhcpd's config file.''' + cfg = self.config + base_workfile = self.workdir.joinpath(self.config.interface) + lease_file = base_workfile.with_suffix(".leases") + pidfile = base_workfile.with_suffix(".pid") + lines = [ + ("start", cfg.range.start), + ("end", cfg.range.end), + ("max_leases", cfg.max_leases), + ("interface", cfg.interface), + ("lease_file", lease_file), + ("pidfile", pidfile), + ("opt lease", cfg.lease_time), + ("opt router", cfg.range.router), + ] + return "\n".join(f"{opt}\t{value}" for opt, value in lines) + + async def __aexit__(self, *_): + await super().__aexit__(*_) + self._temp_dir.__exit__(*_) + + def get_cmdline_options(self) -> tuple[str]: + '''All commandline options passed to udhcpd.''' + return ( + 'udhcpd', + '-f', # run in foreground + '-a', + str(self.config.arp_ping_timeout_ms), + str(self.config_file), + ) + + +@pytest.fixture +def udhcpd_config( + veth_pair: tuple[str, str], dhcp_range: DHCPRangeConfig +) -> UdhcpdConfig: + '''udhcpd options useful for test purposes.''' + return UdhcpdConfig( + range=dhcp_range, + interface=veth_pair[0], + lease_time=1, # very short leases for tests + ) + + +@pytest_asyncio.fixture +async def udhcpd( + udhcpd_config: UdhcpdConfig, +) -> AsyncGenerator[UdhcpdFixture, None]: + '''An udhcpd instance running for the duration of the test.''' + async with UdhcpdFixture(config=udhcpd_config) as dhcp_server: + yield dhcp_server + + +if __name__ == '__main__': + asyncio.run(run_fixture_as_main(UdhcpdFixture)) diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py deleted file mode 100644 index e453d2ad8..000000000 --- a/tests/test_linux/fixtures/dnsmasq.py +++ /dev/null @@ -1,156 +0,0 @@ -import asyncio -from argparse import ArgumentParser -from dataclasses import dataclass -from ipaddress import IPv4Address -from shutil import which -from typing import AsyncGenerator, ClassVar, Literal, Optional - -import pytest -import pytest_asyncio -from fixtures.interfaces import DHCPRangeConfig - - -@dataclass -class DnsmasqOptions: - '''Options for the dnsmasq server.''' - - range_start: IPv4Address - range_end: IPv4Address - interface: str - lease_time: str = '12h' - - def __iter__(self): - opts = ( - f'--interface={self.interface}', - f'--dhcp-range={self.range_start},' - f'{self.range_end},{self.lease_time}', - ) - return iter(opts) - - -class DnsmasqFixture: - '''Runs the dnsmasq server as an async context manager.''' - - DNSMASQ_PATH: ClassVar[Optional[str]] = which('dnsmasq') - - def __init__(self, options: DnsmasqOptions) -> None: - self.options = options - self.stdout: list[str] = [] - self.stderr: list[str] = [] - self.process: Optional[asyncio.subprocess.Process] = None - self.output_poller: Optional[asyncio.Task] = None - - async def _read_output(self, name: Literal['stdout', 'stderr']): - '''Read stdout or stderr until the process exits.''' - stream = getattr(self.process, name) - output = getattr(self, name) - while line := await stream.readline(): - output.append(line.decode().strip()) - - async def _read_outputs(self): - '''Read stdout & stderr until the process exits.''' - assert self.process - await asyncio.gather( - self._read_output('stderr'), self._read_output('stdout') - ) - - def _get_base_cmdline_options(self) -> tuple[str]: - '''The base commandline options for dnsmasq.''' - return ( - '--keep-in-foreground', # self explanatory - '--no-resolv', # don't mess w/ resolv.conf - '--log-facility=-', # log to stdout - '--no-hosts', # don't read /etc/hosts - '--bind-interfaces', # don't bind on wildcard - '--no-ping', # don't ping to check if ips are attributed - ) - - def get_cmdline_options(self) -> tuple[str]: - '''All commandline options passed to dnsmasq.''' - return (*self._get_base_cmdline_options(), *self.options) - - async def __aenter__(self): - '''Start the dnsmasq process and start polling its output.''' - assert self.DNSMASQ_PATH - self.process = await asyncio.create_subprocess_exec( - self.DNSMASQ_PATH, - *self.get_cmdline_options(), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env={'LANG': 'C'}, - ) - self.output_poller = asyncio.Task(self._read_outputs()) - return self - - async def __aexit__(self, *_): - if self.process: - if self.process.returncode is None: - self.process.terminate() - await self.process.wait() - await self.output_poller - - -@pytest.fixture -def dnsmasq_options( - veth_pair: tuple[str, str], dhcp_range: DHCPRangeConfig -) -> DnsmasqOptions: - '''dnsmasq options useful for test purposes.''' - return DnsmasqOptions( - range_start=dhcp_range.range_start, - range_end=dhcp_range.range_end, - interface=veth_pair[0], - ) - - -@pytest_asyncio.fixture -async def dnsmasq( - dnsmasq_options: DnsmasqOptions, -) -> AsyncGenerator[DnsmasqFixture, None]: - '''A dnsmasq instance running for the duration of the test.''' - async with DnsmasqFixture(options=dnsmasq_options) as dnsf: - yield dnsf - - -def get_psr() -> ArgumentParser: - psr = ArgumentParser() - psr.add_argument('interface', help='Interface to listen on') - psr.add_argument( - '--range-start', - type=IPv4Address, - default=IPv4Address('192.168.186.10'), - help='Start of the DHCP client range.', - ) - psr.add_argument( - '--range-end', - type=IPv4Address, - default=IPv4Address('192.168.186.100'), - help='End of the DHCP client range.', - ) - psr.add_argument( - '--lease-time', - default='2m', - help='DHCP lease time (minimum 2 minutes according to man)', - ) - return psr - - -async def main(): - '''Commandline entrypoint to start dnsmasq the same way the fixture does. - - Useful for debugging. - ''' - args = get_psr().parse_args() - opts = DnsmasqOptions(**args.__dict__) - read_lines: int = 0 - async with DnsmasqFixture(opts) as dnsm: - # quick & dirty stderr polling - while True: - if len(dnsm.stderr) > read_lines: - read_lines += len(lines := dnsm.stderr[read_lines:]) - print(*lines, sep='\n') - else: - await asyncio.sleep(0.2) - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/tests/test_linux/fixtures/interfaces.py b/tests/test_linux/fixtures/interfaces.py index c276b2cbc..caa2229ab 100644 --- a/tests/test_linux/fixtures/interfaces.py +++ b/tests/test_linux/fixtures/interfaces.py @@ -1,6 +1,6 @@ import asyncio import random -from ipaddress import IPv4Address, IPv4Interface +from ipaddress import IPv4Address from typing import AsyncGenerator, NamedTuple import pytest @@ -8,9 +8,10 @@ class DHCPRangeConfig(NamedTuple): - range_start: IPv4Address - range_end: IPv4Address - gw: IPv4Interface + start: IPv4Address + end: IPv4Address + router: IPv4Address + netmask: IPv4Address async def ip(*args: str): @@ -23,12 +24,13 @@ async def ip(*args: str): @pytest.fixture def dhcp_range() -> DHCPRangeConfig: - ''' 'An IPv4 DHCP range configuration.''' + '''An IPv4 DHCP range configuration.''' rangeidx = random.randint(1, 254) return DHCPRangeConfig( - range_start=IPv4Address(f'10.{rangeidx}.0.10'), - range_end=IPv4Address(f'10.{rangeidx}.0.20'), - gw=IPv4Interface(f'10.{rangeidx}.0.1/16'), + start=IPv4Address(f'10.{rangeidx}.0.10'), + end=IPv4Address(f'10.{rangeidx}.0.20'), + router=IPv4Address(f'10.{rangeidx}.0.1'), + netmask=IPv4Address('255.255.255.0'), ) @@ -60,7 +62,13 @@ async def veth_pair( 'name', client_ifname, ) - await ip('addr', 'add', str(dhcp_range.gw), 'dev', server_ifname) + await ip( + 'addr', + 'add', + f"{dhcp_range.router}/{dhcp_range.netmask}", + 'dev', + server_ifname, + ) await ip('link', 'set', server_ifname, 'up') await ip('link', 'set', client_ifname, 'up') yield VethPair(server_ifname, client_ifname) diff --git a/tests/test_linux/test_raw/test_dhcp.py b/tests/test_linux/test_raw/test_dhcp.py index e53015733..42c8059a5 100644 --- a/tests/test_linux/test_raw/test_dhcp.py +++ b/tests/test_linux/test_raw/test_dhcp.py @@ -4,13 +4,15 @@ from pathlib import Path import pytest -from fixtures.dnsmasq import DnsmasqFixture +from fixtures.dhcp_servers.dnsmasq import DnsmasqFixture +from fixtures.dhcp_servers.udhcpd import UdhcpdFixture from fixtures.interfaces import VethPair from pr2test.marks import require_root -from pyroute2.dhcp import client, fsm +from pyroute2.dhcp import fsm +from pyroute2.dhcp.client import AsyncDHCPClient from pyroute2.dhcp.constants import bootp, dhcp -from pyroute2.dhcp.leases import JSONFileLease +from pyroute2.dhcp.leases import JSONFileLease, JSONStdoutLease pytestmark = [require_root()] @@ -21,24 +23,19 @@ async def test_get_lease( veth_pair: VethPair, tmpdir: str, monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, ): '''The client can get a lease and write it to a file.''' work_dir = Path(tmpdir) + caplog.set_level("DEBUG") # Patch JSONFileLease so leases get written to the temp dir # instead of whatever the working directory is monkeypatch.setattr(JSONFileLease, '_get_lease_dir', lambda: work_dir) # boot up the dhcp client and wait for a lease - async with client.AsyncDHCPClient(veth_pair.client) as cli: + async with AsyncDHCPClient(veth_pair.client) as cli: await cli.bootstrap() - try: - await asyncio.wait_for( - cli.wait_for_state(fsm.State.BOUND), timeout=5 - ) - except TimeoutError: - raise AssertionError( - f'Timed out. dnsmasq output: {dnsmasq.stderr}' - ) + await cli.wait_for_state(fsm.State.BOUND, timeout=10) assert cli.state == fsm.State.BOUND lease = cli.lease assert lease.ack['xid'] == cli.xid @@ -48,9 +45,9 @@ async def test_get_lease( assert lease.ack['op'] == bootp.MessageType.BOOTREPLY assert lease.ack['options']['message_type'] == dhcp.MessageType.ACK assert ( - dnsmasq.options.range_start + dnsmasq.config.range.start <= IPv4Address(lease.ip) - <= dnsmasq.options.range_end + <= dnsmasq.config.range.end ) assert lease.ack['chaddr'] # TODO: check chaddr matches veth_pair.client's MAC @@ -72,8 +69,8 @@ async def test_client_console(dnsmasq: DnsmasqFixture, veth_pair: VethPair): '--lease-type', 'pyroute2.dhcp.leases.JSONStdoutLease', '--exit-on-lease', + '--log-level=DEBUG', stdout=asyncio.subprocess.PIPE, - # stderr=asyncio.subprocess.PIPE, ) try: stdout, _ = await asyncio.wait_for(process.communicate(), timeout=5) @@ -84,7 +81,48 @@ async def test_client_console(dnsmasq: DnsmasqFixture, veth_pair: VethPair): json_lease = json.loads(stdout) assert json_lease['interface'] == veth_pair.client assert ( - dnsmasq.options.range_start + dnsmasq.config.range.start <= IPv4Address(json_lease['ack']['yiaddr']) - <= dnsmasq.options.range_end + <= dnsmasq.config.range.end ) + + +@pytest.mark.asyncio +async def test_client_lifecycle(udhcpd: UdhcpdFixture, veth_pair: VethPair): + '''Test getting a lease, expiring & getting a lease again.''' + async with AsyncDHCPClient( + veth_pair.client, lease_type=JSONStdoutLease + ) as cli: + # No lease, we're in the INIT state + assert cli.state == fsm.State.INIT + # Start requesting an IP + await cli.bootstrap() + # Then, the client in the SELECTING state while sending DISCOVERs + await cli.wait_for_state(fsm.State.SELECTING, timeout=1) + # Once we get an OFFER the client switches to REQUESTING + await cli.wait_for_state(fsm.State.REQUESTING, timeout=1) + # After getting an ACK, we're BOUND ! + await cli.wait_for_state(fsm.State.BOUND, timeout=1) + + # Ideally, we would test the REBINDING & RENEWING states here, + # but they depend on timers that udhcpd does not implement. + + # The lease expires, and we're back to INIT + await cli.wait_for_state(fsm.State.INIT, timeout=5) + await cli.wait_for_state(fsm.State.SELECTING, timeout=1) + await cli.wait_for_state(fsm.State.REQUESTING, timeout=1) + await cli.wait_for_state(fsm.State.BOUND, timeout=1) + + # Stop here, that's enough + lease = cli.lease + assert lease.ack['xid'] == cli.xid + + # The obtained IP must be in the range + assert ( + udhcpd.config.range.start + <= IPv4Address(lease.ip) + <= udhcpd.config.range.end + ) + assert lease.routers == [str(udhcpd.config.range.router)] + assert lease.interface == veth_pair.client + assert lease.ack["options"]["lease_time"] == udhcpd.config.lease_time From c5679a048c0bfa65803ee9b92710276b0bfcd871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Noss?= Date: Mon, 20 Jan 2025 18:39:16 +0100 Subject: [PATCH 18/18] properly do unicast & broadcast for REQUESTs according to RFC --- pyroute2/dhcp/__init__.py | 20 +-- pyroute2/dhcp/client.py | 91 ++++++----- pyroute2/dhcp/constants/__init__.py | 0 pyroute2/dhcp/dhcp4msg.py | 9 +- pyroute2/dhcp/dhcp4socket.py | 61 ++++--- pyroute2/dhcp/enums/__init__.py | 1 + pyroute2/dhcp/{constants => enums}/bootp.py | 0 pyroute2/dhcp/{constants => enums}/dhcp.py | 0 pyroute2/dhcp/hooks.py | 14 ++ pyroute2/dhcp/leases.py | 5 + pyroute2/dhcp/messages.py | 172 ++++++++++++++++---- tests/test_linux/test_raw/test_dhcp.py | 2 +- 12 files changed, 257 insertions(+), 118 deletions(-) delete mode 100644 pyroute2/dhcp/constants/__init__.py create mode 100644 pyroute2/dhcp/enums/__init__.py rename pyroute2/dhcp/{constants => enums}/bootp.py (100%) rename pyroute2/dhcp/{constants => enums}/dhcp.py (100%) diff --git a/pyroute2/dhcp/__init__.py b/pyroute2/dhcp/__init__.py index a8c72e555..42dd0b1de 100644 --- a/pyroute2/dhcp/__init__.py +++ b/pyroute2/dhcp/__init__.py @@ -105,18 +105,7 @@ class array8(option): from pyroute2.common import basestring from pyroute2.protocols import msg -BOOTREQUEST = 1 -BOOTREPLY = 2 - -DHCPDISCOVER = 1 -DHCPOFFER = 2 -DHCPREQUEST = 3 -DHCPDECLINE = 4 -DHCPACK = 5 -DHCPNAK = 6 -DHCPRELEASE = 7 -DHCPINFORM = 8 - +from . import enums if not hasattr(array, 'tobytes'): # Python2 and Python3 versions of array differ, @@ -262,8 +251,8 @@ def encode(self): self._register_options() # put message type options = self.get('options') or { - 'message_type': DHCPDISCOVER, - 'parameter_list': [1, 3, 6, 12, 15, 28], + 'message_type': enums.dhcp.MessageType.DISCOVER, + 'parameter_list': [1, 3, 6, 12, 15, 28], # FIXME } self.buf += ( @@ -321,3 +310,6 @@ class array8(option): class client_id(option): fields = (('type', 'uint8'), ('key', 'l2addr')) + + class message_type(option): + policy = {'format': 'B', 'decode': lambda x: enums.dhcp.MessageType(x)} diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 485aefe05..8120fb4b7 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -3,9 +3,8 @@ from typing import ClassVar, DefaultDict, Iterable, Optional from pyroute2.dhcp import fsm, messages -from pyroute2.dhcp.constants import dhcp -from pyroute2.dhcp.dhcp4msg import dhcp4msg from pyroute2.dhcp.dhcp4socket import AsyncDHCP4Socket +from pyroute2.dhcp.enums import dhcp from pyroute2.dhcp.hooks import Hook from pyroute2.dhcp.leases import JSONFileLease, Lease from pyroute2.dhcp.timers import Timers @@ -45,7 +44,9 @@ def __init__( # Current lease, read from persistent storage or sent by a server self._lease: Optional[Lease] = None # dhcp messages put in this queue are sent by _send_forever - self._sendq: asyncio.Queue[Optional[dhcp4msg]] = asyncio.Queue() + self._sendq: asyncio.Queue[Optional[messages.SentDHCPMessage]] = ( + asyncio.Queue() + ) # Handle to run _send_forever for the context manager's lifetime self._sender_task: Optional[asyncio.Task] = None # Handle to run _recv_forever for the context manager's lifetime @@ -96,10 +97,10 @@ async def bootstrap(self): # send request for lease await self.transition( to=fsm.State.REBOOTING, - send=messages.request( - requested_ip=self.lease.ip, - server_id=self.lease.server_id, + send=messages.request_for_lease( parameter_list=self.requested_parameters, + lease=self.lease, + state=fsm.State.REBOOTING, ), ) # the decorator prevents the needs for an else @@ -153,10 +154,10 @@ async def _renew(self): self.timers._reset_timer('renewal') # FIXME should be automatic await self.transition( to=fsm.State.RENEWING, - send=messages.request( - requested_ip=self.lease.ip, - server_id=self.lease.server_id, + send=messages.request_for_lease( parameter_list=self.requested_parameters, + lease=self.lease, + state=fsm.State.RENEWING, ), ) @@ -167,10 +168,10 @@ async def _rebind(self): self.timers._reset_timer('rebinding') await self.transition( to=fsm.State.REBINDING, - send=messages.request( - requested_ip=self.lease.ip, - server_id=self.lease.server_id, + send=messages.request_for_lease( parameter_list=self.requested_parameters, + lease=self.lease, + state=fsm.State.REBINDING, ), ) @@ -189,30 +190,27 @@ async def _expire_lease(self): async def _send_forever(self): """Send packets from _sendq until the client stops.""" - packet_to_send = None + msg_to_send = None wait_til_stopped = asyncio.Task(self.wait_for_state(None)) interval = 5 # TODO make dynamic while not wait_til_stopped.done(): - wait_for_packet_to_send = asyncio.Task( + wait_for_msg_to_send = asyncio.Task( self._sendq.get(), name='wait for packet to send' ) done, pending = await asyncio.wait( - (wait_til_stopped, wait_for_packet_to_send), + (wait_til_stopped, wait_for_msg_to_send), return_when=asyncio.FIRST_COMPLETED, timeout=interval, ) - if wait_for_packet_to_send in done: - if packet_to_send := wait_for_packet_to_send.result(): - packet_to_send['xid'] = self.xid - elif wait_for_packet_to_send in pending: - wait_for_packet_to_send.cancel() - - if packet_to_send: - LOG.debug( - 'Sending %s', - packet_to_send['options']['message_type'].name, - ) - await self._sock.put(packet_to_send) + if wait_for_msg_to_send in done: + if msg_to_send := wait_for_msg_to_send.result(): + msg_to_send.dhcp['xid'] = self.xid + elif wait_for_msg_to_send in pending: + wait_for_msg_to_send.cancel() + + if msg_to_send: + LOG.debug('Sending %s', msg_to_send) + await self._sock.put(msg_to_send) async def _recv_forever(self) -> None: """Receive & process DHCP packets until the client stops. @@ -236,10 +234,10 @@ async def _recv_forever(self) -> None: if wait_for_received_packet in done: received_packet = wait_for_received_packet.result() msg_type = dhcp.MessageType( - received_packet['options']['message_type'] + received_packet.dhcp['options']['message_type'] ) - LOG.info('Received %s', msg_type.name) - if received_packet.get('xid') != self.xid: + LOG.info('Received %s', received_packet) + if received_packet.dhcp.get('xid') != self.xid: LOG.error('Missing or wrong xid, discarding') else: handler_name = f'{msg_type.name.lower()}_received' @@ -252,7 +250,9 @@ async def _recv_forever(self) -> None: elif wait_for_received_packet in pending: wait_for_received_packet.cancel() - async def transition(self, to: fsm.State, send: Optional[dhcp4msg] = None): + async def transition( + self, to: fsm.State, send: Optional[messages.SentDHCPMessage] = None + ): '''Change the client's state, and start sending a message repeatedly. If the message is None, any current message will stop being sent. @@ -268,12 +268,14 @@ async def transition(self, to: fsm.State, send: Optional[dhcp4msg] = None): fsm.State.REBINDING, fsm.State.RENEWING, ) - async def ack_received(self, pkt: dhcp4msg): + async def ack_received(self, msg: messages.ReceivedDHCPMessage): '''Called when an ACK is received. Stores the lease and puts the client in the BOUND state. ''' - self.lease = self.lease_type(ack=pkt, interface=self.interface) + self.lease = self.lease_type( + ack=msg.dhcp, interface=self.interface, server_mac=msg.eth_src + ) LOG.info( 'Got lease for %s from %s', self.lease.ip, self.lease.server_id ) @@ -288,11 +290,12 @@ async def ack_received(self, pkt: dhcp4msg): fsm.State.RENEWING, fsm.State.REBINDING, ) - async def nak_received(self, pkt: dhcp4msg): + async def nak_received(self, msg: messages.ReceivedDHCPMessage): '''Called when a NAK is received. Resets the client and starts looking for a new IP. ''' + # TODO: check the NAK matches something we asked for ? await self.transition(to=fsm.State.INIT) # Reset lease & timers and start again self._lease = None @@ -300,17 +303,15 @@ async def nak_received(self, pkt: dhcp4msg): await self.bootstrap() @fsm.state_guard(fsm.State.SELECTING) - async def offer_received(self, pkt: dhcp4msg): + async def offer_received(self, msg: messages.ReceivedDHCPMessage): '''Called when an OFFER is received. Sends a REQUEST for the offered IP address. ''' await self.transition( to=fsm.State.REQUESTING, - send=messages.request( - requested_ip=pkt['yiaddr'], - server_id=pkt['options']['server_id'], - parameter_list=self.requested_parameters, + send=messages.request_for_offer( + parameter_list=self.requested_parameters, offer=msg ), ) @@ -353,12 +354,12 @@ async def __aexit__(self, *_): If there's an active lease, send a RELEASE for it first. ''' self.timers.cancel() - if self.lease and not self.lease.expired: - await self._sendq.put( - messages.release( - requested_ip=self.lease.ip, server_id=self.lease.server_id - ) - ) + # FIXME: call hooks in a non blocking way (maybe call_soon ?) + if self.lease: + for i in self.hooks: + await i.unbound(self.lease) + if not self.lease.expired: + await self._sendq.put(messages.release(lease=self.lease)) self.state = None await self._sender_task await self._receiver_task diff --git a/pyroute2/dhcp/constants/__init__.py b/pyroute2/dhcp/constants/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pyroute2/dhcp/dhcp4msg.py b/pyroute2/dhcp/dhcp4msg.py index ec5f93c6a..3416f3a90 100644 --- a/pyroute2/dhcp/dhcp4msg.py +++ b/pyroute2/dhcp/dhcp4msg.py @@ -14,6 +14,11 @@ class dhcp4msg(dhcpmsg): ('hops', 'uint8'), ('xid', 'uint32'), ('secs', 'uint16'), + # TODO: set flags to broadcast (RFC 2131) + # A client that cannot receive unicast IP datagrams until its protocol + # software has been configured with an IP address SHOULD set the + # BROADCAST bit in the 'flags' field to 1 in any DHCPDISCOVER or + # DHCPREQUEST messages that client sends. ('flags', 'uint16'), ('ciaddr', 'ip4addr'), ('yiaddr', 'ip4addr'), @@ -40,10 +45,10 @@ class dhcp4msg(dhcpmsg): (9, 'lpr_server', 'ip4list'), (50, 'requested_ip', 'ip4addr'), (51, 'lease_time', 'be32'), - (53, 'message_type', 'uint8'), + (53, 'message_type', 'message_type'), (54, 'server_id', 'ip4addr'), (55, 'parameter_list', 'array8'), - (57, 'messagi_size', 'be16'), + (57, 'message_size', 'be16'), (58, 'renewal_time', 'be32'), (59, 'rebinding_time', 'be32'), (60, 'vendor_id', 'string'), diff --git a/pyroute2/dhcp/dhcp4socket.py b/pyroute2/dhcp/dhcp4socket.py index d0af5cc06..a19ef2712 100644 --- a/pyroute2/dhcp/dhcp4socket.py +++ b/pyroute2/dhcp/dhcp4socket.py @@ -11,11 +11,13 @@ from pyroute2.common import AddrPool from pyroute2.compat import ETHERTYPE_IP from pyroute2.dhcp.dhcp4msg import dhcp4msg +from pyroute2.dhcp.messages import ReceivedDHCPMessage, SentDHCPMessage from pyroute2.ext.rawsocket import AsyncRawSocket from pyroute2.protocols import ethmsg, ip4msg, udp4_pseudo_header, udpmsg LOG = logging.getLogger(__name__) + UDP_HEADER_SIZE = 8 IPV4_HEADER_SIZE = 20 @@ -67,20 +69,9 @@ def __init__(self, ifname, port: int = 68): ) # TODO : maybe it should be in the client and not here ? self.aio_loop = asyncio.get_running_loop() - async def put( - self, - msg: dhcp4msg, - eth_dst: str = 'ff:ff:ff:ff:ff:ff', - ip_dst: str = '255.255.255.255', - dport: int = 67, - ) -> dhcp4msg: + async def put(self, msg: SentDHCPMessage) -> SentDHCPMessage: ''' - Put DHCP message. Parameters: - - * msg -- dhcp4msg instance - * eth_dst -- dest MAC address - * ip_dst -- dest IP address - * dport -- DHCP server port + Put DHCP message. Examples:: @@ -91,11 +82,21 @@ async def put( 'requested_ip': '172.16.101.2', 'server_id': '172.16.101.1'}})) - The method returns the sent dhcp4msg, so one can get from + The method returns the SentDHCPMessage, so one can get from there the `xid` (transaction id) and other details. ''' + + if msg.sport != self.port: + raise ValueError( + f"Client source port is set to {self.port}, " + f"cannot send message from port {msg.sport}." + ) + + if not msg.eth_src: + msg.eth_src = self.l2addr + # DHCP layer - dhcp = msg + dhcp = msg.dhcp # dhcp transaction id if dhcp['xid'] is None: @@ -103,7 +104,7 @@ async def put( # auto add src addr if dhcp['chaddr'] is None: - dhcp['chaddr'] = self.l2addr + dhcp['chaddr'] = msg.eth_src data = dhcp.encode().buf dhcp_payload_size = len(data) @@ -112,13 +113,17 @@ async def put( udp = udpmsg( { 'sport': self.port, - 'dport': dport, + 'dport': msg.dport, 'len': UDP_HEADER_SIZE + dhcp_payload_size, } ) # Pseudo UDP header, only for checksum purposes udph = udp4_pseudo_header( - {'dst': ip_dst, 'len': UDP_HEADER_SIZE + dhcp_payload_size} + { + 'src': msg.ip_src, + 'dst': msg.ip_dst, + 'len': UDP_HEADER_SIZE + dhcp_payload_size, + } ) udp['csum'] = self.csum(udph.encode().buf + udp.encode().buf + data) udp.reset() @@ -128,7 +133,8 @@ async def put( { 'len': IPV4_HEADER_SIZE + UDP_HEADER_SIZE + dhcp_payload_size, 'proto': socket.IPPROTO_UDP, - 'dst': ip_dst, + 'dst': msg.ip_dst, + 'src': msg.ip_src, } ) ip4['csum'] = self.csum(ip4.encode().buf) @@ -136,15 +142,15 @@ async def put( # MAC layer eth = ethmsg( - {'dst': eth_dst, 'src': self.l2addr, 'type': ETHERTYPE_IP} + {'dst': msg.eth_dst, 'src': msg.eth_src, 'type': ETHERTYPE_IP} ) data = eth.encode().buf + ip4.encode().buf + udp.encode().buf + data await self.aio_loop.sock_sendall(self, data) dhcp.reset() - return dhcp + return msg - async def get(self) -> dhcp4msg: + async def get(self) -> ReceivedDHCPMessage: ''' Get the next incoming packet from the socket and try to decode it as IPv4 DHCP. No analysis is done here, @@ -155,4 +161,13 @@ async def get(self) -> dhcp4msg: eth = ethmsg(buf=data).decode() ip4 = ip4msg(buf=data, offset=eth.offset).decode() udp = udpmsg(buf=data, offset=ip4.offset).decode() - return dhcp4msg(buf=data, offset=udp.offset).decode() + dhcp = dhcp4msg(buf=data, offset=udp.offset).decode() + return ReceivedDHCPMessage( + dhcp=dhcp, + eth_src=eth['src'], + eth_dst=eth['dst'], + ip_src=ip4['src'], + ip_dst=ip4['dst'], + sport=udp['sport'], + dport=udp['dport'], + ) diff --git a/pyroute2/dhcp/enums/__init__.py b/pyroute2/dhcp/enums/__init__.py new file mode 100644 index 000000000..9cb0220d5 --- /dev/null +++ b/pyroute2/dhcp/enums/__init__.py @@ -0,0 +1 @@ +from . import bootp, dhcp # noqa: F401 diff --git a/pyroute2/dhcp/constants/bootp.py b/pyroute2/dhcp/enums/bootp.py similarity index 100% rename from pyroute2/dhcp/constants/bootp.py rename to pyroute2/dhcp/enums/bootp.py diff --git a/pyroute2/dhcp/constants/dhcp.py b/pyroute2/dhcp/enums/dhcp.py similarity index 100% rename from pyroute2/dhcp/constants/dhcp.py rename to pyroute2/dhcp/enums/dhcp.py diff --git a/pyroute2/dhcp/hooks.py b/pyroute2/dhcp/hooks.py index cd5d50f9d..431296226 100644 --- a/pyroute2/dhcp/hooks.py +++ b/pyroute2/dhcp/hooks.py @@ -25,6 +25,20 @@ async def unbound(self, lease: Lease): class ConfigureIP(Hook): async def bound(self, lease: Lease): LOG.info('STUB: add %s to %s', lease.ip, lease.interface) + # await ip( + # "addr", + # "replace", + # f"{lease.ip}/{lease.subnet_mask}", + # "dev", + # lease.interface, + # ) async def unbound(self, lease: Lease): LOG.info('STUB: remove %s from %s', lease.ip, lease.interface) + # await ip( + # "addr", + # "del", + # f"{lease.ip}/{lease.subnet_mask}", + # "dev", + # lease.interface, + # ) diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py index 87b5c1346..355e84066 100644 --- a/pyroute2/dhcp/leases.py +++ b/pyroute2/dhcp/leases.py @@ -26,6 +26,8 @@ class Lease(abc.ABC): ack: dhcp4msg # Name of the interface for which this lease was requested interface: str + # MAC address of the server that allocated the lease + server_mac: str # Timestamp of when this lease was obtained obtained: float = field(default_factory=_now) @@ -167,3 +169,6 @@ def load(cls, interface: str) -> 'Optional[JSONFileLease]': except FileNotFoundError: LOG.info('No existing lease at %s for %s', lease_path, interface) return None + except TypeError as err: + LOG.warning("Error loading lease: %s", err) + return None diff --git a/pyroute2/dhcp/messages.py b/pyroute2/dhcp/messages.py index 5994bda25..a6735bbd5 100644 --- a/pyroute2/dhcp/messages.py +++ b/pyroute2/dhcp/messages.py @@ -1,45 +1,151 @@ """Helper functions to build dhcp client messages.""" -from pyroute2.dhcp.constants import bootp, dhcp +from dataclasses import dataclass +from typing import Literal, Optional + +from pyroute2.dhcp import enums from pyroute2.dhcp.dhcp4msg import dhcp4msg +from pyroute2.dhcp.fsm import State +from pyroute2.dhcp.leases import Lease + + +@dataclass +class _DHCPMessage: + '''A DHCP message with some extra info from other layers.''' + + dhcp: dhcp4msg + eth_src: Optional[str] = None + eth_dst: str = 'ff:ff:ff:ff:ff:ff' + ip_src: str = '0.0.0.0' + ip_dst: str = '255.255.255.255' + sport: int = 68 + dport: int = 67 + + @property + def message_type(self) -> enums.dhcp.MessageType: + '''The DHCP message type (DISCOVER, REQUEST, ACK...)''' + return self.dhcp['options']['message_type'] + + +class SentDHCPMessage(_DHCPMessage): + '''A DHCP message to be sent to a server or broadcast.''' + + def __str__(self) -> str: + type_name = self.message_type.name + return f"{type_name} to {self.eth_dst}/{self.ip_dst}:{self.dport}" + + +class ReceivedDHCPMessage(_DHCPMessage): + '''A DHCP message received by the client.''' + + def __str__(self) -> str: + type_name = self.dhcp['options']['message_type'].name + return f"{type_name} from {self.eth_src}/{self.ip_src}:{self.sport}" -def discover(parameter_list: list[dhcp.Parameter]) -> dhcp4msg: - return dhcp4msg( - { - 'op': bootp.MessageType.BOOTREQUEST, - 'options': { - 'message_type': dhcp.MessageType.DISCOVER, - 'parameter_list': parameter_list, - }, - } +def discover(parameter_list: list[enums.dhcp.Parameter]) -> SentDHCPMessage: + # Default for SentDHCPMessage is broadcast which is what we want here + return SentDHCPMessage( + dhcp=dhcp4msg( + { + 'op': enums.bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': enums.dhcp.MessageType.DISCOVER, + 'parameter_list': parameter_list, + }, + } + ) ) -def request( - requested_ip: str, server_id: str, parameter_list: list[dhcp.Parameter] -) -> dhcp4msg: - return dhcp4msg( - { - 'op': bootp.MessageType.BOOTREQUEST, - 'options': { - 'message_type': dhcp.MessageType.REQUEST, - 'requested_ip': requested_ip, - 'server_id': server_id, - 'parameter_list': parameter_list, - }, - } +def request_for_offer( + parameter_list: list[enums.dhcp.Parameter], offer: ReceivedDHCPMessage +) -> SentDHCPMessage: + '''Make a REQUEST message for a given OFFER. + + Since we don't have an IP yet, the message is always broadcast. + Contrary to other cases where an REQUEST is sent, the server_id DHCP option + is always set. + + See RFC 2131 section 4.3.2. + ''' + return SentDHCPMessage( + dhcp=dhcp4msg( + { + 'op': enums.bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': enums.dhcp.MessageType.REQUEST, + 'requested_ip': offer.dhcp['yiaddr'], + 'server_id': offer.dhcp['options']['server_id'], + 'parameter_list': parameter_list, + }, + } + ) ) -def release(requested_ip: str, server_id: str) -> dhcp4msg: - return dhcp4msg( - { - 'op': bootp.MessageType.BOOTREQUEST, - 'options': { - 'message_type': dhcp.MessageType.RELEASE, - 'requested_ip': requested_ip, - 'server_id': server_id, - }, - } +def request_for_lease( + parameter_list: list[enums.dhcp.Parameter], + lease: Lease, + state: Literal[State.RENEWING, State.REBINDING, State.REBOOTING], +) -> SentDHCPMessage: + '''Make a REQUEST for an existing lease. + + This differs from REQUESTs in response to an OFFER in that the server_id + option is never set. + + When rebooting, the message is broadcast, and the requested_ip option is + set to the IP in the stored lease. The bootp client IP is left blank. + + When renewing, (i.e. T1 expires) the message is for the server that granted + the lease. The leases's IP is expected to be assigned to the client's + interface at this point. + + When rebinding (T2), the message is broadcast on the network. + + In both cases, the bootp client IP (ciaddr) is set to the leases's IP. + + See RFC 2131 section 4.3.6. + ''' + kwargs = { + 'dhcp': dhcp4msg( + { + 'op': enums.bootp.MessageType.BOOTREQUEST, + # TODO: broadcast flag + 'options': { + 'message_type': enums.dhcp.MessageType.REQUEST, + 'parameter_list': parameter_list, + }, + } + ) + } + if state == State.INIT_REBOOT: + kwargs['dhcp']['options']['requested_ip'] = lease.ip + else: + kwargs['dhcp']['ciaddr'] = lease.ip + if state == State.RENEWING: + kwargs['eth_dst'] = lease.server_mac + kwargs['ip_dst'] = lease.server_id + kwargs['ip_src'] = lease.ip + + return SentDHCPMessage(**kwargs) + + +def release(lease: Lease) -> SentDHCPMessage: + '''Make a RELEASE for an existing & active lease.''' + return SentDHCPMessage( + dhcp=dhcp4msg( + { + 'op': enums.bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': enums.dhcp.MessageType.RELEASE, + 'requested_ip': lease.ip, + 'server_id': lease.server_id, + }, + } + ), + # RELEASEs are unicast (see rfc section 4.4.4) + eth_dst=lease.server_mac, + ip_dst=lease.server_id, + ip_src=lease.ip, ) diff --git a/tests/test_linux/test_raw/test_dhcp.py b/tests/test_linux/test_raw/test_dhcp.py index 42c8059a5..be912429a 100644 --- a/tests/test_linux/test_raw/test_dhcp.py +++ b/tests/test_linux/test_raw/test_dhcp.py @@ -11,7 +11,7 @@ from pyroute2.dhcp import fsm from pyroute2.dhcp.client import AsyncDHCPClient -from pyroute2.dhcp.constants import bootp, dhcp +from pyroute2.dhcp.enums import bootp, dhcp from pyroute2.dhcp.leases import JSONFileLease, JSONStdoutLease pytestmark = [require_root()]