Skip to content

Commit

Permalink
Issue #7351 - support demand of multiple frames in PerMessageDeflateE…
Browse files Browse the repository at this point in the history
…xtension

Signed-off-by: Lachlan Roberts <[email protected]>
  • Loading branch information
lachlan-roberts committed Jan 11, 2022
1 parent 78ba66b commit 1c21dcd
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongConsumer;
import java.util.zip.DataFormatException;
Expand Down Expand Up @@ -243,15 +244,9 @@ protected void nextOutgoingFrame(Frame frame, Callback callback, boolean batch)
@Override
public void demand(long n, LongConsumer nextDemand)
{
if (!incomingFlusher.isFinished())
{
// TODO: what to do with n?
incomingFlusher.succeeded();
}
else
{
nextDemand.accept(n);
}
incomingFlusher._demand.addAndGet(n);
incomingFlusher.setNextDemand(nextDemand);
incomingFlusher.iterate();
}

private class OutgoingFlusher extends TransformingFlusher
Expand Down Expand Up @@ -309,7 +304,7 @@ private boolean deflate(Callback callback)

if (buffer.limit() == bufferSize)
{
// We need to fragment. TODO: what if there was only bufferSize of content?
// We need to fragment.
if (!getConfiguration().isAutoFragment())
throw new MessageTooLargeException("Deflated payload exceeded the compress buffer size");
break;
Expand Down Expand Up @@ -355,6 +350,7 @@ else if (_frame.isFin())

private class IncomingFlusher extends IteratingCallback
{
private final AtomicLong _demand = new AtomicLong();
private final AtomicReference<Throwable> _failure = new AtomicReference<>();
private boolean _complete = true;
private boolean _finished = true;
Expand All @@ -363,29 +359,24 @@ private class IncomingFlusher extends IteratingCallback
private ByteBuffer _framePayload;
private Callback _frameCallback;
private boolean _tailBytes;
private LongConsumer _nextDemand;

public boolean isFinished()
{
return _finished;
}

@Override
protected Action process() throws Throwable
public void setNextDemand(LongConsumer nextDemand)
{
Throwable failure = _failure.get();
if (failure != null)
throw failure;
_nextDemand = nextDemand;
}

try
{
Action action = inflate();
_first = false;
return action;
}
catch (DataFormatException e)
{
throw new BadPayloadException(e);
}
public void demand(long n)
{
if (n <= 0)
throw new IllegalArgumentException("Demand must be positive");
_demand.addAndGet(n);
iterate();
}

public void onFrame(Frame frame, Callback callback)
Expand Down Expand Up @@ -445,15 +436,40 @@ public void onFrame(Frame frame, Callback callback)
iterate();
}

private Action inflate() throws DataFormatException
@Override
protected Action process() throws Throwable
{
WebSocketCoreSession coreSession = (WebSocketCoreSession)getCoreSession();
while (_demand.get() > 0)
{
Throwable failure = _failure.get();
if (failure != null)
throw failure;

try
{
inflate();
_first = false;
if (_finished)
{
if (_demand.get() > 0)
_nextDemand.accept(1);
break;
}
}
catch (DataFormatException e)
{
throw new BadPayloadException(e);
}
}
return Action.IDLE;
}

private void inflate() throws DataFormatException
{
if (_complete)
{
clear();
coreSession.internalDemand(1);
return Action.IDLE;
return;
}

// Get a buffer for the inflated payload.
Expand Down Expand Up @@ -500,37 +516,30 @@ private Action inflate() throws DataFormatException

boolean succeedCallback = _complete;
Callback frameCallback = _frameCallback;
WebSocketCoreSession coreSession = (WebSocketCoreSession)getCoreSession();
Callback payloadCallback = Callback.from(() ->
{
getBufferPool().release(payload);
if (succeedCallback)
{
frameCallback.succeeded();
}
else
{
if (!coreSession.isDemanding())
coreSession.internalDemand(1);
}
}, t ->
{
// The error needs to be forwarded to the CoreSession if callback is failed.
getBufferPool().release(payload);
failFlusher(t);
coreSession.processHandlerError(t, NOOP);
});
_demand.decrementAndGet();
nextIncomingFrame(chunk, payloadCallback);

if (LOG.isDebugEnabled())
LOG.debug("Decompress finished: {} {}", _complete, chunk);

return Action.SCHEDULED;
}

private void clear()
{
_finished = true;
_complete = false;
_complete = true;
_first = false;
_frame = null;
_framePayload = null;
Expand All @@ -552,10 +561,7 @@ protected void onCompleteFailure(Throwable cause)
private void failFlusher(Throwable t)
{
if (_failure.compareAndSet(null, t))
{
// TODO: if the callback is pending then this will be noop.
iterate();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ public void succeeded()
frame.close();
if (referenced != null)
referenced.release();

if (!coreSession.isDemanding())
coreSession.internalDemand(1);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ public void onOpen()
if (LOG.isDebugEnabled())
LOG.debug("ConnectionState: Transition to OPEN");
if (!demanding)
connection.demand(1);
autoDemand();
},
x ->
{
Expand Down Expand Up @@ -424,12 +424,12 @@ public void demand(long n)
{
if (!demanding)
throw new IllegalStateException("FrameHandler is not demanding: " + this);
internalDemand(n);
getExtensionStack().demand(n, connection::demand);
}

public void internalDemand(long n)
public void autoDemand()
{
getExtensionStack().demand(n, connection::demand);
getExtensionStack().demand(1, connection::demand);
}

@Override
Expand Down Expand Up @@ -647,7 +647,7 @@ public void setMaxOutgoingFrames(int maxOutgoingFrames)
private class IncomingAdaptor implements IncomingFrames
{
@Override
public void onFrame(Frame frame, final Callback callback)
public void onFrame(Frame frame, Callback callback)
{
Callback closeCallback = null;
try
Expand All @@ -660,7 +660,14 @@ public void onFrame(Frame frame, final Callback callback)
// Handle inbound frame
if (frame.getOpCode() != OpCode.CLOSE)
{
handle(() -> handler.onFrame(frame, callback));
Callback handlerCallback = isDemanding() ? callback : Callback.from(() ->
{
callback.succeeded();
if (!isDemanding())
autoDemand();
}, callback::failed);

handle(() -> handler.onFrame(frame, handlerCallback));
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ public void parseIncomingHex(String... rawhex)
int parts = rawhex.length;
byte[] net;

// Simulate initial demand from onOpen().
coreSession.autoDemand();

for (int i = 0; i < parts; i++)
{
String hex = rawhex[i].replaceAll("\\s*(0x)?", "");
Expand All @@ -101,7 +104,7 @@ public void succeeded()
{
super.succeeded();
if (!coreSession.isDemanding())
coreSession.internalDemand(1);
coreSession.autoDemand();
}
};
ext.onFrame(frame, callback);
Expand Down Expand Up @@ -175,9 +178,10 @@ private WebSocketCoreSession newWebSocketCoreSession(List<ExtensionConfig> confi
return new WebSocketCoreSession(new TestMessageHandler(), Behavior.SERVER, Negotiated.from(exStack), components)
{
@Override
public void internalDemand(long n)
public void autoDemand()
{
getExtensionStack().demand(n, l -> {});
// Never delegate to WebSocketConnection as it is null for this test.
getExtensionStack().demand(1, l -> {});
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -368,17 +368,19 @@ public void testIncomingUncompressedFrames()
@Test
public void testIncomingFrameNoPayload()
{
PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
ExtensionConfig config = ExtensionConfig.parse("permessage-deflate");
ext.init(config, components);
ext.setCoreSession(newSession());
WebSocketCoreSession coreSession = newSession(config);
PerMessageDeflateExtension ext = (PerMessageDeflateExtension)coreSession.getExtensionStack().getExtensions().get(0);

// Setup capture of incoming frames
IncomingFramesCapture capture = new IncomingFramesCapture();

// Wire up stack
ext.setNextIncomingFrames(capture);

// Simulate initial demand from onOpen().
coreSession.autoDemand();

Frame ping = new Frame(OpCode.TEXT);
ping.setRsv1(true);
ext.onFrame(ping, Callback.NOOP);
Expand Down Expand Up @@ -592,15 +594,28 @@ public void testPyWebSocketServerToraToraTora()

private WebSocketCoreSession newSession()
{
return newSessionFromConfig(new ConfigurationCustomizer());
return newSession(null);
}

private WebSocketCoreSession newSessionFromConfig(ConfigurationCustomizer configuration)
private WebSocketCoreSession newSession(ExtensionConfig config)
{
return newSessionFromConfig(new ConfigurationCustomizer(), config == null ? Collections.emptyList() : Collections.singletonList(config));
}

private WebSocketCoreSession newSessionFromConfig(ConfigurationCustomizer configuration, List<ExtensionConfig> configs)
{
ExtensionStack exStack = new ExtensionStack(components, Behavior.SERVER);
exStack.negotiate(new LinkedList<>(), new LinkedList<>());
exStack.negotiate(configs, configs);

WebSocketCoreSession coreSession = new WebSocketCoreSession(new TestMessageHandler(), Behavior.SERVER, Negotiated.from(exStack), components);
WebSocketCoreSession coreSession = new WebSocketCoreSession(new TestMessageHandler(), Behavior.SERVER, Negotiated.from(exStack), components)
{
@Override
public void autoDemand()
{
// Never delegate to WebSocketConnection as it is null for this test.
getExtensionStack().demand(1, l -> {});
}
};
configuration.customize(configuration);
return coreSession;
}
Expand Down

0 comments on commit 1c21dcd

Please sign in to comment.