Skip to content

Commit

Permalink
Merge pull request #87 from thatstoasty/udp
Browse files Browse the repository at this point in the history
UDP Socket support
  • Loading branch information
saviorand authored Jan 15, 2025
2 parents b41e88e + bd14fc5 commit 064d7f6
Show file tree
Hide file tree
Showing 10 changed files with 685 additions and 32 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ jobs:
magic run test
magic run integration_tests_py
magic run integration_tests_external
magic run integration_tests_udp
7 changes: 5 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ install_id
# Rattler
output

# integration tests
udp_client.DSYM
udp_server.DSYM
__pycache__

# misc
.vscode

__pycache__
270 changes: 269 additions & 1 deletion lightbug_http/libc.mojo

Large diffs are not rendered by default.

275 changes: 254 additions & 21 deletions lightbug_http/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from lightbug_http.libc import (
AF_INET,
AF_INET6,
SOCK_STREAM,
SOCK_DGRAM,
SOL_SOCKET,
SO_REUSEADDR,
SO_REUSEPORT,
Expand Down Expand Up @@ -71,7 +72,7 @@ trait Connection(Movable):
fn teardown(mut self) raises:
...

fn local_addr(mut self) -> TCPAddr:
fn local_addr(self) -> TCPAddr:
...

fn remote_addr(self) -> TCPAddr:
Expand Down Expand Up @@ -135,12 +136,8 @@ struct ListenConfig:
fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive):
self._keep_alive = keep_alive

