diff --git a/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.c b/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.c index 243dbd1d9466d1..e14ef08fac0947 100644 --- a/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.c +++ b/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.c @@ -451,8 +451,6 @@ jmethodID g_ByteBufferGet; jmethodID g_ByteBufferLimit; jmethodID g_ByteBufferPosition; jmethodID g_ByteBufferPutBuffer; -jmethodID g_ByteBufferPutByteArray; -jmethodID g_ByteBufferPutByteArrayWithLength; jmethodID g_ByteBufferRemaining; // javax/net/ssl/SSLContext @@ -477,6 +475,7 @@ jclass g_SSLEngineResult; jmethodID g_SSLEngineResultGetStatus; jmethodID g_SSLEngineResultGetHandshakeStatus; bool g_SSLEngineResultStatusLegacyOrder; +jmethodID g_SSLEngineResultBytesConsumed; // javax/crypto/KeyAgreement jclass g_KeyAgreementClass; @@ -1074,8 +1073,6 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_ByteBufferLimit = GetMethod(env, false, g_ByteBuffer, "limit", "()I"); g_ByteBufferPosition = GetMethod(env, false, g_ByteBuffer, "position", "()I"); g_ByteBufferPutBuffer = GetMethod(env, false, g_ByteBuffer, "put", "(Ljava/nio/ByteBuffer;)Ljava/nio/ByteBuffer;"); - g_ByteBufferPutByteArray = GetMethod(env, false, g_ByteBuffer, "put", "([B)Ljava/nio/ByteBuffer;"); - g_ByteBufferPutByteArrayWithLength = GetMethod(env, false, g_ByteBuffer, "put", "([BII)Ljava/nio/ByteBuffer;"); g_ByteBufferRemaining = GetMethod(env, false, g_ByteBuffer, "remaining", "()I"); g_SSLContext = GetClassGRef(env, "javax/net/ssl/SSLContext"); @@ -1096,6 +1093,7 @@ JNI_OnLoad(JavaVM *vm, void *reserved) g_SSLEngineResult = GetClassGRef(env, "javax/net/ssl/SSLEngineResult"); g_SSLEngineResultGetStatus = GetMethod(env, false, g_SSLEngineResult, "getStatus", "()Ljavax/net/ssl/SSLEngineResult$Status;"); g_SSLEngineResultGetHandshakeStatus = GetMethod(env, false, g_SSLEngineResult, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;"); + g_SSLEngineResultBytesConsumed = GetMethod(env, false, g_SSLEngineResult, "bytesConsumed", "()I"); g_SSLEngineResultStatusLegacyOrder = android_get_device_api_level() < 24; g_KeyAgreementClass = GetClassGRef(env, "javax/crypto/KeyAgreement"); diff --git a/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.h b/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.h index 79bc888224629f..e0e0abbab1874a 100644 --- a/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.h +++ b/src/native/libs/System.Security.Cryptography.Native.Android/pal_jni.h @@ -465,8 +465,6 @@ extern jmethodID g_ByteBufferGet; extern jmethodID g_ByteBufferLimit; extern jmethodID g_ByteBufferPosition; extern jmethodID g_ByteBufferPutBuffer; -extern jmethodID g_ByteBufferPutByteArray; -extern jmethodID g_ByteBufferPutByteArrayWithLength; extern jmethodID g_ByteBufferRemaining; // javax/net/ssl/SSLContext @@ -491,6 +489,7 @@ extern jclass g_SSLEngineResult; extern jmethodID g_SSLEngineResultGetStatus; extern jmethodID g_SSLEngineResultGetHandshakeStatus; extern bool g_SSLEngineResultStatusLegacyOrder; +extern jmethodID g_SSLEngineResultBytesConsumed; // javax/crypto/KeyAgreement extern jclass g_KeyAgreementClass; diff --git a/src/native/libs/System.Security.Cryptography.Native.Android/pal_sslstream.c b/src/native/libs/System.Security.Cryptography.Native.Android/pal_sslstream.c index 9aa7444e391bcc..c7933f73aaee29 100644 --- a/src/native/libs/System.Security.Cryptography.Native.Android/pal_sslstream.c +++ b/src/native/libs/System.Security.Cryptography.Native.Android/pal_sslstream.c @@ -44,7 +44,7 @@ struct ApplicationProtocolData_t ARGS_NON_NULL(1) static uint16_t* AllocateString(JNIEnv* env, jstring source); ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* sslStream); -ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus); +ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed); ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus); ARGS_NON_NULL_ALL static int GetHandshakeStatus(JNIEnv* env, SSLStream* sslStream) @@ -112,7 +112,7 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslSt { // Call wrap to clear any remaining data before closing int unused; - PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused); + PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused, &unused); // sslEngine.closeOutbound(); (*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutbound); @@ -120,7 +120,7 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslSt return ret; // Flush any remaining data (e.g. sending close notification) - return DoWrap(env, sslStream, &unused); + return DoWrap(env, sslStream, &unused, &unused); } ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Flush(JNIEnv* env, SSLStream* sslStream) @@ -172,10 +172,14 @@ ARGS_NON_NULL_ALL static jobject ExpandBuffer(JNIEnv* env, jobject oldBuffer, in ARGS_NON_NULL_ALL static jobject EnsureRemaining(JNIEnv* env, jobject oldBuffer, int32_t newRemaining) { + IGNORE_RETURN((*env)->CallObjectMethod(env, oldBuffer, g_ByteBufferCompact)); + int32_t oldPosition = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferPosition); int32_t oldRemaining = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferRemaining); if (oldRemaining < newRemaining) { - return ExpandBuffer(env, oldBuffer, oldRemaining + newRemaining); + // After compacting the oldBuffer, the oldPosition is equal to the number of bytes in the buffer at the moment + // we need to change the capacity to the oldPosition + newRemaining + return ExpandBuffer(env, oldBuffer, oldPosition + newRemaining); } else { @@ -204,22 +208,19 @@ static int MapLegacySSLEngineResultStatus(int legacyStatus) } } -ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus) +ARGS_NON_NULL_ALL static PAL_SSLStreamStatus WrapAndProcessResult(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed, bool* repeat) { - // appOutBuffer.flip(); // SSLEngineResult result = sslEngine.wrap(appOutBuffer, netOutBuffer); - IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip)); jobject result = (*env)->CallObjectMethod( env, sslStream->sslEngine, g_SSLEngineWrap, sslStream->appOutBuffer, sslStream->netOutBuffer); if (CheckJNIExceptions(env)) return SSLStreamStatus_Error; - // appOutBuffer.compact(); - IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); - // handshakeStatus = result.getHandshakeStatus(); + // bytesConsumed = result.bytesConsumed(); // SSLEngineResult.Status status = result.getStatus(); *handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetHandshakeStatus)); + *bytesConsumed = (*env)->CallIntMethod(env, result, g_SSLEngineResultBytesConsumed); int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetStatus)); (*env)->DeleteLocalRef(env, result); @@ -242,11 +243,10 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS } case STATUS__BUFFER_OVERFLOW: { - // Expand buffer - // int newCapacity = sslSession.getPacketBufferSize() + netOutBuffer.remaining(); - int32_t newCapacity = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize) + - (*env)->CallIntMethod(env, sslStream->netOutBuffer, g_ByteBufferRemaining); - sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, newCapacity); + // Expand buffer and repeat the wrap + int32_t packetBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize); + sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, packetBufferSize); + *repeat = true; return SSLStreamStatus_OK; } default: @@ -257,18 +257,44 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS } } +ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed) +{ + // appOutBuffer.flip(); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip)); + + bool repeat = false; + PAL_SSLStreamStatus status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat); + + if (repeat) + { + repeat = false; + status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat); + + if (repeat) + { + LOG_ERROR("Unexpected repeat in DoWrap"); + return SSLStreamStatus_Error; + } + } + + // appOutBuffer.compact(); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); + + return status; +} + ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus) { // if (netInBuffer.position() == 0) // { - // byte[] tmp = new byte[netInBuffer.limit()]; - // int count = streamReader(tmp, 0, tmp.length); - // netInBuffer.put(tmp, 0, count); + // int netInBufferLimit = netInBuffer.limit(); + // ByteBuffer tmp = ByteBuffer.allocateDirect(netInBufferLimit); + // int count = streamReader(tmp, 0, netInBufferLimit); + // netInBuffer.put(tmp); // } if ((*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferPosition) == 0) { int netInBufferLimit = (*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferLimit); - jbyteArray tmp = make_java_byte_array(env, netInBufferLimit); uint8_t* tmpNative = (uint8_t*)xmalloc((size_t)netInBufferLimit); int count = netInBufferLimit; // todo assert streamReader != 0 ? @@ -276,13 +302,15 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* ss if (status != SSLStreamStatus_OK) { free(tmpNative); - (*env)->DeleteLocalRef(env, tmp); return status; } - (*env)->SetByteArrayRegion(env, tmp, 0, count, (jbyte*)(tmpNative)); + jobject tmp = (*env)->NewDirectByteBuffer(env, tmpNative, count); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + IGNORE_RETURN( - (*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferPutByteArrayWithLength, tmp, 0, count)); + (*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferPutBuffer, tmp)); +cleanup: free(tmpNative); (*env)->DeleteLocalRef(env, tmp); } @@ -350,13 +378,14 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* PAL_SSLStreamStatus status = SSLStreamStatus_OK; int handshakeStatus = GetHandshakeStatus(env, sslStream); assert(handshakeStatus >= 0); + int bytesConsumed; while (IsHandshaking(handshakeStatus) && status == SSLStreamStatus_OK) { switch (handshakeStatus) { case HANDSHAKE_STATUS__NEED_WRAP: - status = DoWrap(env, sslStream, &handshakeStatus); + status = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed); break; case HANDSHAKE_STATUS__NEED_UNWRAP: status = DoUnwrap(env, sslStream, &handshakeStatus); @@ -858,26 +887,24 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin JNIEnv* env = GetJNIEnv(); PAL_SSLStreamStatus ret = SSLStreamStatus_Error; - // int remaining = appOutBuffer.remaining(); - // int arraySize = length > remaining ? remaining : length; - // byte[] data = new byte[arraySize]; - int32_t remaining = (*env)->CallIntMethod(env, sslStream->appOutBuffer, g_ByteBufferRemaining); - int32_t arraySize = length > remaining ? remaining : length; - jbyteArray data = make_java_byte_array(env, arraySize); + // ByteBuffer bufferByteBuffer = ...; + jobject bufferByteBuffer = (*env)->NewDirectByteBuffer(env, buffer, length); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); + + // appOutBuffer.compact(); + // appOutBuffer = EnsureRemaining(appOutBuffer, length); + // appOutBuffer.put(bufferByteBuffer); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact)); + sslStream->appOutBuffer = EnsureRemaining(env, sslStream->appOutBuffer, length); + IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutBuffer, bufferByteBuffer)); + ON_EXCEPTION_PRINT_AND_GOTO(cleanup); int32_t written = 0; while (written < length) { - int32_t toWrite = length - written > arraySize ? arraySize : length - written; - (*env)->SetByteArrayRegion(env, data, 0, toWrite, (jbyte*)(buffer + written)); - - // appOutBuffer.put(data, 0, toWrite); - IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArrayWithLength, data, 0, toWrite)); - ON_EXCEPTION_PRINT_AND_GOTO(cleanup); - written += toWrite; - int handshakeStatus; - ret = DoWrap(env, sslStream, &handshakeStatus); + int bytesConsumed; + ret = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed); if (ret != SSLStreamStatus_OK) { goto cleanup; @@ -887,10 +914,12 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin ret = SSLStreamStatus_Renegotiate; goto cleanup; } + + written += bytesConsumed; } cleanup: - (*env)->DeleteLocalRef(env, data); + (*env)->DeleteLocalRef(env, bufferByteBuffer); return ret; }