Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDP Socket support #87

Merged
merged 7 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ jobs:
magic run test
magic run integration_tests_py
magic run integration_tests_external
magic run integration_tests_udp
7 changes: 5 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ install_id
# Rattler
output

# integration tests
udp_client.DSYM
udp_server.DSYM
__pycache__

# misc
.vscode

__pycache__
270 changes: 269 additions & 1 deletion lightbug_http/libc.mojo

Large diffs are not rendered by default.

275 changes: 254 additions & 21 deletions lightbug_http/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from lightbug_http.libc import (
AF_INET,
AF_INET6,
SOCK_STREAM,
SOCK_DGRAM,
SOL_SOCKET,
SO_REUSEADDR,
SO_REUSEPORT,
Expand Down Expand Up @@ -71,7 +72,7 @@ trait Connection(Movable):
fn teardown(mut self) raises:
...

fn local_addr(mut self) -> TCPAddr:
fn local_addr(self) -> TCPAddr:
...

fn remote_addr(self) -> TCPAddr:
Expand Down Expand Up @@ -135,12 +136,8 @@ struct ListenConfig:
fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive):
self._keep_alive = keep_alive

fn listen[network: NetworkType, address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener:
fn listen[address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener:
constrained[address_family in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]()
constrained[
network in NetworkType.SUPPORTED_TYPES,
"Unsupported network type for internet address resolution. Unix addresses are not supported yet.",
]()
var local = parse_address(address)
var addr = TCPAddr(local[0], local[1])
var socket: Socket[TCPAddr]
Expand Down Expand Up @@ -196,18 +193,18 @@ struct ListenConfig:
return listener^


struct TCPConnection(Connection):
struct TCPConnection:
var socket: Socket[TCPAddr]

fn __init__(inout self, owned socket: Socket[TCPAddr]):
fn __init__(out self, owned socket: Socket[TCPAddr]):
self.socket = socket^

fn __moveinit__(inout self, owned existing: Self):
fn __moveinit__(out self, owned existing: Self):
self.socket = existing.socket^

fn read(self, mut buf: Bytes) raises -> Int:
try:
return self.socket.receive_into(buf)
return self.socket.receive(buf)
except e:
if str(e) == "EOF":
raise e
Expand Down Expand Up @@ -237,13 +234,101 @@ struct TCPConnection(Connection):
fn is_closed(self) -> Bool:
return self.socket._closed

fn local_addr(mut self) -> TCPAddr:
# TODO: Switch to property or return ref when trait supports attributes.
fn local_addr(self) -> TCPAddr:
return self.socket.local_address()

fn remote_addr(self) -> TCPAddr:
return self.socket.remote_address()


struct UDPConnection:
var socket: Socket[UDPAddr]

fn __init__(out self, owned socket: Socket[UDPAddr]):
self.socket = socket^

fn __moveinit__(out self, owned existing: Self):
self.socket = existing.socket^

fn read_from(mut self, size: Int = default_buffer_size) raises -> (Bytes, String, UInt16):
"""Reads data from the underlying file descriptor.

Args:
size: The size of the buffer to read data into.

Returns:
The number of bytes read, or an error if one occurred.

Raises:
Error: If an error occurred while reading data.
"""
return self.socket.receive_from(size)

fn read_from(mut self, mut dest: Bytes) raises -> (UInt, String, UInt16):
"""Reads data from the underlying file descriptor.

Args:
dest: The buffer to read data into.

Returns:
The number of bytes read, or an error if one occurred.

Raises:
Error: If an error occurred while reading data.
"""
return self.socket.receive_from(dest)

fn write_to(mut self, src: Span[Byte], address: UDPAddr) raises -> Int:
"""Writes data to the underlying file descriptor.

Args:
src: The buffer to read data into.
address: The remote peer address.

Returns:
The number of bytes written, or an error if one occurred.

Raises:
Error: If an error occurred while writing data.
"""
return self.socket.send_to(src, address.ip, address.port)

fn write_to(mut self, src: Span[Byte], host: String, port: UInt16) raises -> Int:
"""Writes data to the underlying file descriptor.

Args:
src: The buffer to read data into.
host: The remote peer address in IPv4 format.
port: The remote peer port.

Returns:
The number of bytes written, or an error if one occurred.

Raises:
Error: If an error occurred while writing data.
"""
return self.socket.send_to(src, host, port)

fn close(mut self) raises:
self.socket.close()

fn shutdown(mut self) raises:
self.socket.shutdown()

fn teardown(mut self) raises:
self.socket.teardown()

fn is_closed(self) -> Bool:
return self.socket._closed

fn local_addr(self) -> ref [self.socket._local_address] UDPAddr:
return self.socket.local_address()

fn remote_addr(self) -> ref [self.socket._remote_address] UDPAddr:
return self.socket.remote_address()


@value
@register_passable("trivial")
struct addrinfo_macos(AnAddrInfo):
Expand All @@ -261,12 +346,19 @@ struct addrinfo_macos(AnAddrInfo):
var ai_addr: UnsafePointer[sockaddr]
var ai_next: OpaquePointer

fn __init__(out self, ai_flags: c_int = 0, ai_family: c_int = 0, ai_socktype: c_int = 0, ai_protocol: c_int = 0):
self.ai_flags = 0
self.ai_family = 0
self.ai_socktype = 0
self.ai_protocol = 0
self.ai_addrlen = 0
fn __init__(
out self,
ai_flags: c_int = 0,
ai_family: c_int = 0,
ai_socktype: c_int = 0,
ai_protocol: c_int = 0,
ai_addrlen: socklen_t = 0,
):
self.ai_flags = ai_flags
self.ai_family = ai_family
self.ai_socktype = ai_socktype
self.ai_protocol = ai_protocol
self.ai_addrlen = ai_addrlen
self.ai_canonname = UnsafePointer[c_char]()
self.ai_addr = UnsafePointer[sockaddr]()
self.ai_next = OpaquePointer()
Expand Down Expand Up @@ -314,12 +406,19 @@ struct addrinfo_unix(AnAddrInfo):
var ai_canonname: UnsafePointer[c_char]
var ai_next: OpaquePointer

fn __init__(out self, ai_flags: c_int = 0, ai_family: c_int = 0, ai_socktype: c_int = 0, ai_protocol: c_int = 0):
fn __init__(
out self,
ai_flags: c_int = 0,
ai_family: c_int = 0,
ai_socktype: c_int = 0,
ai_protocol: c_int = 0,
ai_addrlen: socklen_t = 0,
):
self.ai_flags = ai_flags
self.ai_family = ai_family
self.ai_socktype = ai_socktype
self.ai_protocol = ai_protocol
self.ai_addrlen = 0
self.ai_addrlen = ai_addrlen
self.ai_addr = UnsafePointer[sockaddr]()
self.ai_canonname = UnsafePointer[c_char]()
self.ai_next = OpaquePointer()
Expand Down Expand Up @@ -395,10 +494,10 @@ struct TCPAddr(Addr):
fn network(self) -> String:
return NetworkType.tcp.value

fn __eq__(self, other: TCPAddr) -> Bool:
fn __eq__(self, other: Self) -> Bool:
return self.ip == other.ip and self.port == other.port and self.zone == other.zone

fn __ne__(self, other: TCPAddr) -> Bool:
fn __ne__(self, other: Self) -> Bool:
return not self == other

fn __str__(self) -> String:
Expand All @@ -413,6 +512,140 @@ struct TCPAddr(Addr):
writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")


@value
struct UDPAddr(Addr):
alias _type = "UDPAddr"
var ip: String
var port: UInt16
var zone: String # IPv6 addressing zone

fn __init__(out self):
self.ip = "127.0.0.1"
self.port = 8000
self.zone = ""

fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000):
self.ip = ip
self.port = port
self.zone = ""

fn network(self) -> String:
return NetworkType.udp.value

fn __eq__(self, other: Self) -> Bool:
return self.ip == other.ip and self.port == other.port and self.zone == other.zone

fn __ne__(self, other: Self) -> Bool:
return not self == other

fn __str__(self) -> String:
if self.zone != "":
return join_host_port(self.ip + "%" + self.zone, str(self.port))
return join_host_port(self.ip, str(self.port))

fn __repr__(self) -> String:
return String.write(self)

fn write_to[W: Writer, //](self, mut writer: W):
writer.write("UDPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")


fn listen_udp(local_address: UDPAddr) raises -> UDPConnection:
"""Creates a new UDP listener.

Args:
local_address: The local address to listen on.

Returns:
A UDP connection.

Raises:
Error: If the address is invalid or failed to bind the socket.
"""
socket = Socket[UDPAddr](socket_type=SOCK_DGRAM)
socket.bind(local_address.ip, local_address.port)
return UDPConnection(socket^)


fn listen_udp(local_address: String) raises -> UDPConnection:
"""Creates a new UDP listener.

Args:
local_address: The address to listen on. The format is "host:port".

Returns:
A UDP connection.

Raises:
Error: If the address is invalid or failed to bind the socket.
"""
var address = parse_address(local_address)
return listen_udp(UDPAddr(address[0], address[1]))


fn listen_udp(host: String, port: UInt16) raises -> UDPConnection:
"""Creates a new UDP listener.

Args:
host: The address to listen on in ipv4 format.
port: The port number.

Returns:
A UDP connection.

Raises:
Error: If the address is invalid or failed to bind the socket.
"""
return listen_udp(UDPAddr(host, port))


fn dial_udp(local_address: UDPAddr) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".

Args:
local_address: The local address.

Returns:
The UDP connection.

Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
return UDPConnection(Socket(local_address=local_address, socket_type=SOCK_DGRAM))


fn dial_udp(local_address: String) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".

Args:
local_address: The local address.

Returns:
The UDP connection.

Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
var address = parse_address(local_address)
return dial_udp(UDPAddr(address[0], address[1]))


fn dial_udp(host: String, port: UInt16) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".

Args:
host: The host to connect to.
port: The port to connect on.

Returns:
The UDP connection.

Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
return dial_udp(UDPAddr(host, port))


# TODO: Support IPv6 long form.
fn join_host_port(host: String, port: String) -> String:
if host.find(":") != -1: # must be IPv6 literal
Expand Down
4 changes: 2 additions & 2 deletions lightbug_http/server.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from lightbug_http.io.sync import Duration
from lightbug_http.io.bytes import Bytes, bytes
from lightbug_http.strings import NetworkType
from lightbug_http.utils import ByteReader, logger
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig, TCPAddr
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig
from lightbug_http.socket import Socket
from lightbug_http.http import HTTPRequest, encode
from lightbug_http.http.common_response import InternalError
Expand Down Expand Up @@ -92,7 +92,7 @@ struct Server(Movable):
handler: An object that handles incoming HTTP requests.
"""
var net = ListenConfig()
var listener = net.listen[NetworkType.tcp4](address)
var listener = net.listen(address)
self.set_address(address)
self.serve(listener^, handler)

Expand Down
Loading
Loading