Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Android] Improve SslStream PAL buffer resizing #104726

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.IO;
using System.Security.Cryptography.X509Certificates;
using System.Security.Authentication;
using System.Threading.Tasks;
using Xunit;
using System.Threading;
using System.Linq;

namespace System.Net.Security.Tests
{
using Configuration = System.Net.Test.Common.Configuration;

public sealed class SslStreamAppDataTests
{
[Fact]
public async Task UtilizeFullSizeOfTlsFrames()
{
(Stream client, Stream server) = TestHelper.GetConnectedTcpStreams();
using (client)
using (server)
{
using var clientInterceptingStream = new TlsFrameInterceptingStream(client);
using var serverInterceptingStream = new TlsFrameInterceptingStream(server);

using var clientStream = new SslStream(clientInterceptingStream, leaveInnerStreamOpen: true, (sender, cert, chain, errors) => true);
using var serverStream = new SslStream(serverInterceptingStream, leaveInnerStreamOpen: true);

using var serverCertificate = Configuration.Certificates.GetServerCertificate();
var hostName = serverCertificate.GetNameInfo(X509NameType.SimpleName, forIssuer: false);

Task t1 = clientStream.AuthenticateAsClientAsync(hostName, [], SslProtocols.None, checkCertificateRevocation: false);
Task t2 = serverStream.AuthenticateAsServerAsync(serverCertificate, clientCertificateRequired: false, checkCertificateRevocation: false);
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);

// Clear the intercepted frames of the handshake
clientInterceptingStream.InterceptedTlsFrameHeaders.Clear();
serverInterceptingStream.InterceptedTlsFrameHeaders.Clear();

var cts = new CancellationTokenSource(TestConfiguration.PassingTestTimeoutMilliseconds);
var bytesToSend = 123_456;
var sentData = new byte[bytesToSend];
Random.Shared.NextBytes(sentData);
var receivedData = new byte[bytesToSend];

await clientStream.WriteAsync(sentData, cts.Token);

int receivedBytes = 0;
while (receivedBytes < bytesToSend)
{
receivedBytes += await serverStream.ReadAsync(receivedData.AsMemory(receivedBytes), cts.Token);
}

Assert.Equal(sentData, receivedData);

Assert.Equal(8, clientInterceptingStream.InterceptedTlsFrameHeaders.Count);
Assert.All(clientInterceptingStream.InterceptedTlsFrameHeaders, static frameHeader => Assert.Equal(TlsContentType.AppData, frameHeader.Type));

for (int i = 0; i < 7; i++)
{
// The first 7 frames should contain 16384 bytes of data + TLS frame overhead
Assert.True(clientInterceptingStream.InterceptedTlsFrameHeaders[i].Length > 16384
&& clientInterceptingStream.InterceptedTlsFrameHeaders[i].Length <= 16709);
}

// The last frame should contain less data than the previous ones
Assert.True(clientInterceptingStream.InterceptedTlsFrameHeaders[7].Length < 16384);
}
}

private sealed class TlsFrameInterceptingStream(Stream innerStream) : Stream
{
public List<TlsFrameHeader> InterceptedTlsFrameHeaders { get; } = new();
public List<byte[]> InterceptedTlsFramePayloads { get; } = new();

private readonly Stream _innerStream = innerStream;
private readonly List<byte> _tlsFrameBuffer = new();

public override bool CanRead => true;
public override bool CanWrite => true;
public override bool CanSeek => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }

public override void Flush() => _innerStream.Flush();
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) { }

public override int Read(byte[] buffer, int offset, int count)
{
return _innerStream.Read(buffer, offset, count);
}

