Skip to content

Commit

Permalink
Issue #12429 - case-insensitive headers for websocket
Browse files Browse the repository at this point in the history
Signed-off-by: Lachlan Roberts <[email protected]>
  • Loading branch information
lachlan-roberts committed Oct 30, 2024
1 parent 7bc94d6 commit 8ff6828
Show file tree
Hide file tree
Showing 18 changed files with 314 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -927,6 +928,13 @@ default int size()
return size;
}

static Map<String, List<String>> asMap(HttpFields fields)
{
Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
fields.getFieldNamesCollection().forEach(name -> headers.putIfAbsent(name, fields.getValuesList(name)));
return headers;
}

/**
* @return a sequential stream of the {@link HttpField}s in this instance
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.stream.Collectors;

import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpScheme;
import org.eclipse.jetty.io.EndPoint;
Expand Down Expand Up @@ -78,7 +79,7 @@ public List<String> getHeaders(String name)
@Override
public Map<String, List<String>> getHeaders()
{
return null;
return Collections.unmodifiableMap(HttpFields.asMap(delegate.getHeaders()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

package org.eclipse.jetty.websocket.client.internal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.eclipse.jetty.client.Response;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.websocket.api.ExtensionConfig;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
Expand Down Expand Up @@ -65,9 +65,7 @@ public List<String> getHeaders(String name)
@Override
public Map<String, List<String>> getHeaders()
{
Map<String, List<String>> headers = getHeaderNames().stream()
.collect(Collectors.toMap((name) -> name, (name) -> new ArrayList<>(getHeaders(name))));
return Collections.unmodifiableMap(headers);
return Collections.unmodifiableMap(HttpFields.asMap(delegate.getHeaders()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.Map;
import java.util.stream.Collectors;

import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpScheme;
Expand Down Expand Up @@ -73,14 +72,7 @@ public int getHeaderInt(String name)
@Override
public Map<String, List<String>> getHeaders()
{
Map<String, List<String>> result = new LinkedHashMap<>();
HttpFields headers = request.getHeaders();
for (HttpField header : headers)
{
String name = header.getName();
result.put(name, headers.getValuesList(name));
}
return result;
return HttpFields.asMap(request.getHeaders());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@

package org.eclipse.jetty.websocket.server.internal;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.websocket.api.ExtensionConfig;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
Expand Down Expand Up @@ -64,14 +62,7 @@ public Set<String> getHeaderNames()
@Override
public Map<String, List<String>> getHeaders()
{
Map<String, List<String>> result = new LinkedHashMap<>();
HttpFields.Mutable headers = response.getHeaders();
for (HttpField header : headers)
{
String name = header.getName();
result.put(name, headers.getValuesList(name));
}
return result;
return HttpFields.asMap(response.getHeaders());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@

package org.eclipse.jetty.ee10.websocket.jakarta.client.internal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand All @@ -41,16 +38,9 @@ public void onHandshakeRequest(Request request)
if (configurator == null)
return;

HttpFields fields = request.getHeaders();
Map<String, List<String>> originalHeaders = new HashMap<>();
fields.forEach(field ->
{
originalHeaders.putIfAbsent(field.getName(), new ArrayList<>());
List<String> values = originalHeaders.get(field.getName());
Collections.addAll(values, field.getValues());
});

// Give headers to configurator
HttpFields fields = request.getHeaders();
Map<String, List<String>> originalHeaders = HttpFields.asMap(fields);
configurator.beforeRequest(originalHeaders);

// Reset headers on HttpRequest per configurator
Expand All @@ -67,18 +57,7 @@ public void onHandshakeResponse(Request request, Response response)
if (configurator == null)
return;

HandshakeResponse handshakeResponse = () ->
{
Map<String, List<String>> ret = new HashMap<>();
response.getHeaders().forEach(field ->
{
ret.putIfAbsent(field.getName(), new ArrayList<>());
List<String> values = ret.get(field.getName());
Collections.addAll(values, field.getValues());
});
return ret;
};

HandshakeResponse handshakeResponse = () -> HttpFields.asMap(response.getHeaders());
configurator.afterResponse(handshakeResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@

import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.websocket.server.HandshakeRequest;
import org.eclipse.jetty.ee10.websocket.jakarta.server.JakartaWebSocketServerContainer;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.pathmap.PathSpec;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.Fields;
Expand All @@ -47,9 +46,7 @@ public JsrHandshakeRequest(ServerUpgradeRequest req)
@Override
public Map<String, List<String>> getHeaders()
{
Map<String, List<String>> headers = delegate.getHeaders().getFieldNamesCollection().stream()
.collect(Collectors.toMap((name) -> name, (name) -> new ArrayList<>(delegate.getHeaders().getValuesList(name))));
return Collections.unmodifiableMap(headers);
return Collections.unmodifiableMap(HttpFields.asMap(delegate.getHeaders()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@

package org.eclipse.jetty.ee10.websocket.jakarta.server.internal;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import jakarta.websocket.HandshakeResponse;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.websocket.core.server.ServerUpgradeResponse;

public class JsrHandshakeResponse implements HandshakeResponse
Expand All @@ -29,8 +28,7 @@ public class JsrHandshakeResponse implements HandshakeResponse
public JsrHandshakeResponse(ServerUpgradeResponse resp)
{
this.delegate = resp;
this.headers = delegate.getHeaders().getFieldNamesCollection().stream()
.collect(Collectors.toMap((name) -> name, (name) -> new ArrayList<>(delegate.getHeaders().getValuesList(name))));
this.headers = HttpFields.asMap(delegate.getHeaders());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.ee10.websocket.jakarta.tests;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.HandshakeResponse;
import jakarta.websocket.Session;
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpointConfig;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.websocket.jakarta.client.JakartaWebSocketClientContainer;
import org.eclipse.jetty.ee10.websocket.jakarta.server.config.JakartaWebSocketServletContainerInitializer;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertTrue;

public class UpgradeHeadersTest
{
private Server _server;
private JakartaWebSocketClientContainer _client;
private ServerConnector _connector;

public static class MyEndpoint extends Endpoint
{
@Override
public void onOpen(Session session, EndpointConfig config)
{
}
}

public void start(ServerEndpointConfig.Configurator configurator) throws Exception
{
_server = new Server();
_connector = new ServerConnector(_server);
_server.addConnector(_connector);

ServletContextHandler contextHandler = new ServletContextHandler();
_server.setHandler(contextHandler);
JakartaWebSocketServletContainerInitializer.configure(contextHandler, (context, container) ->
{
container.addEndpoint(ServerEndpointConfig.Builder
.create(MyEndpoint.class, "/")
.configurator(configurator)
.build());
});

_server.start();
_client = new JakartaWebSocketClientContainer();
_client.start();
}

@AfterEach
public void after() throws Exception
{
_client.stop();
_server.stop();
}

@Test
public void testCaseInsensitiveUpgradeHeaders() throws Exception
{
ClientEndpointConfig.Configurator configurator = new ClientEndpointConfig.Configurator()
{
@Override
public void beforeRequest(Map<String, List<String>> headers)
{
// Verify that existing headers can be accessed in a case-insensitive way.
if (headers.get("cOnnEcTiOn") == null)
throw new IllegalStateException("No Connection Header on client Request");
headers.put("sentHeader", List.of("value123"));
}

@Override
public void afterResponse(HandshakeResponse hr)
{
if (hr.getHeaders().get("MyHeAdEr") == null)
throw new IllegalStateException("No custom Header on HandshakeResponse");
if (hr.getHeaders().get("cOnnEcTiOn") == null)
throw new IllegalStateException("No Connection Header on HandshakeRequest");
}
};

start(new ServerEndpointConfig.Configurator()
{
@Override
public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response)
{
// Verify that existing headers can be accessed in a case-insensitive way.
if (request.getHeaders().get("cOnnEcTiOn") == null)
throw new IllegalStateException("No Connection Header on HandshakeRequest");
if (response.getHeaders().get("sErVeR") == null)
throw new IllegalStateException("No Server Header on HandshakeResponse");

// Verify custom header sent from client.
if (request.getHeaders().get("SeNtHeadEr") == null)
throw new IllegalStateException("No sent Header on HandshakeResponse");

// Add custom response header.
response.getHeaders().put("myHeader", List.of("foobar"));
if (response.getHeaders().get("MyHeAdEr") == null)
throw new IllegalStateException("No custom Header on HandshakeResponse");

super.modifyHandshake(sec, request, response);
}
});

WSEndpointTracker clientEndpoint = new WSEndpointTracker(){};
ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().configurator(configurator).build();
URI uri = URI.create("ws://localhost:" + _connector.getLocalPort());

// If any of the above throw it would fail to upgrade to websocket.
Session session = _client.connectToServer(clientEndpoint, clientConfig, uri);
assertTrue(clientEndpoint.openLatch.await(5, TimeUnit.SECONDS));
session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.net.URI;
import java.security.Principal;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
Expand All @@ -33,6 +32,7 @@
import jakarta.servlet.http.HttpSession;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.util.URIUtil;
import org.eclipse.jetty.websocket.api.ExtensionConfig;
Expand Down Expand Up @@ -121,9 +121,7 @@ public int getHeaderInt(String name)
@Override
public Map<String, List<String>> getHeaders()
{
Map<String, List<String>> headers = upgradeRequest.getHeaders().getFieldNamesCollection().stream()
.collect(Collectors.toMap((name) -> name, (name) -> new ArrayList<>(getHeaders(name))));
return Collections.unmodifiableMap(headers);
return Collections.unmodifiableMap(HttpFields.asMap(upgradeRequest.getHeaders()));
}

@Override
Expand Down
Loading

0 comments on commit 8ff6828

Please sign in to comment.