-
Notifications
You must be signed in to change notification settings - Fork 4.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Android] Improve SslStream PAL buffer resizing #104726
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Security.Cryptography.X509Certificates; | ||
using System.Security.Authentication; | ||
using System.Threading.Tasks; | ||
using Xunit; | ||
using System.Threading; | ||
using System.Linq; | ||
|
||
namespace System.Net.Security.Tests | ||
{ | ||
using Configuration = System.Net.Test.Common.Configuration; | ||
|
||
public sealed class SslStreamAppDataTests | ||
{ | ||
[Fact] | ||
public async Task UtilizeFullSizeOfTlsFrames() | ||
{ | ||
(Stream client, Stream server) = TestHelper.GetConnectedTcpStreams(); | ||
using (client) | ||
using (server) | ||
{ | ||
using var clientInterceptingStream = new TlsFrameInterceptingStream(client); | ||
using var serverInterceptingStream = new TlsFrameInterceptingStream(server); | ||
|
||
using var clientStream = new SslStream(clientInterceptingStream, leaveInnerStreamOpen: true, (sender, cert, chain, errors) => true); | ||
using var serverStream = new SslStream(serverInterceptingStream, leaveInnerStreamOpen: true); | ||
|
||
using var serverCertificate = Configuration.Certificates.GetServerCertificate(); | ||
var hostName = serverCertificate.GetNameInfo(X509NameType.SimpleName, forIssuer: false); | ||
|
||
Task t1 = clientStream.AuthenticateAsClientAsync(hostName, [], SslProtocols.None, checkCertificateRevocation: false); | ||
Task t2 = serverStream.AuthenticateAsServerAsync(serverCertificate, clientCertificateRequired: false, checkCertificateRevocation: false); | ||
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); | ||
|
||
// Clear the intercepted frames of the handshake | ||
clientInterceptingStream.InterceptedTlsFrameHeaders.Clear(); | ||
serverInterceptingStream.InterceptedTlsFrameHeaders.Clear(); | ||
|
||
var cts = new CancellationTokenSource(TestConfiguration.PassingTestTimeoutMilliseconds); | ||
var bytesToSend = 123_456; | ||
var sentData = new byte[bytesToSend]; | ||
Random.Shared.NextBytes(sentData); | ||
var receivedData = new byte[bytesToSend]; | ||
|
||
await clientStream.WriteAsync(sentData, cts.Token); | ||
|
||
int receivedBytes = 0; | ||
while (receivedBytes < bytesToSend) | ||
{ | ||
receivedBytes += await serverStream.ReadAsync(receivedData.AsMemory(receivedBytes), cts.Token); | ||
} | ||
|
||
Assert.Equal(sentData, receivedData); | ||
|
||
Assert.Equal(8, clientInterceptingStream.InterceptedTlsFrameHeaders.Count); | ||
Assert.All(clientInterceptingStream.InterceptedTlsFrameHeaders, static frameHeader => Assert.Equal(TlsContentType.AppData, frameHeader.Type)); | ||
|
||
for (int i = 0; i < 7; i++) | ||
{ | ||
// The first 7 frames should contain 16384 bytes of data + TLS frame overhead | ||
Assert.True(clientInterceptingStream.InterceptedTlsFrameHeaders[i].Length > 16384 | ||
&& clientInterceptingStream.InterceptedTlsFrameHeaders[i].Length <= 16709); | ||
} | ||
|
||
// The last frame should contain less data than the previous ones | ||
Assert.True(clientInterceptingStream.InterceptedTlsFrameHeaders[7].Length < 16384); | ||
} | ||
} | ||
|
||
private sealed class TlsFrameInterceptingStream(Stream innerStream) : Stream | ||
{ | ||
public List<TlsFrameHeader> InterceptedTlsFrameHeaders { get; } = new(); | ||
public List<byte[]> InterceptedTlsFramePayloads { get; } = new(); | ||
|
||
private readonly Stream _innerStream = innerStream; | ||
private readonly List<byte> _tlsFrameBuffer = new(); | ||
|
||
public override bool CanRead => true; | ||
public override bool CanWrite => true; | ||
public override bool CanSeek => false; | ||
public override long Length => throw new NotSupportedException(); | ||
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } | ||
|
||
public override void Flush() => _innerStream.Flush(); | ||
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); | ||
public override void SetLength(long value) { } | ||
|
||
public override int Read(byte[] buffer, int offset, int count) | ||
{ | ||
return _innerStream.Read(buffer, offset, count); | ||
} | ||
|
||
public override void Write(byte[] buffer, int offset, int count) | ||
{ | ||
_tlsFrameBuffer.AddRange(buffer[offset..(offset+count)]); | ||
_innerStream.Write(buffer, offset, count); | ||
|
||
TlsFrameHeader header = default; | ||
while (TlsFrameHelper.TryGetFrameHeader(_tlsFrameBuffer.ToArray(), ref header)) | ||
{ | ||
if (header.Length <= _tlsFrameBuffer.Count) | ||
{ | ||
InterceptedTlsFrameHeaders.Add(header); | ||
InterceptedTlsFramePayloads.Add(_tlsFrameBuffer.Take(header.Length).ToArray()); | ||
_tlsFrameBuffer.RemoveRange(0, header.Length); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,7 @@ struct ApplicationProtocolData_t | |
ARGS_NON_NULL(1) static uint16_t* AllocateString(JNIEnv* env, jstring source); | ||
|
||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* sslStream); | ||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus); | ||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed); | ||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus); | ||
|
||
ARGS_NON_NULL_ALL static int GetHandshakeStatus(JNIEnv* env, SSLStream* sslStream) | ||
|
@@ -112,15 +112,15 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslSt | |
{ | ||
// Call wrap to clear any remaining data before closing | ||
int unused; | ||
PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused); | ||
PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused, &unused); | ||
|
||
// sslEngine.closeOutbound(); | ||
(*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutbound); | ||
if (ret != SSLStreamStatus_OK) | ||
return ret; | ||
|
||
// Flush any remaining data (e.g. sending close notification) | ||
return DoWrap(env, sslStream, &unused); | ||
return DoWrap(env, sslStream, &unused, &unused); | ||
} | ||
|
||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Flush(JNIEnv* env, SSLStream* sslStream) | ||
|
@@ -172,10 +172,14 @@ ARGS_NON_NULL_ALL static jobject ExpandBuffer(JNIEnv* env, jobject oldBuffer, in | |
|
||
ARGS_NON_NULL_ALL static jobject EnsureRemaining(JNIEnv* env, jobject oldBuffer, int32_t newRemaining) | ||
{ | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, oldBuffer, g_ByteBufferCompact)); | ||
int32_t oldPosition = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferPosition); | ||
int32_t oldRemaining = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferRemaining); | ||
if (oldRemaining < newRemaining) | ||
{ | ||
return ExpandBuffer(env, oldBuffer, oldRemaining + newRemaining); | ||
// After compacting the oldBuffer, the oldPosition is equal to the number of bytes in the buffer at the moment | ||
// we need to change the capacity to the oldPosition + newRemaining | ||
return ExpandBuffer(env, oldBuffer, oldPosition + newRemaining); | ||
} | ||
else | ||
{ | ||
|
@@ -204,22 +208,19 @@ static int MapLegacySSLEngineResultStatus(int legacyStatus) | |
} | ||
} | ||
|
||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus) | ||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus WrapAndProcessResult(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed, bool* repeat) | ||
{ | ||
// appOutBuffer.flip(); | ||
// SSLEngineResult result = sslEngine.wrap(appOutBuffer, netOutBuffer); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip)); | ||
jobject result = (*env)->CallObjectMethod( | ||
env, sslStream->sslEngine, g_SSLEngineWrap, sslStream->appOutBuffer, sslStream->netOutBuffer); | ||
if (CheckJNIExceptions(env)) | ||
return SSLStreamStatus_Error; | ||
|
||
// appOutBuffer.compact(); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); | ||
|
||
// handshakeStatus = result.getHandshakeStatus(); | ||
// bytesConsumed = result.bytesConsumed(); | ||
// SSLEngineResult.Status status = result.getStatus(); | ||
*handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetHandshakeStatus)); | ||
*bytesConsumed = (*env)->CallIntMethod(env, result, g_SSLEngineResultBytesConsumed); | ||
int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetStatus)); | ||
(*env)->DeleteLocalRef(env, result); | ||
|
||
|
@@ -242,11 +243,10 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS | |
} | ||
case STATUS__BUFFER_OVERFLOW: | ||
{ | ||
// Expand buffer | ||
// int newCapacity = sslSession.getPacketBufferSize() + netOutBuffer.remaining(); | ||
int32_t newCapacity = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize) + | ||
(*env)->CallIntMethod(env, sslStream->netOutBuffer, g_ByteBufferRemaining); | ||
sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, newCapacity); | ||
// Expand buffer and repeat the wrap | ||
int32_t packetBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize); | ||
sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, packetBufferSize); | ||
*repeat = true; | ||
return SSLStreamStatus_OK; | ||
} | ||
default: | ||
|
@@ -257,6 +257,32 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS | |
} | ||
} | ||
|
||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed) | ||
{ | ||
// appOutBuffer.flip(); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip)); | ||
|
||
bool repeat = false; | ||
PAL_SSLStreamStatus status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat); | ||
|
||
if (repeat) | ||
{ | ||
repeat = false; | ||
status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat); | ||
|
||
if (repeat) | ||
{ | ||
LOG_ERROR("Unexpected repeat in DoWrap"); | ||
return SSLStreamStatus_Error; | ||
} | ||
} | ||
|
||
// appOutBuffer.compact(); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); | ||
|
||
return status; | ||
} | ||
|
||
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus) | ||
{ | ||
// if (netInBuffer.position() == 0) | ||
|
@@ -350,13 +376,14 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* | |
PAL_SSLStreamStatus status = SSLStreamStatus_OK; | ||
int handshakeStatus = GetHandshakeStatus(env, sslStream); | ||
assert(handshakeStatus >= 0); | ||
int bytesConsumed; | ||
|
||
while (IsHandshaking(handshakeStatus) && status == SSLStreamStatus_OK) | ||
{ | ||
switch (handshakeStatus) | ||
{ | ||
case HANDSHAKE_STATUS__NEED_WRAP: | ||
status = DoWrap(env, sslStream, &handshakeStatus); | ||
status = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed); | ||
break; | ||
case HANDSHAKE_STATUS__NEED_UNWRAP: | ||
status = DoUnwrap(env, sslStream, &handshakeStatus); | ||
|
@@ -858,26 +885,24 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin | |
JNIEnv* env = GetJNIEnv(); | ||
PAL_SSLStreamStatus ret = SSLStreamStatus_Error; | ||
|
||
// int remaining = appOutBuffer.remaining(); | ||
// int arraySize = length > remaining ? remaining : length; | ||
// byte[] data = new byte[arraySize]; | ||
int32_t remaining = (*env)->CallIntMethod(env, sslStream->appOutBuffer, g_ByteBufferRemaining); | ||
int32_t arraySize = length > remaining ? remaining : length; | ||
jbyteArray data = make_java_byte_array(env, arraySize); | ||
// data.setByteArrayRegion(0, length, buffer); | ||
jbyteArray data = make_java_byte_array(env, length); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to allocate a java array here or can we use NIO? We can create a Then we can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a thought, can be done as a follow up. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion, I'll give it a try. |
||
(*env)->SetByteArrayRegion(env, data, 0, length, (jbyte*)buffer); | ||
|
||
// appOutBuffer.compact(); | ||
// appOutBuffer.put(data, 0, length); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); | ||
sslStream->appOutBuffer = EnsureRemaining(env, sslStream->appOutBuffer, length); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArrayWithLength, data, 0, length)); | ||
ON_EXCEPTION_PRINT_AND_GOTO(cleanup); | ||
|
||
int32_t written = 0; | ||
while (written < length) | ||
{ | ||
int32_t toWrite = length - written > arraySize ? arraySize : length - written; | ||
(*env)->SetByteArrayRegion(env, data, 0, toWrite, (jbyte*)(buffer + written)); | ||
|
||
// appOutBuffer.put(data, 0, toWrite); | ||
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArrayWithLength, data, 0, toWrite)); | ||
ON_EXCEPTION_PRINT_AND_GOTO(cleanup); | ||
written += toWrite; | ||
|
||
int handshakeStatus; | ||
ret = DoWrap(env, sslStream, &handshakeStatus); | ||
int bytesConsumed; | ||
ret = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed); | ||
if (ret != SSLStreamStatus_OK) | ||
{ | ||
goto cleanup; | ||
|
@@ -887,6 +912,8 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin | |
ret = SSLStreamStatus_Renegotiate; | ||
goto cleanup; | ||
} | ||
|
||
written += bytesConsumed; | ||
} | ||
|
||
cleanup: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about capping the size of the
data
buffer to the max payload size of one TLS frame and keep copying just a subset of the data as we did previously, but the SslStream is already chunking the data (StreamSizes.Default == 32,768
).