diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 00000000..29a56f6c --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,18 @@ +name: Run the benchmarking suite + +on: + workflow_call: + +jobs: + test: + name: Run benchmarks + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Run the test suite + run: | + curl -ssL https://magic.modular.com | bash + source $HOME/.bash_profile + magic run bench + # magic run bench_server # Commented out until we get `wrk` installed diff --git a/.github/workflows/branch.yml b/.github/workflows/branch.yml index 9a5b2ec6..fda11dd1 100644 --- a/.github/workflows/branch.yml +++ b/.github/workflows/branch.yml @@ -2,10 +2,10 @@ name: Branch workflow on: push: - branches: + branches: - '*' pull_request: - branches: + branches: - '*' permissions: @@ -14,6 +14,9 @@ permissions: jobs: test: uses: ./.github/workflows/test.yml - + + bench: + uses: ./.github/workflows/bench.yml + package: uses: ./.github/workflows/package.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2db3154e..efbf7722 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,4 +15,5 @@ jobs: curl -ssL https://magic.modular.com | bash source $HOME/.bash_profile magic run test - + magic run integration_tests_py + magic run integration_tests_external diff --git a/.gitignore b/.gitignore index b89f8c78..7740ac4d 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ install_id output # misc -.vscode \ No newline at end of file +.vscode + +__pycache__ diff --git a/bench.mojo b/benchmark/bench.mojo similarity index 90% rename from bench.mojo rename to benchmark/bench.mojo index 1574d67d..accd1ad5 100644 --- a/bench.mojo +++ b/benchmark/bench.mojo @@ -74,9 +74,8 @@ fn lightbug_benchmark_response_parse(mut b: Bencher): @always_inline @parameter fn response_parse(): - var res = Response try: - _ = HTTPResponse.from_bytes(res.as_bytes()) + _ = HTTPResponse.from_bytes(Response.as_bytes()) except: pass @@ -88,9 +87,8 @@ fn lightbug_benchmark_request_parse(mut b: Bencher): @always_inline @parameter fn request_parse(): - var r = Request try: - _ = HTTPRequest.from_bytes("127.0.0.1/path", 4096, r.as_bytes()) + _ = HTTPRequest.from_bytes("127.0.0.1/path", 4096, Request.as_bytes()) except: pass @@ -103,7 +101,7 @@ fn lightbug_benchmark_request_encode(mut b: Bencher): @parameter fn request_encode(): var req = HTTPRequest( - URI.parse("http://127.0.0.1:8080/some-path")[URI], + URI.parse("http://127.0.0.1:8080/some-path"), headers=headers_struct, body=body_bytes, ) @@ -118,8 +116,7 @@ fn lightbug_benchmark_header_encode(mut b: Bencher): @parameter fn header_encode(): var b = ByteWriter() - var h = headers_struct - b.write(h) + b.write(headers_struct) b.iter[header_encode]() @@ -130,9 +127,8 @@ fn lightbug_benchmark_header_parse(mut b: Bencher): @parameter fn header_parse(): try: - var b = headers var header = Headers() - var reader = ByteReader(b.as_bytes()) + var reader = ByteReader(headers.as_bytes()) _ = header.parse_raw(reader) except: print("failed") diff --git a/bench_server.mojo b/benchmark/bench_server.mojo similarity index 100% rename from bench_server.mojo rename to benchmark/bench_server.mojo diff --git a/client.mojo b/client.mojo index 8af25d57..d52be369 100644 --- a/client.mojo +++ b/client.mojo @@ -3,7 +3,7 @@ from lightbug_http.client import Client fn test_request(mut client: Client) raises -> None: - var uri = URI.parse_raises("google.com") + var uri = URI.parse("google.com") var headers = Headers(Header("Host", "google.com")) var request = HTTPRequest(uri, headers) var response = client.do(request^) diff --git a/lightbug_http/__init__.mojo b/lightbug_http/__init__.mojo index 7d89a5f9..b4b6b77b 100644 --- a/lightbug_http/__init__.mojo +++ b/lightbug_http/__init__.mojo @@ -5,8 +5,3 @@ from lightbug_http.cookie import Cookie, RequestCookieJar, ResponseCookieJar from lightbug_http.service import HTTPService, Welcome, Counter from lightbug_http.server import Server from lightbug_http.strings import to_string - - -trait DefaultConstructible: - fn __init__(out self) raises: - ... diff --git a/lightbug_http/client.mojo b/lightbug_http/client.mojo index a9df7a1c..2f655199 100644 --- a/lightbug_http/client.mojo +++ b/lightbug_http/client.mojo @@ -1,4 +1,6 @@ -from .libc import ( +from collections import Dict +from memory import UnsafePointer +from lightbug_http.libc import ( c_int, AF_INET, SOCK_STREAM, @@ -12,32 +14,32 @@ from lightbug_http.strings import to_string from lightbug_http.net import default_buffer_size from lightbug_http.http import HTTPRequest, HTTPResponse, encode from lightbug_http.header import Headers, HeaderKey -from lightbug_http.net import create_connection, SysConnection +from lightbug_http.net import create_connection, TCPConnection from lightbug_http.io.bytes import Bytes from lightbug_http.utils import ByteReader, logger -from collections import Dict +from lightbug_http.pool_manager import PoolManager struct Client: var host: String var port: Int var name: String + var allow_redirects: Bool - var _connections: Dict[String, SysConnection] + var _connections: PoolManager[TCPConnection] - fn __init__(out self, host: String = "127.0.0.1", port: Int = 8888): + fn __init__( + out self, + host: String = "127.0.0.1", + port: Int = 8888, + cached_connections: Int = 10, + allow_redirects: Bool = False, + ): self.host = host self.port = port self.name = "lightbug_http_client" - self._connections = Dict[String, SysConnection]() - - fn __del__(owned self): - for conn in self._connections.values(): - try: - conn[].close() - except: - # TODO: Add an optional debug log entry here - pass + self.allow_redirects = allow_redirects + self._connections = PoolManager[TCPConnection](cached_connections) fn do(mut self, owned req: HTTPRequest) raises -> HTTPResponse: """The `do` method is responsible for sending an HTTP request to a server and receiving the corresponding response. @@ -84,17 +86,15 @@ struct Client: else: port = 80 - var conn: SysConnection var cached_connection = False + var conn: TCPConnection try: - conn = self._connections[host_str] + conn = self._connections.take(host_str) cached_connection = True - except: - # If connection is not cached, create a new one. - try: - conn = create_connection(socket(AF_INET, SOCK_STREAM, 0), host_str, port) - self._connections[host_str] = conn - except e: + except e: + if str(e) == "PoolManager.take: Key not found.": + conn = create_connection(host_str, port) + else: logger.error(e) raise Error("Client.do: Failed to create a connection to host.") @@ -105,35 +105,49 @@ struct Client: # Maybe peer reset ungracefully, so try a fresh connection if str(e) == "SendError: Connection reset by peer.": logger.debug("Client.do: Connection reset by peer. Trying a fresh connection.") - self._close_conn(host_str) + conn.teardown() if cached_connection: return self.do(req^) logger.error("Client.do: Failed to send message.") raise e - # TODO: What if the response is too large for the buffer? We should read until the end of the response. + # TODO: What if the response is too large for the buffer? We should read until the end of the response. (@thatstoasty) var new_buf = Bytes(capacity=default_buffer_size) - var bytes_recv = conn.read(new_buf) - if bytes_recv == 0: - self._close_conn(host_str) - if cached_connection: - return self.do(req^) - raise Error("Client.do: No response received from the server.") + try: + _ = conn.read(new_buf) + except e: + if str(e) == "EOF": + conn.teardown() + if cached_connection: + return self.do(req^) + raise Error("Client.do: No response received from the server.") + else: + logger.error(e) + raise Error("Client.do: Failed to read response from peer.") + var res: HTTPResponse try: - var res = HTTPResponse.from_bytes(new_buf, conn) - if res.is_redirect(): - self._close_conn(host_str) - return self._handle_redirect(req^, res^) - if res.connection_close(): - self._close_conn(host_str) - return res + res = HTTPResponse.from_bytes(new_buf, conn) except e: - self._close_conn(host_str) + logger.error("Failed to parse a response...") + try: + conn.teardown() + except: + logger.error("Failed to teardown connection...") raise e - return HTTPResponse(Bytes()) + # Redirects should not keep the connection alive, as redirects can send the client to a different server. + if self.allow_redirects and res.is_redirect(): + conn.teardown() + return self._handle_redirect(req^, res^) + # Server told the client to close the connection, we can assume the server closed their side after sending the response. + elif res.connection_close(): + conn.teardown() + # Otherwise, persist the connection by giving it back to the pool manager. + else: + self._connections.give(host_str, conn^) + return res fn _handle_redirect( mut self, owned original_req: HTTPRequest, owned original_response: HTTPResponse @@ -144,20 +158,12 @@ struct Client: new_location = original_response.headers[HeaderKey.LOCATION] except e: raise Error("Client._handle_redirect: `Location` header was not received in the response.") - + if new_location and new_location.startswith("http"): - try: - new_uri = URI.parse_raises(new_location) - except e: - raise Error("Client._handle_redirect: Failed to parse the new URI - " + str(e)) + new_uri = URI.parse(new_location) original_req.headers[HeaderKey.HOST] = new_uri.host else: new_uri = original_req.uri new_uri.path = new_location original_req.uri = new_uri return self.do(original_req^) - - fn _close_conn(mut self, host: String) raises: - if host in self._connections: - self._connections[host].close() - _ = self._connections.pop(host) diff --git a/lightbug_http/cookie/cookie.mojo b/lightbug_http/cookie/cookie.mojo index e06afc16..7969f756 100644 --- a/lightbug_http/cookie/cookie.mojo +++ b/lightbug_http/cookie/cookie.mojo @@ -86,7 +86,7 @@ struct Cookie(CollectionElement): self.partitioned = partitioned fn __str__(self) -> String: - return "Name: " + self.name + " Value: " + self.value + return String.write("Name: ", self.name, " Value: ", self.value) fn __copyinit__(out self: Cookie, existing: Cookie): self.name = existing.name @@ -101,15 +101,15 @@ struct Cookie(CollectionElement): self.partitioned = existing.partitioned fn __moveinit__(out self: Cookie, owned existing: Cookie): - self.name = existing.name - self.value = existing.value - self.max_age = existing.max_age - self.expires = existing.expires - self.domain = existing.domain - self.path = existing.path + self.name = existing.name^ + self.value = existing.value^ + self.max_age = existing.max_age^ + self.expires = existing.expires^ + self.domain = existing.domain^ + self.path = existing.path^ self.secure = existing.secure self.http_only = existing.http_only - self.same_site = existing.same_site + self.same_site = existing.same_site^ self.partitioned = existing.partitioned fn clear_cookie(mut self): @@ -120,23 +120,23 @@ struct Cookie(CollectionElement): return Header(HeaderKey.SET_COOKIE, self.build_header_value()) fn build_header_value(self) -> String: - var header_value = self.name + Cookie.EQUAL + self.value + var header_value = String.write(self.name, Cookie.EQUAL, self.value) if self.expires.is_datetime(): var v = self.expires.http_date_timestamp() if v: - header_value += Cookie.SEPERATOR + Cookie.EXPIRES + Cookie.EQUAL + v.value() + header_value.write(Cookie.SEPERATOR, Cookie.EXPIRES, Cookie.EQUAL, v.value()) if self.max_age: - header_value += Cookie.SEPERATOR + Cookie.MAX_AGE + Cookie.EQUAL + str(self.max_age.value().total_seconds) + header_value.write(Cookie.SEPERATOR, Cookie.MAX_AGE, Cookie.EQUAL, str(self.max_age.value().total_seconds)) if self.domain: - header_value += Cookie.SEPERATOR + Cookie.DOMAIN + Cookie.EQUAL + self.domain.value() + header_value.write(Cookie.SEPERATOR, Cookie.DOMAIN, Cookie.EQUAL, self.domain.value()) if self.path: - header_value += Cookie.SEPERATOR + Cookie.PATH + Cookie.EQUAL + self.path.value() + header_value.write(Cookie.SEPERATOR, Cookie.PATH, Cookie.EQUAL, self.path.value()) if self.secure: - header_value += Cookie.SEPERATOR + Cookie.SECURE + header_value.write(Cookie.SEPERATOR, Cookie.SECURE) if self.http_only: - header_value += Cookie.SEPERATOR + Cookie.HTTP_ONLY + header_value.write(Cookie.SEPERATOR, Cookie.HTTP_ONLY) if self.same_site: - header_value += Cookie.SEPERATOR + Cookie.SAME_SITE + Cookie.EQUAL + str(self.same_site.value()) + header_value.write(Cookie.SEPERATOR, Cookie.SAME_SITE, Cookie.EQUAL, str(self.same_site.value())) if self.partitioned: - header_value += Cookie.SEPERATOR + Cookie.PARTITIONED + header_value.write(Cookie.SEPERATOR, Cookie.PARTITIONED) return header_value diff --git a/lightbug_http/cookie/expiration.mojo b/lightbug_http/cookie/expiration.mojo index bf7094aa..fa865a95 100644 --- a/lightbug_http/cookie/expiration.mojo +++ b/lightbug_http/cookie/expiration.mojo @@ -3,6 +3,7 @@ from small_time import SmallTime alias HTTP_DATE_FORMAT = "ddd, DD MMM YYYY HH:mm:ss ZZZ" alias TZ_GMT = TimeZone(0, "GMT") + @value struct Expiration(CollectionElement): var variant: UInt8 diff --git a/lightbug_http/error.mojo b/lightbug_http/error.mojo index 546c17c8..a513e8f6 100644 --- a/lightbug_http/error.mojo +++ b/lightbug_http/error.mojo @@ -1,7 +1,6 @@ from lightbug_http.http import HTTPResponse -from lightbug_http.io.bytes import bytes -alias TODO_MESSAGE = String("TODO").as_bytes() +alias TODO_MESSAGE = "TODO".as_bytes() # TODO: Custom error handlers provided by the user diff --git a/lightbug_http/header.mojo b/lightbug_http/header.mojo index 89ecd2cd..8c30a212 100644 --- a/lightbug_http/header.mojo +++ b/lightbug_http/header.mojo @@ -22,10 +22,16 @@ struct HeaderKey: @value -struct Header: +struct Header(Writable, Stringable): var key: String var value: String + fn __str__(self) -> String: + return String.write(self) + + fn write_to[T: Writer, //](self, mut writer: T): + writer.write(self.key + ": ", self.value, lineBreak) + @always_inline fn write_header[T: Writer](mut writer: T, key: String, value: String): @@ -63,7 +69,7 @@ struct Headers(Writable, Stringable): return self._inner[key.lower()] except: raise Error("KeyError: Key not found in headers: " + key) - + @always_inline fn get(self, key: String) -> Optional[String]: return self._inner.get(key.lower()) diff --git a/lightbug_http/http/__init__.mojo b/lightbug_http/http/__init__.mojo index 118a394c..0f7c784c 100644 --- a/lightbug_http/http/__init__.mojo +++ b/lightbug_http/http/__init__.mojo @@ -4,11 +4,11 @@ from .request import * from .http_version import HttpVersion -@always_inline -fn encode(owned req: HTTPRequest) -> Bytes: - return req._encoded() +trait Encodable: + fn encode(owned self) -> Bytes: + ... @always_inline -fn encode(owned res: HTTPResponse) -> Bytes: - return res._encoded() +fn encode[T: Encodable](owned data: T) -> Bytes: + return data^.encode() diff --git a/lightbug_http/http/common_response.mojo b/lightbug_http/http/common_response.mojo index 75d5018a..7c8a8942 100644 --- a/lightbug_http/http/common_response.mojo +++ b/lightbug_http/http/common_response.mojo @@ -1,25 +1,11 @@ -fn OK(body: String) -> HTTPResponse: - return HTTPResponse( - headers=Headers(Header(HeaderKey.CONTENT_TYPE, "text/plain")), - body_bytes=bytes(body), - ) - - -fn OK(body: String, content_type: String) -> HTTPResponse: +fn OK(body: String, content_type: String = "text/plain") -> HTTPResponse: return HTTPResponse( headers=Headers(Header(HeaderKey.CONTENT_TYPE, content_type)), body_bytes=bytes(body), ) -fn OK(body: Bytes) -> HTTPResponse: - return HTTPResponse( - headers=Headers(Header(HeaderKey.CONTENT_TYPE, "text/plain")), - body_bytes=body, - ) - - -fn OK(body: Bytes, content_type: String) -> HTTPResponse: +fn OK(body: Bytes, content_type: String = "text/plain") -> HTTPResponse: return HTTPResponse( headers=Headers(Header(HeaderKey.CONTENT_TYPE, content_type)), body_bytes=body, diff --git a/lightbug_http/http/request.mojo b/lightbug_http/http/request.mojo index 0943cbb0..83572e94 100644 --- a/lightbug_http/http/request.mojo +++ b/lightbug_http/http/request.mojo @@ -34,27 +34,28 @@ struct HTTPRequest(Writable, Stringable): fn from_bytes(addr: String, max_body_size: Int, b: Span[Byte]) raises -> HTTPRequest: var reader = ByteReader(b) var headers = Headers() - var cookies = RequestCookieJar() var method: String var protocol: String - var uri_str: String + var uri: String try: var rest = headers.parse_raw(reader) - method, uri_str, protocol = rest[0], rest[1], rest[2] + method, uri, protocol = rest[0], rest[1], rest[2] except e: raise Error("HTTPRequest.from_bytes: Failed to parse request headers: " + str(e)) - + + var cookies = RequestCookieJar() try: cookies.parse_cookies(headers) except e: raise Error("HTTPRequest.from_bytes: Failed to parse cookies: " + str(e)) - var uri = URI.parse_raises(addr + uri_str) var content_length = headers.content_length() if content_length > 0 and max_body_size > 0 and content_length > max_body_size: raise Error("HTTPRequest.from_bytes: Request body too large.") - var request = HTTPRequest(uri, headers=headers, method=method, protocol=protocol, cookies=cookies) + var request = HTTPRequest( + URI.parse(addr + uri), headers=headers, method=method, protocol=protocol, cookies=cookies + ) try: request.read_body(reader, content_length, max_body_size) except e: @@ -87,6 +88,9 @@ struct HTTPRequest(Writable, Stringable): if HeaderKey.HOST not in self.headers: self.headers[HeaderKey.HOST] = uri.host + fn get_body(self) -> StringSlice[__origin_of(self.body_raw)]: + return StringSlice(unsafe_from_utf8=Span(self.body_raw)) + fn set_connection_close(mut self): self.headers[HeaderKey.CONNECTION] = "close" @@ -104,55 +108,51 @@ struct HTTPRequest(Writable, Stringable): if content_length > max_body_size: raise Error("Request body too large") - self.body_raw = r.bytes(content_length) + self.body_raw = r.read_bytes(content_length) self.set_content_length(content_length) - fn write_to[T: Writer](self, mut writer: T): - writer.write(self.method, whitespace) + fn write_to[T: Writer, //](self, mut writer: T): path = self.uri.path if len(self.uri.path) > 1 else strSlash if len(self.uri.query_string) > 0: - path += "?" + self.uri.query_string - - writer.write(path) + path.write("?", self.uri.query_string) writer.write( + self.method, + whitespace, + path, whitespace, self.protocol, lineBreak, + self.headers, + self.cookies, + lineBreak, + to_string(self.body_raw), ) - self.headers.write_to(writer) - self.cookies.write_to(writer) - writer.write(lineBreak) - writer.write(to_string(self.body_raw)) - - # TODO: If we want to consume the args for speed, then this should be owned and not mut. self is being consumed and is invalid after this call. - fn _encoded(owned self) -> Bytes: + fn encode(owned self) -> Bytes: """Encodes request as bytes. This method consumes the data in this request and it should no longer be considered valid. """ - var writer = ByteWriter() - writer.write(self.method) - writer.write(whitespace) var path = self.uri.path if len(self.uri.path) > 1 else strSlash if len(self.uri.query_string) > 0: - path += "?" + self.uri.query_string - writer.write(path) - writer.write(whitespace) - writer.write(self.protocol) - writer.write(lineBreak) - - writer.write(self.headers) - writer.write(self.cookies) - # self.headers.encode_to(writer) - # self.cookies.encode_to(writer) - writer.write(lineBreak) - - writer.consuming_write(self.body_raw) + path.write("?", self.uri.query_string) + var writer = ByteWriter() + writer.write( + self.method, + whitespace, + path, + whitespace, + self.protocol, + lineBreak, + self.headers, + self.cookies, + lineBreak, + ) + writer.consuming_write(self^.body_raw) return writer.consume() fn __str__(self) -> String: - return to_string(self) + return String.write(self) diff --git a/lightbug_http/http/response.mojo b/lightbug_http/http/response.mojo index 26000d4f..333ef494 100644 --- a/lightbug_http/http/response.mojo +++ b/lightbug_http/http/response.mojo @@ -14,7 +14,7 @@ from lightbug_http.strings import ( ) from collections import Optional from utils import StringSlice -from lightbug_http.net import SysConnection, default_buffer_size +from lightbug_http.net import TCPConnection, default_buffer_size struct StatusCode: @@ -38,7 +38,7 @@ struct HTTPResponse(Writable, Stringable): var protocol: String @staticmethod - fn from_bytes(b: Span[Byte], conn: Optional[SysConnection] = None) raises -> HTTPResponse: + fn from_bytes(b: Span[Byte]) raises -> HTTPResponse: var reader = ByteReader(b) var headers = Headers() var cookies = ResponseCookieJar() @@ -50,10 +50,40 @@ struct HTTPResponse(Writable, Stringable): var properties = headers.parse_raw(reader) protocol, status_code, status_text = properties[0], properties[1], properties[2] cookies.from_headers(properties[3]) - reader.skip_newlines() + reader.skip_carriage_return() except e: - raise Error("Failed to parse response headers: " + e.__str__()) - + raise Error("Failed to parse response headers: " + str(e)) + + try: + return HTTPResponse( + reader=reader, + headers=headers, + cookies=cookies, + protocol=protocol, + status_code=int(status_code), + status_text=status_text, + ) + except e: + logger.error(e) + raise Error("Failed to read request body") + + @staticmethod + fn from_bytes(b: Span[Byte], conn: TCPConnection) raises -> HTTPResponse: + var reader = ByteReader(b) + var headers = Headers() + var cookies = ResponseCookieJar() + var protocol: String + var status_code: String + var status_text: String + + try: + var properties = headers.parse_raw(reader) + protocol, status_code, status_text = properties[0], properties[1], properties[2] + cookies.from_headers(properties[3]) + reader.skip_carriage_return() + except e: + raise Error("Failed to parse response headers: " + str(e)) + var response = HTTPResponse( Bytes(), headers=headers, @@ -65,21 +95,23 @@ struct HTTPResponse(Writable, Stringable): var transfer_encoding = response.headers.get(HeaderKey.TRANSFER_ENCODING) if transfer_encoding and transfer_encoding.value() == "chunked": - var b = reader.bytes() - + var b = Bytes(reader.read_bytes()) var buff = Bytes(capacity=default_buffer_size) try: - while conn.value().read(buff) > 0: + while conn.read(buff) > 0: b += buff - if buff[-5] == byte('0') and buff[-4] == byte('\r') - and buff[-3] == byte('\n') - and buff[-2] == byte('\r') - and buff[-1] == byte('\n'): + if ( + buff[-5] == byte("0") + and buff[-4] == byte("\r") + and buff[-3] == byte("\n") + and buff[-2] == byte("\r") + and buff[-1] == byte("\n") + ): break buff.resize(0) - response.read_chunks(b^) + response.read_chunks(b) return response except e: logger.error(e) @@ -115,28 +147,57 @@ struct HTTPResponse(Writable, Stringable): self.set_content_length(len(body_bytes)) if HeaderKey.DATE not in self.headers: try: - var current_time = now(utc=True).__str__() + var current_time = str(now(utc=True)) + self.headers[HeaderKey.DATE] = current_time + except: + logger.debug("DATE header not set, unable to get current time and it was instead omitted.") + + fn __init__( + mut self, + mut reader: ByteReader, + headers: Headers = Headers(), + cookies: ResponseCookieJar = ResponseCookieJar(), + status_code: Int = 200, + status_text: String = "OK", + protocol: String = strHttp11, + ) raises: + self.headers = headers + self.cookies = cookies + if HeaderKey.CONTENT_TYPE not in self.headers: + self.headers[HeaderKey.CONTENT_TYPE] = "application/octet-stream" + self.status_code = status_code + self.status_text = status_text + self.protocol = protocol + self.body_raw = reader.read_bytes() + self.set_content_length(len(self.body_raw)) + if HeaderKey.CONNECTION not in self.headers: + self.set_connection_keep_alive() + if HeaderKey.CONTENT_LENGTH not in self.headers: + self.set_content_length(len(self.body_raw)) + if HeaderKey.DATE not in self.headers: + try: + var current_time = str(now(utc=True)) self.headers[HeaderKey.DATE] = current_time except: pass - fn get_body_bytes(self) -> Bytes: - return self.body_raw + fn get_body(self) -> StringSlice[__origin_of(self.body_raw)]: + return StringSlice(unsafe_from_utf8=Span(self.body_raw)) @always_inline fn set_connection_close(mut self): self.headers[HeaderKey.CONNECTION] = "close" - @always_inline - fn set_connection_keep_alive(mut self): - self.headers[HeaderKey.CONNECTION] = "keep-alive" - fn connection_close(self) -> Bool: var result = self.headers.get(HeaderKey.CONNECTION) if not result: return False return result.value() == "close" + @always_inline + fn set_connection_keep_alive(mut self): + self.headers[HeaderKey.CONNECTION] = "keep-alive" + @always_inline fn set_content_length(mut self, l: Int): self.headers[HeaderKey.CONTENT_LENGTH] = str(l) @@ -159,17 +220,17 @@ struct HTTPResponse(Writable, Stringable): @always_inline fn read_body(mut self, mut r: ByteReader) raises -> None: - self.body_raw = r.bytes(self.content_length()) + self.body_raw = r.read_bytes(self.content_length()) self.set_content_length(len(self.body_raw)) - fn read_chunks(mut self, chunks: Bytes) raises: - var reader = ByteReader(Span(chunks)) + fn read_chunks(mut self, chunks: Span[Byte]) raises: + var reader = ByteReader(chunks) while True: var size = atol(StringSlice(unsafe_from_utf8=reader.read_line()), 16) if size == 0: break - var data = reader.bytes(size) - reader.skip_newlines() + var data = reader.read_bytes(size) + reader.skip_carriage_return() self.set_content_length(self.content_length() + len(data)) self.body_raw += data @@ -179,40 +240,33 @@ struct HTTPResponse(Writable, Stringable): if HeaderKey.SERVER not in self.headers: writer.write("server: lightbug_http", lineBreak) - writer.write(self.headers) - writer.write(self.cookies) - writer.write(lineBreak) - writer.write(to_string(self.body_raw)) + writer.write(self.headers, self.cookies, lineBreak, to_string(self.body_raw)) - fn _encoded(mut self) -> Bytes: + fn encode(owned self) -> Bytes: """Encodes response as bytes. This method consumes the data in this request and it should no longer be considered valid. """ var writer = ByteWriter() - writer.write(self.protocol) - writer.write(whitespace) - writer.consuming_write(bytes(str(self.status_code))) - writer.write(whitespace) - writer.write(self.status_text) - writer.write(lineBreak) - writer.write("server: lightbug_http") - writer.write(lineBreak) - + writer.write( + self.protocol, + whitespace, + str(self.status_code), + whitespace, + self.status_text, + lineBreak, + "server: lightbug_http", + lineBreak, + ) if HeaderKey.DATE not in self.headers: try: - var current_time = now(utc=True).__str__() - write_header(writer, HeaderKey.DATE, current_time) + write_header(writer, HeaderKey.DATE, str(now(utc=True))) except: pass - - writer.write(self.headers) - writer.write(self.cookies) - writer.write(lineBreak) - writer.consuming_write(self.body_raw) - + writer.write(self.headers, self.cookies, lineBreak) + writer.consuming_write(self^.body_raw) return writer.consume() fn __str__(self) -> String: - return to_string(self) + return String.write(self) diff --git a/lightbug_http/io/__init__.mojo b/lightbug_http/io/__init__.mojo index e69de29b..ab9b9eb3 100644 --- a/lightbug_http/io/__init__.mojo +++ b/lightbug_http/io/__init__.mojo @@ -0,0 +1,2 @@ +from lightbug_http.io.bytes import Bytes +from lightbug_http.io.sync import Duration diff --git a/lightbug_http/io/bytes.mojo b/lightbug_http/io/bytes.mojo index 325f68d0..915bd911 100644 --- a/lightbug_http/io/bytes.mojo +++ b/lightbug_http/io/bytes.mojo @@ -1,4 +1,4 @@ -alias Bytes = List[Byte, True] # TODO: We need to push upstream changes to Mojo so bytes correctly uses List[Byte, True] for the trivial type optimization. +alias Bytes = List[Byte, True] @always_inline @@ -9,12 +9,3 @@ fn byte(s: String) -> Byte: @always_inline fn bytes(s: String) -> Bytes: return s.as_bytes() - - -fn compare_case_insensitive(a: Bytes, b: Bytes) -> Bool: - if len(a) != len(b): - return False - for i in range(len(a) - 1): - if (a[i] | 0x20) != (b[i] | 0x20): - return False - return True diff --git a/lightbug_http/libc.mojo b/lightbug_http/libc.mojo index 1faa08da..70e7ad13 100644 --- a/lightbug_http/libc.mojo +++ b/lightbug_http/libc.mojo @@ -76,28 +76,29 @@ alias EPIPE = 32 alias EDOM = 33 alias ERANGE = 34 alias EWOULDBLOCK = EAGAIN -alias EINPROGRESS = 36 -alias EALREADY = 37 -alias ENOTSOCK = 38 -alias EDESTADDRREQ = 39 -alias EMSGSIZE = 40 -alias ENOPROTOOPT = 42 -alias EAFNOSUPPORT = 47 -alias EADDRINUSE = 48 -alias EADDRNOTAVAIL = 49 -alias ENETUNREACH = 51 -alias ECONNABORTED = 53 -alias ECONNRESET = 54 -alias ENOBUFS = 55 -alias EISCONN = 56 -alias ENOTCONN = 57 -alias ETIMEDOUT = 60 -alias ECONNREFUSED = 61 -alias ELOOP = 62 -alias ENAMETOOLONG = 63 -alias EDQUOT = 69 -alias EPROTO = 100 -alias EOPNOTSUPP = 102 +alias EINPROGRESS = 36 if os_is_macos() else 115 +alias EALREADY = 37 if os_is_macos() else 114 +alias ENOTSOCK = 38 if os_is_macos() else 88 +alias EDESTADDRREQ = 39 if os_is_macos() else 89 +alias EMSGSIZE = 40 if os_is_macos() else 90 +alias ENOPROTOOPT = 42 if os_is_macos() else 92 +alias EAFNOSUPPORT = 47 if os_is_macos() else 97 +alias EADDRINUSE = 48 if os_is_macos() else 98 +alias EADDRNOTAVAIL = 49 if os_is_macos() else 99 +alias ENETUNREACH = 51 if os_is_macos() else 101 +alias ECONNABORTED = 53 if os_is_macos() else 103 +alias ECONNRESET = 54 if os_is_macos() else 104 +alias ENOBUFS = 55 if os_is_macos() else 105 +alias EISCONN = 56 if os_is_macos() else 106 +alias ENOTCONN = 57 if os_is_macos() else 107 +alias ETIMEDOUT = 60 if os_is_macos() else 110 +alias ECONNREFUSED = 61 if os_is_macos() else 111 +alias ELOOP = 62 if os_is_macos() else 40 +alias ENAMETOOLONG = 63 if os_is_macos() else 36 +alias EDQUOT = 69 if os_is_macos() else 122 +alias ENOMSG = 91 if os_is_macos() else 42 +alias EPROTO = 100 if os_is_macos() else 71 +alias EOPNOTSUPP = 102 if os_is_macos() else 95 # --- ( Network Related Constants )--------------------------------------------- alias sa_family_t = c_ushort @@ -144,7 +145,7 @@ alias AF_SIP = 29 # Simple Internet Protocol alias AF_KEY = 30 alias pseudo_AF_HDRCMPLT = 31 # Used by BPF to not rewrite headers in interface output routine alias AF_BLUETOOTH = 32 # Bluetooth -alias AF_MPLS = 33 # MPLS +alias AF_MPLS = 33 # MPLS alias pseudo_AF_PFLOW = 34 # pflow alias pseudo_AF_PIPEX = 35 # PIPEX alias AF_FRAME = 36 # frame (Ethernet) sockets @@ -172,12 +173,12 @@ alias PF_HYLINK = AF_HYLINK alias PF_APPLETALK = AF_APPLETALK alias PF_ROUTE = AF_ROUTE alias PF_LINK = AF_LINK -alias PF_XTP = pseudo_AF_XTP # really just proto family, no AF +alias PF_XTP = pseudo_AF_XTP # really just proto family, no AF alias PF_COIP = AF_COIP alias PF_CNT = AF_CNT alias PF_IPX = AF_IPX # same format as = AF_NS alias PF_INET6 = AF_INET6 -alias PF_RTIP = pseudo_AF_RTIP # same format as AF_INET +alias PF_RTIP = pseudo_AF_RTIP # same format as AF_INET alias PF_PIP = pseudo_AF_PIP alias PF_ISDN = AF_ISDN alias PF_NATM = AF_NATM @@ -219,7 +220,7 @@ alias SHUT_RD = 0 alias SHUT_WR = 1 alias SHUT_RDWR = 2 -alias SOL_SOCKET = 0xffff +alias SOL_SOCKET = 0xFFFF # Socket option flags # TODO: These are probably platform specific, on MacOS I have these values, but we should check on Linux. @@ -285,6 +286,19 @@ struct sockaddr_in: var sin_addr: in_addr var sin_zero: StaticTuple[c_char, 8] + fn __init__(out self, address_family: Int, port: UInt16, binary_ip: UInt32): + """Construct a sockaddr_in struct. + + Args: + address_family: The address family. + port: A 16-bit integer port in host byte order, gets converted to network byte order via `htons`. + binary_ip: The binary representation of the IP address. + """ + self.sin_family = address_family + self.sin_port = htons(port) + self.sin_addr = in_addr(binary_ip) + self.sin_zero = StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0) + @value @register_passable("trivial") @@ -347,10 +361,10 @@ fn htons(hostshort: c_ushort) -> c_ushort: Args: hostshort: A 16-bit integer in host byte order. - + Returns: The value provided in network byte order. - + #### C Function ```c uint16_t htons(uint16_t hostshort) @@ -370,12 +384,12 @@ fn ntohl(netlong: c_uint) -> c_uint: Returns: The value provided in host byte order. - + #### C Function ```c uint32_t ntohl(uint32_t netlong) ``` - + #### Notes: * Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html """ @@ -390,12 +404,12 @@ fn ntohs(netshort: c_ushort) -> c_ushort: Returns: The value provided in host byte order. - + #### C Function ```c uint16_t ntohs(uint16_t netshort) ``` - + #### Notes: * Reference: https://man7.org/linux/man-pages/man3/htonl.3p.html """ @@ -418,7 +432,7 @@ fn _inet_ntop( Returns: A UnsafePointer to the buffer containing the result. - + #### C Function ```c const char *inet_ntop(int af, const void *restrict src, char *restrict dst, socklen_t size) @@ -437,28 +451,26 @@ fn _inet_ntop( ](af, src, dst, size) -fn inet_ntop( - af: c_int, - src: UnsafePointer[c_void], - dst: UnsafePointer[c_char], - size: socklen_t, -) raises -> String: +fn inet_ntop[ + address_family: Int32, address_length: Int +](ip_address: UInt32,) raises -> String: """Libc POSIX `inet_ntop` function. + Parameters: + address_family: Address Family see AF_ aliases. + address_length: Address length. + Args: - af: Address Family see AF_ aliases. - src: A UnsafePointer to a binary address. - dst: A UnsafePointer to a buffer to store the result. - size: The size of the buffer. + ip_address: Binary IP address. Returns: - A UnsafePointer to the buffer containing the result. - + The IP Address in the human readable format. + Raises: Error: If an error occurs while converting the address. EAFNOSUPPORT: `*src` was not an `AF_INET` or `AF_INET6` family address. ENOSPC: The buffer size, `size`, was not large enough to store the presentation form of the address. - + #### C Function ```c const char *inet_ntop(int af, const void *restrict src, char *restrict dst, socklen_t size) @@ -467,7 +479,24 @@ fn inet_ntop( #### Notes: * Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html. """ - var result = _inet_ntop(af, src, dst, size) + constrained[ + int(address_family) in [AF_INET, AF_INET6], "Address family must be either INET_ADDRSTRLEN or INET6_ADDRSTRLEN." + ]() + constrained[ + address_length in [INET_ADDRSTRLEN, INET6_ADDRSTRLEN], + "Address family must be either INET_ADDRSTRLEN or INET6_ADDRSTRLEN.", + ]() + var dst = String(capacity=address_length) + var result = _inet_ntop( + address_family, UnsafePointer.address_of(ip_address).bitcast[c_void](), dst.unsafe_ptr(), address_length + ) + + var i = 0 + while i <= address_length: + if result[i] == 0: + break + i += 1 + dst._buffer.size = i + 1 # Need to modify internal buffer's size for the string to be valid. # `inet_ntop` returns NULL on error. if not result: @@ -475,12 +504,15 @@ fn inet_ntop( if errno == EAFNOSUPPORT: raise Error("inet_ntop Error: `*src` was not an `AF_INET` or `AF_INET6` family address.") elif errno == ENOSPC: - raise Error("inet_ntop Error: The buffer size, `size`, was not large enough to store the presentation form of the address.") + raise Error( + "inet_ntop Error: The buffer size, `size`, was not large enough to store the presentation form of the" + " address." + ) else: raise Error("inet_ntop Error: An error occurred while converting the address. Error code: " + str(errno)) - + # We want the string representation of the address, so it's ok to take ownership of the pointer here. - return String(ptr=result, length=int(size)) + return dst fn _inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) -> c_int: @@ -490,13 +522,13 @@ fn _inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) or -1 if some system error occurred (in which case errno will have been set). Args: - af: Address Family see AF_ aliases. + af: Address Family: `AF_INET` or `AF_INET6`. src: A UnsafePointer to a string containing the address. dst: A UnsafePointer to a buffer to store the result. - + Returns: 1 on success, 0 if the input is not a valid address, -1 on error. - + #### C Function ```c int inet_pton(int af, const char *restrict src, void *restrict dst) @@ -514,18 +546,22 @@ fn _inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) ](af, src, dst) -fn inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) raises: +fn inet_pton[address_family: Int32](src: UnsafePointer[c_char]) raises -> c_uint: """Libc POSIX `inet_pton` function. Converts a presentation format address (that is, printable form as held in a character string) to network format (usually a struct in_addr or some other internal binary representation, in network byte order). + Parameters: + address_family: Address Family: `AF_INET` or `AF_INET6`. + Args: - af: Address Family see AF_ aliases. src: A UnsafePointer to a string containing the address. - dst: A UnsafePointer to a buffer to store the result. - + + Returns: + The binary representation of the ip address. + Raises: Error: If an error occurs while converting the address or the input is not a valid address. - + #### C Function ```c int inet_pton(int af, const char *restrict src, void *restrict dst) @@ -535,13 +571,26 @@ fn inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) * Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html * This function is valid for `AF_INET` and `AF_INET6`. """ - var result = _inet_pton(af, src, dst) + constrained[ + int(address_family) in [AF_INET, AF_INET6], "Address family must be either INET_ADDRSTRLEN or INET6_ADDRSTRLEN." + ]() + var ip_buffer: UnsafePointer[c_void] + + @parameter + if address_family == AF_INET6: + ip_buffer = stack_allocation[16, c_void]() + else: + ip_buffer = stack_allocation[4, c_void]() + + var result = _inet_pton(address_family, src, ip_buffer) if result == 0: raise Error("inet_pton Error: The input is not a valid address.") elif result == -1: var errno = get_errno() raise Error("inet_pton Error: An error occurred while converting the address. Error code: " + str(errno)) + return ip_buffer.bitcast[c_uint]().take_pointee() + fn _socket(domain: c_int, type: c_int, protocol: c_int) -> c_int: """Libc POSIX `socket` function. @@ -573,9 +622,9 @@ fn socket(domain: c_int, type: c_int, protocol: c_int) raises -> c_int: type: Socket Type see SOCK_ aliases. protocol: The protocol to use. - Returns: + Returns: A File Descriptor or -1 in case of failure. - + Raises: SocketError: If an error occurs while creating the socket. EACCES: Permission to create a socket of the specified type and/or protocol is denied. @@ -598,26 +647,42 @@ fn socket(domain: c_int, type: c_int, protocol: c_int) raises -> c_int: if fd == -1: var errno = get_errno() if errno == EACCES: - raise Error("SocketError (EACCES): Permission to create a socket of the specified type and/or protocol is denied.") + raise Error( + "SocketError (EACCES): Permission to create a socket of the specified type and/or protocol is denied." + ) elif errno == EAFNOSUPPORT: raise Error("SocketError (EAFNOSUPPORT): The implementation does not support the specified address family.") elif errno == EINVAL: - raise Error("SocketError (EINVAL): Invalid flags in type, Unknown protocol, or protocol family not available.") + raise Error( + "SocketError (EINVAL): Invalid flags in type, Unknown protocol, or protocol family not available." + ) elif errno == EMFILE: - raise Error("SocketError (EMFILE): The per-process limit on the number of open file descriptors has been reached.") + raise Error( + "SocketError (EMFILE): The per-process limit on the number of open file descriptors has been reached." + ) elif errno == ENFILE: - raise Error("SocketError (ENFILE): The system-wide limit on the total number of open files has been reached.") + raise Error( + "SocketError (ENFILE): The system-wide limit on the total number of open files has been reached." + ) elif int(errno) in [ENOBUFS, ENOMEM]: - raise Error("SocketError (ENOBUFS or ENOMEM): Insufficient memory is available. The socket cannot be created until sufficient resources are freed.") + raise Error( + "SocketError (ENOBUFS or ENOMEM): Insufficient memory is available. The socket cannot be created until" + " sufficient resources are freed." + ) elif errno == EPROTONOSUPPORT: - raise Error("SocketError (EPROTONOSUPPORT): The protocol type or the specified protocol is not supported within this domain.") + raise Error( + "SocketError (EPROTONOSUPPORT): The protocol type or the specified protocol is not supported within" + " this domain." + ) else: raise Error("SocketError: An error occurred while creating the socket. Error code: " + str(errno)) return fd -fn _setsockopt[origin: Origin]( +fn _setsockopt[ + origin: Origin +]( socket: c_int, level: c_int, option_name: c_int, @@ -655,12 +720,11 @@ fn _setsockopt[origin: Origin]( ](socket, level, option_name, option_value, option_len) -fn setsockopt[origin: Origin]( +fn setsockopt( socket: c_int, level: c_int, option_name: c_int, - option_value: Pointer[c_void, origin], - option_len: socklen_t, + option_value: c_void, ) raises: """Libc POSIX `setsockopt` function. Manipulate options for the socket referred to by the file descriptor, `socket`. @@ -669,7 +733,6 @@ fn setsockopt[origin: Origin]( level: The protocol level. option_name: The option to set. option_value: A UnsafePointer to the value to set. - option_len: The size of the value. Raises: Error: If an error occurs while setting the socket option. @@ -687,7 +750,7 @@ fn setsockopt[origin: Origin]( #### Notes: * Reference: https://man7.org/linux/man-pages/man3/setsockopt.3p.html """ - var result = _setsockopt(socket, level, option_name, option_value, option_len) + var result = _setsockopt(socket, level, option_name, Pointer.address_of(option_value), sizeof[Int]()) if result == -1: var errno = get_errno() if errno == EBADF: @@ -695,27 +758,121 @@ fn setsockopt[origin: Origin]( elif errno == EFAULT: raise Error("setsockopt: The argument `option_value` points outside the process's allocated address space.") elif errno == EINVAL: - raise Error("setsockopt: The argument `option_len` is invalid. Can sometimes occur when `option_value` is invalid.") + raise Error( + "setsockopt: The argument `option_len` is invalid. Can sometimes occur when `option_value` is invalid." + ) elif errno == ENOPROTOOPT: - raise Error("setsockopt: The option is unknown at the level indicated.") + raise Error("setsockopt [InvalidProtocol]: The option is unknown at the level indicated.") elif errno == ENOTSOCK: raise Error("setsockopt: The argument `socket` is not a socket.") else: raise Error("setsockopt: An error occurred while setting the socket option. Error code: " + str(errno)) -fn _getsockname[origin: Origin]( +fn _getsockopt[ + len_origin: Origin +]( socket: c_int, - address: UnsafePointer[sockaddr], - address_len: Pointer[socklen_t, origin], + level: c_int, + option_name: c_int, + option_value: UnsafePointer[c_void], + option_len: Pointer[socklen_t, len_origin], ) -> c_int: + """Libc POSIX `setsockopt` function. + + Args: + socket: A File Descriptor. + level: The protocol level. + option_name: The option to set. + option_value: A Pointer to the value to set. + option_len: The size of the value. + + Returns: + 0 on success, -1 on error. + + #### C Function + ```c + int getsockopt(int socket, int level, int option_name, const void *option_value, socklen_t option_len) + ``` + + #### Notes: + * Reference: https://man7.org/linux/man-pages/man3/setsockopt.3p.html + """ + return external_call[ + "getsockopt", + c_int, # FnName, RetType + c_int, + c_int, + c_int, + UnsafePointer[c_void], + Pointer[socklen_t, len_origin], # Args + ](socket, level, option_name, option_value, option_len) + + +fn getsockopt( + socket: c_int, + level: c_int, + option_name: c_int, +) raises -> Int: + """Libc POSIX `getsockopt` function. Manipulate options for the socket referred to by the file descriptor, `socket`. + + Args: + socket: A File Descriptor. + level: The protocol level. + option_name: The option to set. + + Returns: + The value of the option. + + Raises: + Error: If an error occurs while setting the socket option. + EBADF: The argument `socket` is not a valid descriptor. + EFAULT: The argument `option_value` points outside the process's allocated address space. + EINVAL: The argument `option_len` is invalid. Can sometimes occur when `option_value` is invalid. + ENOPROTOOPT: The option is unknown at the level indicated. + ENOTSOCK: The argument `socket` is not a socket. + + #### C Function + ```c + int getsockopt(int sockfd, int level, int optname, void optval[restrict *.optlen], socklen_t *restrict optlen); + ``` + + #### Notes: + * Reference: https://man7.org/linux/man-pages/man3/getsockopt.3p.html + """ + var option_value = stack_allocation[1, c_void]() + var option_len: socklen_t = sizeof[Int]() + var result = _getsockopt(socket, level, option_name, option_value, Pointer.address_of(option_len)) + if result == -1: + var errno = get_errno() + if errno == EBADF: + raise Error("getsockopt: The argument `socket` is not a valid descriptor.") + elif errno == EFAULT: + raise Error("getsockopt: The argument `option_value` points outside the process's allocated address space.") + elif errno == EINVAL: + raise Error( + "getsockopt: The argument `option_len` is invalid. Can sometimes occur when `option_value` is invalid." + ) + elif errno == ENOPROTOOPT: + raise Error("getsockopt: The option is unknown at the level indicated.") + elif errno == ENOTSOCK: + raise Error("getsockopt: The argument `socket` is not a socket.") + else: + raise Error("getsockopt: An error occurred while setting the socket option. Error code: " + str(errno)) + + return option_value.bitcast[Int]().take_pointee() + + +fn _getsockname[ + origin: Origin +](socket: c_int, address: UnsafePointer[sockaddr], address_len: Pointer[socklen_t, origin],) -> c_int: """Libc POSIX `getsockname` function. Args: socket: A File Descriptor. address: A UnsafePointer to a buffer to store the address of the peer. address_len: A UnsafePointer to the size of the buffer. - + Returns: 0 on success, -1 on error. @@ -736,18 +893,16 @@ fn _getsockname[origin: Origin]( ](socket, address, address_len) -fn getsockname[origin: Origin]( - socket: c_int, - address: UnsafePointer[sockaddr], - address_len: Pointer[socklen_t, origin], -) raises: +fn getsockname[ + origin: Origin +](socket: c_int, address: UnsafePointer[sockaddr], address_len: Pointer[socklen_t, origin],) raises: """Libc POSIX `getsockname` function. Args: socket: A File Descriptor. address: A UnsafePointer to a buffer to store the address of the peer. address_len: A UnsafePointer to the size of the buffer. - + Raises: Error: If an error occurs while getting the socket name. EBADF: The argument `socket` is not a valid descriptor. @@ -770,7 +925,9 @@ fn getsockname[origin: Origin]( if errno == EBADF: raise Error("getsockname: The argument `socket` is not a valid descriptor.") elif errno == EFAULT: - raise Error("getsockname: The `address` argument points to memory not in a valid part of the process address space.") + raise Error( + "getsockname: The `address` argument points to memory not in a valid part of the process address space." + ) elif errno == EINVAL: raise Error("getsockname: `address_len` is invalid (e.g., is negative).") elif errno == ENOBUFS: @@ -781,11 +938,9 @@ fn getsockname[origin: Origin]( raise Error("getsockname: An error occurred while getting the socket name. Error code: " + str(errno)) -fn _getpeername[origin: Origin]( - sockfd: c_int, - addr: UnsafePointer[sockaddr], - address_len: Pointer[socklen_t, origin], -) -> c_int: +fn _getpeername[ + origin: Origin +](sockfd: c_int, addr: UnsafePointer[sockaddr], address_len: Pointer[socklen_t, origin],) -> c_int: """Libc POSIX `getpeername` function. Args: @@ -813,17 +968,11 @@ fn _getpeername[origin: Origin]( ](sockfd, addr, address_len) -fn getpeername[origin: Origin]( - sockfd: c_int, - addr: UnsafePointer[sockaddr], - address_len: Pointer[socklen_t, origin], -) raises: +fn getpeername(file_descriptor: c_int) raises -> sockaddr_in: """Libc POSIX `getpeername` function. Args: - sockfd: A File Descriptor. - addr: A UnsafePointer to a buffer to store the address of the peer. - address_len: A UnsafePointer to the size of the buffer. + file_descriptor: A File Descriptor. Raises: Error: If an error occurs while getting the socket name. @@ -842,13 +991,16 @@ fn getpeername[origin: Origin]( #### Notes: * Reference: https://man7.org/linux/man-pages/man2/getpeername.2.html """ - var result = _getpeername(sockfd, addr, address_len) + var remote_address = stack_allocation[1, sockaddr]() + var result = _getpeername(file_descriptor, remote_address, Pointer.address_of(socklen_t(sizeof[sockaddr]()))) if result == -1: var errno = get_errno() if errno == EBADF: raise Error("getpeername: The argument `socket` is not a valid descriptor.") elif errno == EFAULT: - raise Error("getpeername: The `addr` argument points to memory not in a valid part of the process address space.") + raise Error( + "getpeername: The `addr` argument points to memory not in a valid part of the process address space." + ) elif errno == EINVAL: raise Error("getpeername: `address_len` is invalid (e.g., is negative).") elif errno == ENOBUFS: @@ -860,6 +1012,9 @@ fn getpeername[origin: Origin]( else: raise Error("getpeername: An error occurred while getting the socket name. Error code: " + str(errno)) + # Cast sockaddr struct to sockaddr_in + return remote_address.bitcast[sockaddr_in]().take_pointee() + fn _bind[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, origin], address_len: socklen_t) -> c_int: """Libc POSIX `bind` function. @@ -868,7 +1023,7 @@ fn _bind[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, ori socket: A File Descriptor. address: A UnsafePointer to the address to bind to. address_len: The size of the address. - + Returns: 0 on success, -1 on error. @@ -883,14 +1038,13 @@ fn _bind[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, ori return external_call["bind", c_int, c_int, Pointer[sockaddr_in, origin], socklen_t](socket, address, address_len) -fn bind[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, origin], address_len: socklen_t) raises: +fn bind(socket: c_int, mut address: sockaddr_in) raises: """Libc POSIX `bind` function. Args: socket: A File Descriptor. address: A UnsafePointer to the address to bind to. - address_len: The size of the address. - + Raises: Error: If an error occurs while binding the socket. EACCES: The address, `address`, is protected, and the user is not the superuser. @@ -919,7 +1073,7 @@ fn bind[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, orig #### Notes: * Reference: https://man7.org/linux/man-pages/man3/bind.3p.html """ - var result = _bind(socket, address, address_len) + var result = _bind(socket, Pointer.address_of(address), sizeof[sockaddr_in]()) if result == -1: var errno = get_errno() if errno == EACCES: @@ -955,7 +1109,7 @@ fn bind[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, orig # raise Error("bind: A component of the path prefix is not a directory.") # elif errno == EROFS: # raise Error("bind: The socket inode would reside on a read-only file system.") - + raise Error("bind: An error occurred while binding the socket. Error code: " + str(errno)) @@ -1017,11 +1171,9 @@ fn listen(socket: c_int, backlog: c_int) raises: raise Error("listen: An error occurred while listening on the socket. Error code: " + str(errno)) -fn _accept[address_origin: MutableOrigin, len_origin: Origin]( - socket: c_int, - address: Pointer[sockaddr, address_origin], - address_len: Pointer[socklen_t, len_origin], -) -> c_int: +fn _accept[ + address_origin: MutableOrigin, len_origin: Origin +](socket: c_int, address: Pointer[sockaddr, address_origin], address_len: Pointer[socklen_t, len_origin],) -> c_int: """Libc POSIX `accept` function. Args: @@ -1041,25 +1193,15 @@ fn _accept[address_origin: MutableOrigin, len_origin: Origin]( * Reference: https://man7.org/linux/man-pages/man3/accept.3p.html """ return external_call[ - "accept", - c_int, # FnName, RetType - c_int, - Pointer[sockaddr, address_origin], - Pointer[socklen_t, len_origin] + "accept", c_int, c_int, Pointer[sockaddr, address_origin], Pointer[socklen_t, len_origin] # FnName, RetType ](socket, address, address_len) -fn accept[address_origin: MutableOrigin, len_origin: Origin]( - socket: c_int, - address: Pointer[sockaddr, address_origin], - address_len: Pointer[socklen_t, len_origin], -) raises -> c_int: +fn accept(socket: c_int) raises -> c_int: """Libc POSIX `accept` function. Args: socket: A File Descriptor. - address: A UnsafePointer to a buffer to store the address of the peer. - address_len: A UnsafePointer to the size of the buffer. Raises: Error: If an error occurs while listening on the socket. @@ -1087,11 +1229,16 @@ fn accept[address_origin: MutableOrigin, len_origin: Origin]( #### Notes: * Reference: https://man7.org/linux/man-pages/man3/accept.3p.html """ - var result = _accept(socket, address, address_len) + var remote_address = sockaddr() + var result = _accept(socket, Pointer.address_of(remote_address), Pointer.address_of(socklen_t(sizeof[socklen_t]()))) if result == -1: var errno = get_errno() if int(errno) in [EAGAIN, EWOULDBLOCK]: - raise Error("accept: The socket is marked nonblocking and no connections are present to be accepted. POSIX.1-2001 allows either error to be returned for this case, and does not require these constants to have the same value, so a portable application should check for both possibilities..") + raise Error( + "accept: The socket is marked nonblocking and no connections are present to be accepted. POSIX.1-2001" + " allows either error to be returned for this case, and does not require these constants to have the" + " same value, so a portable application should check for both possibilities.." + ) elif errno == EBADF: raise Error("accept: `socket` is not a valid descriptor.") elif errno == ECONNABORTED: @@ -1099,22 +1246,30 @@ fn accept[address_origin: MutableOrigin, len_origin: Origin]( elif errno == EFAULT: raise Error("accept: The `address` argument is not in a writable part of the user address space.") elif errno == EINTR: - raise Error("accept: The system call was interrupted by a signal that was caught before a valid connection arrived; see `signal(7)`.") + raise Error( + "accept: The system call was interrupted by a signal that was caught before a valid connection arrived;" + " see `signal(7)`." + ) elif errno == EINVAL: - raise Error("accept: Socket is not listening for connections, or `addr_length` is invalid (e.g., is negative).") + raise Error( + "accept: Socket is not listening for connections, or `addr_length` is invalid (e.g., is negative)." + ) elif errno == EMFILE: raise Error("accept: The per-process limit of open file descriptors has been reached.") elif errno == ENFILE: raise Error("accept: The system limit on the total number of open files has been reached.") elif int(errno) in [ENOBUFS, ENOMEM]: - raise Error("accept: Not enough free memory. This often means that the memory allocation is limited by the socket buffer limits, not by the system memory.") + raise Error( + "accept: Not enough free memory. This often means that the memory allocation is limited by the socket" + " buffer limits, not by the system memory." + ) elif errno == ENOTSOCK: raise Error("accept: `socket` is a descriptor for a file, not a socket.") elif errno == EOPNOTSUPP: raise Error("accept: The referenced socket is not of type `SOCK_STREAM`.") elif errno == EPROTO: raise Error("accept: Protocol error.") - + @parameter if os_is_linux(): if errno == EPERM: @@ -1123,7 +1278,8 @@ fn accept[address_origin: MutableOrigin, len_origin: Origin]( return result -fn _connect[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, origin], address_len: socklen_t) -> c_int: + +fn _connect[origin: Origin](socket: c_int, address: Pointer[sockaddr_in, origin], address_len: socklen_t) -> c_int: """Libc POSIX `connect` function. Args: socket: A File Descriptor. @@ -1142,14 +1298,13 @@ fn _connect[origin: MutableOrigin](socket: c_int, address: Pointer[sockaddr_in, return external_call["connect", c_int](socket, address, address_len) -fn connect(socket: c_int, mut address: sockaddr_in, address_len: socklen_t) raises: +fn connect(socket: c_int, address: sockaddr_in) raises: """Libc POSIX `connect` function. Args: socket: A File Descriptor. - address: A UnsafePointer to the address to connect to. - address_len: The size of the address. - + address: The address to connect to. + Raises: Error: If an error occurs while connecting to the socket. EACCES: For UNIX domain sockets, which are identified by pathname: Write permission is denied on the socket file, or search permission is denied for one of the directories in the path prefix. (See also path_resolution(7)). @@ -1175,17 +1330,23 @@ fn connect(socket: c_int, mut address: sockaddr_in, address_len: socklen_t) rais #### Notes: * Reference: https://man7.org/linux/man-pages/man3/connect.3p.html """ - var result = _connect(socket, Pointer.address_of(address), address_len) + var result = _connect(socket, Pointer.address_of(address), sizeof[sockaddr_in]()) if result == -1: var errno = get_errno() if errno == EACCES: - raise Error("connect: For UNIX domain sockets, which are identified by pathname: Write permission is denied on the socket file, or search permission is denied for one of the directories in the path prefix. (See also path_resolution(7)).") + raise Error( + "connect: For UNIX domain sockets, which are identified by pathname: Write permission is denied on the" + " socket file, or search permission is denied for one of the directories in the path prefix. (See also" + " path_resolution(7))." + ) elif errno == EADDRINUSE: raise Error("connect: Local address is already in use.") elif errno == EAGAIN: raise Error("connect: No more free local ports or insufficient entries in the routing cache.") elif errno == EALREADY: - raise Error("connect: The socket is nonblocking and a previous connection attempt has not yet been completed.") + raise Error( + "connect: The socket is nonblocking and a previous connection attempt has not yet been completed." + ) elif errno == EBADF: raise Error("connect: The file descriptor is not a valid index in the descriptor table.") elif errno == ECONNREFUSED: @@ -1193,7 +1354,13 @@ fn connect(socket: c_int, mut address: sockaddr_in, address_len: socklen_t) rais elif errno == EFAULT: raise Error("connect: The socket structure address is outside the user's address space.") elif errno == EINPROGRESS: - raise Error("connect: The socket is nonblocking and the connection cannot be completed immediately. It is possible to select(2) or poll(2) for completion by selecting the socket for writing. After select(2) indicates writability, use getsockopt(2) to read the SO_ERROR option at level SOL_SOCKET to determine whether connect() completed successfully (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error codes listed here, explaining the reason for the failure).") + raise Error( + "connect: The socket is nonblocking and the connection cannot be completed immediately. It is possible" + " to select(2) or poll(2) for completion by selecting the socket for writing. After select(2) indicates" + " writability, use getsockopt(2) to read the SO_ERROR option at level SOL_SOCKET to determine whether" + " connect() completed successfully (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual" + " error codes listed here, explaining the reason for the failure)." + ) elif errno == EINTR: raise Error("connect: The system call was interrupted by a signal that was caught.") elif errno == EISCONN: @@ -1205,7 +1372,9 @@ fn connect(socket: c_int, mut address: sockaddr_in, address_len: socklen_t) rais elif errno == EAFNOSUPPORT: raise Error("connect: The passed address didn't have the correct address family in its `sa_family` field.") elif errno == ETIMEDOUT: - raise Error("connect: Timeout while attempting connection. The server may be too busy to accept new connections.") + raise Error( + "connect: Timeout while attempting connection. The server may be too busy to accept new connections." + ) else: raise Error("connect: An error occurred while connecting to the socket. Error code: " + str(errno)) @@ -1223,7 +1392,7 @@ fn _recv( buffer: A UnsafePointer to the buffer to store the received data. length: The size of the buffer. flags: Flags to control the behaviour of the function. - + Returns: The number of bytes received or -1 in case of failure. @@ -1258,7 +1427,7 @@ fn recv( buffer: A UnsafePointer to the buffer to store the received data. length: The size of the buffer. flags: Flags to control the behaviour of the function. - + Returns: The number of bytes received. @@ -1274,22 +1443,33 @@ fn recv( if result == -1: var errno = get_errno() if int(errno) in [EAGAIN, EWOULDBLOCK]: - raise Error("ReceiveError: The socket is marked nonblocking and the receive operation would block, or a receive timeout had been set and the timeout expired before data was received.") + raise Error( + "ReceiveError: The socket is marked nonblocking and the receive operation would block, or a receive" + " timeout had been set and the timeout expired before data was received." + ) elif errno == EBADF: raise Error("ReceiveError: The argument `socket` is an invalid descriptor.") elif errno == ECONNREFUSED: - raise Error("ReceiveError: The remote host refused to allow the network connection (typically because it is not running the requested service).") + raise Error( + "ReceiveError: The remote host refused to allow the network connection (typically because it is not" + " running the requested service)." + ) elif errno == EFAULT: raise Error("ReceiveError: `buffer` points outside the process's address space.") elif errno == EINTR: - raise Error("ReceiveError: The receive was interrupted by delivery of a signal before any data were available.") + raise Error( + "ReceiveError: The receive was interrupted by delivery of a signal before any data were available." + ) elif errno == ENOTCONN: raise Error("ReceiveError: The socket is not connected.") elif errno == ENOTSOCK: raise Error("ReceiveError: The file descriptor is not associated with a socket.") else: - raise Error("ReceiveError: An error occurred while attempting to receive data from the socket. Error code: " + str(errno)) - + raise Error( + "ReceiveError: An error occurred while attempting to receive data from the socket. Error code: " + + str(errno) + ) + return result @@ -1327,7 +1507,7 @@ fn send(socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c Returns: The number of bytes sent. - + Raises: Error: If an error occurs while attempting to receive data from the socket. EAGAIN or EWOULDBLOCK: The socket is marked nonblocking and the receive operation would block, or a receive timeout had been set and the timeout expired before data was received. @@ -1359,7 +1539,10 @@ fn send(socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c if result == -1: var errno = get_errno() if int(errno) in [EAGAIN, EWOULDBLOCK]: - raise Error("SendError: The socket is marked nonblocking and the receive operation would block, or a receive timeout had been set and the timeout expired before data was received.") + raise Error( + "SendError: The socket is marked nonblocking and the receive operation would block, or a receive" + " timeout had been set and the timeout expired before data was received." + ) elif errno == EBADF: raise Error("SendError: The argument `socket` is an invalid descriptor.") elif errno == EAGAIN: @@ -1369,19 +1552,30 @@ fn send(socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c elif errno == EDESTADDRREQ: raise Error("SendError: The socket is not connection-mode, and no peer address is set.") elif errno == ECONNREFUSED: - raise Error("SendError: The remote host refused to allow the network connection (typically because it is not running the requested service).") + raise Error( + "SendError: The remote host refused to allow the network connection (typically because it is not" + " running the requested service)." + ) elif errno == EFAULT: raise Error("SendError: `buffer` points outside the process's address space.") elif errno == EINTR: - raise Error("SendError: The receive was interrupted by delivery of a signal before any data were available.") + raise Error( + "SendError: The receive was interrupted by delivery of a signal before any data were available." + ) elif errno == EINVAL: raise Error("SendError: Invalid argument passed.") elif errno == EISCONN: raise Error("SendError: The connection-mode socket was connected already but a recipient was specified.") elif errno == EMSGSIZE: - raise Error("SendError: The socket type requires that message be sent atomically, and the size of the message to be sent made this impossible..") + raise Error( + "SendError: The socket type requires that message be sent atomically, and the size of the message to be" + " sent made this impossible.." + ) elif errno == ENOBUFS: - raise Error("SendError: The output queue for a network interface was full. This generally indicates that the interface has stopped sending, but may be caused by transient congestion.") + raise Error( + "SendError: The output queue for a network interface was full. This generally indicates that the" + " interface has stopped sending, but may be caused by transient congestion." + ) elif errno == ENOMEM: raise Error("SendError: No memory available.") elif errno == ENOTCONN: @@ -1391,10 +1585,16 @@ fn send(socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c elif errno == EOPNOTSUPP: raise Error("SendError: Some bit in the flags argument is inappropriate for the socket type.") elif errno == EPIPE: - raise Error("SendError: The local end has been shut down on a connection oriented socket. In this case the process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.") + raise Error( + "SendError: The local end has been shut down on a connection oriented socket. In this case the process" + " will also receive a SIGPIPE unless MSG_NOSIGNAL is set." + ) else: - raise Error("SendError: An error occurred while attempting to receive data from the socket. Error code: " + str(errno)) - + raise Error( + "SendError: An error occurred while attempting to receive data from the socket. Error code: " + + str(errno) + ) + return result @@ -1404,7 +1604,7 @@ fn _shutdown(socket: c_int, how: c_int) -> c_int: Args: socket: A File Descriptor. how: How to shutdown the socket. - + Returns: 0 on success, -1 on error. @@ -1419,13 +1619,19 @@ fn _shutdown(socket: c_int, how: c_int) -> c_int: return external_call["shutdown", c_int, c_int, c_int](socket, how) +alias ShutdownInvalidDescriptorError = "ShutdownError (EBADF): The argument `socket` is an invalid descriptor." +alias ShutdownInvalidArgumentError = "ShutdownError (EINVAL): Invalid argument passed." +alias ShutdownNotConnectedError = "ShutdownError (ENOTCONN): The socket is not connected." +alias ShutdownNotSocketError = "ShutdownError (ENOTSOCK): The file descriptor is not associated with a socket." + + fn shutdown(socket: c_int, how: c_int) raises: """Libc POSIX `shutdown` function. Args: socket: A File Descriptor. how: How to shutdown the socket. - + Raises: Error: If an error occurs while attempting to receive data from the socket. EBADF: The argument `socket` is an invalid descriptor. @@ -1445,15 +1651,18 @@ fn shutdown(socket: c_int, how: c_int) raises: if result == -1: var errno = get_errno() if errno == EBADF: - raise Error("ShutdownError: The argument `socket` is an invalid descriptor.") + raise ShutdownInvalidDescriptorError elif errno == EINVAL: - raise Error("ShutdownError: Invalid argument passed.") + raise ShutdownInvalidArgumentError elif errno == ENOTCONN: - raise Error("ShutdownError: The socket is not connected.") + raise ShutdownNotConnectedError elif errno == ENOTSOCK: - raise Error("ShutdownError: The file descriptor is not associated with a socket.") + raise ShutdownNotSocketError else: - raise Error("ShutdownError: An error occurred while attempting to receive data from the socket. Error code: " + str(errno)) + raise Error( + "ShutdownError: An error occurred while attempting to receive data from the socket. Error code: " + + str(errno) + ) fn gai_strerror(ecode: c_int) -> UnsafePointer[c_char]: @@ -1476,27 +1685,6 @@ fn gai_strerror(ecode: c_int) -> UnsafePointer[c_char]: return external_call["gai_strerror", UnsafePointer[c_char], c_int](ecode) -fn inet_pton(address_family: Int, address: String) raises -> Int: - """Converts an IP address from text to binary form. - - Args: - address_family: The address family (AF_INET or AF_INET6). - address: The IP address in text form. - - Returns: - The IP address in binary form. - """ - var ip_buf_size = 4 - if address_family == AF_INET6: - ip_buf_size = 16 - - var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) - inet_pton(rebind[c_int](address_family), address.unsafe_ptr(), ip_buf) - var result = int(ip_buf.bitcast[c_uint]()) - ip_buf.free() - return result - - # --- ( File Related Syscalls & Structs )--------------------------------------- alias O_NONBLOCK = 16384 alias O_ACCMODE = 3 @@ -1512,7 +1700,7 @@ fn _close(fildes: c_int) -> c_int: Returns: Upon successful completion, 0 shall be returned; otherwise, -1 shall be returned and errno set to indicate the error. - + #### C Function ```c int close(int fildes). @@ -1524,12 +1712,18 @@ fn _close(fildes: c_int) -> c_int: return external_call["close", c_int, c_int](fildes) -fn close(fildes: c_int) raises: +alias CloseInvalidDescriptorError = "CloseError (EBADF): The file_descriptor argument is not a valid open file descriptor." +alias CloseInterruptedError = "CloseError (EINTR): The close() function was interrupted by a signal." +alias CloseRWError = "CloseError (EIO): An I/O error occurred while reading from or writing to the file system." +alias CloseOutOfSpaceError = "CloseError (ENOSPC or EDQUOT): On NFS, these errors are not normally reported against the first write which exceeds the available storage space, but instead against a subsequent write(2), fsync(2), or close()." + + +fn close(file_descriptor: c_int) raises: """Libc POSIX `close` function. Args: - fildes: A File Descriptor to close. - + file_descriptor: A File Descriptor to close. + Raises: SocketError: If an error occurs while creating the socket. EACCES: Permission to create a socket of the specified type and/or protocol is denied. @@ -1539,7 +1733,7 @@ fn close(fildes: c_int) raises: ENFILE: The system-wide limit on the total number of open files has been reached. ENOBUFS or ENOMEM: Insufficient memory is available. The socket cannot be created until sufficient resources are freed. EPROTONOSUPPORT: The protocol type or the specified protocol is not supported within this domain. - + #### C Function ```c int close(int fildes) @@ -1548,16 +1742,16 @@ fn close(fildes: c_int) raises: #### Notes: * Reference: https://man7.org/linux/man-pages/man3/close.3p.html """ - if _close(fildes) == -1: + if _close(file_descriptor) == -1: var errno = get_errno() if errno == EBADF: - raise Error("CloseError (EBADF): The fildes argument is not a valid open file descriptor.") + raise CloseInvalidDescriptorError elif errno == EINTR: - raise Error("CloseError (EINTR): The close() function was interrupted by a signal.") + raise CloseInterruptedError elif errno == EIO: - raise Error("CloseError (EIO): An I/O error occurred while reading from or writing to the file system.") + raise CloseRWError elif int(errno) in [ENOSPC, EDQUOT]: - raise Error("CloseError (ENOSPC or EDQUOT): On NFS, these errors are not normally reported against the first write which exceeds the available storage space, but instead against a subsequent write(2), fsync(2), or close().") + raise CloseOutOfSpaceError else: raise Error("SocketError: An error occurred while creating the socket. Error code: " + str(errno)) diff --git a/lightbug_http/net.mojo b/lightbug_http/net.mojo index 1bad2f23..414c7661 100644 --- a/lightbug_http/net.mojo +++ b/lightbug_http/net.mojo @@ -3,11 +3,10 @@ from time import sleep, perf_counter_ns from memory import UnsafePointer, stack_allocation, Span from sys.info import sizeof, os_is_macos from sys.ffi import external_call, OpaquePointer -from sys._libc import free from lightbug_http.strings import NetworkType, to_string from lightbug_http.io.bytes import Bytes, bytes from lightbug_http.io.sync import Duration -from .libc import ( +from lightbug_http.libc import ( c_void, c_int, c_uint, @@ -23,6 +22,7 @@ from .libc import ( SOCK_STREAM, SOL_SOCKET, SO_REUSEADDR, + SO_REUSEPORT, SHUT_RDWR, htons, ntohs, @@ -42,9 +42,11 @@ from .libc import ( getsockname, getpeername, gai_strerror, - INET_ADDRSTRLEN + INET_ADDRSTRLEN, + INET6_ADDRSTRLEN, ) -from .utils import logger +from lightbug_http.utils import logger +from lightbug_http.socket import Socket alias default_buffer_size = 4096 @@ -63,6 +65,12 @@ trait Connection(Movable): fn close(mut self) raises: ... + fn shutdown(mut self) raises -> None: + ... + + fn teardown(mut self) raises: + ... + fn local_addr(mut self) -> TCPAddr: ... @@ -70,11 +78,13 @@ trait Connection(Movable): ... -trait Addr(CollectionElement, Stringable): +trait Addr(Stringable, Representable, Writable, EqualityComparableCollectionElement): + alias _type: StringLiteral + fn __init__(out self): ... - fn __init__(out self, ip: String, port: Int): + fn __init__(out self, ip: String, port: UInt16): ... fn network(self) -> String: @@ -89,109 +99,74 @@ trait AnAddrInfo: ... -@value struct NoTLSListener: - """A TCP listener that listens for incoming connections and can accept them. - """ + """A TCP listener that listens for incoming connections and can accept them.""" - var fd: c_int - """The file descriptor of the listener.""" - var __addr: TCPAddr - """The address of the listener.""" + var socket: Socket[TCPAddr] - fn __init__(out self, addr: TCPAddr = TCPAddr("localhost", 8080)) raises: - self.__addr = addr - self.fd = socket(AF_INET, SOCK_STREAM, 0) + fn __init__(out self, owned socket: Socket[TCPAddr]): + self.socket = socket^ - fn __init__(out self, addr: TCPAddr, fd: c_int): - self.__addr = addr - self.fd = fd + fn __init__(out self) raises: + self.socket = Socket[TCPAddr]() - fn accept(self) raises -> SysConnection: - var their_addr = sockaddr() - var new_sockfd: c_int - try: - new_sockfd = accept(self.fd, Pointer.address_of(their_addr), Pointer.address_of(socklen_t(sizeof[socklen_t]()))) - except e: - logger.error(e) - raise Error("NoTLSListener.accept: Failed to accept connection, system `accept()` returned an error.") + fn __moveinit__(out self, owned existing: Self): + self.socket = existing.socket^ - var peer = get_peer_name(new_sockfd) - return SysConnection(self.__addr, TCPAddr(peer.host, atol(peer.port)), new_sockfd) + fn accept(self) raises -> TCPConnection: + return TCPConnection(self.socket.accept()) - fn close(self) raises: - try: - shutdown(self.fd, SHUT_RDWR) - except e: - logger.error("NoTLSListener.close: Failed to shutdown listener: " + str(e)) - logger.error(e) + fn close(mut self) raises -> None: + return self.socket.close() - try: - close(self.fd) - except e: - logger.error(e) - raise Error("NoTLSListener.close: Failed to close listener.") + fn shutdown(mut self) raises -> None: + return self.socket.shutdown() + + fn teardown(mut self) raises: + self.socket.teardown() fn addr(self) -> TCPAddr: - return self.__addr + return self.socket.local_address() struct ListenConfig: - var __keep_alive: Duration + var _keep_alive: Duration fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive): - self.__keep_alive = keep_alive - - fn listen(mut self, network: String, address: String) raises -> NoTLSListener: - var addr: TCPAddr - try: - addr = resolve_internet_addr(network, address) - except e: - raise Error("ListenConfig.listen: Failed to resolve host address - " + str(e)) - var address_family = AF_INET - - var sockfd: c_int + self._keep_alive = keep_alive + + fn listen[network: NetworkType, address_family: Int = AF_INET](mut self, address: String) raises -> NoTLSListener: + constrained[address_family in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]() + constrained[ + network in NetworkType.SUPPORTED_TYPES, + "Unsupported network type for internet address resolution. Unix addresses are not supported yet.", + ]() + var local = parse_address(address) + var addr = TCPAddr(local[0], local[1]) + var socket: Socket[TCPAddr] try: - sockfd = socket(address_family, SOCK_STREAM, 0) + socket = Socket[TCPAddr]() except e: logger.error(e) raise Error("ListenConfig.listen: Failed to create listener due to socket creation failure.") try: - setsockopt( - sockfd, - SOL_SOCKET, - SO_REUSEADDR, - Pointer[c_void].address_of(1), - sizeof[Int](), - ) + + @parameter + # REUSEADDR doesn't work on ubuntu. + if os_is_macos(): + socket.set_socket_option(SO_REUSEADDR, 1) + else: + socket.set_socket_option(SO_REUSEPORT, 1) except e: logger.warn("ListenConfig.listen: Failed to set socket as reusable", e) # TODO: Maybe raise here if we want to make this a hard failure. var bind_success = False var bind_fail_logged = False - - var ip_buf_size = 4 - if address_family == AF_INET6: - ip_buf_size = 16 - var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) - - try: - inet_pton(address_family, addr.ip.unsafe_ptr(), ip_buf) - except e: - logger.error(e) - raise Error("ListenConfig.listen: Failed to convert IP address to binary form.") - - var ai = sockaddr_in( - sin_family=address_family, - sin_port=htons(addr.port), - sin_addr=in_addr(ip_buf.bitcast[c_uint]().take_pointee()), - sin_zero=StaticTuple[c_char, 8]() - ) while not bind_success: try: - bind(sockfd, Pointer.address_of(ai), sizeof[sockaddr_in]()) + socket.bind(addr.ip, addr.port) bind_success = True except e: if not bind_fail_logged: @@ -201,124 +176,73 @@ struct ListenConfig: print(".", end="", flush=True) try: - shutdown(sockfd, SHUT_RDWR) + socket.shutdown() except e: logger.error("ListenConfig.listen: Failed to shutdown socket:", e) # TODO: Should shutdown failure be a hard failure? We can still ungracefully close the socket. sleep(UInt(1)) + try: - listen(sockfd, 128) + socket.listen(128) except e: logger.error(e) - raise Error("ListenConfig.listen: Listen failed on sockfd: " + str(sockfd)) + raise Error("ListenConfig.listen: Listen failed on sockfd: " + str(socket.fd)) - var listener = NoTLSListener(addr, sockfd) + var listener = NoTLSListener(socket^) var msg = String.write("\n🔥🐝 Lightbug is listening on ", "http://", addr.ip, ":", str(addr.port)) print(msg) print("Ready to accept connections...") - return listener - + return listener^ -@value -struct SysConnection(Connection): - var fd: c_int - var raddr: TCPAddr - var laddr: TCPAddr - var _closed: Bool - fn __init__(out self, laddr: String, raddr: String) raises: - try: - self.raddr = resolve_internet_addr(NetworkType.tcp4.value, raddr) - except e: - raise Error("Failed to resolve remote address: " + str(e)) - - try: - self.laddr = resolve_internet_addr(NetworkType.tcp4.value, laddr) - except e: - raise Error("Failed to resolve local address: " + str(e)) - - try: - self.fd = socket(AF_INET, SOCK_STREAM, 0) - except e: - logger.error(e) - raise Error("Failed to create connection to remote host.") - - self._closed = False +struct TCPConnection(Connection): + var socket: Socket[TCPAddr] - fn __init__(out self, laddr: TCPAddr, raddr: TCPAddr) raises: - self.raddr = raddr - self.laddr = laddr - try: - self.fd = socket(AF_INET, SOCK_STREAM, 0) - except e: - logger.error(e) - raise Error("Failed to create connection to remote host.") - self._closed = False + fn __init__(inout self, owned socket: Socket[TCPAddr]): + self.socket = socket^ - fn __init__(out self, laddr: TCPAddr, raddr: TCPAddr, fd: c_int): - self.raddr = raddr - self.laddr = laddr - self.fd = fd - self._closed = False + fn __moveinit__(inout self, owned existing: Self): + self.socket = existing.socket^ fn read(self, mut buf: Bytes) raises -> Int: try: - var bytes_recv = recv( - self.fd, - buf.unsafe_ptr().offset(buf.size), - buf.capacity - buf.size, - 0, - ) - buf.size += bytes_recv - return bytes_recv + return self.socket.receive_into(buf) except e: - logger.error(e) - raise Error("SysConnection.read: Failed to read data from connection.") + if str(e) == "EOF": + raise e + else: + logger.error(e) + raise Error("TCPConnection.read: Failed to read data from connection.") fn write(self, buf: Span[Byte]) raises -> Int: if buf[-1] == 0: - raise Error("SysConnection.write: Buffer must not be null-terminated.") - + raise Error("TCPConnection.write: Buffer must not be null-terminated.") + try: - return send(self.fd, buf.unsafe_ptr(), len(buf), 0) + return self.socket.send(buf) except e: - logger.error("SysConnection.write: Failed to write data to connection.") + logger.error("TCPConnection.write: Failed to write data to connection.") raise e fn close(mut self) raises: - if self._closed: - return + self.socket.close() - try: - shutdown(self.fd, SHUT_RDWR) - except e: - # TODO: In the case where the connection was already closed, should we just info or debug log? - logger.debug(e) - logger.debug("SysConnection.close: Failed to shutdown connection.") - - try: - close(self.fd) - except e: - logger.error(e) - raise Error("SysConnection.close: Failed to close connection.") - self._closed = True + fn shutdown(mut self) raises: + self.socket.shutdown() - fn local_addr(mut self) -> TCPAddr: - return self.laddr - - fn remote_addr(self) -> TCPAddr: - return self.raddr + fn teardown(mut self) raises: + self.socket.teardown() + fn is_closed(self) -> Bool: + return self.socket._closed -struct SysNet: - var __lc: ListenConfig + fn local_addr(mut self) -> TCPAddr: + return self.socket.local_address() - fn __init__(out self, keep_alive: Duration = default_tcp_keep_alive): - self.__lc = ListenConfig(default_tcp_keep_alive) + fn remote_addr(self) -> TCPAddr: + return self.socket.remote_address() - fn listen(mut self, network: String, addr: String) raises -> NoTLSListener: - return self.__lc.listen(network, addr) @value @register_passable("trivial") @@ -327,6 +251,7 @@ struct addrinfo_macos(AnAddrInfo): For MacOS, I had to swap the order of ai_canonname and ai_addr. https://stackoverflow.com/questions/53575101/calling-getaddrinfo-directly-from-python-ai-addr-is-null-pointer. """ + var ai_flags: c_int var ai_family: c_int var ai_socktype: c_int @@ -346,10 +271,10 @@ struct addrinfo_macos(AnAddrInfo): self.ai_addr = UnsafePointer[sockaddr]() self.ai_next = OpaquePointer() - fn get_ip_address(self, host: String) raises -> in_addr: + fn get_ip_address(self, host: String) raises -> in_addr: """Returns an IP address based on the host. This is a MacOS-specific implementation. - + Args: host: String - The host to get the IP from. @@ -357,12 +282,7 @@ struct addrinfo_macos(AnAddrInfo): The IP address. """ var result = UnsafePointer[Self]() - var hints = Self( - ai_flags=0, - ai_family=AF_INET, - ai_socktype=SOCK_STREAM, - ai_protocol=0 - ) + var hints = Self(ai_flags=0, ai_family=AF_INET, ai_socktype=SOCK_STREAM, ai_protocol=0) try: getaddrinfo(host, String(), hints, result) except e: @@ -376,7 +296,8 @@ struct addrinfo_macos(AnAddrInfo): var ip = result[].ai_addr.bitcast[sockaddr_in]()[].sin_addr freeaddrinfo(result) return ip - + + @value @register_passable("trivial") struct addrinfo_unix(AnAddrInfo): @@ -414,12 +335,7 @@ struct addrinfo_unix(AnAddrInfo): The IP address. """ var result = UnsafePointer[Self]() - var hints = Self( - ai_flags=0, - ai_family=AF_INET, - ai_socktype=SOCK_STREAM, - ai_protocol=0 - ) + var hints = Self(ai_flags=0, ai_family=AF_INET, ai_socktype=SOCK_STREAM, ai_protocol=0) try: getaddrinfo(host, String(), hints, result) except e: @@ -435,52 +351,43 @@ struct addrinfo_unix(AnAddrInfo): return ip -fn create_connection(sock: c_int, host: String, port: UInt16) raises -> SysConnection: +fn create_connection(host: String, port: UInt16) raises -> TCPConnection: """Connect to a server using a socket. Args: - sock: The socket file descriptor. host: The host to connect to. port: The port to connect on. - + Returns: - Int32 - The socket file descriptor. + The socket file descriptor. """ - @parameter - if os_is_macos(): - ip = addrinfo_macos().get_ip_address(host) - else: - ip = addrinfo_unix().get_ip_address(host) - - var addr = sockaddr_in(AF_INET, htons(port), in_addr(ip.s_addr), StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0)) + var socket = Socket[TCPAddr]() try: - connect(sock, addr, sizeof[sockaddr_in]()) + socket.connect(host, port) except e: logger.error(e) try: - shutdown(sock, SHUT_RDWR) + socket.shutdown() except e: logger.error("Failed to shutdown socket: " + str(e)) raise Error("Failed to establish a connection to the server.") - return SysConnection(sock, TCPAddr(), TCPAddr(host, int(port)), False) - - -alias TCPAddrList = List[TCPAddr] + return TCPConnection(socket^) @value struct TCPAddr(Addr): + alias _type = "TCPAddr" var ip: String - var port: Int + var port: UInt16 var zone: String # IPv6 addressing zone fn __init__(out self): - self.ip = String("127.0.0.1") + self.ip = "127.0.0.1" self.port = 8000 self.zone = "" - fn __init__(out self, ip: String, port: Int): + fn __init__(out self, ip: String = "127.0.0.1", port: UInt16 = 8000): self.ip = ip self.port = port self.zone = "" @@ -488,39 +395,25 @@ struct TCPAddr(Addr): fn network(self) -> String: return NetworkType.tcp.value + fn __eq__(self, other: TCPAddr) -> Bool: + return self.ip == other.ip and self.port == other.port and self.zone == other.zone + + fn __ne__(self, other: TCPAddr) -> Bool: + return not self == other + fn __str__(self) -> String: if self.zone != "": - return join_host_port(self.ip + "%" + self.zone, self.port.__str__()) - return join_host_port(self.ip, self.port.__str__()) + return join_host_port(self.ip + "%" + self.zone, str(self.port)) + return join_host_port(self.ip, str(self.port)) + fn __repr__(self) -> String: + return String.write(self) -fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: - var host: String = "" - var port: String = "" - var port_number = 0 - if ( - network == NetworkType.tcp.value - or network == NetworkType.tcp4.value - or network == NetworkType.tcp6.value - or network == NetworkType.udp.value - or network == NetworkType.udp4.value - or network == NetworkType.udp6.value - ): - if address != "": - var host_port = split_host_port(address) - host = host_port.host - port = host_port.port - port_number = atol(str(port)) - elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value: - if address != "": - host = address - elif network == NetworkType.unix.value: - raise Error("Couldn't resolve internet address as Unix addresses not supported yet") - else: - raise Error("Received an unsupported network type for internet address resolution: " + network) - return TCPAddr(host, port_number) + fn write_to[W: Writer, //](self, mut writer: W): + writer.write("TCPAddr(", "ip=", repr(self.ip), ", port=", str(self.port), ", zone=", repr(self.zone), ")") +# TODO: Support IPv6 long form. fn join_host_port(host: String, port: String) -> String: if host.find(":") != -1: # must be IPv6 literal return "[" + host + "]:" + port @@ -531,123 +424,102 @@ alias MissingPortError = Error("missing port in address") alias TooManyColonsError = Error("too many colons in address") -struct HostPort: - var host: String - var port: String +fn parse_address(address: String) raises -> (String, UInt16): + """Parse an address string into a host and port. - fn __init__(out self, host: String, port: String): - self.host = host - self.port = port + Args: + address: The address string. + Returns: + A tuple containing the host and port. + """ + var colon_index = address.rfind(":") + if colon_index == -1: + raise MissingPortError -fn split_host_port(hostport: String) raises -> HostPort: var host: String = "" var port: String = "" - var colon_index = hostport.rfind(":") var j: Int = 0 var k: Int = 0 - if colon_index == -1: - raise MissingPortError - if hostport[0] == "[": - var end_bracket_index = hostport.find("]") + if address[0] == "[": + var end_bracket_index = address.find("]") if end_bracket_index == -1: raise Error("missing ']' in address") - if end_bracket_index + 1 == len(hostport): + + if end_bracket_index + 1 == len(address): raise MissingPortError elif end_bracket_index + 1 == colon_index: - host = hostport[1:end_bracket_index] + host = address[1:end_bracket_index] j = 1 k = end_bracket_index + 1 else: - if hostport[end_bracket_index + 1] == ":": + if address[end_bracket_index + 1] == ":": raise TooManyColonsError else: raise MissingPortError else: - host = hostport[:colon_index] + host = address[:colon_index] if host.find(":") != -1: raise TooManyColonsError - if hostport[j:].find("[") != -1: + + if address[j:].find("[") != -1: raise Error("unexpected '[' in address") - if hostport[k:].find("]") != -1: + if address[k:].find("]") != -1: raise Error("unexpected ']' in address") - port = hostport[colon_index + 1 :] + port = address[colon_index + 1 :] if port == "": raise MissingPortError if host == "": raise Error("missing host") - return HostPort(host, port) + return host, UInt16(int(port)) + +fn binary_port_to_int(port: UInt16) -> Int: + """Convert a binary port to an integer. -fn convert_binary_port_to_int(port: UInt16) -> Int: + Args: + port: The binary port. + + Returns: + The port as an integer. + """ return int(ntohs(port)) -fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32, address_length: UInt32) raises -> String: +fn binary_ip_to_string[address_family: Int32](owned ip_address: UInt32) raises -> String: """Convert a binary IP address to a string by calling `inet_ntop`. + Parameters: + address_family: The address family of the IP address. + Args: ip_address: The binary IP address. - address_family: The address family of the IP address. - address_length: The length of the address. Returns: The IP address as a string. """ - var ip_buffer = UnsafePointer[c_void].alloc(INET_ADDRSTRLEN) - var ip = inet_ntop(address_family, UnsafePointer.address_of(ip_address).bitcast[c_void](), ip_buffer, INET_ADDRSTRLEN) - return ip - - -fn get_sock_name(fd: Int32) raises -> HostPort: - """Return the address of the socket.""" - var local_address = stack_allocation[1, sockaddr]() - try: - getsockname( - fd, - local_address, - Pointer.address_of(socklen_t(sizeof[sockaddr]())), - ) - except e: - logger.error(e) - raise Error("get_sock_name: Failed to get address of local socket.") - - var addr_in = local_address.bitcast[sockaddr_in]().take_pointee() - return HostPort( - host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, INET_ADDRSTRLEN), - port=str(convert_binary_port_to_int(addr_in.sin_port)), - ) - + constrained[int(address_family) in [AF_INET, AF_INET6], "Address family must be either AF_INET or AF_INET6."]() + var ip: String -fn get_peer_name(fd: Int32) raises -> HostPort: - """Return the address of the peer connected to the socket.""" - var remote_address = stack_allocation[1, sockaddr]() - try: - getpeername( - fd, - remote_address, - Pointer.address_of(socklen_t(sizeof[sockaddr]())), - ) - except e: - logger.error(e) - raise Error("get_peer_name: Failed to get address of remote socket.") + @parameter + if address_family == AF_INET: + ip = inet_ntop[address_family, INET_ADDRSTRLEN](ip_address) + else: + ip = inet_ntop[address_family, INET6_ADDRSTRLEN](ip_address) - # Cast sockaddr struct to sockaddr_in to convert binary IP to string. - var addr_in = remote_address.bitcast[sockaddr_in]().take_pointee() - return HostPort( - host=convert_binary_ip_to_string(addr_in.sin_addr.s_addr, AF_INET, INET_ADDRSTRLEN), - port=str(convert_binary_port_to_int(addr_in.sin_port)), - ) + return ip -fn _getaddrinfo[T: AnAddrInfo, hints_origin: MutableOrigin, result_origin: MutableOrigin, //]( +fn _getaddrinfo[ + T: AnAddrInfo, hints_origin: MutableOrigin, result_origin: MutableOrigin, // +]( nodename: UnsafePointer[c_char], servname: UnsafePointer[c_char], hints: Pointer[T, hints_origin], res: Pointer[UnsafePointer[T], result_origin], -)-> c_int: +) -> c_int: """Libc POSIX `getaddrinfo` function. Args: @@ -655,7 +527,7 @@ fn _getaddrinfo[T: AnAddrInfo, hints_origin: MutableOrigin, result_origin: Mutab servname: The service name. hints: A Pointer to the hints. res: A UnsafePointer to the result. - + Returns: 0 on success, an error code on failure. @@ -677,12 +549,9 @@ fn _getaddrinfo[T: AnAddrInfo, hints_origin: MutableOrigin, result_origin: Mutab ](nodename, servname, hints, res) -fn getaddrinfo[T: AnAddrInfo, //]( - node: String, - service: String, - mut hints: T, - mut res: UnsafePointer[T], -) raises: +fn getaddrinfo[ + T: AnAddrInfo, // +](node: String, service: String, mut hints: T, mut res: UnsafePointer[T],) raises: """Libc POSIX `getaddrinfo` function. Args: @@ -690,7 +559,7 @@ fn getaddrinfo[T: AnAddrInfo, //]( service: The service name. hints: A Pointer to the hints. res: A UnsafePointer to the result. - + Raises: Error: If an error occurs while attempting to receive data from the socket. EAI_AGAIN: The name could not be resolved at this time. Future attempts may succeed. @@ -711,7 +580,9 @@ fn getaddrinfo[T: AnAddrInfo, //]( #### Notes: * Reference: https://man7.org/linux/man-pages/man3/getaddrinfo.3p.html """ - var result = _getaddrinfo(node.unsafe_ptr(), service.unsafe_ptr(), Pointer.address_of(hints), Pointer.address_of(res)) + var result = _getaddrinfo( + node.unsafe_ptr(), service.unsafe_ptr(), Pointer.address_of(hints), Pointer.address_of(res) + ) if result != 0: # gai_strerror returns a char buffer that we don't know the length of. # TODO: Perhaps switch to writing bytes once the Writer trait allows writing individual bytes. diff --git a/lightbug_http/owning_list.mojo b/lightbug_http/owning_list.mojo new file mode 100644 index 00000000..5c6885ff --- /dev/null +++ b/lightbug_http/owning_list.mojo @@ -0,0 +1,510 @@ +from os import abort +from sys import sizeof +from sys.intrinsics import _type_is_eq + +from memory import Pointer, UnsafePointer, memcpy, Span + +from collections import Optional + + +trait EqualityComparableMovable(EqualityComparable, Movable): + """A trait for types that are both `EqualityComparable` and `Movable`.""" + + ... + + +# ===-----------------------------------------------------------------------===# +# List +# ===-----------------------------------------------------------------------===# + + +@value +struct _OwningListIter[ + list_mutability: Bool, //, + T: Movable, + list_origin: Origin[list_mutability], + forward: Bool = True, +]: + """Iterator for List. + + Parameters: + list_mutability: Whether the reference to the list is mutable. + T: The type of the elements in the list. + list_origin: The origin of the List + forward: The iteration direction. `False` is backwards. + """ + + alias list_type = OwningList[T] + + var index: Int + var src: Pointer[Self.list_type, list_origin] + + fn __iter__(self) -> Self: + return self + + fn __next__( + mut self, + ) -> Pointer[T, list_origin]: + @parameter + if forward: + self.index += 1 + return Pointer.address_of(self.src[][self.index - 1]) + else: + self.index -= 1 + return Pointer.address_of(self.src[][self.index]) + + @always_inline + fn __has_next__(self) -> Bool: + return self.__len__() > 0 + + fn __len__(self) -> Int: + @parameter + if forward: + return len(self.src[]) - self.index + else: + return self.index + + +struct OwningList[T: Movable](Movable, Sized, Boolable): + """The `List` type is a dynamically-allocated list. + + It supports pushing and popping from the back resizing the underlying + storage as needed. When it is deallocated, it frees its memory. + + Parameters: + T: The type of the elements. + """ + + # Fields + var data: UnsafePointer[T] + """The underlying storage for the list.""" + var size: Int + """The number of elements in the list.""" + var capacity: Int + """The amount of elements that can fit in the list without resizing it.""" + + # ===-------------------------------------------------------------------===# + # Life cycle methods + # ===-------------------------------------------------------------------===# + + fn __init__(out self): + """Constructs an empty list.""" + self.data = UnsafePointer[T]() + self.size = 0 + self.capacity = 0 + + fn __init__(out self, *, capacity: Int): + """Constructs a list with the given capacity. + + Args: + capacity: The requested capacity of the list. + """ + self.data = UnsafePointer[T].alloc(capacity) + self.size = 0 + self.capacity = capacity + + fn __moveinit__(out self, owned existing: Self): + """Move data of an existing list into a new one. + + Args: + existing: The existing list. + """ + self.data = existing.data + self.size = existing.size + self.capacity = existing.capacity + + fn __del__(owned self): + """Destroy all elements in the list and free its memory.""" + for i in range(self.size): + (self.data + i).destroy_pointee() + self.data.free() + + # ===-------------------------------------------------------------------===# + # Operator dunders + # ===-------------------------------------------------------------------===# + + fn __contains__[U: EqualityComparableMovable, //](self: OwningList[U, *_], value: U) -> Bool: + """Verify if a given value is present in the list. + + Parameters: + U: The type of the elements in the list. Must implement the + traits `EqualityComparable` and `CollectionElement`. + + Args: + value: The value to find. + + Returns: + True if the value is contained in the list, False otherwise. + """ + for i in self: + if i[] == value: + return True + return False + + fn __iter__(ref self) -> _OwningListIter[T, __origin_of(self)]: + """Iterate over elements of the list, returning immutable references. + + Returns: + An iterator of immutable references to the list elements. + """ + return _OwningListIter(0, Pointer.address_of(self)) + + # ===-------------------------------------------------------------------===# + # Trait implementations + # ===-------------------------------------------------------------------===# + + fn __len__(self) -> Int: + """Gets the number of elements in the list. + + Returns: + The number of elements in the list. + """ + return self.size + + fn __bool__(self) -> Bool: + """Checks whether the list has any elements or not. + + Returns: + `False` if the list is empty, `True` if there is at least one element. + """ + return len(self) > 0 + + @no_inline + fn __str__[U: RepresentableCollectionElement, //](self: OwningList[U, *_]) -> String: + """Returns a string representation of a `List`. + + When the compiler supports conditional methods, then a simple `str(my_list)` will + be enough. + + The elements' type must implement the `__repr__()` method for this to work. + + Parameters: + U: The type of the elements in the list. Must implement the + traits `Representable` and `CollectionElement`. + + Returns: + A string representation of the list. + """ + var output = String() + self.write_to(output) + return output^ + + @no_inline + fn write_to[W: Writer, U: RepresentableCollectionElement, //](self: OwningList[U, *_], mut writer: W): + """Write `my_list.__str__()` to a `Writer`. + + Parameters: + W: A type conforming to the Writable trait. + U: The type of the List elements. Must have the trait `RepresentableCollectionElement`. + + Args: + writer: The object to write to. + """ + writer.write("[") + for i in range(len(self)): + writer.write(repr(self[i])) + if i < len(self) - 1: + writer.write(", ") + writer.write("]") + + @no_inline + fn __repr__[U: RepresentableCollectionElement, //](self: OwningList[U, *_]) -> String: + """Returns a string representation of a `List`. + + Note that since we can't condition methods on a trait yet, + the way to call this method is a bit special. Here is an example below: + + ```mojo + var my_list = List[Int](1, 2, 3) + print(my_list.__repr__()) + ``` + + When the compiler supports conditional methods, then a simple `repr(my_list)` will + be enough. + + The elements' type must implement the `__repr__()` for this to work. + + Parameters: + U: The type of the elements in the list. Must implement the + traits `Representable` and `CollectionElement`. + + Returns: + A string representation of the list. + """ + return self.__str__() + + # ===-------------------------------------------------------------------===# + # Methods + # ===-------------------------------------------------------------------===# + + fn bytecount(self) -> Int: + """Gets the bytecount of the List. + + Returns: + The bytecount of the List. + """ + return len(self) * sizeof[T]() + + fn _realloc(mut self, new_capacity: Int): + var new_data = UnsafePointer[T].alloc(new_capacity) + + _move_pointee_into_many_elements( + dest=new_data, + src=self.data, + size=self.size, + ) + + if self.data: + self.data.free() + self.data = new_data + self.capacity = new_capacity + + fn append(mut self, owned value: T): + """Appends a value to this list. + + Args: + value: The value to append. + """ + if self.size >= self.capacity: + self._realloc(max(1, self.capacity * 2)) + (self.data + self.size).init_pointee_move(value^) + self.size += 1 + + fn insert(mut self, i: Int, owned value: T): + """Inserts a value to the list at the given index. + `a.insert(len(a), value)` is equivalent to `a.append(value)`. + + Args: + i: The index for the value. + value: The value to insert. + """ + debug_assert(i <= self.size, "insert index out of range") + + var normalized_idx = i + if i < 0: + normalized_idx = max(0, len(self) + i) + + var earlier_idx = len(self) + var later_idx = len(self) - 1 + self.append(value^) + + for _ in range(normalized_idx, len(self) - 1): + var earlier_ptr = self.data + earlier_idx + var later_ptr = self.data + later_idx + + var tmp = earlier_ptr.take_pointee() + later_ptr.move_pointee_into(earlier_ptr) + later_ptr.init_pointee_move(tmp^) + + earlier_idx -= 1 + later_idx -= 1 + + fn extend(mut self, owned other: OwningList[T, *_]): + """Extends this list by consuming the elements of `other`. + + Args: + other: List whose elements will be added in order at the end of this list. + """ + + var final_size = len(self) + len(other) + var other_original_size = len(other) + + self.reserve(final_size) + + # Defensively mark `other` as logically being empty, as we will be doing + # consuming moves out of `other`, and so we want to avoid leaving `other` + # in a partially valid state where some elements have been consumed + # but are still part of the valid `size` of the list. + # + # That invalid intermediate state of `other` could potentially be + # visible outside this function if a `__moveinit__()` constructor were + # to throw (not currently possible AFAIK though) part way through the + # logic below. + other.size = 0 + + var dest_ptr = self.data + len(self) + + for i in range(other_original_size): + var src_ptr = other.data + i + + # This (TODO: optimistically) moves an element directly from the + # `other` list into this list using a single `T.__moveinit()__` + # call, without moving into an intermediate temporary value + # (avoiding an extra redundant move constructor call). + src_ptr.move_pointee_into(dest_ptr) + + dest_ptr = dest_ptr + 1 + + # Update the size now that all new elements have been moved into this + # list. + self.size = final_size + + fn pop(mut self, i: Int = -1) -> T: + """Pops a value from the list at the given index. + + Args: + i: The index of the value to pop. + + Returns: + The popped value. + """ + debug_assert(-len(self) <= i < len(self), "pop index out of range") + + var normalized_idx = i + if i < 0: + normalized_idx += len(self) + + var ret_val = (self.data + normalized_idx).take_pointee() + for j in range(normalized_idx + 1, self.size): + (self.data + j).move_pointee_into(self.data + j - 1) + self.size -= 1 + if self.size * 4 < self.capacity: + if self.capacity > 1: + self._realloc(self.capacity // 2) + return ret_val^ + + fn reserve(mut self, new_capacity: Int): + """Reserves the requested capacity. + + If the current capacity is greater or equal, this is a no-op. + Otherwise, the storage is reallocated and the date is moved. + + Args: + new_capacity: The new capacity. + """ + if self.capacity >= new_capacity: + return + self._realloc(new_capacity) + + fn resize(mut self, new_size: Int): + """Resizes the list to the given new size. + + With no new value provided, the new size must be smaller than or equal + to the current one. Elements at the end are discarded. + + Args: + new_size: The new size. + """ + if self.size < new_size: + abort( + "You are calling List.resize with a new_size bigger than the" + " current size. If you want to make the List bigger, provide a" + " value to fill the new slots with. If not, make sure the new" + " size is smaller than the current size." + ) + for i in range(new_size, self.size): + (self.data + i).destroy_pointee() + self.size = new_size + self.reserve(new_size) + + # TODO: Remove explicit self type when issue 1876 is resolved. + fn index[ + C: EqualityComparableMovable, // + ](ref self: OwningList[C, *_], value: C, start: Int = 0, stop: Optional[Int] = None,) raises -> Int: + """ + Returns the index of the first occurrence of a value in a list + restricted by the range given the start and stop bounds. + + ```mojo + var my_list = List[Int](1, 2, 3) + print(my_list.index(2)) # prints `1` + ``` + + Args: + value: The value to search for. + start: The starting index of the search, treated as a slice index + (defaults to 0). + stop: The ending index of the search, treated as a slice index + (defaults to None, which means the end of the list). + + Parameters: + C: The type of the elements in the list. Must implement the + `EqualityComparableMovable` trait. + + Returns: + The index of the first occurrence of the value in the list. + + Raises: + ValueError: If the value is not found in the list. + """ + var start_normalized = start + + var stop_normalized: Int + if stop is None: + # Default end + stop_normalized = len(self) + else: + stop_normalized = stop.value() + + if start_normalized < 0: + start_normalized += len(self) + if stop_normalized < 0: + stop_normalized += len(self) + + start_normalized = _clip(start_normalized, 0, len(self)) + stop_normalized = _clip(stop_normalized, 0, len(self)) + + for i in range(start_normalized, stop_normalized): + if self[i] == value: + return i + raise "ValueError: Given element is not in list" + + fn clear(mut self): + """Clears the elements in the list.""" + for i in range(self.size): + (self.data + i).destroy_pointee() + self.size = 0 + + fn steal_data(mut self) -> UnsafePointer[T]: + """Take ownership of the underlying pointer from the list. + + Returns: + The underlying data. + """ + var ptr = self.data + self.data = UnsafePointer[T]() + self.size = 0 + self.capacity = 0 + return ptr + + fn __getitem__(ref self, idx: Int) -> ref [self] T: + """Gets the list element at the given index. + + Args: + idx: The index of the element. + + Returns: + A reference to the element at the given index. + """ + + var normalized_idx = idx + + debug_assert( + -self.size <= normalized_idx < self.size, + "index: ", + normalized_idx, + " is out of bounds for `List` of size: ", + self.size, + ) + if normalized_idx < 0: + normalized_idx += len(self) + + return (self.data + normalized_idx)[] + + @always_inline + fn unsafe_ptr(self) -> UnsafePointer[T]: + """Retrieves a pointer to the underlying memory. + + Returns: + The UnsafePointer to the underlying memory. + """ + return self.data + + +fn _clip(value: Int, start: Int, end: Int) -> Int: + return max(start, min(value, end)) + + +fn _move_pointee_into_many_elements[T: Movable](dest: UnsafePointer[T], src: UnsafePointer[T], size: Int): + for i in range(size): + (src + i).move_pointee_into(dest + i) diff --git a/lightbug_http/pool_manager.mojo b/lightbug_http/pool_manager.mojo new file mode 100644 index 00000000..a83072ef --- /dev/null +++ b/lightbug_http/pool_manager.mojo @@ -0,0 +1,78 @@ +from sys.ffi import OpaquePointer +from bit import is_power_of_two +from builtin.value import StringableCollectionElement +from memory import UnsafePointer, bitcast, memcpy +from collections import Dict, Optional +from collections.dict import RepresentableKeyElement +from lightbug_http.net import create_connection, TCPConnection, Connection +from lightbug_http.utils import logger +from lightbug_http.owning_list import OwningList + + +struct PoolManager[ConnectionType: Connection](): + var _connections: OwningList[ConnectionType] + var _capacity: Int + var mapping: Dict[String, Int] + + fn __init__(out self, capacity: Int = 10): + self._connections = OwningList[ConnectionType](capacity=capacity) + self._capacity = capacity + self.mapping = Dict[String, Int]() + + fn __del__(owned self): + logger.debug( + "PoolManager shutting down and closing remaining connections before destruction:", self._connections.size + ) + self.clear() + + fn give(mut self, host: String, owned value: ConnectionType) raises: + if host in self.mapping: + self._connections[self.mapping[host]] = value^ + return + + if self._connections.size == self._capacity: + raise Error("PoolManager.give: Cache is full.") + + self._connections[self._connections.size] = value^ + self.mapping[host] = self._connections.size + self._connections.size += 1 + logger.debug("Checked in connection for peer:", host + ", at index:", self._connections.size) + + fn take(mut self, host: String) raises -> ConnectionType: + var index: Int + try: + index = self.mapping[host] + _ = self.mapping.pop(host) + except: + raise Error("PoolManager.take: Key not found.") + + var connection = self._connections.pop(index) + # Shift everything over by one + for kv in self.mapping.items(): + if kv[].value > index: + self.mapping[kv[].key] -= 1 + + logger.debug("Checked out connection for peer:", host + ", from index:", self._connections.size + 1) + return connection^ + + fn clear(mut self): + while self._connections: + var connection = self._connections.pop(0) + try: + connection.teardown() + except e: + # TODO: This is used in __del__, would be nice if we didn't have to absorb the error. + logger.error("Failed to tear down connection. Error:", e) + self.mapping.clear() + + fn __contains__(self, host: String) -> Bool: + return host in self.mapping + + fn __setitem__(mut self, host: String, owned value: ConnectionType) raises -> None: + if host in self.mapping: + self._connections[self.mapping[host]] = value^ + else: + self.give(host, value^) + + fn __getitem__(self, host: String) raises -> ref [self._connections] ConnectionType: + return self._connections[self.mapping[host]] diff --git a/lightbug_http/server.mojo b/lightbug_http/server.mojo index 3545bcc8..59e1ffd1 100644 --- a/lightbug_http/server.mojo +++ b/lightbug_http/server.mojo @@ -3,7 +3,8 @@ from lightbug_http.io.sync import Duration from lightbug_http.io.bytes import Bytes, bytes from lightbug_http.strings import NetworkType from lightbug_http.utils import ByteReader, logger -from lightbug_http.net import NoTLSListener, default_buffer_size, NoTLSListener, SysConnection, SysNet +from lightbug_http.net import NoTLSListener, default_buffer_size, TCPConnection, ListenConfig, TCPAddr +from lightbug_http.socket import Socket from lightbug_http.http import HTTPRequest, encode from lightbug_http.http.common_response import InternalError from lightbug_http.uri import URI @@ -16,66 +17,69 @@ alias DefaultConcurrency: Int = 256 * 1024 alias default_max_request_body_size = 4 * 1024 * 1024 # 4MB -@value -struct Server: - """ - A Mojo-based server that accept incoming requests and delivers HTTP services. - """ +struct Server(Movable): + """A Mojo-based server that accept incoming requests and delivers HTTP services.""" var error_handler: ErrorHandler var name: String - var __address: String - var max_concurrent_connections: Int - var max_requests_per_connection: Int + var _address: String + var max_concurrent_connections: UInt + var max_requests_per_connection: UInt - var __max_request_body_size: Int + var _max_request_body_size: UInt var tcp_keep_alive: Bool - var ln: NoTLSListener - fn __init__( out self, error_handler: ErrorHandler = ErrorHandler(), name: String = "lightbug_http", address: String = "127.0.0.1", - max_concurrent_connections: Int = 1000, - max_requests_per_connection: Int = 0, - max_request_body_size: Int = default_max_request_body_size, + max_concurrent_connections: UInt = 1000, + max_requests_per_connection: UInt = 0, + max_request_body_size: UInt = default_max_request_body_size, tcp_keep_alive: Bool = False, ) raises: self.error_handler = error_handler self.name = name - self.__address = address - self.max_concurrent_connections = max_concurrent_connections + self._address = address self.max_requests_per_connection = max_requests_per_connection - self.__max_request_body_size = default_max_request_body_size + self._max_request_body_size = default_max_request_body_size self.tcp_keep_alive = tcp_keep_alive - self.ln = NoTLSListener() - - fn address(self) -> String: - return self.__address + if max_concurrent_connections == 0: + self.max_concurrent_connections = DefaultConcurrency + else: + self.max_concurrent_connections = max_concurrent_connections + + fn __moveinit__(mut self, owned other: Server) -> None: + self.error_handler = other.error_handler^ + self.name = other.name^ + self._address = other._address^ + self.max_concurrent_connections = other.max_concurrent_connections + self.max_requests_per_connection = other.max_requests_per_connection + self._max_request_body_size = other._max_request_body_size + self.tcp_keep_alive = other.tcp_keep_alive + + fn address(self) -> ref [self._address] String: + return self._address fn set_address(mut self, own_address: String) -> None: - self.__address = own_address + self._address = own_address - fn max_request_body_size(self) -> Int: - return self.__max_request_body_size + fn max_request_body_size(self) -> UInt: + return self._max_request_body_size - fn set_max_request_body_size(mut self, size: Int) -> None: - self.__max_request_body_size = size + fn set_max_request_body_size(mut self, size: UInt) -> None: + self._max_request_body_size = size - fn get_concurrency(self) -> Int: + fn get_concurrency(self) -> UInt: """Retrieve the concurrency level which is either the configured `max_concurrent_connections` or the `DefaultConcurrency`. Returns: Concurrency level for the server. """ - var concurrency = self.max_concurrent_connections - if concurrency <= 0: - concurrency = DefaultConcurrency - return concurrency + return self.max_concurrent_connections fn listen_and_serve[T: HTTPService](mut self, address: String, mut handler: T) raises: """Listen for incoming connections and serve HTTP requests. @@ -87,12 +91,12 @@ struct Server: address: The address (host:port) to listen on. handler: An object that handles incoming HTTP requests. """ - var net = SysNet() - var listener = net.listen(NetworkType.tcp4.value, address) + var net = ListenConfig() + var listener = net.listen[NetworkType.tcp4](address) self.set_address(address) - self.serve(listener, handler) + self.serve(listener^, handler) - fn serve[T: HTTPService](mut self, ln: NoTLSListener, mut handler: T) raises: + fn serve[T: HTTPService](mut self, owned ln: NoTLSListener, mut handler: T) raises: """Serve HTTP requests. Parameters: @@ -105,12 +109,11 @@ struct Server: Raises: If there is an error while serving requests. """ - self.ln = ln while True: - var conn = self.ln.accept() + var conn = ln.accept() self.serve_connection(conn, handler) - fn serve_connection[T: HTTPService](mut self, mut conn: SysConnection, mut handler: T) raises -> None: + fn serve_connection[T: HTTPService](mut self, mut conn: TCPConnection, mut handler: T) raises -> None: """Serve a single connection. Parameters: @@ -123,6 +126,9 @@ struct Server: Raises: If there is an error while serving the connection. """ + logger.debug( + "Connection accepted! IP:", conn.socket._remote_address.ip, "Port:", conn.socket._remote_address.port + ) var max_request_body_size = self.max_request_body_size() if max_request_body_size <= 0: max_request_body_size = default_max_request_body_size @@ -131,14 +137,19 @@ struct Server: while True: req_number += 1 - # TODO: We should read until 0 bytes are received. + # TODO: We should read until 0 bytes are received. (@thatstoasty) # If we completely fill the buffer haven't read the full request, we end up processing a partial request. var b = Bytes(capacity=default_buffer_size) - var bytes_recv = conn.read(b) - # TODO: Should the connection be closed here? The client should close it for 1.1 http. - if bytes_recv == 0: - conn.close() - break + try: + _ = conn.read(b) + except e: + conn.teardown() + # 0 bytes were read from the peer, which indicates their side of the connection was closed. + if str(e) == "EOF": + break + else: + logger.error(e) + raise Error("Server.serve_connection: Failed to read request") var request: HTTPRequest try: @@ -146,31 +157,41 @@ struct Server: except e: logger.error(e) raise Error("Server.serve_connection: Failed to parse request") - - var res: HTTPResponse + + var response: HTTPResponse try: - res = handler.func(request) + response = handler.func(request) except: - if not conn._closed: + if not conn.is_closed(): + # Try to send back an internal server error, but always attempt to teardown the connection. try: + # TODO: Move InternalError response to an alias when Mojo can support Dict operations at compile time. (@thatstoasty) _ = conn.write(encode(InternalError())) - conn.close() except e: logger.error(e) raise Error("Failed to send InternalError response") + finally: + conn.teardown() return + # If the server is set to not support keep-alive connections, or the client requests a connection close, we mark the connection to be closed. var close_connection = (not self.tcp_keep_alive) or request.connection_close() if close_connection: - res.set_connection_close() - + response.set_connection_close() + + logger.debug( + conn.socket._remote_address.ip, + str(conn.socket._remote_address.port), + request.method, + request.uri.path, + response.status_code, + ) try: - _ = conn.write(encode(res^)) + _ = conn.write(encode(response^)) except e: - conn.close() + conn.teardown() break - + if close_connection: - conn.close() + conn.teardown() break - diff --git a/lightbug_http/socket.mojo b/lightbug_http/socket.mojo new file mode 100644 index 00000000..0b99cec0 --- /dev/null +++ b/lightbug_http/socket.mojo @@ -0,0 +1,594 @@ +from memory import Span, stack_allocation +from utils import StaticTuple +from sys import sizeof, external_call +from sys.info import os_is_macos +from memory import Pointer, UnsafePointer +from lightbug_http.libc import ( + socket, + connect, + recv, + # recvfrom, + send, + # sendto, + shutdown, + inet_pton, + inet_ntop, + htons, + ntohs, + gai_strerror, + bind, + listen, + accept, + setsockopt, + getsockopt, + getsockname, + getpeername, + close, + sockaddr, + sockaddr_in, + addrinfo, + socklen_t, + c_void, + c_uint, + c_char, + c_int, + in_addr, + SHUT_RDWR, + SOL_SOCKET, + AF_INET, + AF_INET6, + SOCK_STREAM, + INET_ADDRSTRLEN, + SO_REUSEADDR, + SO_RCVTIMEO, + CloseInvalidDescriptorError, + ShutdownInvalidArgumentError, +) +from lightbug_http.io.bytes import Bytes +from lightbug_http.strings import NetworkType +from lightbug_http.net import ( + Addr, + TCPAddr, + default_buffer_size, + binary_port_to_int, + binary_ip_to_string, + addrinfo_macos, + addrinfo_unix, +) +from lightbug_http.utils import logger + + +alias SocketClosedError = "Socket: Socket is already closed" + + +struct Socket[AddrType: Addr, address_family: Int = AF_INET](Representable, Stringable, Writable): + """Represents a network file descriptor. Wraps around a file descriptor and provides network functions. + + Args: + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + address_family: The address family of the socket. + socket_type: The socket type. + protocol: The protocol. + """ + + var fd: Int32 + """The file descriptor of the socket.""" + var socket_type: Int32 + """The socket type.""" + var protocol: Byte + """The protocol.""" + var _local_address: AddrType + """The local address of the socket (local address if bound).""" + var _remote_address: AddrType + """The remote address of the socket (peer's address if connected).""" + var _closed: Bool + """Whether the socket is closed.""" + var _connected: Bool + """Whether the socket is connected.""" + + fn __init__( + out self, + local_address: AddrType = AddrType(), + remote_address: AddrType = AddrType(), + socket_type: Int32 = SOCK_STREAM, + protocol: Byte = 0, + ) raises: + """Create a new socket object. + + Args: + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + socket_type: The socket type. + protocol: The protocol. + + Raises: + Error: If the socket creation fails. + """ + self.socket_type = socket_type + self.protocol = protocol + + self.fd = socket(address_family, socket_type, 0) + self._local_address = local_address + self._remote_address = remote_address + self._closed = False + self._connected = False + + fn __init__( + out self, + fd: Int32, + socket_type: Int32, + protocol: Byte, + local_address: AddrType, + remote_address: AddrType = AddrType(), + ): + """ + Create a new socket object when you already have a socket file descriptor. Typically through socket.accept(). + + Args: + fd: The file descriptor of the socket. + socket_type: The socket type. + protocol: The protocol. + local_address: The local address of the socket (local address if bound). + remote_address: The remote address of the socket (peer's address if connected). + """ + self.fd = fd + self.socket_type = socket_type + self.protocol = protocol + self._local_address = local_address + self._remote_address = remote_address + self._closed = False + self._connected = True + + fn __moveinit__(out self, owned existing: Self): + """Initialize a new socket object by moving the data from an existing socket object. + + Args: + existing: The existing socket object to move the data from. + """ + self.fd = existing.fd + self.socket_type = existing.socket_type + self.protocol = existing.protocol + self._local_address = existing._local_address^ + self._remote_address = existing._remote_address^ + self._closed = existing._closed + self._connected = existing._connected + + fn teardown(mut self) raises: + """Close the socket and free the file descriptor.""" + if self._connected: + try: + self.shutdown() + except e: + logger.debug("Socket.teardown: Failed to shutdown socket: " + str(e)) + + if not self._closed: + try: + self.close() + except e: + logger.error("Socket.teardown: Failed to close socket.") + raise e + + fn __enter__(owned self) -> Self: + return self^ + + fn __exit__(mut self) raises: + self.teardown() + + # TODO: Seems to be bugged if this is included. Mojo tries to delete a mystical 0 fd socket that was never initialized? + # fn __del__(owned self): + # """Close the socket when the object is deleted.""" + # logger.info("In socket del", self) + # try: + # self.teardown() + # except e: + # logger.debug("Socket.__del__: Failed to close socket during deletion:", str(e)) + + fn __str__(self) -> String: + return String.write(self) + + fn __repr__(self) -> String: + return String.write(self) + + fn write_to[W: Writer, //](self, mut writer: W): + @parameter + fn af() -> String: + if address_family == AF_INET: + return "AF_INET" + else: + return "AF_INET6" + + writer.write( + "Socket[", + AddrType._type, + ", ", + af(), + "]", + "(", + "fd=", + str(self.fd), + ", _local_address=", + repr(self._local_address), + ", _remote_address=", + repr(self._remote_address), + ", _closed=", + str(self._closed), + ", _connected=", + str(self._connected), + ")", + ) + + fn local_address(ref self) -> ref [self._local_address] AddrType: + """Return the local address of the socket as a UDP address. + + Returns: + The local address of the socket as a UDP address. + """ + return self._local_address + + fn set_local_address(mut self, address: AddrType) -> None: + """Set the local address of the socket. + + Args: + address: The local address to set. + """ + self._local_address = address + + fn remote_address(ref self) -> ref [self._remote_address] AddrType: + """Return the remote address of the socket as a UDP address. + + Returns: + The remote address of the socket as a UDP address. + """ + return self._remote_address + + fn set_remote_address(mut self, address: AddrType) -> None: + """Set the remote address of the socket. + + Args: + address: The remote address to set. + """ + self._remote_address = address + + fn accept(self) raises -> Socket[AddrType]: + """Accept a connection. The socket must be bound to an address and listening for connections. + The return value is a connection where conn is a new socket object usable to send and receive data on the connection, + and address is the address bound to the socket on the other end of the connection. + + Returns: + A new socket object and the address of the remote socket. + + Raises: + Error: If the connection fails. + """ + var new_socket_fd: c_int + try: + new_socket_fd = accept(self.fd) + except e: + logger.error(e) + raise Error("Socket.accept: Failed to accept connection, system `accept()` returned an error.") + + var new_socket = Socket( + fd=new_socket_fd, + socket_type=self.socket_type, + protocol=self.protocol, + local_address=self.local_address(), + ) + var peer = new_socket.get_peer_name() + new_socket.set_remote_address(AddrType(peer[0], peer[1])) + return new_socket^ + + fn listen(self, backlog: UInt = 0) raises: + """Enable a server to accept connections. + + Args: + backlog: The maximum number of queued connections. Should be at least 0, and the maximum is system-dependent (usually 5). + + Raises: + Error: If listening for a connection fails. + """ + try: + listen(self.fd, backlog) + except e: + logger.error(e) + raise Error("Socket.listen: Failed to listen for connections.") + + fn bind[network: String = NetworkType.tcp4.value](mut self, address: String, port: UInt16) raises: + """Bind the socket to address. The socket must not already be bound. (The format of address depends on the address family). + + When a socket is created with Socket(), it exists in a name + space (address family) but has no address assigned to it. bind() + assigns the address specified by addr to the socket referred to + by the file descriptor fd. addrlen specifies the size, in + bytes, of the address structure pointed to by addr. + Traditionally, this operation is called 'assigning a name to a + socket'. + + Args: + address: The IP address to bind the socket to. + port: The port number to bind the socket to. + + Raises: + Error: If binding the socket fails. + """ + var binary_ip: c_uint + try: + binary_ip = inet_pton[address_family](address.unsafe_ptr()) + except e: + logger.error(e) + raise Error("ListenConfig.listen: Failed to convert IP address to binary form.") + + var local_address = sockaddr_in( + address_family=address_family, + port=port, + binary_ip=binary_ip, + ) + try: + bind(self.fd, local_address) + except e: + logger.error(e) + raise Error("Socket.bind: Binding socket failed.") + + var local = self.get_sock_name() + self._local_address = AddrType(local[0], local[1]) + + fn get_sock_name(self) raises -> (String, UInt16): + """Return the address of the socket. + + Returns: + The address of the socket. + + Raises: + Error: If getting the address of the socket fails. + """ + if self._closed: + raise SocketClosedError + + # TODO: Add check to see if the socket is bound and error if not. + var local_address = stack_allocation[1, sockaddr]() + try: + getsockname( + self.fd, + local_address, + Pointer.address_of(socklen_t(sizeof[sockaddr]())), + ) + except e: + logger.error(e) + raise Error("get_sock_name: Failed to get address of local socket.") + + var addr_in = local_address.bitcast[sockaddr_in]().take_pointee() + return binary_ip_to_string[address_family](addr_in.sin_addr.s_addr), UInt16( + binary_port_to_int(addr_in.sin_port) + ) + + fn get_peer_name(self) raises -> (String, UInt16): + """Return the address of the peer connected to the socket. + + Returns: + The address of the peer connected to the socket. + + Raises: + Error: If getting the address of the peer connected to the socket fails. + """ + if self._closed: + raise SocketClosedError + + # TODO: Add check to see if the socket is bound and error if not. + var addr_in: sockaddr_in + try: + addr_in = getpeername(self.fd) + except e: + logger.error(e) + raise Error("get_peer_name: Failed to get address of remote socket.") + + return binary_ip_to_string[address_family](addr_in.sin_addr.s_addr), UInt16( + binary_port_to_int(addr_in.sin_port) + ) + + fn get_socket_option(self, option_name: Int) raises -> Int: + """Return the value of the given socket option. + + Args: + option_name: The socket option to get. + + Returns: + The value of the given socket option. + + Raises: + Error: If getting the socket option fails. + """ + try: + return getsockopt(self.fd, SOL_SOCKET, option_name) + except e: + # TODO: Should this be a warning or an error? + logger.warn("Socket.get_socket_option: Failed to get socket option.") + raise e + + fn set_socket_option(self, option_name: Int, owned option_value: Byte = 1) raises: + """Return the value of the given socket option. + + Args: + option_name: The socket option to set. + option_value: The value to set the socket option to. Defaults to 1 (True). + + Raises: + Error: If setting the socket option fails. + """ + try: + setsockopt(self.fd, SOL_SOCKET, option_name, option_value) + except e: + # TODO: Should this be a warning or an error? + logger.warn("Socket.set_socket_option: Failed to set socket option.") + raise e + + fn connect(mut self, address: String, port: UInt16) raises -> None: + """Connect to a remote socket at address. + + Args: + address: The IP address to connect to. + port: The port number to connect to. + + Raises: + Error: If connecting to the remote socket fails. + """ + + @parameter + if os_is_macos(): + ip = addrinfo_macos().get_ip_address(address) + else: + ip = addrinfo_unix().get_ip_address(address) + + var addr = sockaddr_in(address_family=address_family, port=port, binary_ip=ip.s_addr) + try: + connect(self.fd, addr) + except e: + logger.error("Socket.connect: Failed to establish a connection to the server.") + raise e + + var remote = self.get_peer_name() + self._remote_address = AddrType(remote[0], remote[1]) + + fn send(self, buffer: Span[Byte]) raises -> Int: + if buffer[-1] == 0: + raise Error("Socket.send: Buffer must not be null-terminated.") + + try: + return send(self.fd, buffer.unsafe_ptr(), len(buffer), 0) + except e: + logger.error("Socket.send: Failed to write data to connection.") + raise e + + fn send_all(self, src: Span[Byte], max_attempts: Int = 3) raises -> None: + """Send data to the socket. The socket must be connected to a remote socket. + + Args: + src: The data to send. + max_attempts: The maximum number of attempts to send the data. + + Raises: + Error: If sending the data fails, or if the data is not sent after the maximum number of attempts. + """ + var total_bytes_sent = 0 + var attempts = 0 + + # Try to send all the data in the buffer. If it did not send all the data, keep trying but start from the offset of the last successful send. + while total_bytes_sent < len(src): + if attempts > max_attempts: + raise Error("Failed to send message after " + str(max_attempts) + " attempts.") + + var sent: Int + try: + sent = self.send(src[total_bytes_sent:]) + except e: + logger.error(e) + raise Error( + "Socket.send_all: Failed to send message, wrote" + str(total_bytes_sent) + "bytes before failing." + ) + + total_bytes_sent += sent + attempts += 1 + + fn _receive(self, mut buffer: Bytes) raises -> Int: + """Receive data from the socket into the buffer. + + Args: + buffer: The buffer to read data into. + + Returns: + The buffer with the received data, and an error if one occurred. + + Raises: + Error: If reading data from the socket fails. + EOF: If 0 bytes are received, return EOF. + """ + var bytes_received: Int + try: + bytes_received = recv( + self.fd, + buffer.unsafe_ptr().offset(buffer.size), + buffer.capacity - buffer.size, + 0, + ) + buffer.size += bytes_received + except e: + logger.error(e) + raise Error("Socket.receive: Failed to read data from connection.") + + if bytes_received == 0: + raise Error("EOF") + + return bytes_received + + fn receive(self, size: Int = default_buffer_size) raises -> List[Byte, True]: + """Receive data from the socket into the buffer with capacity of `size` bytes. + + Args: + size: The size of the buffer to receive data into. + + Returns: + The buffer with the received data, and an error if one occurred. + """ + var buffer = Bytes(capacity=size) + _ = self._receive(buffer) + return buffer + + fn receive_into(self, mut buffer: Bytes) raises -> Int: + """Receive data from the socket into the buffer. + + Args: + buffer: The buffer to read data into. + + Returns: + The buffer with the received data, and an error if one occurred. + + Raises: + Error: If reading data from the socket fails. + EOF: If 0 bytes are received, return EOF. + """ + return self._receive(buffer) + + fn shutdown(mut self) raises -> None: + """Shut down the socket. The remote end will receive no more data (after queued data is flushed).""" + try: + shutdown(self.fd, SHUT_RDWR) + except e: + # For the other errors, either the socket is already closed or the descriptor is invalid. + # At that point we can feasibly say that the socket is already shut down. + if str(e) == ShutdownInvalidArgumentError: + logger.error("Socket.shutdown: Failed to shutdown socket.") + raise e + logger.debug(e) + + self._connected = False + + fn close(mut self) raises -> None: + """Mark the socket closed. + Once that happens, all future operations on the socket object will fail. + The remote end will receive no more data (after queued data is flushed). + + Raises: + Error: If closing the socket fails. + """ + try: + close(self.fd) + except e: + # If the file descriptor is invalid, then it was most likely already closed. + # Other errors indicate a failure while attempting to close the socket. + if str(e) != CloseInvalidDescriptorError: + logger.error("Socket.close: Failed to close socket.") + raise e + logger.debug(e) + + self._closed = True + + fn get_timeout(self) raises -> Int: + """Return the timeout value for the socket.""" + return self.get_socket_option(SO_RCVTIMEO) + + fn set_timeout(self, owned duration: Int) raises: + """Set the timeout value for the socket. + + Args: + duration: Seconds - The timeout duration in seconds. + """ + self.set_socket_option(SO_RCVTIMEO, duration) diff --git a/lightbug_http/strings.mojo b/lightbug_http/strings.mojo index a06c58b2..700f4ef3 100644 --- a/lightbug_http/strings.mojo +++ b/lightbug_http/strings.mojo @@ -1,6 +1,5 @@ from utils import StringSlice from memory import Span -from lightbug_http.io.bytes import Bytes from lightbug_http.io.bytes import Bytes, bytes, byte alias strSlash = "/" @@ -33,7 +32,7 @@ struct BytesConstant: @value -struct NetworkType: +struct NetworkType(EqualityComparableCollectionElement): var value: String alias empty = NetworkType("") @@ -48,6 +47,39 @@ struct NetworkType: alias ip6 = NetworkType("ip6") alias unix = NetworkType("unix") + alias SUPPORTED_TYPES = [ + Self.tcp, + Self.tcp4, + Self.tcp6, + Self.udp, + Self.udp4, + Self.udp6, + Self.ip, + Self.ip4, + Self.ip6, + ] + alias TCP_TYPES = [ + Self.tcp, + Self.tcp4, + Self.tcp6, + ] + alias UDP_TYPES = [ + Self.udp, + Self.udp4, + Self.udp6, + ] + alias IP_TYPES = [ + Self.ip, + Self.ip4, + Self.ip6, + ] + + fn __eq__(self, other: NetworkType) -> Bool: + return self.value == other.value + + fn __ne__(self, other: NetworkType) -> Bool: + return self.value != other.value + @value struct ConnType: diff --git a/lightbug_http/uri.mojo b/lightbug_http/uri.mojo index c8b314e7..d56295ea 100644 --- a/lightbug_http/uri.mojo +++ b/lightbug_http/uri.mojo @@ -12,12 +12,12 @@ from lightbug_http.strings import ( @value -struct URI(Writable, Stringable): - var __path_original: String +struct URI(Writable, Stringable, Representable): + var _original_path: String var scheme: String var path: String var query_string: String - var __hash: String + var _hash: String var host: String var full_uri: String @@ -27,96 +27,93 @@ struct URI(Writable, Stringable): var password: String @staticmethod - fn parse(uri: String) -> Variant[URI, String]: - var u = URI(uri) - try: - u._parse() - except e: - return "Failed to parse URI: " + str(e) - - return u - - @staticmethod - fn parse_raises(uri: String) raises -> URI: - var u = URI(uri) - u._parse() - return u - - fn __init__( - mut self, - uri: String = "", - ) -> None: - self.__path_original = "/" - self.scheme = "" - self.path = "/" - self.query_string = "" - self.__hash = "" - self.host = "" - self.full_uri = uri - self.request_uri = "" - self.username = "" - self.password = "" - - fn __str__(self) -> String: - var s = self.scheme + "://" + self.host + self.path - if len(self.query_string) > 0: - s += "?" + self.query_string - return s - - fn write_to[T: Writer](self, mut writer: T): - writer.write(str(self)) - - fn is_https(self) -> Bool: - return self.scheme == https - - fn is_http(self) -> Bool: - return self.scheme == http or len(self.scheme) == 0 - - fn _parse(mut self) raises -> None: + fn parse(uri: String) -> URI: var proto_str = String(strHttp11) var is_https = False - var proto_end = self.full_uri.find("://") + var proto_end = uri.find("://") var remainder_uri: String if proto_end >= 0: - proto_str = self.full_uri[:proto_end] + proto_str = uri[:proto_end] if proto_str == https: is_https = True - remainder_uri = self.full_uri[proto_end + 3 :] + remainder_uri = uri[proto_end + 3 :] else: - remainder_uri = self.full_uri - - self.scheme = proto_str^ + remainder_uri = uri var path_start = remainder_uri.find("/") var host_and_port: String var request_uri: String + var host: String if path_start >= 0: host_and_port = remainder_uri[:path_start] request_uri = remainder_uri[path_start:] - self.host = host_and_port[:path_start] + host = host_and_port[:path_start] else: host_and_port = remainder_uri request_uri = strSlash - self.host = host_and_port + host = host_and_port + var scheme: String if is_https: - self.scheme = https + scheme = https else: - self.scheme = http + scheme = http var n = request_uri.find("?") + var original_path: String + var query_string: String if n >= 0: - self.__path_original = request_uri[:n] - self.query_string = request_uri[n + 1 :] + original_path = request_uri[:n] + query_string = request_uri[n + 1 :] else: - self.__path_original = request_uri - self.query_string = Bytes() + original_path = request_uri + query_string = "" + + return URI( + _original_path=original_path, + scheme=scheme, + path=original_path, + query_string=query_string, + _hash="", + host=host, + full_uri=uri, + request_uri=request_uri, + username="", + password="", + ) - self.path = self.__path_original - self.request_uri = request_uri + fn __str__(self) -> String: + var result = String.write(self.scheme, "://", self.host, self.path) + if len(self.query_string) > 0: + result.write("?", self.query_string) + return result^ + fn __repr__(self) -> String: + return String.write(self) -fn normalise_path(path: String, path_original: String) -> String: - # TODO: implement - return path + fn write_to[T: Writer](self, mut writer: T): + writer.write( + "URI(", + "scheme=", + repr(self.scheme), + ", host=", + repr(self.host), + ", path=", + repr(self.path), + ", _original_path=", + repr(self._original_path), + ", query_string=", + repr(self.query_string), + ", full_uri=", + repr(self.full_uri), + ", request_uri=", + repr(self.request_uri), + ")", + ) + + fn is_https(self) -> Bool: + return self.scheme == https + + fn is_http(self) -> Bool: + return self.scheme == http or len(self.scheme) == 0 diff --git a/lightbug_http/utils.mojo b/lightbug_http/utils.mojo index a8d493c2..4097568b 100644 --- a/lightbug_http/utils.mojo +++ b/lightbug_http/utils.mojo @@ -1,7 +1,8 @@ +from memory import Span +from sys.param_env import env_get_string from lightbug_http.io.bytes import Bytes, Byte from lightbug_http.strings import BytesConstant from lightbug_http.net import default_buffer_size -from memory import memcpy, Span @always_inline @@ -19,18 +20,16 @@ struct ByteWriter(Writer): fn __init__(out self, capacity: Int = default_buffer_size): self._inner = Bytes(capacity=capacity) - + @always_inline fn write_bytes(mut self, bytes: Span[Byte]) -> None: - """Writes the contents of `src` into the internal buffer. - If `total_bytes_written` < `len(src)`, it also returns an error explaining - why the write is short. + """Writes the contents of `bytes` into the internal buffer. Args: bytes: The bytes to write. """ self._inner.extend(bytes) - + fn write[*Ts: Writable](mut self, *args: *Ts) -> None: """Write data to the `Writer`. @@ -40,6 +39,7 @@ struct ByteWriter(Writer): Args: args: The data to write. """ + @parameter fn write_arg[T: Writable](arg: T): arg.write_to(self) @@ -67,6 +67,10 @@ struct ByteWriter(Writer): return ret^ +alias EndOfReaderError = "No more bytes to read." +alias OutOfBoundsError = "Tried to read past the end of the ByteReader." + + struct ByteReader[origin: Origin]: var _inner: Span[Byte, origin] var read_pos: Int @@ -75,15 +79,37 @@ struct ByteReader[origin: Origin]: self._inner = b self.read_pos = 0 - fn peek(self) -> Byte: - if self.read_pos >= len(self._inner): - return 0 + @always_inline + fn available(self) -> Bool: + return self.read_pos < len(self._inner) + + fn __len__(self) -> Int: + return len(self._inner) - self.read_pos + + fn peek(self) raises -> Byte: + if not self.available(): + raise EndOfReaderError return self._inner[self.read_pos] + fn read_bytes(mut self, n: Int = -1) raises -> Span[Byte, origin]: + var count = n + var start = self.read_pos + if n == -1: + count = len(self) + + if start + count > len(self._inner): + raise OutOfBoundsError + + self.read_pos += count + return self._inner[start : start + count] + fn read_until(mut self, char: Byte) -> Span[Byte, origin]: var start = self.read_pos - while self.peek() != char: + for i in range(start, len(self._inner)): + if self._inner[i] == char: + break self.increment() + return self._inner[start : self.read_pos] @always_inline @@ -92,10 +118,17 @@ struct ByteReader[origin: Origin]: fn read_line(mut self) -> Span[Byte, origin]: var start = self.read_pos - while not is_newline(self.peek()): + for i in range(start, len(self._inner)): + if is_newline(self._inner[i]): + break self.increment() + + # If we are at the end of the buffer, there is no newline to check for. var ret = self._inner[start : self.read_pos] - if self.peek() == BytesConstant.rChar: + if not self.available(): + return ret + + if self._inner[self.read_pos] == BytesConstant.rChar: self.increment(2) else: self.increment() @@ -103,33 +136,30 @@ struct ByteReader[origin: Origin]: @always_inline fn skip_whitespace(mut self): - while is_space(self.peek()): - self.increment() + for i in range(self.read_pos, len(self._inner)): + if is_space(self._inner[i]): + self.increment() + else: + break @always_inline - fn skip_newlines(mut self): - while self.peek() == BytesConstant.rChar: - self.increment(2) + fn skip_carriage_return(mut self): + for i in range(self.read_pos, len(self._inner)): + if self._inner[i] == BytesConstant.rChar: + self.increment(2) + else: + break @always_inline fn increment(mut self, v: Int = 1): self.read_pos += v @always_inline - fn bytes(mut self, bytes_len: Int = -1) -> Bytes: - var pos = self.read_pos - var read_len: Int - if bytes_len == -1: - self.read_pos = -1 - read_len = len(self._inner) - pos - else: - self.read_pos += bytes_len - read_len = bytes_len + fn consume(owned self, bytes_len: Int = -1) -> Bytes: + return self^._inner[self.read_pos : self.read_pos + len(self) + 1] - return self._inner[pos : pos + read_len + 1] - -struct LogLevel(): +struct LogLevel: alias FATAL = 0 alias ERROR = 1 alias WARN = 2 @@ -137,59 +167,105 @@ struct LogLevel(): alias DEBUG = 4 +fn get_log_level() -> Int: + """Returns the log level based on the parameter environment variable `LOG_LEVEL`. + + Returns: + The log level. + """ + alias level = env_get_string["LB_LOG_LEVEL", "INFO"]() + if level == "INFO": + return LogLevel.INFO + elif level == "WARN": + return LogLevel.WARN + elif level == "ERROR": + return LogLevel.ERROR + elif level == "DEBUG": + return LogLevel.DEBUG + elif level == "FATAL": + return LogLevel.FATAL + else: + return LogLevel.INFO + + +alias LOG_LEVEL = get_log_level() +"""Logger level determined by the `LB_LOG_LEVEL` param environment variable. + +When building or running the application, you can set `LB_LOG_LEVEL` by providing the the following option: + +```bash +mojo build ... -D LB_LOG_LEVEL=DEBUG +# or +mojo ... -D LB_LOG_LEVEL=DEBUG +``` +""" + + @value -struct Logger(): - var level: Int +struct Logger[level: Int]: + alias STDOUT = 1 + alias STDERR = 2 - fn __init__(out self, level: Int = LogLevel.INFO): - self.level = level + fn _log_message[event_level: Int](self, message: String): + @parameter + if level >= event_level: - fn _log_message(self, message: String, level: Int): - if self.level >= level: - if level < LogLevel.WARN: - print(message, file=2) + @parameter + if event_level < LogLevel.WARN: + # Write to stderr if FATAL or ERROR + print(message, file=Self.STDERR) else: print(message) fn info[*Ts: Writable](self, *messages: *Ts): var msg = String.write("\033[36mINFO\033[0m - ") + @parameter fn write_message[T: Writable](message: T): msg.write(message, " ") + messages.each[write_message]() - self._log_message(msg, LogLevel.INFO) + self._log_message[LogLevel.INFO](msg) fn warn[*Ts: Writable](self, *messages: *Ts): var msg = String.write("\033[33mWARN\033[0m - ") + @parameter fn write_message[T: Writable](message: T): msg.write(message, " ") + messages.each[write_message]() - self._log_message(msg, LogLevel.WARN) + self._log_message[LogLevel.WARN](msg) fn error[*Ts: Writable](self, *messages: *Ts): var msg = String.write("\033[31mERROR\033[0m - ") + @parameter fn write_message[T: Writable](message: T): msg.write(message, " ") + messages.each[write_message]() - self._log_message(msg, LogLevel.ERROR) + self._log_message[LogLevel.ERROR](msg) fn debug[*Ts: Writable](self, *messages: *Ts): var msg = String.write("\033[34mDEBUG\033[0m - ") + @parameter fn write_message[T: Writable](message: T): msg.write(message, " ") + messages.each[write_message]() - self._log_message(msg, LogLevel.DEBUG) + self._log_message[LogLevel.DEBUG](msg) fn fatal[*Ts: Writable](self, *messages: *Ts): var msg = String.write("\033[35mFATAL\033[0m - ") + @parameter fn write_message[T: Writable](message: T): msg.write(message, " ") + messages.each[write_message]() - self._log_message(msg, LogLevel.FATAL) + self._log_message[LogLevel.FATAL](msg) -alias logger = Logger() +alias logger = Logger[LOG_LEVEL]() diff --git a/mojoproject.toml b/mojoproject.toml index abc5f228..8ca95172 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -9,12 +9,13 @@ version = "0.1.8" [tasks] build = { cmd = "rattler-build build --recipe recipes -c https://conda.modular.com/max -c conda-forge --skip-existing=all", env = {MODULAR_MOJO_IMPORT_PATH = "$CONDA_PREFIX/lib/mojo"} } publish = { cmd = "bash scripts/publish.sh", env = { PREFIX_API_KEY = "$PREFIX_API_KEY" } } -test = { cmd = "magic run mojo test -I . tests" } -integration_test = { cmd = "bash scripts/integration_test.sh" } -bench = { cmd = "magic run mojo bench.mojo" } +test = { cmd = "magic run mojo test -I . tests/lightbug_http" } +integration_tests_py = { cmd = "bash scripts/integration_test.sh" } +integration_tests_external = { cmd = "magic run mojo test -I . tests/integration" } +bench = { cmd = "magic run mojo -I . benchmark/bench.mojo" } bench_server = { cmd = "bash scripts/bench_server.sh" } format = { cmd = "magic run mojo format -l 120 lightbug_http" } [dependencies] max = ">=24.6.0,<25" -small_time = "0.1.6" \ No newline at end of file +small_time = "==0.1.6" diff --git a/python_integration_client.py b/python_integration_client.py deleted file mode 100644 index 70cb9dce..00000000 --- a/python_integration_client.py +++ /dev/null @@ -1,4 +0,0 @@ -import requests - - -requests.get('http://127.0.0.1:8080/redirect', headers={'connection': 'keep-alive'}) \ No newline at end of file diff --git a/scripts/bench_server.sh b/scripts/bench_server.sh index 33b27f05..50d663d6 100644 --- a/scripts/bench_server.sh +++ b/scripts/bench_server.sh @@ -1,6 +1,6 @@ -magic run mojo build bench_server.mojo || exit 1 +magic run mojo build -I . benchmark/bench_server.mojo || exit 1 echo "running server..." ./bench_server& diff --git a/scripts/integration_test.sh b/scripts/integration_test.sh index 4381ce6c..c1abf525 100644 --- a/scripts/integration_test.sh +++ b/scripts/integration_test.sh @@ -1,18 +1,39 @@ #!/bin/bash +echo "[INFO] Building mojo binaries.." -(magic run mojo build --debug-level full integration_test_server.mojo) || exit 1 -(magic run mojo build --debug-level full integration_test_client.mojo) || exit 1 +kill_server() { + pid=$(ps aux | grep "$1" | grep -v grep | awk '{print $2}' | head -n 1) + kill $pid + wait $pid 2>/dev/null +} -echo "starting server..." -./integration_test_server & +test_server() { + (magic run mojo build -D LB_LOG_LEVEL=DEBUG -I . --debug-level full tests/integration/integration_test_server.mojo) || exit 1 -sleep 5 + echo "[INFO] Starting Mojo server..." + ./integration_test_server & -echo "starting test suite" -./integration_test_client + sleep 5 -kill $! -wait $! 2>/dev/null -echo "cleaning up binaries" -rm ./integration_test_server -rm ./integration_test_client \ No newline at end of file + echo "[INFO] Testing server with Python client" + magic run python3 tests/integration/integration_client.py + + rm ./integration_test_server + kill_server "integration_test_server" || echo "Failed to kill Mojo server" +} + +test_client() { + echo "[INFO] Testing Mojo client with Python server" + (magic run mojo build -D LB_LOG_LEVEL=DEBUG -I . --debug-level full tests/integration/integration_test_client.mojo) || exit 1 + + echo "[INFO] Starting Python server..." + magic run fastapi run tests/integration/integration_server.py & + sleep 5 + + ./integration_test_client + rm ./integration_test_client + kill_server "fastapi run" || echo "Failed to kill fastapi server" +} + +test_server +test_client diff --git a/tests/integration/integration_client.py b/tests/integration/integration_client.py new file mode 100644 index 00000000..c98c8ebb --- /dev/null +++ b/tests/integration/integration_client.py @@ -0,0 +1,18 @@ +import requests + + +# TODO: Pair with the Mojo integration server to test the client and server independently. +print("\n~~~ Testing redirect ~~~") +session = requests.Session() +response = session.get('http://127.0.0.1:8080/redirect', allow_redirects=True) +assert response.status_code == 200 +assert response.text == "yay you made it" + +print("\n~~~ Testing close connection ~~~") +response = session.get('http://127.0.0.1:8080/close-connection', headers={'connection': 'close'}) +assert response.status_code == 200 +assert response.text == "connection closed" + +print("\n~~~ Testing internal server error ~~~") +response = session.get('http://127.0.0.1:8080/error', headers={'connection': 'keep-alive'}) +assert response.status_code == 500 diff --git a/tests/integration/integration_server.py b/tests/integration/integration_server.py new file mode 100644 index 00000000..9da862ca --- /dev/null +++ b/tests/integration/integration_server.py @@ -0,0 +1,31 @@ +from typing import Union + +from fastapi import FastAPI, Response +from fastapi.responses import RedirectResponse, PlainTextResponse + +app = FastAPI() + + +@app.get("/redirect") +async def redirect(response: Response): + return RedirectResponse( + url="/rd-destination", status_code=308, headers={"Location": "/rd-destination"} + ) + + +@app.get("/rd-destination") +async def rd_destination(response: Response): + response.headers["Content-Type"] = "text/plain" + return PlainTextResponse("yay you made it") + + +@app.get("/close-connection") +async def close_connection(response: Response): + response.headers["Content-Type"] = "text/plain" + response.headers["Connection"] = "close" + return PlainTextResponse("connection closed") + + +@app.get("/error", status_code=500) +async def error(response: Response): + return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/integration_test_client.mojo b/tests/integration/integration_test_client.mojo similarity index 80% rename from integration_test_client.mojo rename to tests/integration/integration_test_client.mojo index c3d0e647..210d2839 100644 --- a/integration_test_client.mojo +++ b/tests/integration/integration_test_client.mojo @@ -4,27 +4,29 @@ from lightbug_http.client import Client from lightbug_http.utils import logger from testing import * + fn u(s: String) raises -> URI: - return URI.parse_raises("http://127.0.0.1:8080/" + s) + return URI.parse("http://127.0.0.1:8000/" + s) + struct IntegrationTest: var client: Client var results: Dict[String, String] fn __init__(out self): - self.client = Client() + self.client = Client(allow_redirects=True) self.results = Dict[String, String]() - + fn mark_successful(mut self, name: String): self.results[name] = "✅" - + fn mark_failed(mut self, name: String): self.results[name] = "❌" fn test_redirect(mut self): alias name = "test_redirect" - logger.info("Testing redirect...") - var h = Headers(Header(HeaderKey.CONNECTION, 'keep-alive')) + print("\n~~~ Testing redirect ~~~") + var h = Headers(Header(HeaderKey.CONNECTION, "keep-alive")) try: var res = self.client.do(HTTPRequest(u("redirect"), headers=h)) assert_equal(res.status_code, StatusCode.OK) @@ -41,8 +43,8 @@ struct IntegrationTest: fn test_close_connection(mut self): alias name = "test_close_connection" - logger.info("Testing close connection...") - var h = Headers(Header(HeaderKey.CONNECTION, 'close')) + print("\n~~~ Testing close connection ~~~") + var h = Headers(Header(HeaderKey.CONNECTION, "close")) try: var res = self.client.do(HTTPRequest(u("close-connection"), headers=h)) assert_equal(res.status_code, StatusCode.OK) @@ -57,7 +59,7 @@ struct IntegrationTest: fn test_server_error(mut self): alias name = "test_server_error" - logger.info("Testing internal server error...") + print("\n~~~ Testing internal server error ~~~") try: var res = self.client.do(HTTPRequest(u("error"))) assert_equal(res.status_code, StatusCode.INTERNAL_ERROR) @@ -69,15 +71,17 @@ struct IntegrationTest: self.mark_failed(name) return - fn run_tests(mut self): + fn run_tests(mut self) -> Dict[String, String]: logger.info("Running Client Integration Tests...") self.test_redirect() self.test_close_connection() self.test_server_error() - for test in self.results.items(): - print(test[].key + ":", test[].value) + return self.results + fn main(): var test = IntegrationTest() - test.run_tests() + var results = test.run_tests() + for test in results.items(): + print(test[].key + ":", test[].value) diff --git a/integration_test_server.mojo b/tests/integration/integration_test_server.mojo similarity index 79% rename from integration_test_server.mojo rename to tests/integration/integration_test_server.mojo index 484b0d9c..ee538fb7 100644 --- a/integration_test_server.mojo +++ b/tests/integration/integration_test_server.mojo @@ -8,10 +8,8 @@ struct IntegrationTestService(HTTPService): if p == "/redirect": return HTTPResponse( "get off my lawn".as_bytes(), - headers=Headers( - Header(HeaderKey.LOCATION, "/rd-destination") - ), - status_code=StatusCode.PERMANENT_REDIRECT + headers=Headers(Header(HeaderKey.LOCATION, "/rd-destination")), + status_code=StatusCode.PERMANENT_REDIRECT, ) elif p == "/rd-destination": return OK("yay you made it") @@ -27,4 +25,3 @@ fn main() raises: var server = Server(tcp_keep_alive=True) var service = IntegrationTestService() server.listen_and_serve("127.0.0.1:8080", service) - diff --git a/tests/lightbug_http/test_client.mojo b/tests/integration/test_client.mojo similarity index 86% rename from tests/lightbug_http/test_client.mojo rename to tests/integration/test_client.mojo index c90384e4..a18aef2e 100644 --- a/tests/lightbug_http/test_client.mojo +++ b/tests/integration/test_client.mojo @@ -9,7 +9,7 @@ from lightbug_http.io.bytes import bytes fn test_mojo_client_redirect_external_req_google() raises: var client = Client() var req = HTTPRequest( - uri=URI.parse_raises("http://google.com"), + uri=URI.parse("http://google.com"), headers=Headers( Header("Connection", "close")), method="GET", @@ -23,7 +23,7 @@ fn test_mojo_client_redirect_external_req_google() raises: fn test_mojo_client_redirect_external_req_302() raises: var client = Client() var req = HTTPRequest( - uri=URI.parse_raises("http://httpbin.org/status/302"), + uri=URI.parse("http://httpbin.org/status/302"), headers=Headers( Header("Connection", "close")), method="GET", @@ -37,7 +37,7 @@ fn test_mojo_client_redirect_external_req_302() raises: fn test_mojo_client_redirect_external_req_308() raises: var client = Client() var req = HTTPRequest( - uri=URI.parse_raises("http://httpbin.org/status/308"), + uri=URI.parse("http://httpbin.org/status/308"), headers=Headers( Header("Connection", "close")), method="GET", @@ -51,7 +51,7 @@ fn test_mojo_client_redirect_external_req_308() raises: fn test_mojo_client_redirect_external_req_307() raises: var client = Client() var req = HTTPRequest( - uri=URI.parse_raises("http://httpbin.org/status/307"), + uri=URI.parse("http://httpbin.org/status/307"), headers=Headers( Header("Connection", "close")), method="GET", @@ -65,7 +65,7 @@ fn test_mojo_client_redirect_external_req_307() raises: fn test_mojo_client_redirect_external_req_301() raises: var client = Client() var req = HTTPRequest( - uri=URI.parse_raises("http://httpbin.org/status/301"), + uri=URI.parse("http://httpbin.org/status/301"), headers=Headers( Header("Connection", "close")), method="GET", @@ -81,7 +81,7 @@ fn test_mojo_client_lightbug_external_req_200() raises: try: var client = Client() var req = HTTPRequest( - uri=URI.parse_raises("http://httpbin.org/status/200"), + uri=URI.parse("http://httpbin.org/status/200"), headers=Headers( Header("Connection", "close")), method="GET", diff --git a/tests/integration/test_net.mojo b/tests/integration/test_net.mojo new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_pool_manager.mojo b/tests/integration/test_pool_manager.mojo new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_server.mojo b/tests/integration/test_server.mojo new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_socket.mojo b/tests/integration/test_socket.mojo new file mode 100644 index 00000000..e69de29b diff --git a/tests/lightbug_http/test_cookie.mojo b/tests/lightbug_http/cookie/test_cookie.mojo similarity index 64% rename from tests/lightbug_http/test_cookie.mojo rename to tests/lightbug_http/cookie/test_cookie.mojo index 6d3a21aa..bc0023ce 100644 --- a/tests/lightbug_http/test_cookie.mojo +++ b/tests/lightbug_http/cookie/test_cookie.mojo @@ -3,18 +3,19 @@ from small_time.small_time import SmallTime, now from testing import assert_true, assert_equal from collections import Optional + fn test_set_cookie() raises: cookie = Cookie( - name="mycookie", - value="myvalue", - max_age=Duration(minutes=20), - expires=Expiration.from_datetime(SmallTime(2037, 1, 22, 12, 0, 10, 0)), - path=str("/"), - domain=str("localhost"), - secure=True, - http_only=True, - same_site=SameSite.none, - partitioned=False + name="mycookie", + value="myvalue", + max_age=Duration(minutes=20), + expires=Expiration.from_datetime(SmallTime(2037, 1, 22, 12, 0, 10, 0)), + path=str("/"), + domain=str("localhost"), + secure=True, + http_only=True, + same_site=SameSite.none, + partitioned=False, ) var header = cookie.to_header() var header_value = header.value @@ -24,20 +25,16 @@ fn test_set_cookie() raises: fn test_set_cookie_partial_arguments() raises: - cookie = Cookie( - name="mycookie", - value="myvalue", - same_site=SameSite.lax - ) + cookie = Cookie(name="mycookie", value="myvalue", same_site=SameSite.lax) var header = cookie.to_header() var header_value = header.value var expected = "mycookie=myvalue; SameSite=lax" assert_equal("set-cookie", header.key) - assert_equal( header_value, expected) + assert_equal(header_value, expected) fn test_expires_http_timestamp_format() raises: var expected = "Thu, 22 Jan 2037 12:00:10 GMT" var http_date = Expiration.from_datetime(SmallTime(2037, 1, 22, 12, 0, 10, 0)).http_date_timestamp() assert_true(http_date is not None, msg="Http date is None") - assert_equal(expected , http_date.value()) + assert_equal(expected, http_date.value()) diff --git a/tests/lightbug_http/cookie/test_cookie_jar.mojo b/tests/lightbug_http/cookie/test_cookie_jar.mojo new file mode 100644 index 00000000..e69de29b diff --git a/tests/lightbug_http/cookie/test_duration.mojo b/tests/lightbug_http/cookie/test_duration.mojo new file mode 100644 index 00000000..d5e22ec0 --- /dev/null +++ b/tests/lightbug_http/cookie/test_duration.mojo @@ -0,0 +1,11 @@ +import testing +from lightbug_http.cookie.duration import Duration + + +def test_from_string(): + testing.assert_equal(Duration.from_string("10").value().total_seconds, 10) + testing.assert_false(Duration.from_string("10s").__bool__()) + + +def test_ctor(): + testing.assert_equal(Duration(seconds=1, minutes=1, hours=1, days=1).total_seconds, 90061) diff --git a/tests/lightbug_http/cookie/test_expiration.mojo b/tests/lightbug_http/cookie/test_expiration.mojo new file mode 100644 index 00000000..3d0cf0bf --- /dev/null +++ b/tests/lightbug_http/cookie/test_expiration.mojo @@ -0,0 +1,12 @@ +import testing +from lightbug_http.cookie.expiration import Expiration +from small_time import SmallTime + + +def test_ctors(): + # TODO: The string parsing is not correct, possibly a smalltime bug. I will look into it later. (@thatstoasty) + # print(Expiration.from_string("Thu, 22 Jan 2037 12:00:10 GMT").value().datetime.value(), Expiration.from_datetime(SmallTime(2037, 1, 22, 12, 0, 10, 0)).datetime.value()) + # testing.assert_true(Expiration.from_string("Thu, 22 Jan 2037 12:00:10 GMT").value() == Expiration.from_datetime(SmallTime(2037, 1, 22, 12, 0, 10, 0))) + # Failure returns None + # testing.assert_false(Expiration.from_string("abc").__bool__()) + pass diff --git a/tests/lightbug_http/http/test_request.mojo b/tests/lightbug_http/http/test_request.mojo new file mode 100644 index 00000000..d9e6fdfb --- /dev/null +++ b/tests/lightbug_http/http/test_request.mojo @@ -0,0 +1,25 @@ +import testing +from lightbug_http.http import HTTPRequest, StatusCode +from lightbug_http.strings import to_string + + +def test_request_from_bytes(): + alias data = "GET /redirect HTTP/1.1\r\nHost: 127.0.0.1:8080\r\nUser-Agent: python-requests/2.32.3\r\nAccept-Encoding: gzip, deflate, br, zstd\r\nAccept: */*\r\nconnection: keep-alive\r\n\r\n" + var request = HTTPRequest.from_bytes("127.0.0.1", 4096, data.as_bytes()) + testing.assert_equal(request.protocol, "HTTP/1.1") + testing.assert_equal(request.method, "GET") + testing.assert_equal(request.uri.request_uri, "/redirect") + testing.assert_equal(request.headers["Host"], "127.0.0.1:8080") + testing.assert_equal(request.headers["User-Agent"], "python-requests/2.32.3") + + testing.assert_false(request.connection_close()) + request.set_connection_close() + testing.assert_true(request.connection_close()) + + +def test_read_body(): + ... + + +def test_encode(): + ... diff --git a/tests/lightbug_http/http/test_response.mojo b/tests/lightbug_http/http/test_response.mojo new file mode 100644 index 00000000..c20db1ca --- /dev/null +++ b/tests/lightbug_http/http/test_response.mojo @@ -0,0 +1,55 @@ +import testing +from lightbug_http.http import HTTPResponse, StatusCode +from lightbug_http.strings import to_string + + +def test_response_from_bytes(): + alias data = "HTTP/1.1 200 OK\r\nServer: example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Encoding: gzip\r\nContent-Length: 17\r\n\r\nThis is the body!" + var response = HTTPResponse.from_bytes(data.as_bytes()) + testing.assert_equal(response.protocol, "HTTP/1.1") + testing.assert_equal(response.status_code, 200) + testing.assert_equal(response.status_text, "OK") + testing.assert_equal(response.headers["Server"], "example.com") + testing.assert_equal(response.headers["Content-Type"], "text/html") + testing.assert_equal(response.headers["Content-Encoding"], "gzip") + + testing.assert_equal(response.content_length(), 17) + response.set_content_length(10) + testing.assert_equal(response.content_length(), 10) + + testing.assert_false(response.connection_close()) + response.set_connection_close() + testing.assert_true(response.connection_close()) + response.set_connection_keep_alive() + testing.assert_false(response.connection_close()) + testing.assert_equal(response.get_body(), "This is the body!") + + +def test_is_redirect(): + alias data = "HTTP/1.1 200 OK\r\nServer: example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Encoding: gzip\r\nContent-Length: 17\r\n\r\nThis is the body!" + var response = HTTPResponse.from_bytes(data.as_bytes()) + testing.assert_false(response.is_redirect()) + + response.status_code = StatusCode.MOVED_PERMANENTLY + testing.assert_true(response.is_redirect()) + + response.status_code = StatusCode.FOUND + testing.assert_true(response.is_redirect()) + + response.status_code = StatusCode.TEMPORARY_REDIRECT + testing.assert_true(response.is_redirect()) + + response.status_code = StatusCode.PERMANENT_REDIRECT + testing.assert_true(response.is_redirect()) + + +def test_read_body(): + ... + + +def test_read_chunks(): + ... + + +def test_encode(): + ... diff --git a/tests/lightbug_http/test_byte_reader.mojo b/tests/lightbug_http/test_byte_reader.mojo new file mode 100644 index 00000000..9a0ceb0b --- /dev/null +++ b/tests/lightbug_http/test_byte_reader.mojo @@ -0,0 +1,76 @@ +import testing +from lightbug_http.utils import ByteReader, EndOfReaderError +from lightbug_http.io.bytes import Bytes + +alias example = "Hello, World!" + + +def test_peek(): + var r = ByteReader("H".as_bytes()) + testing.assert_equal(r.peek(), 72) + + # Peeking does not move the reader. + testing.assert_equal(r.peek(), 72) + + # Trying to peek past the end of the reader should raise an Error + r.read_pos = 1 + with testing.assert_raises(contains="No more bytes to read."): + _ = r.peek() + + +def test_read_until(): + var r = ByteReader(example.as_bytes()) + testing.assert_equal(r.read_pos, 0) + testing.assert_equal(Bytes(r.read_until(ord(","))), Bytes(72, 101, 108, 108, 111)) + testing.assert_equal(r.read_pos, 5) + + +def test_read_bytes(): + var r = ByteReader(example.as_bytes()) + testing.assert_equal(Bytes(r.read_bytes()), Bytes(72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33)) + + r = ByteReader(example.as_bytes()) + testing.assert_equal(Bytes(r.read_bytes(7)), Bytes(72, 101, 108, 108, 111, 44, 32)) + testing.assert_equal(Bytes(r.read_bytes()), Bytes(87, 111, 114, 108, 100, 33)) + + +def test_read_word(): + var r = ByteReader(example.as_bytes()) + testing.assert_equal(r.read_pos, 0) + testing.assert_equal(Bytes(r.read_word()), Bytes(72, 101, 108, 108, 111, 44)) + testing.assert_equal(r.read_pos, 6) + + +def test_read_line(): + # No newline, go to end of line + var r = ByteReader(example.as_bytes()) + testing.assert_equal(r.read_pos, 0) + testing.assert_equal(Bytes(r.read_line()), Bytes(72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33)) + testing.assert_equal(r.read_pos, 13) + + # Newline, go to end of line. Should cover carriage return and newline + var r2 = ByteReader("Hello\r\nWorld\n!".as_bytes()) + testing.assert_equal(r2.read_pos, 0) + testing.assert_equal(Bytes(r2.read_line()), Bytes(72, 101, 108, 108, 111)) + testing.assert_equal(r2.read_pos, 7) + testing.assert_equal(Bytes(r2.read_line()), Bytes(87, 111, 114, 108, 100)) + testing.assert_equal(r2.read_pos, 13) + + +def test_skip_whitespace(): + var r = ByteReader(" Hola".as_bytes()) + r.skip_whitespace() + testing.assert_equal(r.read_pos, 1) + testing.assert_equal(Bytes(r.read_word()), Bytes(72, 111, 108, 97)) + + +def test_skip_carriage_return(): + var r = ByteReader("\r\nHola".as_bytes()) + r.skip_carriage_return() + testing.assert_equal(r.read_pos, 2) + testing.assert_equal(Bytes(r.read_bytes(4)), Bytes(72, 111, 108, 97)) + + +def test_consume(): + var r = ByteReader(example.as_bytes()) + testing.assert_equal(Bytes(r^.consume()), Bytes(72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33)) diff --git a/tests/lightbug_http/test_byte_writer.mojo b/tests/lightbug_http/test_byte_writer.mojo new file mode 100644 index 00000000..86d28e11 --- /dev/null +++ b/tests/lightbug_http/test_byte_writer.mojo @@ -0,0 +1,31 @@ +import testing +from lightbug_http.utils import ByteWriter +from lightbug_http.io.bytes import Bytes + + +def test_write_byte(): + var w = ByteWriter() + w.write_byte(0x01) + testing.assert_equal(w.consume(), Bytes(0x01)) + w.write_byte(2) + testing.assert_equal(w.consume(), Bytes(2)) + + +def test_consuming_write(): + var w = ByteWriter() + var my_string: String = "World" + w.consuming_write("Hello ") + w.consuming_write(my_string^) + testing.assert_equal(w.consume(), Bytes(72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100)) + + var my_bytes = Bytes(72, 101, 108, 108, 111, 32) + w.consuming_write(my_bytes^) + w.consuming_write(Bytes(87, 111, 114, 108, 10)) + testing.assert_equal(w.consume(), Bytes(72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100)) + + +def test_write(): + var w = ByteWriter() + w.write("Hello", ", ") + w.write_bytes("World!".as_bytes()) + testing.assert_equal(w.consume(), Bytes(72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33)) diff --git a/tests/lightbug_http/test_header.mojo b/tests/lightbug_http/test_header.mojo index 5462aa32..cac3fd60 100644 --- a/tests/lightbug_http/test_header.mojo +++ b/tests/lightbug_http/test_header.mojo @@ -15,12 +15,9 @@ def test_header_case_insensitive(): def test_parse_request_header(): - var headers_str = bytes( - """GET /index.html HTTP/1.1\r\nHost:example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message\r\n\r\n""" - ) + var headers_str = "GET /index.html HTTP/1.1\r\nHost:example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message\r\n\r\n" var header = Headers() - var b = Bytes(headers_str) - var reader = ByteReader(Span(b)) + var reader = ByteReader(headers_str.as_bytes()) var method: String var protocol: String var uri: String diff --git a/tests/lightbug_http/test_host_port.mojo b/tests/lightbug_http/test_host_port.mojo new file mode 100644 index 00000000..2ad444b3 --- /dev/null +++ b/tests/lightbug_http/test_host_port.mojo @@ -0,0 +1,30 @@ +import testing +from lightbug_http.net import join_host_port, parse_address, TCPAddr +from lightbug_http.strings import NetworkType + + +def test_split_host_port(): + # IPv4 + var hp = parse_address("127.0.0.1:8080") + testing.assert_equal(hp[0], "127.0.0.1") + testing.assert_equal(hp[1], 8080) + + # IPv6 + hp = parse_address("[::1]:8080") + testing.assert_equal(hp[0], "::1") + testing.assert_equal(hp[1], 8080) + + # # TODO: IPv6 long form - Not supported yet. + # hp = parse_address("0:0:0:0:0:0:0:1:8080") + # testing.assert_equal(hp[0], "0:0:0:0:0:0:0:1") + # testing.assert_equal(hp[1], 8080) + + +def test_join_host_port(): + # IPv4 + testing.assert_equal(join_host_port("127.0.0.1", "8080"), "127.0.0.1:8080") + + # IPv6 + testing.assert_equal(join_host_port("::1", "8080"), "[::1]:8080") + + # TODO: IPv6 long form - Not supported yet. diff --git a/tests/lightbug_http/test_http.mojo b/tests/lightbug_http/test_http.mojo index 55289594..35907256 100644 --- a/tests/lightbug_http/test_http.mojo +++ b/tests/lightbug_http/test_http.mojo @@ -11,7 +11,7 @@ from lightbug_http.strings import to_string alias default_server_conn_string = "http://localhost:8080" def test_encode_http_request(): - var uri = URI.parse_raises(default_server_conn_string + "/foobar?baz") + var uri = URI.parse(default_server_conn_string + "/foobar?baz") var req = HTTPRequest( uri, body=String("Hello world!").as_bytes(), diff --git a/tests/lightbug_http/test_net.mojo b/tests/lightbug_http/test_net.mojo deleted file mode 100644 index 2a4d241b..00000000 --- a/tests/lightbug_http/test_net.mojo +++ /dev/null @@ -1,7 +0,0 @@ -def test_net(): - test_split_host_port() - - -def test_split_host_port(): - # TODO: Implement this test - ... diff --git a/tests/lightbug_http/test_owning_list.mojo b/tests/lightbug_http/test_owning_list.mojo new file mode 100644 index 00000000..0a486b60 --- /dev/null +++ b/tests/lightbug_http/test_owning_list.mojo @@ -0,0 +1,494 @@ +from lightbug_http.owning_list import OwningList +from sys.info import sizeof + +from memory import UnsafePointer, Span +from testing import assert_equal, assert_false, assert_raises, assert_true + + +def test_mojo_issue_698(): + var list = OwningList[Float64]() + for i in range(5): + list.append(i) + + assert_equal(0.0, list[0]) + assert_equal(1.0, list[1]) + assert_equal(2.0, list[2]) + assert_equal(3.0, list[3]) + assert_equal(4.0, list[4]) + + +def test_list(): + var list = OwningList[Int]() + + for i in range(5): + list.append(i) + + assert_equal(5, len(list)) + assert_equal(5 * sizeof[Int](), list.bytecount()) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + assert_equal(3, list[3]) + assert_equal(4, list[4]) + + assert_equal(0, list[-5]) + assert_equal(3, list[-2]) + assert_equal(4, list[-1]) + + list[2] = -2 + assert_equal(-2, list[2]) + + list[-5] = 5 + assert_equal(5, list[-5]) + list[-2] = 3 + assert_equal(3, list[-2]) + list[-1] = 7 + assert_equal(7, list[-1]) + + +def test_list_clear(): + var list = OwningList[Int](capacity=3) + list.append(1) + list.append(2) + list.append(3) + assert_equal(len(list), 3) + assert_equal(list.capacity, 3) + list.clear() + + assert_equal(len(list), 0) + assert_equal(list.capacity, 3) + + +def test_list_pop(): + var list = OwningList[Int]() + # Test pop with index + for i in range(6): + list.append(i) + + # try popping from index 3 for 3 times + for i in range(3, 6): + assert_equal(i, list.pop(3)) + + # list should have 3 elements now + assert_equal(3, len(list)) + assert_equal(0, list[0]) + assert_equal(1, list[1]) + assert_equal(2, list[2]) + + # Test pop with negative index + for i in range(0, 2): + assert_equal(i, list.pop(-len(list))) + + # test default index as well + assert_equal(2, list.pop()) + list.append(2) + assert_equal(2, list.pop()) + + # list should be empty now + assert_equal(0, len(list)) + # capacity should be 1 according to shrink_to_fit behavior + assert_equal(1, list.capacity) + + +def test_list_resize(): + var l = OwningList[Int]() + l.append(1) + l.resize(0) + assert_equal(len(l), 0) + + +def test_list_insert(): + # + # Test the list [1, 2, 3] created with insert + # + + v1 = OwningList[Int]() + v1.insert(len(v1), 1) + v1.insert(len(v1), 3) + v1.insert(1, 2) + + assert_equal(len(v1), 3) + assert_equal(v1[0], 1) + assert_equal(v1[1], 2) + assert_equal(v1[2], 3) + + # + # Test the list [1, 2, 3, 4, 5] created with negative and positive index + # + + v2 = OwningList[Int]() + v2.insert(-1729, 2) + v2.insert(len(v2), 3) + v2.insert(len(v2), 5) + v2.insert(-1, 4) + v2.insert(-len(v2), 1) + + assert_equal(len(v2), 5) + assert_equal(v2[0], 1) + assert_equal(v2[1], 2) + assert_equal(v2[2], 3) + assert_equal(v2[3], 4) + assert_equal(v2[4], 5) + + # + # Test the list [1, 2, 3, 4] created with negative index + # + + v3 = OwningList[Int]() + v3.insert(-11, 4) + v3.insert(-13, 3) + v3.insert(-17, 2) + v3.insert(-19, 1) + + assert_equal(len(v3), 4) + assert_equal(v3[0], 1) + assert_equal(v3[1], 2) + assert_equal(v3[2], 3) + assert_equal(v3[3], 4) + + # + # Test the list [1, 2, 3, 4, 5, 6, 7, 8] created with insert + # + + v4 = OwningList[Int]() + for i in range(4): + v4.insert(0, 4 - i) + v4.insert(len(v4), 4 + i + 1) + + for i in range(len(v4)): + assert_equal(v4[i], i + 1) + + +def test_list_index(): + var test_list_a = OwningList[Int]() + test_list_a.append(10) + test_list_a.append(20) + test_list_a.append(30) + test_list_a.append(40) + test_list_a.append(50) + + # Basic Functionality Tests + assert_equal(test_list_a.index(10), 0) + assert_equal(test_list_a.index(30), 2) + assert_equal(test_list_a.index(50), 4) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(60) + + # Tests With Start Parameter + assert_equal(test_list_a.index(30, start=1), 2) + assert_equal(test_list_a.index(30, start=-4), 2) + assert_equal(test_list_a.index(30, start=-1000), 2) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(30, start=3) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(30, start=5) + + # Tests With Start and End Parameters + assert_equal(test_list_a.index(30, start=1, stop=3), 2) + assert_equal(test_list_a.index(30, start=-4, stop=-2), 2) + assert_equal(test_list_a.index(30, start=-1000, stop=1000), 2) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(30, start=1, stop=2) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(30, start=3, stop=1) + + # Tests With End Parameter Only + assert_equal(test_list_a.index(30, stop=3), 2) + assert_equal(test_list_a.index(30, stop=-2), 2) + assert_equal(test_list_a.index(30, stop=1000), 2) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(30, stop=1) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(30, stop=2) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(60, stop=50) + + # Edge Cases and Special Conditions + assert_equal(test_list_a.index(10, start=-5, stop=-1), 0) + assert_equal(test_list_a.index(10, start=0, stop=50), 0) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(50, start=-5, stop=-1) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(50, start=0, stop=-1) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(10, start=-4, stop=-1) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(10, start=5, stop=50) + with assert_raises(contains="ValueError: Given element is not in list"): + _ = OwningList[Int]().index(10) + + # Test empty slice + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(10, start=1, stop=1) + # Test empty slice with 0 start and end + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_a.index(10, start=0, stop=0) + + var test_list_b = OwningList[Int]() + test_list_b.append(10) + test_list_b.append(20) + test_list_b.append(30) + test_list_b.append(20) + test_list_b.append(10) + + # Test finding the first occurrence of an item + assert_equal(test_list_b.index(10), 0) + assert_equal(test_list_b.index(20), 1) + + # Test skipping the first occurrence with a start parameter + assert_equal(test_list_b.index(20, start=2), 3) + + # Test constraining search with start and end, excluding last occurrence + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_b.index(10, start=1, stop=4) + + # Test search within a range that includes multiple occurrences + assert_equal(test_list_b.index(20, start=1, stop=4), 1) + + # Verify error when constrained range excludes occurrences + with assert_raises(contains="ValueError: Given element is not in list"): + _ = test_list_b.index(20, start=4, stop=5) + + +def test_list_extend(): + # + # Test extending the list [1, 2, 3] with itself + # + + vec = OwningList[Int]() + vec.append(1) + vec.append(2) + vec.append(3) + + assert_equal(len(vec), 3) + assert_equal(vec[0], 1) + assert_equal(vec[1], 2) + assert_equal(vec[2], 3) + + var copy = OwningList[Int]() + copy.append(1) + copy.append(2) + copy.append(3) + vec.extend(copy^) + + # vec == [1, 2, 3, 1, 2, 3] + assert_equal(len(vec), 6) + assert_equal(vec[0], 1) + assert_equal(vec[1], 2) + assert_equal(vec[2], 3) + assert_equal(vec[3], 1) + assert_equal(vec[4], 2) + assert_equal(vec[5], 3) + + +def test_list_extend_non_trivial(): + # Tests three things: + # - extend() for non-plain-old-data types + # - extend() with mixed-length self and other lists + # - extend() using optimal number of __moveinit__() calls + + # Preallocate with enough capacity to avoid reallocation making the + # move count checks below flaky. + var v1 = OwningList[String](capacity=5) + v1.append(String("Hello")) + v1.append(String("World")) + + var v2 = OwningList[String](capacity=3) + v2.append(String("Foo")) + v2.append(String("Bar")) + v2.append(String("Baz")) + + v1.extend(v2^) + + assert_equal(len(v1), 5) + assert_equal(v1[0], "Hello") + assert_equal(v1[1], "World") + assert_equal(v1[2], "Foo") + assert_equal(v1[3], "Bar") + assert_equal(v1[4], "Baz") + + +def test_2d_dynamic_list(): + var list = OwningList[OwningList[Int]]() + + for i in range(2): + var v = OwningList[Int]() + for j in range(3): + v.append(i + j) + list.append(v^) + + assert_equal(0, list[0][0]) + assert_equal(1, list[0][1]) + assert_equal(2, list[0][2]) + assert_equal(1, list[1][0]) + assert_equal(2, list[1][1]) + assert_equal(3, list[1][2]) + + assert_equal(2, len(list)) + assert_equal(2, list.capacity) + + assert_equal(3, len(list[0])) + + list[0].clear() + assert_equal(0, len(list[0])) + assert_equal(4, list[0].capacity) + + list.clear() + assert_equal(0, len(list)) + assert_equal(2, list.capacity) + + +def test_list_iter(): + var vs = OwningList[Int]() + vs.append(1) + vs.append(2) + vs.append(3) + + # Borrow immutably + fn sum(vs: OwningList[Int]) -> Int: + var sum = 0 + for v in vs: + sum += v[] + return sum + + assert_equal(6, sum(vs)) + + +def test_list_iter_mutable(): + var vs = OwningList[Int]() + vs.append(1) + vs.append(2) + vs.append(3) + + for v in vs: + v[] += 1 + + var sum = 0 + for v in vs: + sum += v[] + + assert_equal(9, sum) + + +def test_list_realloc_trivial_types(): + a = OwningList[Int]() + for i in range(100): + a.append(i) + + assert_equal(len(a), 100) + for i in range(100): + assert_equal(a[i], i) + + b = OwningList[Int8]() + for i in range(100): + b.append(Int8(i)) + + assert_equal(len(b), 100) + for i in range(100): + assert_equal(b[i], Int8(i)) + + +def test_list_boolable(): + var l = OwningList[Int]() + l.append(1) + assert_true(l) + assert_false(OwningList[Int]()) + + +def test_converting_list_to_string(): + # This is also testing the method `to_format` because + # essentially, `OwningList.__str__()` just creates a String and applies `to_format` to it. + # If we were to write unit tests for `to_format`, we would essentially copy-paste the code + # of `OwningList.__str__()` + var my_list = OwningList[Int]() + my_list.append(1) + my_list.append(2) + my_list.append(3) + assert_equal(my_list.__str__(), "[1, 2, 3]") + + var my_list4 = OwningList[String]() + my_list4.append("a") + my_list4.append("b") + my_list4.append("c") + my_list4.append("foo") + assert_equal(my_list4.__str__(), "['a', 'b', 'c', 'foo']") + + +def test_list_contains(): + var x = OwningList[Int]() + x.append(1) + x.append(2) + x.append(3) + assert_false(0 in x) + assert_true(1 in x) + assert_false(4 in x) + + +def test_indexing(): + var l = OwningList[Int]() + l.append(1) + l.append(2) + l.append(3) + assert_equal(l[int(1)], 2) + assert_equal(l[False], 1) + assert_equal(l[True], 2) + assert_equal(l[2], 3) + + +# ===-------------------------------------------------------------------===# +# OwningList dtor tests +# ===-------------------------------------------------------------------===# +var g_dtor_count: Int = 0 + + +struct DtorCounter(CollectionElement): + # NOTE: payload is required because OwningList does not support zero sized structs. + var payload: Int + + fn __init__(out self): + self.payload = 0 + + fn __init__(out self, *, other: Self): + self.payload = other.payload + + fn __copyinit__(out self, existing: Self, /): + self.payload = existing.payload + + fn __moveinit__(out self, owned existing: Self, /): + self.payload = existing.payload + existing.payload = 0 + + fn __del__(owned self): + g_dtor_count += 1 + + +def inner_test_list_dtor(): + # explicitly reset global counter + g_dtor_count = 0 + + var l = OwningList[DtorCounter]() + assert_equal(g_dtor_count, 0) + + l.append(DtorCounter()) + assert_equal(g_dtor_count, 0) + + l^.__del__() + assert_equal(g_dtor_count, 1) + + +def test_list_dtor(): + # call another function to force the destruction of the list + inner_test_list_dtor() + + # verify we still only ran the destructor once + assert_equal(g_dtor_count, 1) + + +def test_list_repr(): + var l = OwningList[Int]() + l.append(1) + l.append(2) + l.append(3) + assert_equal(l.__repr__(), "[1, 2, 3]") + var empty = OwningList[Int]() + assert_equal(empty.__repr__(), "[]") diff --git a/tests/lightbug_http/test_server.mojo b/tests/lightbug_http/test_server.mojo new file mode 100644 index 00000000..0402d06d --- /dev/null +++ b/tests/lightbug_http/test_server.mojo @@ -0,0 +1,14 @@ +import testing +from lightbug_http.server import Server + + +def test_server(): + var server = Server() + server.set_address("0.0.0.0") + testing.assert_equal(server.address(), "0.0.0.0") + server.set_max_request_body_size(1024) + testing.assert_equal(server.max_request_body_size(), 1024) + testing.assert_equal(server.get_concurrency(), 1000) + + server = Server(max_concurrent_connections=10) + testing.assert_equal(server.get_concurrency(), 10) diff --git a/tests/lightbug_http/test_service.mojo b/tests/lightbug_http/test_service.mojo new file mode 100644 index 00000000..17753b8c --- /dev/null +++ b/tests/lightbug_http/test_service.mojo @@ -0,0 +1,22 @@ +import testing +from lightbug_http.service import Printer, Welcome, ExampleRouter, TechEmpowerRouter, Counter + + +def test_printer(): + pass + + +def test_welcome(): + pass + + +def test_example_router(): + pass + + +def test_tech_empower_router(): + pass + + +def test_counter(): + pass diff --git a/tests/lightbug_http/test_uri.mojo b/tests/lightbug_http/test_uri.mojo index 885234b8..7f332841 100644 --- a/tests/lightbug_http/test_uri.mojo +++ b/tests/lightbug_http/test_uri.mojo @@ -5,21 +5,19 @@ from lightbug_http.strings import empty_string, to_string from lightbug_http.io.bytes import Bytes - def test_uri_no_parse_defaults(): - var uri = URI.parse("http://example.com")[URI] + var uri = URI.parse("http://example.com") testing.assert_equal(uri.full_uri, "http://example.com") - testing.assert_equal(uri.scheme, "http") testing.assert_equal(uri.path, "/") def test_uri_parse_http_with_port(): - var uri = URI.parse("http://example.com:8080/index.html")[URI] + var uri = URI.parse("http://example.com:8080/index.html") testing.assert_equal(uri.scheme, "http") testing.assert_equal(uri.host, "example.com:8080") testing.assert_equal(uri.path, "/index.html") - testing.assert_equal(uri.__path_original, "/index.html") + testing.assert_equal(uri._original_path, "/index.html") testing.assert_equal(uri.request_uri, "/index.html") testing.assert_equal(uri.is_https(), False) testing.assert_equal(uri.is_http(), True) @@ -27,11 +25,11 @@ def test_uri_parse_http_with_port(): def test_uri_parse_https_with_port(): - var uri = URI.parse("https://example.com:8080/index.html")[URI] + var uri = URI.parse("https://example.com:8080/index.html") testing.assert_equal(uri.scheme, "https") testing.assert_equal(uri.host, "example.com:8080") testing.assert_equal(uri.path, "/index.html") - testing.assert_equal(uri.__path_original, "/index.html") + testing.assert_equal(uri._original_path, "/index.html") testing.assert_equal(uri.request_uri, "/index.html") testing.assert_equal(uri.is_https(), True) testing.assert_equal(uri.is_http(), False) @@ -39,11 +37,11 @@ def test_uri_parse_https_with_port(): def test_uri_parse_http_with_path(): - var uri = URI.parse("http://example.com/index.html")[URI] + var uri = URI.parse("http://example.com/index.html") testing.assert_equal(uri.scheme, "http") testing.assert_equal(uri.host, "example.com") testing.assert_equal(uri.path, "/index.html") - testing.assert_equal(uri.__path_original, "/index.html") + testing.assert_equal(uri._original_path, "/index.html") testing.assert_equal(uri.request_uri, "/index.html") testing.assert_equal(uri.is_https(), False) testing.assert_equal(uri.is_http(), True) @@ -51,11 +49,11 @@ def test_uri_parse_http_with_path(): def test_uri_parse_https_with_path(): - var uri = URI.parse("https://example.com/index.html")[URI] + var uri = URI.parse("https://example.com/index.html") testing.assert_equal(uri.scheme, "https") testing.assert_equal(uri.host, "example.com") testing.assert_equal(uri.path, "/index.html") - testing.assert_equal(uri.__path_original, "/index.html") + testing.assert_equal(uri._original_path, "/index.html") testing.assert_equal(uri.request_uri, "/index.html") testing.assert_equal(uri.is_https(), True) testing.assert_equal(uri.is_http(), False) @@ -63,27 +61,33 @@ def test_uri_parse_https_with_path(): def test_uri_parse_http_basic(): - var uri = URI.parse("http://example.com")[URI] + var uri = URI.parse("http://example.com") testing.assert_equal(uri.scheme, "http") testing.assert_equal(uri.host, "example.com") testing.assert_equal(uri.path, "/") - testing.assert_equal(uri.__path_original, "/") + testing.assert_equal(uri._original_path, "/") testing.assert_equal(uri.request_uri, "/") testing.assert_equal(uri.query_string, empty_string) def test_uri_parse_http_basic_www(): - var uri = URI.parse("http://www.example.com")[URI] + var uri = URI.parse("http://www.example.com") testing.assert_equal(uri.scheme, "http") testing.assert_equal(uri.host, "www.example.com") testing.assert_equal(uri.path, "/") - testing.assert_equal(uri.__path_original, "/") + testing.assert_equal(uri._original_path, "/") testing.assert_equal(uri.request_uri, "/") testing.assert_equal(uri.query_string, empty_string) def test_uri_parse_http_with_query_string(): - ... + var uri = URI.parse("http://www.example.com/job?title=engineer") + testing.assert_equal(uri.scheme, "http") + testing.assert_equal(uri.host, "www.example.com") + testing.assert_equal(uri.path, "/job") + testing.assert_equal(uri._original_path, "/job") + testing.assert_equal(uri.request_uri, "/job?title=engineer") + testing.assert_equal(uri.query_string, "title=engineer") def test_uri_parse_http_with_hash():