Skip to content

Commit

Permalink
rm hostport
Browse files Browse the repository at this point in the history
  • Loading branch information
thatstoasty committed Jan 12, 2025
1 parent 3469871 commit 9633dae
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 99 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ jobs:
curl -ssL https://magic.modular.com | bash
source $HOME/.bash_profile
magic run test
magic run integration_test
magic run integration_tests
magic run integration_tests_py
magic run integration_tests_external
113 changes: 46 additions & 67 deletions lightbug_http/net.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ trait Addr(Stringable, Representable, Writable, EqualityComparableCollectionElem
fn __init__(out self, ip: String, port: UInt16):
...

@implicit
fn __init__(out self, host_port: HostPort):
...

fn network(self) -> String:
...

Expand All @@ -117,13 +113,6 @@ struct NoTLSListener:
fn __moveinit__(out self, owned existing: Self):
self.socket = existing.socket^

# fn __del__(owned self):
# logger.info("Listener cleaning up", self.socket)
# try:
# self.teardown()
# except e:
# logger.error("NoTLSListener.__del__: Failed to close connection: " + str(e))

fn accept(self) raises -> TCPConnection:
return TCPConnection(self.socket.accept())

Expand Down Expand Up @@ -152,7 +141,8 @@ struct ListenConfig:
network in NetworkType.SUPPORTED_TYPES,
"Unsupported network type for internet address resolution. Unix addresses are not supported yet.",
]()
var addr = TCPAddr(HostPort.from_string(address))
var local = parse_address(address)
var addr = TCPAddr(local[0], local[1])
var socket: Socket[TCPAddr]
try:
socket = Socket[TCPAddr]()
Expand Down Expand Up @@ -215,13 +205,6 @@ struct TCPConnection(Connection):
fn __moveinit__(inout self, owned existing: Self):
self.socket = existing.socket^

# fn __del__(owned self):
# logger.info("TCPConnection cleaning up", self.socket)
# try:
# self.teardown()
# except e:
# logger.error("TCPConnection.__del__: Failed to close connection: " + str(e))

fn read(self, mut buf: Bytes) raises -> Int:
try:
return self.socket.receive_into(buf)
Expand Down Expand Up @@ -409,12 +392,6 @@ struct TCPAddr(Addr):
self.port = port
self.zone = ""

@implicit
fn __init__(out self, host_port: HostPort):
self.ip = host_port.host
self.port = host_port.port
self.zone = ""

fn network(self) -> String:
return NetworkType.tcp.value

Expand Down Expand Up @@ -447,54 +424,56 @@ alias MissingPortError = Error("missing port in address")
alias TooManyColonsError = Error("too many colons in address")


@value
struct HostPort:
var host: String
var port: UInt16
fn parse_address(address: String) raises -> (String, UInt16):
"""Parse an address string into a host and port.
@staticmethod
fn from_string(address: String) raises -> HostPort:
var colon_index = address.rfind(":")
if colon_index == -1:
raise MissingPortError

var host: String = ""
var port: String = ""
var j: Int = 0
var k: Int = 0
Args:
address: The address string.
if address[0] == "[":
var end_bracket_index = address.find("]")
if end_bracket_index == -1:
raise Error("missing ']' in address")
Returns:
A tuple containing the host and port.
"""
var colon_index = address.rfind(":")
if colon_index == -1:
raise MissingPortError

if end_bracket_index + 1 == len(address):
raise MissingPortError
elif end_bracket_index + 1 == colon_index:
host = address[1:end_bracket_index]
j = 1
k = end_bracket_index + 1
else:
if address[end_bracket_index + 1] == ":":
raise TooManyColonsError
else:
raise MissingPortError
else:
host = address[:colon_index]
if host.find(":") != -1:
raise TooManyColonsError
var host: String = ""
var port: String = ""
var j: Int = 0
var k: Int = 0

if address[j:].find("[") != -1:
raise Error("unexpected '[' in address")
if address[k:].find("]") != -1:
raise Error("unexpected ']' in address")
if address[0] == "[":
var end_bracket_index = address.find("]")
if end_bracket_index == -1:
raise Error("missing ']' in address")

port = address[colon_index + 1 :]
if port == "":
if end_bracket_index + 1 == len(address):
raise MissingPortError
if host == "":
raise Error("missing host")
return HostPort(host, int(port))
elif end_bracket_index + 1 == colon_index:
host = address[1:end_bracket_index]
j = 1
k = end_bracket_index + 1
else:
if address[end_bracket_index + 1] == ":":
raise TooManyColonsError
else:
raise MissingPortError
else:
host = address[:colon_index]
if host.find(":") != -1:
raise TooManyColonsError

if address[j:].find("[") != -1:
raise Error("unexpected '[' in address")
if address[k:].find("]") != -1:
raise Error("unexpected ']' in address")

port = address[colon_index + 1 :]
if port == "":
raise MissingPortError
if host == "":
raise Error("missing host")
return host, UInt16(int(port))


fn binary_port_to_int(port: UInt16) -> Int:
Expand Down
28 changes: 10 additions & 18 deletions lightbug_http/socket.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ from lightbug_http.strings import NetworkType
from lightbug_http.net import (
Addr,
TCPAddr,
HostPort,
default_buffer_size,
binary_port_to_int,
binary_ip_to_string,
Expand All @@ -75,8 +74,6 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri

var fd: Int32
"""The file descriptor of the socket."""
# var address_family: Int
# """The address family of the socket."""
var socket_type: Int32
"""The socket type."""
var protocol: Byte
Expand All @@ -94,7 +91,6 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
out self,
local_address: AddrType = AddrType(),
remote_address: AddrType = AddrType(),
# address_family: Int = AF_INET,
socket_type: Int32 = SOCK_STREAM,
protocol: Byte = 0,
) raises:
Expand All @@ -109,7 +105,6 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
Raises:
Error: If the socket creation fails.
"""
# self.address_family = address_family
self.socket_type = socket_type
self.protocol = protocol

Expand Down Expand Up @@ -152,7 +147,6 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
existing: The existing socket object to move the data from.
"""
self.fd = existing.fd
# self.address_family = existing.address_family
self.socket_type = existing.socket_type
self.protocol = existing.protocol
self._local_address = existing._local_address^
Expand Down Expand Up @@ -276,12 +270,12 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri

