diff --git a/pyroute2/dhcp/cli.py b/pyroute2/dhcp/cli.py new file mode 100644 index 000000000..86d3f73aa --- /dev/null +++ b/pyroute2/dhcp/cli.py @@ -0,0 +1,105 @@ +import asyncio +import logging +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from importlib import import_module +from typing import Any + +from pyroute2.dhcp.client import AsyncDHCPClient +from pyroute2.dhcp.hooks import ConfigureIP, Hook +from pyroute2.dhcp.leases import Lease + + +def importable(name: str) -> Any: + '''Imports anything by name. Used by the argument parser.''' + module_name, obj_name = name.rsplit('.', 1) + module = import_module(module_name) + return getattr(module, obj_name) + + +def get_psr() -> ArgumentParser: + psr = ArgumentParser( + description='pyroute2 DHCP client', + formatter_class=ArgumentDefaultsHelpFormatter, + ) + psr.add_argument( + 'interface', help='The interface to request an address for.' + ) + psr.add_argument( + '--lease-type', + help='Class to use for leases. ' + 'Must be a subclass of `pyroute2.dhcp.leases.Lease`.', + type=importable, + default='pyroute2.dhcp.leases.JSONFileLease', + metavar='dotted.name', + ) + psr.add_argument( + '--hook', + help='Hooks to load. ' + 'These are used to run async python code when, ' + 'for example, renewing or expiring a lease.', + nargs='+', + type=importable, + default=[ConfigureIP], + metavar='dotted.name', + ) + psr.add_argument( + '-x', + '--exit-on-lease', + help='Exit as soon as getting a lease.', + default=False, + action='store_true', + ) + psr.add_argument( + '--log-level', + help='Logging level to use.', + choices=('DEBUG', 'INFO', 'WARNING', 'ERROR'), + default='INFO', + ) + return psr + + +async def main(): + psr = get_psr() + args = psr.parse_args() + logging.basicConfig( + level=args.log_level, + format='%(asctime)s %(levelname)s [%(name)s:%(funcName)s] %(message)s', + ) + + if not issubclass(args.lease_type, Lease): + psr.error(f'{args.lease_type!r} must be a Lease subclass') + + # Check hooks are subclasses of Hook + for i in args.hook: + if not issubclass(i, Hook): + psr.error(f'{i!r} must be a Hook subclass') + + acli = AsyncDHCPClient( + interface=args.interface, + lease_type=args.lease_type, + # Instantiate hooks + hooks=[i() for i in args.hook], + ) + + # Open the socket, read existing lease, etc + async with acli: + # Bootstrap the client by sending a DISCOVER or a REQUEST + await acli.bootstrap() + if args.exit_on_lease: + # Wait until we're bound once, then exit + await acli.bound.wait() + else: + # Wait until the client is stopped otherwise + await acli._stopped.wait() + + +def run(): + # for the setup.cfg entrypoint + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + run() diff --git a/pyroute2/dhcp/client.py b/pyroute2/dhcp/client.py index 7b62bd813..4bb37e30a 100644 --- a/pyroute2/dhcp/client.py +++ b/pyroute2/dhcp/client.py @@ -1,82 +1,261 @@ -import json -import select -import sys - -from pyroute2.dhcp import ( - BOOTREQUEST, - DHCPACK, - DHCPDISCOVER, - DHCPOFFER, - DHCPREQUEST, -) +import asyncio +from logging import getLogger +from typing import ClassVar, Iterable + +from pyroute2.dhcp import fsm, messages +from pyroute2.dhcp.constants import dhcp from pyroute2.dhcp.dhcp4msg import dhcp4msg -from pyroute2.dhcp.dhcp4socket import DHCP4Socket - - -def req(s, poll, msg, expect): - do_req = True - xid = None - - while True: - # get transaction id - if do_req: - xid = s.put(msg)['xid'] - # wait for response - events = poll.poll(2) - for fd, event in events: - response = s.get() - if response['xid'] != xid: - do_req = False - continue - if response['options']['message_type'] != expect: - raise Exception("DHCP protocol error") - return response - do_req = True - - -def action(ifname): - s = DHCP4Socket(ifname) - poll = select.poll() - poll.register(s, select.POLLIN | select.POLLPRI) - - # DISCOVER - discover = dhcp4msg( - { - 'op': BOOTREQUEST, - 'chaddr': s.l2addr, - 'options': { - 'message_type': DHCPDISCOVER, - 'parameter_list': [1, 3, 6, 12, 15, 28], - }, - } - ) - reply = req(s, poll, discover, expect=DHCPOFFER) - - # REQUEST - request = dhcp4msg( - { - 'op': BOOTREQUEST, - 'chaddr': s.l2addr, - 'options': { - 'message_type': DHCPREQUEST, - 'requested_ip': reply['yiaddr'], - 'server_id': reply['options']['server_id'], - 'parameter_list': [1, 3, 6, 12, 15, 28], - }, - } - ) - reply = req(s, poll, request, expect=DHCPACK) - s.close() - return reply +from pyroute2.dhcp.dhcp4socket import AsyncDHCP4Socket +from pyroute2.dhcp.hooks import Hook +from pyroute2.dhcp.leases import JSONFileLease, Lease +from pyroute2.dhcp.timers import Timers + +LOG = getLogger(__name__) + + +class AsyncDHCPClient: + '''A simple async DHCP client based on pyroute2.''' + + DEFAULT_PARAMETERS: ClassVar[tuple[dhcp.Parameter]] = [ + dhcp.Parameter.SUBNET_MASK, + dhcp.Parameter.ROUTER, + dhcp.Parameter.DOMAIN_NAME_SERVER, + dhcp.Parameter.DOMAIN_NAME, + ] + + def __init__( + self, + interface: str, + lease_type: type[Lease] = JSONFileLease, + hooks: Iterable[Hook] = (), + requested_parameters: Iterable[dhcp.Parameter] = (), + ): + self.interface = interface + self.lease_type = lease_type + self.hooks = hooks + self._sock: AsyncDHCP4Socket = AsyncDHCP4Socket(self.interface) + self._state: fsm.fsm.State | None = None + self._lease: Lease | None = None + self.requested_parameters = list( + requested_parameters + if requested_parameters + else self.DEFAULT_PARAMETERS + ) + self._stopped = asyncio.Event() + self._sendq: asyncio.Queue[dhcp4msg | None] = asyncio.Queue() + self._send_task: asyncio.Task | None = None + self.bound = asyncio.Event() + self.timers = Timers() + + async def _renew(self): + '''Called when the renewal timer, as defined in the lease, expires.''' + assert self.lease, 'cannot renew without an existing lease' + LOG.info('Renewal timer expired') + # TODO: send only to server that gave us the current lease + self.timers._reset_timer('renewal') + await self.transition( + to=fsm.State.RENEWING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + + async def _rebind(self): + assert self.lease, 'cannot rebind without an existing lease' + LOG.info('Rebinding timer expired') + self.timers._reset_timer('rebinding') + await self.transition( + to=fsm.State.REBINDING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + + async def _expire_lease(self): + LOG.info('Lease expired') + self.timers._reset_timer('expiration') + self.state = fsm.State.INIT + # FIXME: call hooks in a non blocking way (maybe call_soon ?) + for i in self.hooks: + await i.unbound(self.lease) + self._lease = None + await self.bootstrap() + + @property + def lease(self) -> Lease | None: + return self._lease + @lease.setter + def lease(self, value: Lease): + '''Set a fresh lease; only call this when a server grants one.''' + self._lease = value + self.timers.arm( + lease=self.lease, + renewal=self._renew, + rebinding=self._rebind, + expiration=self._expire_lease, + ) + self.lease.dump() + + @property + def state(self) -> fsm.State | None: + return self._state + + @state.setter + def state(self, value: fsm.State | None): + if value and self._state and value not in fsm.TRANSITIONS[self._state]: + raise ValueError( + f'Cannot transition from {self._state} to {value}' + ) + LOG.info('%s -> %s', self.state, value) + self._state = value + + async def _run(self): + packet_to_send: dhcp4msg | None = None + # TODO: refactor, this is even more awkward than select() + while True: + wait_for_stopped = asyncio.Task( + self._stopped.wait(), name='wait until stopped' + ) + wait_for_received_packet = asyncio.Task( + self._sock.get(), name='wait for received packet' + ) + wait_for_packet_to_send = asyncio.Task( + self._sendq.get(), name='wait for packet to send' + ) + tasks = ( + wait_for_stopped, + wait_for_received_packet, + wait_for_packet_to_send, + ) + done, pending = await asyncio.wait( + tasks, + timeout=5, # TODO interval + return_when=asyncio.FIRST_COMPLETED, + ) + for i in pending: + i.cancel() + + if wait_for_packet_to_send in done: + packet_to_send = wait_for_packet_to_send.result() + + if wait_for_received_packet in done: + received_packet = wait_for_received_packet.result() + msg_type = dhcp.MessageType( + received_packet['options']['message_type'] + ) + LOG.info('Received %s', msg_type.name) + if received_packet.get('xid') != self.xid: + LOG.error('Missing or wrong xid, discarding') + else: + handler_name = f'{msg_type.name.lower()}_received' + handler = getattr(self, handler_name, None) + if not handler: + LOG.debug('%r messages are not handled', msg_type.name) + else: + if await handler(received_packet): + packet_to_send = None + + if packet_to_send: + packet_to_send['xid'] = self.xid + LOG.debug( + 'Sending %s', + packet_to_send['options']['message_type'].name, + ) + await self._sock.put(packet_to_send) + + if wait_for_stopped in done and wait_for_stopped.result() is True: + return + + async def transition(self, to: fsm.State, send: dhcp4msg | None = None): + self.state = to + await self._sendq.put(send) + + @fsm.state_guard(fsm.State.INIT, fsm.State.INIT_REBOOT) + async def bootstrap(self): + '''Send a `DISCOVER` or a `REQUEST`, + + depending on whether we're initializing or rebooting. + ''' + match self.state: + case fsm.State.INIT: + # send discover + await self.transition( + to=fsm.State.SELECTING, + send=messages.discover( + parameter_list=self.requested_parameters + ), + ) + case fsm.State.INIT_REBOOT: + assert self.lease, 'cannot init_reboot without a lease' + # send request for lease + await self.transition( + to=fsm.State.REBOOTING, + send=messages.request( + requested_ip=self.lease.ip, + server_id=self.lease.server_id, + parameter_list=self.requested_parameters, + ), + ) + + @fsm.state_guard( + fsm.State.REQUESTING, + fsm.State.REBOOTING, + fsm.State.REBINDING, + fsm.State.RENEWING, + ) + async def ack_received(self, pkt: dhcp4msg): + self.lease = self.lease_type(ack=pkt, interface=self.interface) + LOG.info( + 'Got lease for %s from %s', self.lease.ip, self.lease.server_id + ) + await self.transition(to=fsm.State.BOUND) + self.bound.set() + # FIXME: call hooks in a non blocking way (maybe call_soon ?) + for i in self.hooks: + await i.bound(self.lease) + return True -def run(): - if len(sys.argv) > 1: - ifname = sys.argv[1] - else: - ifname = 'eth0' - print(json.dumps(action(ifname), indent=4)) + @fsm.state_guard(fsm.State.SELECTING) + async def offer_received(self, pkt: dhcp4msg): + await self.transition( + to=fsm.State.REQUESTING, + send=messages.request( + requested_ip=pkt['yiaddr'], + server_id=pkt['options']['server_id'], + parameter_list=self.requested_parameters, + ), + ) + return True + async def __aenter__(self): + self._lease = self.lease_type.load(self.interface) + if self.lease: + # TODO check lease is not expired + self.state = fsm.State.INIT_REBOOT + else: + LOG.debug('No current lease') + self.state = fsm.State.INIT + await self._sock.__aenter__() + self._send_task = asyncio.Task(self._run()) + self.xid = self._sock.xid_pool.alloc() + return self -if __name__ == '__main__': - run() + async def __aexit__(self, *_): + self.timers.cancel() + if self.lease: + await self._sendq.put( + messages.release( + requested_ip=self.lease.ip, server_id=self.lease.server_id + ) + ) + self._stopped.set() + await self._send_task + await self._sock.__aexit__() + self.xid = None + self.state = None + self.bound.clear() diff --git a/pyroute2/dhcp/constants/bootp.py b/pyroute2/dhcp/constants/bootp.py new file mode 100644 index 000000000..f8a0c14fd --- /dev/null +++ b/pyroute2/dhcp/constants/bootp.py @@ -0,0 +1,35 @@ +from enum import IntEnum + + +class MessageType(IntEnum): + BOOTREQUEST = 1 # Client to server + BOOTREPLY = 2 # Server to client + + +class HardwareType(IntEnum): + ETHERNET = 1 # Ethernet (10Mb) + EXPERIMENTAL_ETHERNET = 2 + AMATEUR_RADIO = 3 + TOKEN_RING = 4 + FDDI = 8 + ATM = 19 + WIRELESS_IEEE_802_11 = 20 + + +class Flag(IntEnum): # TODO: use enum.Flag + UNICAST = 0x0000 # Unicast response requested + BROADCAST = 0x8000 # Broadcast response requested + + +class Option(IntEnum): + PAD = 0 # Padding (no operation) + SUBNET_MASK = 1 # Subnet mask + ROUTER = 3 # Router address + DNS_SERVER = 6 # Domain name server + HOSTNAME = 12 # Hostname + BOOTFILE_SIZE = 13 # Boot file size + DOMAIN_NAME = 15 # Domain name + IP_ADDRESS_LEASE_TIME = 51 # DHCP lease time + MESSAGE_TYPE = 53 # DHCP message type (extended from BOOTP) + SERVER_IDENTIFIER = 54 # DHCP server identifier + END = 255 # End of options diff --git a/pyroute2/dhcp/constants/dhcp.py b/pyroute2/dhcp/constants/dhcp.py new file mode 100644 index 000000000..e5f8dadad --- /dev/null +++ b/pyroute2/dhcp/constants/dhcp.py @@ -0,0 +1,130 @@ +from enum import IntEnum + + +class MessageType(IntEnum): + DISCOVER = 1 + OFFER = 2 + REQUEST = 3 + DECLINE = 4 + ACK = 5 + NAK = 6 + RELEASE = 7 + INFORM = 8 + + +class Option(IntEnum): + SUBNET_MASK = 1 + TIME_OFFSET = 2 + ROUTER = 3 + DNS_SERVER = 6 + HOST_NAME = 12 + DOMAIN_NAME = 15 + BROADCAST_ADDRESS = 28 + REQUESTED_IP_ADDRESS = 50 + IP_ADDRESS_LEASE_TIME = 51 + DHCP_MESSAGE_TYPE = 53 + SERVER_IDENTIFIER = 54 + PARAMETER_REQUEST_LIST = 55 + MAX_MSG_SIZE = 57 + RENEWAL_TIME = 58 + REBINDING_TIME = 59 + VENDOR_CLASS_IDENTIFIER = 60 + CLIENT_IDENTIFIER = 61 + PADDING = 255 + + +class Parameter(IntEnum): + SUBNET_MASK = 1 # Subnet Mask + TIME_OFFSET = 2 # Time Offset + ROUTER = 3 # Router + TIME_SERVER = 4 # Time Server + NAME_SERVER = 5 # Name Server + DOMAIN_NAME_SERVER = 6 # Domain Name Server (DNS) + LOG_SERVER = 7 # Log Server + COOKIE_SERVER = 8 # Cookie Server + LPR_SERVER = 9 # Line Printer Server + IMPRESS_SERVER = 10 # Impress Server + RESOURCE_LOCATION_SERVER = 11 # Resource Location Server + HOST_NAME = 12 # Host Name + BOOT_FILE_SIZE = 13 # Boot File Size + MERIT_DUMP_FILE = 14 # Merit Dump File + DOMAIN_NAME = 15 # Domain Name + SWAP_SERVER = 16 # Swap Server + ROOT_PATH = 17 # Root Path + EXTENSIONS_PATH = 18 # Extensions Path + IP_FORWARDING = 19 # IP Forwarding Enable/Disable + NON_LOCAL_SOURCE_ROUTING = 20 # Non-Local Source Routing Enable/Disable + POLICY_FILTER = 21 # Policy Filter + MAX_DATAGRAM_REASSEMBLY = 22 # Maximum Datagram Reassembly Size + DEFAULT_TTL = 23 # Default IP Time-to-Live + PATH_MTU_AGING_TIMEOUT = 24 # Path MTU Aging Timeout + PATH_MTU_PLATEAU_TABLE = 25 # Path MTU Plateau Table + INTERFACE_MTU = 26 # Interface MTU + ALL_SUBNETS_LOCAL = 27 # All Subnets Are Local + BROADCAST_ADDRESS = 28 # Broadcast Address + PERFORM_MASK_DISCOVERY = 29 # Perform Mask Discovery + MASK_SUPPLIER = 30 # Mask Supplier + PERFORM_ROUTER_DISCOVERY = 31 # Perform Router Discovery + ROUTER_SOLICITATION_ADDRESS = 32 # Router Solicitation Address + STATIC_ROUTE = 33 # Static Route + TRAILER_ENCAPSULATION = 34 # Trailer Encapsulation + ARP_CACHE_TIMEOUT = 35 # ARP Cache Timeout + ETHERNET_ENCAPSULATION = 36 # Ethernet Encapsulation + TCP_DEFAULT_TTL = 37 # TCP Default TTL + TCP_KEEPALIVE_INTERVAL = 38 # TCP Keepalive Interval + TCP_KEEPALIVE_GARBAGE = 39 # TCP Keepalive Garbage + NIS_DOMAIN = 40 # Network Information Service Domain + NIS_SERVERS = 41 # NIS Servers + NTP_SERVERS = 42 # NTP Servers + VENDOR_SPECIFIC_INFORMATION = 43 # Vendor Specific Information + NETBIOS_NAME_SERVER = 44 # NetBIOS over TCP/IP Name Server + NETBIOS_DDG_SERVER = 45 # NetBIOS Datagram Distribution Server + NETBIOS_NODE_TYPE = 46 # NetBIOS Node Type + NETBIOS_SCOPE = 47 # NetBIOS Scope + X_WINDOW_FONT_SERVER = 48 # X Window System Font Server + X_WINDOW_DISPLAY_MANAGER = 49 # X Window System Display Manager + REQUESTED_IP_ADDRESS = 50 # Requested IP Address + IP_ADDRESS_LEASE_TIME = 51 # IP Address Lease Time + OPTION_OVERLOAD = 52 # Option Overload + DHCP_MESSAGE_TYPE = 53 # DHCP Message Type + SERVER_IDENTIFIER = 54 # Server Identifier + PARAMETER_REQUEST_LIST = 55 # Parameter Request List + MESSAGE = 56 # Message + MAX_DHCP_MESSAGE_SIZE = 57 # Maximum DHCP Message Size + RENEWAL_TIME_VALUE = 58 # Renewal (T1) Time Value + REBINDING_TIME_VALUE = 59 # Rebinding (T2) Time Value + CLASS_IDENTIFIER = 60 # Vendor Class Identifier + CLIENT_IDENTIFIER = 61 # Client Identifier + NETWARE_IP_DOMAIN = 62 # NetWare/IP Domain Name + NETWARE_IP_OPTION = 63 # NetWare/IP Option + NIS_PLUS_DOMAIN = 64 # NIS+ Domain + NIS_PLUS_SERVERS = 65 # NIS+ Servers + TFTP_SERVER_NAME = 66 # TFTP Server Name + BOOTFILE_NAME = 67 # Bootfile Name + MOBILE_IP_HOME_AGENT = 68 # Mobile IP Home Agent + SMTP_SERVER = 69 # Simple Mail Transport Protocol Server + POP3_SERVER = 70 # Post Office Protocol v3 Server + NNTP_SERVER = 71 # Network News Transport Protocol Server + DEFAULT_WWW_SERVER = 72 # Default World Wide Web Server + DEFAULT_FINGER_SERVER = 73 # Default Finger Server + DEFAULT_IRC_SERVER = 74 # Default Internet Relay Chat Server + STREETTALK_SERVER = 75 # StreetTalk Server + STDA_SERVER = 76 # StreetTalk Directory Assistance Server + USER_CLASS_INFORMATION = 77 # User Class Information + SLP_DIRECTORY_AGENT = 78 # SLP Directory Agent + SLP_SERVICE_SCOPE = 79 # SLP Service Scope + RAPID_COMMIT = 80 # Rapid Commit + CLIENT_FQDN = 81 # Fully Qualified Domain Name + RELAY_AGENT_INFORMATION = 82 # Relay Agent Information + INTERNET_STORAGE_NAME_SERVICE = 83 # ISNS + NDS_SERVERS = 85 # Novell Directory Services Servers + NDS_TREE_NAME = 86 # Novell Directory Services Tree Name + NDS_CONTEXT = 87 # Novell Directory Services Context + BCMCS_CONTROLLER_DOMAIN = 88 # BCMCS Controller Domain Name List + AUTHENTICATION = 90 + CLIENT_SYSTEM_ARCHITECTURE_TYPE = 93 + CLIENT_NETWORK_INTERFACE_IDENTIFIER = 94 + CLASSLESS_STATIC_ROUTE = 121 + DOMAIN_SEARCH = 119 + PRIVATE_CLASSIC_ROUTE_MS = 249 + PRIVATE_PROXY_AUTODISCOVERY = 252 diff --git a/pyroute2/dhcp/dhcp4socket.py b/pyroute2/dhcp/dhcp4socket.py index 9b2286b08..82e36e57d 100644 --- a/pyroute2/dhcp/dhcp4socket.py +++ b/pyroute2/dhcp/dhcp4socket.py @@ -4,13 +4,22 @@ ''' +import asyncio +import logging +import socket + from pyroute2.common import AddrPool from pyroute2.dhcp.dhcp4msg import dhcp4msg -from pyroute2.ext.rawsocket import RawSocket +from pyroute2.ext.rawsocket import AsyncRawSocket from pyroute2.protocols import ethmsg, ip4msg, udp4_pseudo_header, udpmsg +LOG = logging.getLogger(__name__) + +UDP_HEADER_SIZE = 8 +IPV4_HEADER_SIZE = 20 -def listen_udp_port(port=68): + +def listen_udp_port(port: int = 68) -> list[list[int]]: # pre-scripted BPF code that matches UDP port bpf_code = [ [40, 0, 0, 12], @@ -28,7 +37,7 @@ def listen_udp_port(port=68): return bpf_code -class DHCP4Socket(RawSocket): +class AsyncDHCP4Socket(AsyncRawSocket): ''' Parameters: @@ -45,31 +54,33 @@ class DHCP4Socket(RawSocket): not provided, DHCP4Socket generates it for outgoing messages. ''' - def __init__(self, ifname, port=68): - RawSocket.__init__(self, ifname, listen_udp_port(port)) + def __init__(self, ifname, port: int = 68): + AsyncRawSocket.__init__(self, ifname, listen_udp_port(port)) self.port = port # Create xid pool # # Every allocated xid will be released automatically after 1024 # alloc() calls, there is no need to call free(). Minimal xid == 16 - self.xid_pool = AddrPool(minaddr=16, release=1024) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def put(self, msg=None, dport=67): + self.xid_pool = AddrPool( + minaddr=16, release=1024 + ) # TODO : maybe it should be in the client and not here ? + self.aio_loop = asyncio.get_running_loop() + + async def put( + self, + msg: dhcp4msg, + eth_dst: str = 'ff:ff:ff:ff:ff:ff', + ip_dst: str = '255.255.255.255', + dport: int = 67, + ) -> dhcp4msg: ''' Put DHCP message. Parameters: * msg -- dhcp4msg instance + * eth_dst -- dest MAC address + * ip_dst -- dest IP address * dport -- DHCP server port - If `msg` is not provided, it is constructed as default - BOOTREQUEST + DHCPDISCOVER. - Examples:: sock.put(dhcp4msg({'op': BOOTREQUEST, @@ -79,53 +90,67 @@ def put(self, msg=None, dport=67): 'requested_ip': '172.16.101.2', 'server_id': '172.16.101.1'}})) - The method returns dhcp4msg that was sent, so one can get from - there `xid` (transaction id) and other details. + The method returns the sent dhcp4msg, so one can get from + there the `xid` (transaction id) and other details. ''' # DHCP layer - dhcp = msg or dhcp4msg({'chaddr': self.l2addr}) + dhcp = msg # dhcp transaction id if dhcp['xid'] is None: dhcp['xid'] = self.xid_pool.alloc() + # auto add src addr + if dhcp['chaddr'] is None: + dhcp['chaddr'] = self.l2addr + data = dhcp.encode().buf + dhcp_payload_size = len(data) # UDP layer udp = udpmsg( - {'sport': self.port, 'dport': dport, 'len': 8 + len(data)} + { + 'sport': self.port, + 'dport': dport, + 'len': UDP_HEADER_SIZE + dhcp_payload_size, + } ) + # Pseudo UDP header, only for checksum purposes udph = udp4_pseudo_header( - {'dst': '255.255.255.255', 'len': 8 + len(data)} + {'dst': ip_dst, 'len': UDP_HEADER_SIZE + dhcp_payload_size} ) udp['csum'] = self.csum(udph.encode().buf + udp.encode().buf + data) udp.reset() # IPv4 layer ip4 = ip4msg( - {'len': 20 + 8 + len(data), 'proto': 17, 'dst': '255.255.255.255'} + { + 'len': IPV4_HEADER_SIZE + UDP_HEADER_SIZE + dhcp_payload_size, + 'proto': socket.IPPROTO_UDP, + 'dst': ip_dst, + } ) ip4['csum'] = self.csum(ip4.encode().buf) ip4.reset() # MAC layer eth = ethmsg( - {'dst': 'ff:ff:ff:ff:ff:ff', 'src': self.l2addr, 'type': 0x800} + {'dst': eth_dst, 'src': self.l2addr, 'type': socket.ETHERTYPE_IP} ) data = eth.encode().buf + ip4.encode().buf + udp.encode().buf + data - self.send(data) + await self.aio_loop.sock_sendall(self, data) dhcp.reset() return dhcp - def get(self): + async def get(self) -> dhcp4msg: ''' Get the next incoming packet from the socket and try to decode it as IPv4 DHCP. No analysis is done here, only MAC/IPv4/UDP headers are stripped out, and the rest is interpreted as DHCP. ''' - (data, addr) = self.recvfrom(4096) + data, _ = await self.aio_loop.sock_recvfrom(self, 4096) eth = ethmsg(buf=data).decode() ip4 = ip4msg(buf=data, offset=eth.offset).decode() udp = udpmsg(buf=data, offset=ip4.offset).decode() diff --git a/pyroute2/dhcp/fsm.py b/pyroute2/dhcp/fsm.py new file mode 100644 index 000000000..80c45aaca --- /dev/null +++ b/pyroute2/dhcp/fsm.py @@ -0,0 +1,62 @@ +'''DHCP client state machine helpers.''' + +from enum import StrEnum, auto +from logging import getLogger +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from .client import AsyncDHCPClient + + +LOG = getLogger(__name__) + + +class State(StrEnum): + '''DHCP client states. + + see + http://www.tcpipguide.com/free/t_DHCPGeneralOperationandClientFiniteStateMachine.htm + ''' + + INIT = auto() + INIT_REBOOT = auto() + REBOOTING = auto() + REQUESTING = auto() + SELECTING = auto() + BOUND = auto() + RENEWING = auto() + REBINDING = auto() + + +# allowed transitions between states +TRANSITIONS: Final[dict[State, set[State]]] = { + State.INIT_REBOOT: {State.REBOOTING}, + State.REBOOTING: {State.INIT, State.BOUND}, + State.INIT: {State.SELECTING}, + State.SELECTING: {State.REQUESTING, State.INIT}, + State.REQUESTING: {State.BOUND, State.INIT}, + State.BOUND: {State.INIT, State.RENEWING, State.REBINDING}, + State.RENEWING: {State.BOUND, State.INIT, State.REBINDING}, + State.REBINDING: {State.BOUND, State.INIT}, +} + + +def state_guard(*states: State): + '''Decorator that prevents a method from running + + if the associated instance is not in one of the given States.''' + + def decorator(meth): + async def wrapper(self: 'AsyncDHCPClient', *args, **kwargs): + if self.state not in states: + LOG.debug( + 'Ignoring call to %r in %s state', + meth.__name__, + self.state, + ) + return False + return await meth(self, *args, **kwargs) + + return wrapper + + return decorator diff --git a/pyroute2/dhcp/hooks.py b/pyroute2/dhcp/hooks.py new file mode 100644 index 000000000..cd5d50f9d --- /dev/null +++ b/pyroute2/dhcp/hooks.py @@ -0,0 +1,30 @@ +'''Hooks called by the DHCP client when bound, a leases expires, etc.''' + +from logging import getLogger + +from pyroute2.dhcp.leases import Lease + +LOG = getLogger(__name__) + + +class Hook: + '''Base class for pyroute2 dhcp client hooks.''' + + def __init__(self, **settings): + pass + + async def bound(self, lease: Lease): + '''Called when the client gets a lease.''' + pass + + async def unbound(self, lease: Lease): + '''Called when a leases expires.''' + pass + + +class ConfigureIP(Hook): + async def bound(self, lease: Lease): + LOG.info('STUB: add %s to %s', lease.ip, lease.interface) + + async def unbound(self, lease: Lease): + LOG.info('STUB: remove %s from %s', lease.ip, lease.interface) diff --git a/pyroute2/dhcp/leases.py b/pyroute2/dhcp/leases.py new file mode 100644 index 000000000..0ebc9bd31 --- /dev/null +++ b/pyroute2/dhcp/leases.py @@ -0,0 +1,154 @@ +'''Lease classes used by the dhcp client.''' + +import abc +import json +from dataclasses import asdict, dataclass, field +from datetime import datetime +from logging import getLogger +from pathlib import Path +from typing import Self + +from pyroute2.dhcp.dhcp4msg import dhcp4msg + +LOG = getLogger(__name__) + + +def _now() -> float: + '''The current timestamp.''' + return datetime.now().timestamp() + + +@dataclass +class Lease(abc.ABC): + '''Represents a lease obtained through DHCP.''' + + # The DHCP ack sent by the server which allocated this lease + ack: dhcp4msg + # Name of the interface for which this lease was requested + interface: str + # Timestamp of when this lease was obtained + obtained: float = field(default_factory=_now) + + def _seconds_til_timer(self, timer_name: str) -> float | None: + '''The number of seconds to wait until the given timer expires. + + The value is fetched from options as `{timer_name}_time`. + (lease -> lease_time, renewal -> renewal_time, ...) + ''' + try: + delta: int = self.ack['options'][f'{timer_name}_time'] + return self.obtained + delta - _now() + except KeyError: + return None + + @property + def expiration_in(self) -> float | None: + return self._seconds_til_timer('lease') + + @property + def renewal_in(self) -> float | None: + '''The amount of seconds before we have to renew the lease. + + Can be negative if it's past due, + or `None` if the server didn't give a renewal time. + ''' + return self._seconds_til_timer('renewal') + + @property + def rebinding_in(self) -> float | None: + '''The amount of seconds before we have to rebind the lease. + + Can be negative if it's past due, + or `None` if the server didn't give a rebinding time. + ''' + return self._seconds_til_timer('rebinding') + + @property + def ip(self) -> str: + '''The IP address assigned to the client.''' + return self.ack['yiaddr'] + + @property + def subnet_mask(self) -> str: + '''The subnet mask assigned to the client.''' + return self.ack['options']['subnet_mask'] + + @property + def routers(self) -> str: + return self.ack['options']['router'] + + @property + def name_servers(self) -> str: # XXX: list ? + return self.ack['options']['name_server'] + + @property + def server_id(self) -> str: + '''The IP address of the server which allocated this lease.''' + return self.ack['options']['server_id'] + + @abc.abstractmethod + def dump(self) -> None: + '''Write a lease, i.e. to disk or to stdout.''' + pass + + @classmethod + @abc.abstractmethod + def load(cls, interface: str) -> 'Self | None': + '''Load an existing lease for an interface, if it exists.''' + pass + + +class JSONStdoutLease(Lease): + '''Just prints the lease to stdout when the client gets a new one.''' + + def dump(self) -> None: + """Writes the lease as json to stdout.""" + print(json.dumps(asdict(self), indent=2)) + + @classmethod + def load(cls, interface: str) -> None: + '''Does not do anything.''' + return None + + +class JSONFileLease(Lease): + '''Write and load the lease from a JSON file in the working directory.''' + + @classmethod + def _get_lease_dir(cls) -> Path: + '''Where to store the lease file, i.e. the working directory.''' + return Path.cwd() + + @classmethod + def _get_path(cls, interface: str) -> Path: + '''The lease file, named after the interface.''' + return ( + cls._get_lease_dir().joinpath(interface).with_suffix('.lease.json') + ) + + def dump(self) -> None: + '''Dump the lease to a file. + + The lease file is named after the interface + and written in the working directory. + ''' + lease_path = self._get_path(self.interface) + LOG.info('Writing lease for %s to %s', self.interface, lease_path) + with lease_path.open('wt') as lf: + json.dump(asdict(self), lf) + + @classmethod + def load(cls, interface: str) -> 'JSONFileLease | None': + '''Load the lease from a file. + + The lease file is named after the interface + and read from the working directory. + ''' + lease_path = cls._get_path(interface) + try: + with lease_path.open('rt') as lf: + LOG.info('Loading lease for %s from %s', interface, lease_path) + return cls(**json.load(lf)) + except FileNotFoundError: + LOG.info('No existing lease at %s for %s', lease_path, interface) + return None diff --git a/pyroute2/dhcp/messages.py b/pyroute2/dhcp/messages.py new file mode 100644 index 000000000..5994bda25 --- /dev/null +++ b/pyroute2/dhcp/messages.py @@ -0,0 +1,45 @@ +"""Helper functions to build dhcp client messages.""" + +from pyroute2.dhcp.constants import bootp, dhcp +from pyroute2.dhcp.dhcp4msg import dhcp4msg + + +def discover(parameter_list: list[dhcp.Parameter]) -> dhcp4msg: + return dhcp4msg( + { + 'op': bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': dhcp.MessageType.DISCOVER, + 'parameter_list': parameter_list, + }, + } + ) + + +def request( + requested_ip: str, server_id: str, parameter_list: list[dhcp.Parameter] +) -> dhcp4msg: + return dhcp4msg( + { + 'op': bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': dhcp.MessageType.REQUEST, + 'requested_ip': requested_ip, + 'server_id': server_id, + 'parameter_list': parameter_list, + }, + } + ) + + +def release(requested_ip: str, server_id: str) -> dhcp4msg: + return dhcp4msg( + { + 'op': bootp.MessageType.BOOTREQUEST, + 'options': { + 'message_type': dhcp.MessageType.RELEASE, + 'requested_ip': requested_ip, + 'server_id': server_id, + }, + } + ) diff --git a/pyroute2/dhcp/timers.py b/pyroute2/dhcp/timers.py new file mode 100644 index 000000000..65d62ee85 --- /dev/null +++ b/pyroute2/dhcp/timers.py @@ -0,0 +1,64 @@ +'''Timers to manage lease rebinding, renewal & expiration.''' + +import asyncio +import dataclasses +from logging import getLogger +from typing import Awaitable, Callable + +from pyroute2.dhcp.leases import Lease + +LOG = getLogger(__name__) + + +@dataclasses.dataclass +class Timers: + '''Manage callbacks associated with DHCP leases.''' + + renewal: asyncio.TimerHandle | None = None + rebinding: asyncio.TimerHandle | None = None + expiration: asyncio.TimerHandle | None = None + + def cancel(self): + '''Cancel all current timers.''' + for timer_name in ('renewal', 'rebinding', 'expiration'): + self._reset_timer(timer_name) + + def _reset_timer(self, timer_name: str): + '''Cancel a timer and set it to None.''' + if timer := getattr(self, timer_name): + timer: asyncio.TimerHandle + if not timer.cancelled(): + # FIXME: how do we know a timer wasn't cancelled ? + # this causes spurious logs + LOG.debug('Canceling %s timer', timer_name) + timer.cancel() + setattr(self, timer_name, None) + + def arm(self, lease: Lease, **callbacks: Callable[[], Awaitable[None]]): + '''Reset & arm timers from a `Lease`. + + `callbacks` must be async callables with no arguments + that will be called when the associated timer expires. + ''' + self.cancel() + loop = asyncio.get_running_loop() + + for timer_name, async_callback in callbacks.items(): + self._reset_timer(timer_name) + lease_time = getattr(lease, f'{timer_name}_in') + if not lease_time: + LOG.debug('Lease does not set a %s time', timer_name) + continue + if lease_time < 0.0: + LOG.debug('Lease %s is in the past', timer_name) + continue + LOG.info('Scheduling lease %s in %.2fs', timer_name, lease_time) + '''FIXME: calling async_callback() causes a + "coroutine was never awaited" warning. + But deferring its call in a lambda causes the callback to be + the same for all timers, since we're in a loop. + ''' + timer = loop.call_later( + lease_time, asyncio.create_task, async_callback() + ) + setattr(self, timer_name, timer) diff --git a/pyroute2/ext/rawsocket.py b/pyroute2/ext/rawsocket.py index 902c24147..bf08b9305 100644 --- a/pyroute2/ext/rawsocket.py +++ b/pyroute2/ext/rawsocket.py @@ -11,7 +11,7 @@ ) from socket import AF_PACKET, SOCK_RAW, SOL_SOCKET, errno, error, htons, socket -from pyroute2.iproute.linux import IPRoute +from pyroute2.iproute.linux import AsyncIPRoute ETH_P_ALL = 3 SO_ATTACH_FILTER = 26 @@ -34,14 +34,14 @@ class sock_fprog(Structure): _fields_ = [('len', c_ushort), ('filter', c_void_p)] -def compile_bpf(code): +def compile_bpf(code: list[int]): ProgramType = sock_filter * len(code) program = ProgramType(*[sock_filter(*line) for line in code]) sfp = sock_fprog(len(code), addressof(program[0])) return string_at(addressof(sfp), sizeof(sfp)), program -class RawSocket(socket): +class AsyncRawSocket(socket): ''' This raw socket binds to an interface and optionally installs a BPF filter. @@ -55,28 +55,37 @@ class RawSocket(socket): fprog = None - def __init__(self, ifname, bpf=None): - self.ifname = ifname + async def __aexit__(self, *_): + self.close() + + async def __aenter__(self): # lookup the interface details - with IPRoute() as ip: - for link in ip.get_links(): - if link.get_attr('IFLA_IFNAME') == ifname: + async with AsyncIPRoute() as ip: + async for link in await ip.get_links(): + if link.get_attr('IFLA_IFNAME') == self.ifname: break else: raise IOError(2, 'Link not found') - self.l2addr = link.get_attr('IFLA_ADDRESS') - self.ifindex = link['index'] + self.l2addr: str = link.get_attr('IFLA_ADDRESS') + self.ifindex: int = link['index'] # bring up the socket socket.__init__(self, AF_PACKET, SOCK_RAW, htons(ETH_P_ALL)) + socket.setblocking(self, False) socket.bind(self, (self.ifname, ETH_P_ALL)) - if bpf: + if self.bpf: self.clear_buffer() - fstring, self.fprog = compile_bpf(bpf) + fstring, self.fprog = compile_bpf(self.bpf) socket.setsockopt(self, SOL_SOCKET, SO_ATTACH_FILTER, fstring) else: + # FIXME: should be async self.clear_buffer(remove_total_filter=True) + return self + + def __init__(self, ifname: str, bpf: list[list[int]] | None = None): + self.ifname = ifname + self.bpf = bpf - def clear_buffer(self, remove_total_filter=False): + def clear_buffer(self, remove_total_filter: bool = False): # there is a window of time after the socket has been created and # before bind/attaching a filter where packets can be queued onto the # socket buffer @@ -86,7 +95,6 @@ def clear_buffer(self, remove_total_filter=False): # before setting the desired filter total_fstring, prog = compile_bpf(total_filter) socket.setsockopt(self, SOL_SOCKET, SO_ATTACH_FILTER, total_fstring) - self.setblocking(0) while True: try: self.recvfrom(0) @@ -99,7 +107,6 @@ def clear_buffer(self, remove_total_filter=False): break else: raise - self.setblocking(1) if remove_total_filter: # total_fstring ignored socket.setsockopt( diff --git a/setup.cfg b/setup.cfg index a31c73a36..12fb5a2a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,5 +40,5 @@ console_scripts = ss2 = pyroute2.netlink.diag.ss2:run [psutil] pyroute2-cli = pyroute2.ndb.cli:run pyroute2-decoder = pyroute2.decoder.main:run - pyroute2-dhcp-client = pyroute2.dhcp.client:run + pyroute2-dhcp-client = pyroute2.dhcp.cli:run pyroute2-test-platform = pyroute2.config.test_platform:run diff --git a/tests/test_linux/conftest.py b/tests/test_linux/conftest.py index df00a9c68..d96bc743e 100644 --- a/tests/test_linux/conftest.py +++ b/tests/test_linux/conftest.py @@ -8,9 +8,11 @@ from pyroute2 import config from pyroute2.ipset import IPSet, IPSetError from pyroute2.wiset import COUNT +from fixtures.dnsmasq import dnsmasq, dnsmasq_options # noqa: F401 +from fixtures.interfaces import dhcp_range, veth_pair # noqa: F401 config.nlm_generator = True -pytest_plugins = "pytester" +pytest_plugins = 'pytester' @pytest.fixture @@ -62,7 +64,7 @@ def wiset_sock(request): if request.param is None: yield None else: - before_count = COUNT["count"] + before_count = COUNT['count'] with IPSet() as sock: yield sock assert before_count == COUNT['count'] diff --git a/tests/test_linux/fixtures/dnsmasq.py b/tests/test_linux/fixtures/dnsmasq.py new file mode 100644 index 000000000..7125df24b --- /dev/null +++ b/tests/test_linux/fixtures/dnsmasq.py @@ -0,0 +1,156 @@ +import asyncio +from argparse import ArgumentParser +from dataclasses import dataclass +from ipaddress import IPv4Address +from shutil import which +from typing import AsyncGenerator, ClassVar, Literal + +import pytest +import pytest_asyncio + +from fixtures.interfaces import DHCPRangeConfig + + +@dataclass +class DnsmasqOptions: + '''Options for the dnsmasq server.''' + + range_start: IPv4Address + range_end: IPv4Address + interface: str + lease_time: str = '12h' + + def __iter__(self): + opts = ( + f'--interface={self.interface}', + f'--dhcp-range={self.range_start},' + f'{self.range_end},{self.lease_time}', + ) + return iter(opts) + + +class DnsmasqFixture: + '''Runs the dnsmasq server as an async context manager.''' + + DNSMASQ_PATH: ClassVar[str | None] = which('dnsmasq') + + def __init__(self, options: DnsmasqOptions) -> None: + self.options = options + self.stdout: list[bytes] = [] + self.stderr: list[bytes] = [] + self.process: asyncio.subprocess.Process | None = None + self.output_poller: asyncio.Task | None = None + + async def _read_output(self, name: Literal['stdout', 'stderr']): + '''Read stdout or stderr until the process exits.''' + stream = getattr(self.process, name) + output = getattr(self, name) + while line := await stream.readline(): + output.append(line) + + async def _read_outputs(self): + '''Read stdout & stderr until the process exits.''' + assert self.process + await asyncio.gather( + self._read_output('stderr'), self._read_output('stdout') + ) + + def _get_base_cmdline_options(self) -> tuple[str]: + '''The base commandline options for dnsmasq.''' + return ( + '--keep-in-foreground', # self explanatory + '--no-resolv', # don't mess w/ resolv.conf + '--log-facility=-', # log to stdout + '--no-hosts', # don't read /etc/hosts + '--bind-interfaces', # don't bind on wildcard + '--no-ping', # don't ping to check if ips are attributed + ) + + def get_cmdline_options(self) -> tuple[str]: + '''All commandline options passed to dnsmasq.''' + return (*self._get_base_cmdline_options(), *self.options) + + async def __aenter__(self): + '''Start the dnsmasq process and start polling its output.''' + assert self.DNSMASQ_PATH + self.process = await asyncio.create_subprocess_exec( + self.DNSMASQ_PATH, + *self.get_cmdline_options(), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + self.output_poller = asyncio.Task(self._read_outputs()) + return self + + async def __aexit__(self, *_): + if self.process: + if self.process.returncode is None: + self.process.terminate() + await self.process.wait() + await self.output_poller + + +@pytest.fixture +def dnsmasq_options( + veth_pair: tuple[str, str], dhcp_range: DHCPRangeConfig +) -> DnsmasqOptions: + '''dnsmasq options useful for test purposes.''' + return DnsmasqOptions( + range_start=dhcp_range.range_start, + range_end=dhcp_range.range_end, + interface=veth_pair[0], + ) + + +@pytest_asyncio.fixture +async def dnsmasq( + dnsmasq_options: DnsmasqOptions, +) -> AsyncGenerator[DnsmasqFixture, None]: + '''A dnsmasq instance running for the duration of the test.''' + async with DnsmasqFixture(options=dnsmasq_options) as dnsf: + yield dnsf + + +def get_psr() -> ArgumentParser: + psr = ArgumentParser() + psr.add_argument('interface', help='Interface to listen on') + psr.add_argument( + '--range-start', + type=IPv4Address, + default=IPv4Address('192.168.186.10'), + help='Start of the DHCP client range.', + ) + psr.add_argument( + '--range-end', + type=IPv4Address, + default=IPv4Address('192.168.186.100'), + help='End of the DHCP client range.', + ) + psr.add_argument( + '--lease-time', + default='2m', + help='DHCP lease time (minimum 2 minutes according to man)', + ) + return psr + + +async def main(): + '''Commandline entrypoint to start dnsmasq the same way the fixture does. + + Useful for debugging. + ''' + args = get_psr().parse_args() + opts = DnsmasqOptions(**args.__dict__) + read_lines: int = 0 + async with DnsmasqFixture(opts) as dnsm: + # quick & dirty stderr polling + while True: + if len(dnsm.stderr) > read_lines: + read_lines += len(lines := dnsm.stderr[read_lines:]) + print(*(i.decode().strip() for i in lines), sep='\n') + else: + await asyncio.sleep(0.2) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/tests/test_linux/fixtures/interfaces.py b/tests/test_linux/fixtures/interfaces.py new file mode 100644 index 000000000..c276b2cbc --- /dev/null +++ b/tests/test_linux/fixtures/interfaces.py @@ -0,0 +1,68 @@ +import asyncio +import random +from ipaddress import IPv4Address, IPv4Interface +from typing import AsyncGenerator, NamedTuple + +import pytest +import pytest_asyncio + + +class DHCPRangeConfig(NamedTuple): + range_start: IPv4Address + range_end: IPv4Address + gw: IPv4Interface + + +async def ip(*args: str): + '''Call `ip` in a subprocess.''' + proc = await asyncio.create_subprocess_exec('ip', *args) + stdout, stderr = await proc.communicate() + assert proc.returncode == 0, stderr + return stdout + + +@pytest.fixture +def dhcp_range() -> DHCPRangeConfig: + ''' 'An IPv4 DHCP range configuration.''' + rangeidx = random.randint(1, 254) + return DHCPRangeConfig( + range_start=IPv4Address(f'10.{rangeidx}.0.10'), + range_end=IPv4Address(f'10.{rangeidx}.0.20'), + gw=IPv4Interface(f'10.{rangeidx}.0.1/16'), + ) + + +class VethPair(NamedTuple): + '''A pair of veth interfaces.''' + + server: str + client: str + + +@pytest_asyncio.fixture +async def veth_pair( + dhcp_range: DHCPRangeConfig, +) -> AsyncGenerator[VethPair, None]: + '''Fixture that creates a temporary veth pair.''' + # FIXME: use pyroute2 + # TODO: /proc/sys/net/ipv4/conf/{interface}/accept_local ? + idx = random.randint(0, 999) + server_ifname = f'dhcptest{idx}-srv' + client_ifname = f'dhcptest{idx}-cli' + try: + await ip( + 'link', + 'add', + server_ifname, + 'type', + 'veth', + 'peer', + 'name', + client_ifname, + ) + await ip('addr', 'add', str(dhcp_range.gw), 'dev', server_ifname) + await ip('link', 'set', server_ifname, 'up') + await ip('link', 'set', client_ifname, 'up') + yield VethPair(server_ifname, client_ifname) + finally: + await ip('link', 'del', server_ifname) diff --git a/tests/test_linux/test_raw/test_dhcp.py b/tests/test_linux/test_raw/test_dhcp.py index 9bc68f89c..8264fd7f1 100644 --- a/tests/test_linux/test_raw/test_dhcp.py +++ b/tests/test_linux/test_raw/test_dhcp.py @@ -1,65 +1,79 @@ -import collections +import asyncio import json -import subprocess +from ipaddress import IPv4Address +from pathlib import Path import pytest from pr2test.marks import require_root -from pyroute2 import NDB -from pyroute2.common import dqn2int, hexdump, hexload -from pyroute2.dhcp import client +from pyroute2.dhcp import client, fsm +from pyroute2.dhcp.constants import bootp, dhcp +from pyroute2.dhcp.leases import JSONFileLease +from fixtures.dnsmasq import DnsmasqFixture +from fixtures.interfaces import VethPair pytestmark = [require_root()] -@pytest.fixture -def ctx(): - ndb = NDB() - index = 0 - ifname = '' - # get a DHCP default route, if exists - with ndb.routes.dump() as dump: - dump.select_records(proto=16, dst='') - for route in dump: - index = route.oif - ifname = ndb.interfaces[index]['ifname'] - break - yield collections.namedtuple('Context', ['ndb', 'index', 'ifname'])( - ndb, index, ifname - ) - ndb.close() +@pytest.mark.asyncio +async def test_get_lease( + dnsmasq: DnsmasqFixture, + veth_pair: VethPair, + tmpdir: str, + monkeypatch: pytest.MonkeyPatch, +): + """The client can get a lease and write it to a file.""" + work_dir = Path(tmpdir) + # Patch JSONFileLease so leases get written to the temp dir + # instead of whatever the working directory is + monkeypatch.setattr(JSONFileLease, "_get_lease_dir", lambda: work_dir) + # boot up the dhcp client and wait for a lease + async with client.AsyncDHCPClient(veth_pair.client) as cli: + await cli.bootstrap() + await asyncio.wait_for(cli.bound.wait(), timeout=5) + assert cli.state == fsm.State.BOUND + lease = cli.lease + assert lease.ack["xid"] == cli.xid -def _do_test_client_module(ctx): - if ctx.index == 0: - pytest.skip('no DHCP interfaces detected') - - response = client.action(ctx.ifname) - options = response['options'] - router = response['options']['router'][0] - prefixlen = dqn2int(response['options']['subnet_mask']) - address = response['yiaddr'] - l2addr = response['chaddr'] - - # convert addresses like 96:0:1:45:fa:6c into 96:00:01:45:fa:6c + # check the obtained lease + assert lease.interface == veth_pair.client + assert lease.ack["op"] == bootp.MessageType.BOOTREPLY + assert lease.ack["options"]["message_type"] == dhcp.MessageType.ACK assert ( - hexdump(hexload(l2addr)) == ctx.ndb.interfaces[ctx.ifname]['address'] + dnsmasq.options.range_start + <= IPv4Address(lease.ip) + <= dnsmasq.options.range_end ) - assert router == ctx.ndb.routes['default']['gateway'] - assert options['lease_time'] > 0 - assert prefixlen > 0 - assert address is not None - return response - + assert lease.ack["chaddr"] + # TODO: check chaddr matches veth_pair.client's MAC -def test_client_module(ctx): - _do_test_client_module(ctx) + # check the lease was written to disk and can be loaded + expected_lease_file = JSONFileLease._get_path(lease.interface) + assert expected_lease_file.is_file() + json_lease = json.loads(expected_lease_file.read_bytes()) + assert isinstance(json_lease, dict) + assert JSONFileLease(**json_lease) == lease -def test_client_console(ctx): - response_from_module = json.loads(json.dumps(_do_test_client_module(ctx))) - client = subprocess.run( - ['pyroute2-dhcp-client', ctx.ifname], stdout=subprocess.PIPE +@pytest.mark.asyncio +async def test_client_console(dnsmasq: DnsmasqFixture, veth_pair: VethPair): + """The commandline client can get a lease, print it to stdout and exit.""" + process = await asyncio.create_subprocess_exec( + 'pyroute2-dhcp-client', + veth_pair.client, + '--lease-type', + 'pyroute2.dhcp.leases.JSONStdoutLease', + '--exit-on-lease', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await asyncio.wait_for(process.communicate(), timeout=5) + assert process.returncode == 0 + json_lease = json.loads(stdout) + assert json_lease["interface"] == veth_pair.client + assert ( + dnsmasq.options.range_start + <= IPv4Address(json_lease["ack"]["yiaddr"]) + <= dnsmasq.options.range_end ) - response_from_console = json.loads(client.stdout) - assert response_from_module == response_from_console