Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDP Socket support #87

Merged
merged 7 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ install_id
# Rattler
output

# integration tests
udp_client.DSYM
udp_server.DSYM
__pycache__

# misc
.vscode

__pycache__
270 changes: 269 additions & 1 deletion lightbug_http/libc.mojo

Large diffs are not rendered by default.

250 changes: 247 additions & 3 deletions lightbug_http/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from lightbug_http.libc import (
AF_INET,
AF_INET6,
SOCK_STREAM,
SOCK_DGRAM,
SOL_SOCKET,
SO_REUSEADDR,
SO_REUSEPORT,
Expand Down Expand Up @@ -196,7 +197,7 @@ struct ListenConfig:
return listener^


struct TCPConnection(Connection):
struct TCPConnection:
var socket: Socket[TCPAddr]

fn __init__(inout self, owned socket: Socket[TCPAddr]):
Expand Down Expand Up @@ -244,6 +245,93 @@ struct TCPConnection(Connection):
return self.socket.remote_address()


struct UDPConnection:
var socket: Socket[UDPAddr]

fn __init__(inout self, owned socket: Socket[UDPAddr]):
self.socket = socket^

fn __moveinit__(inout self, owned existing: Self):
self.socket = existing.socket^

fn read_from(mut self, size: Int = default_buffer_size) raises -> (Bytes, String, UInt16):
"""Reads data from the underlying file descriptor.

Args:
size: The size of the buffer to read data into.

Returns:
The number of bytes read, or an error if one occurred.

Raises:
Error: If an error occurred while reading data.
"""
return self.socket.receive_from(size)

fn read_from(mut self, mut dest: Bytes) raises -> (UInt, String, UInt16):
"""Reads data from the underlying file descriptor.

Args:
dest: The buffer to read data into.

Returns:
The number of bytes read, or an error if one occurred.

Raises:
Error: If an error occurred while reading data.
"""
return self.socket.receive_from_into(dest)
Copy link
Collaborator

@saviorand saviorand Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on the socket this method is called receive_from_into but here it's one of the overloads on read_from. For consistency we could call this one read_from_into? or am I missing something

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's just verbiage we want to use. Do you think connections should read/write (read_from/write_to) or send/receive (receive_from/send_to)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think read/write is fine on connections, I was more talking about this pattern where we supply the dest to read into, looks like this same pattern has the from_into suffix on the socket but just from on the connection, although they have a similar signature.
I think on the socket this similar to how it's done in the Python socket, right? where recvfrom_into writes into a given buffer, and recvfrom allocates a new one. Maybe on the connection we can have a similar pattern?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or is it better to have an overload here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, I see what you mean. Yeah, the socket follows Python. I think perhaps overloading the socket like the connection would be better. Python doesn't have function overloading which is probably why there's more than one function name. Didn't think of that!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, overloading sounds good!


fn write_to(mut self, src: Span[Byte], address: UDPAddr) raises -> Int:
"""Writes data to the underlying file descriptor.

Args:
src: The buffer to read data into.
address: The remote peer address.

Returns:
The number of bytes written, or an error if one occurred.

Raises:
Error: If an error occurred while writing data.
"""
return self.socket.send_to(src, address.ip, address.port)

fn write_to(mut self, src: Span[Byte], host: String, port: UInt16) raises -> Int:
"""Writes data to the underlying file descriptor.

Args:
src: The buffer to read data into.
host: The remote peer address in IPv4 format.
port: The remote peer port.

Returns:
The number of bytes written, or an error if one occurred.

Raises:
Error: If an error occurred while writing data.
"""
return self.socket.send_to(src, host, port)

fn close(mut self) raises:
self.socket.close()

fn shutdown(mut self) raises:
self.socket.shutdown()

fn teardown(mut self) raises:
self.socket.teardown()

fn is_closed(self) -> Bool:
return self.socket._closed

fn local_addr(mut self) -> UDPAddr:
return self.socket.local_address()

fn remote_addr(self) -> UDPAddr:
return self.socket.remote_address()


@value
@register_passable("trivial")
struct addrinfo_macos(AnAddrInfo):
Expand Down Expand Up @@ -395,10 +483,10 @@ struct TCPAddr(Addr):
fn network(self) -> String:
return NetworkType.tcp.value

fn __eq__(self, other: TCPAddr) -> Bool:
fn __eq__(self, other: Self) -> Bool:
return self.ip == other.ip and self.port == other.port and self.zone == other.zone

fn __ne__(self, other: TCPAddr) -> Bool:
fn __ne__(self, other: Self) -> Bool:
return not self == other

fn __str__(self) -> String:
Expand All @@ -413,6 +501,162 @@ struct TCPAddr(Addr):
writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")


@value
struct UDPAddr(Addr):
alias _type = "UDPAddr"
var ip: String
var port: UInt16
var zone: String # IPv6 addressing zone

fn __init__(out self):
self.ip = "127.0.0.1"
self.port = 8000
self.zone = ""

fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000):
self.ip = ip
self.port = port
self.zone = ""

