Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websocket: Support clients that do not specify subprotocol #2621

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/seastar/websocket/server.hh
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,13 @@ public:

bool is_handler_registered(std::string const& name);

void register_handler(std::string&& name, handler_t handler);
/*!
* \brief Register a handler for specific subprotocol
* \param name The name of the subprotocol. If it is empty string, then the handler is used
* when the protocol is not specified
* \param handler Handler for incoming WebSocket messages.
*/
void register_handler(const std::string& name, handler_t handler);

friend class connection;
protected:
Expand Down
11 changes: 5 additions & 6 deletions src/websocket/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,6 @@ future<> connection::read_http_upgrade_request() {
}

sstring subprotocol = req->get_header("Sec-WebSocket-Protocol");
if (subprotocol.empty()) {
throw websocket::exception("Subprotocol header missing.");
}

if (!_server.is_handler_registered(subprotocol)) {
throw websocket::exception("Subprotocol not supported.");
Expand All @@ -187,8 +184,10 @@ future<> connection::read_http_upgrade_request() {

co_await _write_buf.write(http_upgrade_reply_template);
co_await _write_buf.write(sha1_output);
co_await _write_buf.write("\r\nSec-WebSocket-Protocol: ", 26);
co_await _write_buf.write(_subprotocol);
if (!_subprotocol.empty()) {
co_await _write_buf.write("\r\nSec-WebSocket-Protocol: ", 26);
co_await _write_buf.write(_subprotocol);
}
co_await _write_buf.write("\r\n\r\n", 4);
co_await _write_buf.flush();
}
Expand Down Expand Up @@ -414,7 +413,7 @@ bool server::is_handler_registered(std::string const& name) {
return _handlers.find(name) != _handlers.end();
}

void server::register_handler(std::string&& name, handler_t handler) {
void server::register_handler(const std::string& name, handler_t handler) {
_handlers[name] = handler;
}

Expand Down
71 changes: 48 additions & 23 deletions tests/unit/websocket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,30 @@

using namespace seastar;
using namespace seastar::experimental;
using namespace std::literals::string_view_literals;

std::string build_request(std::string_view key_base64, std::string_view subprotocol) {
std::string subprotocol_line;
if (!subprotocol.empty()) {
subprotocol_line = fmt::format("Sec-WebSocket-Protocol: {}\r\n", subprotocol);
}

return fmt::format(
"GET / HTTP/1.1\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: {}\r\n"
"Sec-WebSocket-Version: 13\r\n"
"{}"
"\r\n",
key_base64,
subprotocol_line);
}

future<> test_websocket_handshake_common(std::string subprotocol) {
return seastar::async([=] {
const std::string request = build_request("dGhlIHNhbXBsZSBub25jZQ==", subprotocol);

SEASTAR_TEST_CASE(test_websocket_handshake) {
return seastar::async([] {
const std::string request =
"GET / HTTP/1.1\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Protocol: echo\r\n"
"\r\n";
loopback_connection_factory factory;
loopback_socket_impl lsi(factory);

Expand All @@ -32,7 +45,7 @@ SEASTAR_TEST_CASE(test_websocket_handshake) {
auto output = sock.output();

websocket::server dummy;
dummy.register_handler("echo", [] (input_stream<char>& in,
dummy.register_handler(subprotocol, [] (input_stream<char>& in,
output_stream<char>& out) {
return repeat([&in, &out]() {
return in.read().then([&out](temporary_buffer<char> f) {
Expand Down Expand Up @@ -81,10 +94,16 @@ SEASTAR_TEST_CASE(test_websocket_handshake) {
});
}

SEASTAR_TEST_CASE(test_websocket_handshake) {
return test_websocket_handshake_common("echo");
}

SEASTAR_TEST_CASE(test_websocket_handshake_no_subprotocol) {
return test_websocket_handshake_common("");
}

SEASTAR_TEST_CASE(test_websocket_handler_registration) {
return seastar::async([] {
future<> test_websocket_handler_registration_common(std::string subprotocol) {
return seastar::async([=] {
loopback_connection_factory factory;
loopback_socket_impl lsi(factory);

Expand All @@ -96,7 +115,7 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) {

// Setup server
websocket::server ws;
ws.register_handler("echo", [] (input_stream<char>& in,
ws.register_handler(subprotocol, [] (input_stream<char>& in,
output_stream<char>& out) {
return repeat([&in, &out]() {
return in.read().then([&out](temporary_buffer<char> f) {
Expand Down Expand Up @@ -124,17 +143,15 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) {
});

// handshake
const std::string request =
"GET / HTTP/1.1\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Protocol: echo\r\n"
"\r\n";
const std::string request = build_request("dGhlIHNhbXBsZSBub25jZQ==", subprotocol);
output.write(request).get();
output.flush().get();
input.read_exactly(186).get();

unsigned reply_size = 156;
if (!subprotocol.empty()) {
reply_size += ("\r\nSec-WebSocket-Protocol: "sv).size() + subprotocol.size();
}
input.read_exactly(reply_size).get();

unsigned ws_frame_len = 10;

Expand All @@ -156,6 +173,14 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) {
});
}

SEASTAR_TEST_CASE(test_websocket_handler_registration) {
return test_websocket_handler_registration_common("echo");
}

SEASTAR_TEST_CASE(test_websocket_handler_registration_no_subprotocol) {
return test_websocket_handler_registration_common("");
}

// Simple wrapper to help create a testable input_stream.
class test_source_impl : public data_source_impl {
std::vector<temporary_buffer<char>> _bufs{};
Expand Down
Loading