fn listen[network: NetworkType, address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener:
fn listen[address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener:
constrained[address_family in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]()
constrained[
network in NetworkType.SUPPORTED_TYPES,
"Unsupported network type for internet address resolution. Unix addresses are not supported yet.",
]()
var local = parse_address(address)
var addr = TCPAddr(local[0], local[1])
var socket: Socket[TCPAddr]
Expand Down Expand Up @@ -196,18 +193,18 @@ struct ListenConfig:
return listener^


struct TCPConnection(Connection):
struct TCPConnection:
var socket: Socket[TCPAddr]

fn __init__(inout self, owned socket: Socket[TCPAddr]):
fn __init__(out self, owned socket: Socket[TCPAddr]):
self.socket = socket^

fn __moveinit__(inout self, owned existing: Self):
fn __moveinit__(out self, owned existing: Self):
self.socket = existing.socket^

fn read(self, mut buf: Bytes) raises -> Int:
try:
return self.socket.receive_into(buf)
return self.socket.receive(buf)
except e:
if str(e) == "EOF":
raise e
Expand Down Expand Up @@ -237,13 +234,101 @@ struct TCPConnection(Connection):
fn is_closed(self) -> Bool:
return self.socket._closed

fn local_addr(mut self) -> TCPAddr:
# TODO: Switch to property or return ref when trait supports attributes.
fn local_addr(self) -> TCPAddr:
return self.socket.local_address()

fn remote_addr(self) -> TCPAddr:
return self.socket.remote_address()


struct UDPConnection:
var socket: Socket[UDPAddr]

fn __init__(out self, owned socket: Socket[UDPAddr]):
self.socket = socket^

fn __moveinit__(out self, owned existing: Self):
self.socket = existing.socket^

fn read_from(mut self, size: Int = default_buffer_size) raises -> (Bytes, String, UInt16):
"""Reads data from the underlying file descriptor.
Args:
size: The size of the buffer to read data into.
Returns:
The number of bytes read, or an error if one occurred.
Raises:
Error: If an error occurred while reading data.
"""
return self.socket.receive_from(size)

fn read_from(mut self, mut dest: Bytes) raises -> (UInt, String, UInt16):
"""Reads data from the underlying file descriptor.
Args:
dest: The buffer to read data into.
Returns:
The number of bytes read, or an error if one occurred.
Raises:
Error: If an error occurred while reading data.
"""
return self.socket.receive_from(dest)

fn write_to(mut self, src: Span[Byte], address: UDPAddr) raises -> Int:
"""Writes data to the underlying file descriptor.
Args:
src: The buffer to read data into.
address: The remote peer address.
Returns:
The number of bytes written, or an error if one occurred.
Raises:
Error: If an error occurred while writing data.
"""
return self.socket.send_to(src, address.ip, address.port)

fn write_to(mut self, src: Span[Byte], host: String, port: UInt16) raises -> Int:
"""Writes data to the underlying file descriptor.
Args:
src: The buffer to read data into.
host: The remote peer address in IPv4 format.
port: The remote peer port.
Returns:
The number of bytes written, or an error if one occurred.
Raises:
Error: If an error occurred while writing data.
"""
return self.socket.send_to(src, host, port)

fn close(mut self) raises:
self.socket.close()

fn shutdown(mut self) raises:
self.socket.shutdown()

fn teardown(mut self) raises:
self.socket.teardown()

fn is_closed(self) -> Bool:
return self.socket._closed

fn local_addr(self) -> ref [self.socket._local_address] UDPAddr:
return self.socket.local_address()

fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr:
return self.socket.remote_address()


@value
@register_passable("trivial")
struct addrinfo_macos(AnAddrInfo):
Expand All @@ -261,12 +346,19 @@ struct addrinfo_macos(AnAddrInfo):
var ai_addr: UnsafePointer[sockaddr]
var ai_next: OpaquePointer

fn __init__(out self, ai_flags: c_int = 0, ai_family: c_int = 0, ai_socktype: c_int = 0, ai_protocol: c_int = 0):
self.ai_flags = 0
self.ai_family = 0
self.ai_socktype = 0
self.ai_protocol = 0
self.ai_addrlen = 0
fn __init__(
out self,
ai_flags: c_int = 0,
ai_family: c_int = 0,
ai_socktype: c_int = 0,
ai_protocol: c_int = 0,
ai_addrlen: socklen_t = 0,
):
self.ai_flags = ai_flags
self.ai_family = ai_family
self.ai_socktype = ai_socktype
self.ai_protocol = ai_protocol
self.ai_addrlen = ai_addrlen
self.ai_canonname = UnsafePointer[c_char]()
self.ai_addr = UnsafePointer[sockaddr]()
self.ai_next = OpaquePointer()
Expand Down Expand Up @@ -314,12 +406,19 @@ struct addrinfo_unix(AnAddrInfo):
var ai_canonname: UnsafePointer[c_char]
var ai_next: OpaquePointer

fn __init__(out self, ai_flags: c_int = 0, ai_family: c_int = 0, ai_socktype: c_int = 0, ai_protocol: c_int = 0):
fn __init__(
out self,
ai_flags: c_int = 0,
ai_family: c_int = 0,
ai_socktype: c_int = 0,
ai_protocol: c_int = 0,
ai_addrlen: socklen_t = 0,
):
self.ai_flags = ai_flags
self.ai_family = ai_family
self.ai_socktype = ai_socktype
self.ai_protocol = ai_protocol
self.ai_addrlen = 0
self.ai_addrlen = ai_addrlen
self.ai_addr = UnsafePointer[sockaddr]()
self.ai_canonname = UnsafePointer[c_char]()
self.ai_next = OpaquePointer()
Expand Down Expand Up @@ -395,10 +494,10 @@ struct TCPAddr(Addr):
fn network(self) -> String:
return NetworkType.tcp.value

fn __eq__(self, other: TCPAddr) -> Bool:
fn __eq__(self, other: Self) -> Bool:
return self.ip == other.ip and self.port == other.port and self.zone == other.zone

fn __ne__(self, other: TCPAddr) -> Bool:
fn __ne__(self, other: Self) -> Bool:
return not self == other

fn __str__(self) -> String:
Expand All @@ -413,6 +512,140 @@ struct TCPAddr(Addr):
writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")


@value
struct UDPAddr(Addr):
alias _type = "UDPAddr"
var ip: String
var port: UInt16
var zone: String # IPv6 addressing zone

fn __init__(out self):
self.ip = "127.0.0.1"
self.port = 8000
self.zone = ""

fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000):
self.ip = ip
self.port = port
self.zone = ""

fn network(self) -> String:
return NetworkType.udp.value

fn __eq__(self, other: Self) -> Bool:
return self.ip == other.ip and self.port == other.port and self.zone == other.zone

fn __ne__(self, other: Self) -> Bool:
return not self == other

fn __str__(self) -> String:
if self.zone != "":
return join_host_port(self.ip + "%" + self.zone, str(self.port))
return join_host_port(self.ip, str(self.port))

fn __repr__(self) -> String:
return String.write(self)

fn write_to[W: Writer, //](self, mut writer: W):
writer.write("UDPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")


fn listen_udp(local_address: UDPAddr) raises -> UDPConnection:
"""Creates a new UDP listener.
Args:
local_address: The local address to listen on.
Returns:
A UDP connection.
Raises:
Error: If the address is invalid or failed to bind the socket.
"""
socket = Socket[UDPAddr](socket_type=SOCK_DGRAM)
socket.bind(local_address.ip, local_address.port)
return UDPConnection(socket^)


fn listen_udp(local_address: String) raises -> UDPConnection:
"""Creates a new UDP listener.
Args:
local_address: The address to listen on. The format is "host:port".
Returns:
A UDP connection.
Raises:
Error: If the address is invalid or failed to bind the socket.
"""
var address = parse_address(local_address)
return listen_udp(UDPAddr(address[0], address[1]))


fn listen_udp(host: String, port: UInt16) raises -> UDPConnection:
"""Creates a new UDP listener.
Args:
host: The address to listen on in ipv4 format.
port: The port number.
Returns:
A UDP connection.
Raises:
Error: If the address is invalid or failed to bind the socket.
"""
return listen_udp(UDPAddr(host, port))


fn dial_udp(local_address: UDPAddr) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".
Args:
local_address: The local address.
Returns:
The UDP connection.
Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
return UDPConnection(Socket(local_address=local_address, socket_type=SOCK_DGRAM))


fn dial_udp(local_address: String) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".
Args:
local_address: The local address.
Returns:
The UDP connection.
Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
var address = parse_address(local_address)
return dial_udp(UDPAddr(address[0], address[1]))


fn dial_udp(host: String, port: UInt16) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".
Args:
host: The host to connect to.
port: The port to connect on.
Returns:
The UDP connection.
Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
return dial_udp(UDPAddr(host, port))


# TODO: Support IPv6 long form.
fn join_host_port(host: String, port: String) -> String:
if host.find(":") != -1: # must be IPv6 literal
Expand Down
4 changes: 2 additions & 2 deletions lightbug_http/server.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from lightbug_http.io.sync import Duration
from lightbug_http.io.bytes import Bytes, bytes
from lightbug_http.strings import NetworkType
from lightbug_http.utils import ByteReader, logger
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig, TCPAddr
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig
from lightbug_http.socket import Socket
from lightbug_http.http import HTTPRequest, encode
from lightbug_http.http.common_response import InternalError
Expand Down Expand Up @@ -92,7 +92,7 @@ struct Server(Movable):
handler: An object that handles incoming HTTP requests.
"""
var net = ListenConfig()
var listener = net.listen[NetworkType.tcp4](address)
var listener = net.listen(address)
self.set_address(address)
self.serve(listener^, handler)

Expand Down
Loading

0 comments on commit 064d7f6

Please sign in to comment.