fn network(self) -> String:
return NetworkType.udp.value

fn __eq__(self, other: Self) -> Bool:
return self.ip == other.ip and self.port == other.port and self.zone == other.zone

fn __ne__(self, other: Self) -> Bool:
return not self == other

fn __str__(self) -> String:
if self.zone != "":
return join_host_port(self.ip + "%" + self.zone, str(self.port))
return join_host_port(self.ip, str(self.port))

fn __repr__(self) -> String:
return String.write(self)

fn write_to[W: Writer, //](self, mut writer: W):
writer.write("UDPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")")


fn listen_udp[network: NetworkType = NetworkType.udp](local_address: UDPAddr) raises -> UDPConnection:
"""Creates a new UDP listener.

Parameters:
network: The network type.

Args:
local_address: The local address to listen on.

Returns:
A UDP connection.

Raises:
Error: If the address is invalid or failed to bind the socket.
"""
socket = Socket[UDPAddr](socket_type=SOCK_DGRAM)
socket.bind(local_address.ip, local_address.port)
return UDPConnection(socket^)


fn listen_udp[network: NetworkType = NetworkType.udp](local_address: String) raises -> UDPConnection:
"""Creates a new UDP listener.

Parameters:
network: The network type.

Args:
local_address: The address to listen on. The format is "host:port".

Returns:
A UDP connection.

Raises:
Error: If the address is invalid or failed to bind the socket.
"""
var address = parse_address(local_address)
return listen_udp[network](UDPAddr(address[0], address[1]))


fn listen_udp[network: NetworkType = NetworkType.udp](host: String, port: UInt16) raises -> UDPConnection:
"""Creates a new UDP listener.

Parameters:
network: The network type.

Args:
host: The address to listen on in ipv4 format.
port: The port number.

Returns:
A UDP connection.

Raises:
Error: If the address is invalid or failed to bind the socket.
"""
return listen_udp[network](UDPAddr(host, port))


fn dial_udp[network: NetworkType = NetworkType.udp](local_address: UDPAddr) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".

Parameters:
network: The network type.

Args:
local_address: The local address.

Returns:
The UDP connection.

Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
constrained[
network in NetworkType.UDP_TYPES,
"Unsupported network type for UDP.",
]()
return UDPConnection(Socket(local_address=local_address, socket_type=SOCK_DGRAM))


fn dial_udp[network: NetworkType = NetworkType.udp](local_address: String) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".

Parameters:
network: The network type.

Args:
local_address: The local address.

Returns:
The UDP connection.

Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
var address = parse_address(local_address)
return dial_udp[network](UDPAddr(address[0], address[1]))


fn dial_udp[network: NetworkType = NetworkType.udp](host: String, port: UInt16) raises -> UDPConnection:
"""Connects to the address on the named network. The network must be "udp", "udp4", or "udp6".

Parameters:
network: The network type.

Args:
host: The host to connect to.
port: The port to connect on.

Returns:
The UDP connection.

Raises:
Error: If the network type is not supported or failed to connect to the address.
"""
return dial_udp[network](UDPAddr(host, port))


# TODO: Support IPv6 long form.
fn join_host_port(host: String, port: String) -> String:
if host.find(":") != -1: # must be IPv6 literal
Expand Down
2 changes: 1 addition & 1 deletion lightbug_http/server.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from lightbug_http.io.sync import Duration
from lightbug_http.io.bytes import Bytes, bytes
from lightbug_http.strings import NetworkType
from lightbug_http.utils import ByteReader, logger
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig, TCPAddr
from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig
from lightbug_http.socket import Socket
from lightbug_http.http import HTTPRequest, encode
from lightbug_http.http.common_response import InternalError
Expand Down
Loading
Loading