public override void Write(byte[] buffer, int offset, int count)
{
_tlsFrameBuffer.AddRange(buffer[offset..(offset+count)]);
_innerStream.Write(buffer, offset, count);

TlsFrameHeader header = default;
while (TlsFrameHelper.TryGetFrameHeader(_tlsFrameBuffer.ToArray(), ref header))
{
if (header.Length <= _tlsFrameBuffer.Count)
{
InterceptedTlsFrameHeaders.Add(header);
InterceptedTlsFramePayloads.Add(_tlsFrameBuffer.Take(header.Length).ToArray());
_tlsFrameBuffer.RemoveRange(0, header.Length);
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
<Compile Include="SslStreamStreamToStreamTest.cs" />
<Compile Include="SslStreamNetworkStreamTest.cs" />
<Compile Include="SslStreamMutualAuthenticationTest.cs" />
<Compile Include="SslStreamAppDataTests.cs" />
<Compile Include="TransportContextTest.cs" />
<!-- NegotiateAuthentication Tests -->
<Compile Include="NegotiateAuthenticationKerberosTest.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ jclass g_SSLEngineResult;
jmethodID g_SSLEngineResultGetStatus;
jmethodID g_SSLEngineResultGetHandshakeStatus;
bool g_SSLEngineResultStatusLegacyOrder;
jmethodID g_SSLEngineResultBytesConsumed;

// javax/crypto/KeyAgreement
jclass g_KeyAgreementClass;
Expand Down Expand Up @@ -1096,6 +1097,7 @@ JNI_OnLoad(JavaVM *vm, void *reserved)
g_SSLEngineResult = GetClassGRef(env, "javax/net/ssl/SSLEngineResult");
g_SSLEngineResultGetStatus = GetMethod(env, false, g_SSLEngineResult, "getStatus", "()Ljavax/net/ssl/SSLEngineResult$Status;");
g_SSLEngineResultGetHandshakeStatus = GetMethod(env, false, g_SSLEngineResult, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;");
g_SSLEngineResultBytesConsumed = GetMethod(env, false, g_SSLEngineResult, "bytesConsumed", "()I");
g_SSLEngineResultStatusLegacyOrder = android_get_device_api_level() < 24;

g_KeyAgreementClass = GetClassGRef(env, "javax/crypto/KeyAgreement");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ extern jclass g_SSLEngineResult;
extern jmethodID g_SSLEngineResultGetStatus;
extern jmethodID g_SSLEngineResultGetHandshakeStatus;
extern bool g_SSLEngineResultStatusLegacyOrder;
extern jmethodID g_SSLEngineResultBytesConsumed;

// javax/crypto/KeyAgreement
extern jclass g_KeyAgreementClass;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct ApplicationProtocolData_t
ARGS_NON_NULL(1) static uint16_t* AllocateString(JNIEnv* env, jstring source);

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* sslStream);
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus);
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed);
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus);

ARGS_NON_NULL_ALL static int GetHandshakeStatus(JNIEnv* env, SSLStream* sslStream)
Expand Down Expand Up @@ -112,15 +112,15 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslSt
{
// Call wrap to clear any remaining data before closing
int unused;
PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused);
PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused, &unused);

// sslEngine.closeOutbound();
(*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutbound);
if (ret != SSLStreamStatus_OK)
return ret;

// Flush any remaining data (e.g. sending close notification)
return DoWrap(env, sslStream, &unused);
return DoWrap(env, sslStream, &unused, &unused);
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Flush(JNIEnv* env, SSLStream* sslStream)
Expand Down Expand Up @@ -172,10 +172,14 @@ ARGS_NON_NULL_ALL static jobject ExpandBuffer(JNIEnv* env, jobject oldBuffer, in

ARGS_NON_NULL_ALL static jobject EnsureRemaining(JNIEnv* env, jobject oldBuffer, int32_t newRemaining)
{
IGNORE_RETURN((*env)->CallObjectMethod(env, oldBuffer, g_ByteBufferCompact));
int32_t oldPosition = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferPosition);
int32_t oldRemaining = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferRemaining);
if (oldRemaining < newRemaining)
{
return ExpandBuffer(env, oldBuffer, oldRemaining + newRemaining);
// After compacting the oldBuffer, the oldPosition is equal to the number of bytes in the buffer at the moment
// we need to change the capacity to the oldPosition + newRemaining
return ExpandBuffer(env, oldBuffer, oldPosition + newRemaining);
}
else
{
Expand Down Expand Up @@ -204,22 +208,19 @@ static int MapLegacySSLEngineResultStatus(int legacyStatus)
}
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus)
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus WrapAndProcessResult(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed, bool* repeat)
{
// appOutBuffer.flip();
// SSLEngineResult result = sslEngine.wrap(appOutBuffer, netOutBuffer);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip));
jobject result = (*env)->CallObjectMethod(
env, sslStream->sslEngine, g_SSLEngineWrap, sslStream->appOutBuffer, sslStream->netOutBuffer);
if (CheckJNIExceptions(env))
return SSLStreamStatus_Error;

// appOutBuffer.compact();
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact));