var new_socket = Socket(
fd=new_socket_fd,
# address_family=self.address_family,
socket_type=self.socket_type,
protocol=self.protocol,
local_address=self.local_address(),
)
new_socket.set_remote_address(new_socket.get_peer_name())
var peer = new_socket.get_peer_name()
new_socket.set_remote_address(AddrType(peer[0], peer[1]))
return new_socket^

fn listen(self, backlog: UInt = 0) raises:
Expand Down Expand Up @@ -336,9 +330,9 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
raise Error("Socket.bind: Binding socket failed.")

var local = self.get_sock_name()
self._local_address = AddrType(local.host, int(local.port))
self._local_address = AddrType(local[0], local[1])

fn get_sock_name(self) raises -> HostPort:
fn get_sock_name(self) raises -> (String, UInt16):
"""Return the address of the socket.
Returns:
Expand All @@ -363,12 +357,11 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
raise Error("get_sock_name: Failed to get address of local socket.")

var addr_in = local_address.bitcast[sockaddr_in]().take_pointee()
return HostPort(
host=binary_ip_to_string[AF_INET](addr_in.sin_addr.s_addr),
port=binary_port_to_int(addr_in.sin_port),
return binary_ip_to_string[address_family](addr_in.sin_addr.s_addr), UInt16(
binary_port_to_int(addr_in.sin_port)
)

fn get_peer_name(self) raises -> HostPort:
fn get_peer_name(self) raises -> (String, UInt16):
"""Return the address of the peer connected to the socket.
Returns:
Expand All @@ -388,9 +381,8 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
logger.error(e)
raise Error("get_peer_name: Failed to get address of remote socket.")

return HostPort(
host=binary_ip_to_string[AF_INET](addr_in.sin_addr.s_addr),
port=binary_port_to_int(addr_in.sin_port),
return binary_ip_to_string[address_family](addr_in.sin_addr.s_addr), UInt16(
binary_port_to_int(addr_in.sin_port)
)

fn get_socket_option(self, option_name: Int) raises -> Int:
Expand Down Expand Up @@ -454,7 +446,7 @@ struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stri
raise e

var remote = self.get_peer_name()
self._remote_address = AddrType(remote.host, remote.port)
self._remote_address = AddrType(remote[0], remote[1])

fn send(self, buffer: Span[Byte]) raises -> Int:
if buffer[-1] == 0:
Expand Down
4 changes: 2 additions & 2 deletions mojoproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ version = "0.1.8"
build = { cmd = "rattler-build build --recipe recipes -c https://conda.modular.com/max -c conda-forge --skip-existing=all", env = {MODULAR_MOJO_IMPORT_PATH = "$CONDA_PREFIX/lib/mojo"} }
publish = { cmd = "bash scripts/publish.sh", env = { PREFIX_API_KEY = "$PREFIX_API_KEY" } }
test = { cmd = "magic run mojo test -I . tests/lightbug_http" }
integration_test = { cmd = "bash scripts/integration_test.sh" }
integration_tests = { cmd = "magic run mojo test -I . tests/integration" }
integration_tests_py = { cmd = "bash scripts/integration_test.sh" }
integration_tests_external = { cmd = "magic run mojo test -I . tests/integration" }
bench = { cmd = "magic run mojo -I . benchmark/bench.mojo" }
bench_server = { cmd = "bash scripts/bench_server.sh" }
format = { cmd = "magic run mojo format -l 120 lightbug_http" }
Expand Down
20 changes: 10 additions & 10 deletions tests/lightbug_http/test_host_port.mojo
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import testing
from lightbug_http.net import join_host_port, HostPort, TCPAddr
from lightbug_http.net import join_host_port, parse_address, TCPAddr
from lightbug_http.strings import NetworkType


def test_split_host_port():
# IPv4
var hp = HostPort.from_string("127.0.0.1:8080")
testing.assert_equal(hp.host, "127.0.0.1")
testing.assert_equal(hp.port, 8080)
var hp = parse_address("127.0.0.1:8080")
testing.assert_equal(hp[0], "127.0.0.1")
testing.assert_equal(hp[1], 8080)

# IPv6
hp = HostPort.from_string("[::1]:8080")
testing.assert_equal(hp.host, "::1")
testing.assert_equal(hp.port, 8080)
hp = parse_address("[::1]:8080")
testing.assert_equal(hp[0], "::1")
testing.assert_equal(hp[1], 8080)

# # TODO: IPv6 long form - Not supported yet.
# hp = HostPort.from_string("0:0:0:0:0:0:0:1:8080")
# testing.assert_equal(hp.host, "0:0:0:0:0:0:0:1")
# testing.assert_equal(hp.port, 8080)
# hp = parse_address("0:0:0:0:0:0:0:1:8080")
# testing.assert_equal(hp[0], "0:0:0:0:0:0:0:1")
# testing.assert_equal(hp[1], 8080)


def test_join_host_port():
Expand Down

0 comments on commit 9633dae

Please sign in to comment.