// handshakeStatus = result.getHandshakeStatus();
// bytesConsumed = result.bytesConsumed();
// SSLEngineResult.Status status = result.getStatus();
*handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetHandshakeStatus));
*bytesConsumed = (*env)->CallIntMethod(env, result, g_SSLEngineResultBytesConsumed);
int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetStatus));
(*env)->DeleteLocalRef(env, result);

Expand All @@ -242,11 +243,10 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS
}
case STATUS__BUFFER_OVERFLOW:
{
// Expand buffer
// int newCapacity = sslSession.getPacketBufferSize() + netOutBuffer.remaining();
int32_t newCapacity = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize) +
(*env)->CallIntMethod(env, sslStream->netOutBuffer, g_ByteBufferRemaining);
sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, newCapacity);
// Expand buffer and repeat the wrap
int32_t packetBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize);
sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, packetBufferSize);
*repeat = true;
return SSLStreamStatus_OK;
}
default:
Expand All @@ -257,6 +257,32 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS
}
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed)
{
// appOutBuffer.flip();
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip));

bool repeat = false;
PAL_SSLStreamStatus status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat);

if (repeat)
{
repeat = false;
status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat);

if (repeat)
{
LOG_ERROR("Unexpected repeat in DoWrap");
return SSLStreamStatus_Error;
}
}

// appOutBuffer.compact();
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact));

return status;
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus)
{
// if (netInBuffer.position() == 0)
Expand Down Expand Up @@ -350,13 +376,14 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream*
PAL_SSLStreamStatus status = SSLStreamStatus_OK;
int handshakeStatus = GetHandshakeStatus(env, sslStream);
assert(handshakeStatus >= 0);
int bytesConsumed;

while (IsHandshaking(handshakeStatus) && status == SSLStreamStatus_OK)
{
switch (handshakeStatus)
{
case HANDSHAKE_STATUS__NEED_WRAP:
status = DoWrap(env, sslStream, &handshakeStatus);
status = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed);
break;
case HANDSHAKE_STATUS__NEED_UNWRAP:
status = DoUnwrap(env, sslStream, &handshakeStatus);
Expand Down Expand Up @@ -858,26 +885,24 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin
JNIEnv* env = GetJNIEnv();
PAL_SSLStreamStatus ret = SSLStreamStatus_Error;

// int remaining = appOutBuffer.remaining();
// int arraySize = length > remaining ? remaining : length;
// byte[] data = new byte[arraySize];
int32_t remaining = (*env)->CallIntMethod(env, sslStream->appOutBuffer, g_ByteBufferRemaining);
int32_t arraySize = length > remaining ? remaining : length;
jbyteArray data = make_java_byte_array(env, arraySize);
// data.setByteArrayRegion(0, length, buffer);
jbyteArray data = make_java_byte_array(env, length);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about capping the size of the data buffer to the max payload size of one TLS frame and keep copying just a subset of the data as we did previously, but the SslStream is already chunking the data (StreamSizes.Default == 32,768).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to allocate a java array here or can we use NIO? appOutBuffer is a ByteBuffer. ByteBuffer has a put that accepts another ByteBuffer.

We can create a ByteBuffer over buffer with jobject bufferByteBuffer = (*env)->NewDirectByteBuffer(env, buffer, length);

Then we can put that in to appOutBuffer. That remove an allocation and a copy from a handshake.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a thought, can be done as a follow up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I'll give it a try.

(*env)->SetByteArrayRegion(env, data, 0, length, (jbyte*)buffer);

// appOutBuffer.compact();
// appOutBuffer.put(data, 0, length);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact));
sslStream->appOutBuffer = EnsureRemaining(env, sslStream->appOutBuffer, length);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArrayWithLength, data, 0, length));
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);

int32_t written = 0;
while (written < length)
{
int32_t toWrite = length - written > arraySize ? arraySize : length - written;
(*env)->SetByteArrayRegion(env, data, 0, toWrite, (jbyte*)(buffer + written));

// appOutBuffer.put(data, 0, toWrite);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArrayWithLength, data, 0, toWrite));
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);
written += toWrite;

int handshakeStatus;
ret = DoWrap(env, sslStream, &handshakeStatus);
int bytesConsumed;
ret = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed);
if (ret != SSLStreamStatus_OK)
{
goto cleanup;
Expand All @@ -887,6 +912,8 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin
ret = SSLStreamStatus_Renegotiate;
goto cleanup;
}

written += bytesConsumed;
}

cleanup:
Expand Down
Loading