Skip to content

Commit

Permalink
Cleanup the code
Browse files Browse the repository at this point in the history
Use conditional compilation for psa and mbedtls code (MBEDTLS_USE_PSA_CRYPTO).

Signed-off-by: Przemyslaw Stekiel <[email protected]>
  • Loading branch information
mprse committed Jan 31, 2022
1 parent d4eab57 commit 6be9cf5
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 89 deletions.
5 changes: 3 additions & 2 deletions library/ssl_misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -937,14 +937,15 @@ struct mbedtls_ssl_transform

#endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */

mbedtls_cipher_context_t cipher_ctx_enc; /*!< encryption context */
mbedtls_cipher_context_t cipher_ctx_dec; /*!< decryption context */
int minor_ver;

#if defined(MBEDTLS_USE_PSA_CRYPTO)
mbedtls_svc_key_id_t psa_key_enc; /*!< psa encryption key */
mbedtls_svc_key_id_t psa_key_dec; /*!< psa decryption key */
psa_algorithm_t psa_alg; /*!< psa algorithm */
#else
mbedtls_cipher_context_t cipher_ctx_enc; /*!< encryption context */
mbedtls_cipher_context_t cipher_ctx_dec; /*!< decryption context */
#endif /* MBEDTLS_USE_PSA_CRYPTO */

#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
Expand Down
106 changes: 102 additions & 4 deletions library/ssl_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,9 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
int (*f_rng)(void *, unsigned char *, size_t),
void *p_rng )
{
#if !defined(MBEDTLS_USE_PSA_CRYPTO)
mbedtls_cipher_mode_t mode;
#endif /* MBEDTLS_USE_PSA_CRYPTO */
int auth_done = 0;
unsigned char * data;
unsigned char add_data[13 + 1 + MBEDTLS_SSL_CID_OUT_LEN_MAX ];
Expand Down Expand Up @@ -568,7 +570,9 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
MBEDTLS_SSL_DEBUG_BUF( 4, "before encrypt: output payload",
data, rec->data_len );

#if !defined(MBEDTLS_USE_PSA_CRYPTO)
mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc );
#endif /* MBEDTLS_USE_PSA_CRYPTO */

if( rec->data_len > MBEDTLS_SSL_OUT_CONTENT_LEN )
{
Expand Down Expand Up @@ -649,8 +653,13 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
* Add MAC before if needed
*/
#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER ||
( transform->psa_alg == PSA_ALG_CBC_NO_PADDING
#else
if( mode == MBEDTLS_MODE_STREAM ||
( mode == MBEDTLS_MODE_CBC
#endif
#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
&& transform->encrypt_then_mac == MBEDTLS_SSL_ETM_DISABLED
#endif
Expand Down Expand Up @@ -707,7 +716,11 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
* Encrypt
*/
#if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER )
#else
if( mode == MBEDTLS_MODE_STREAM )
#endif
{
size_t olen;
#if defined(MBEDTLS_USE_PSA_CRYPTO)
Expand Down Expand Up @@ -779,9 +792,18 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
#if defined(MBEDTLS_GCM_C) || \
defined(MBEDTLS_CCM_C) || \
defined(MBEDTLS_CHACHAPOLY_C)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == PSA_ALG_GCM ||
/* PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to
psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 )
in tls context (TLS only uses the default taglen or 8) */
PSA_ALG_IS_AEAD( transform->psa_alg ) ||
transform->psa_alg == PSA_ALG_CHACHA20_POLY1305 )
#else
if( mode == MBEDTLS_MODE_GCM ||
mode == MBEDTLS_MODE_CCM ||
mode == MBEDTLS_MODE_CHACHAPOLY )
#endif /* MBEDTLS_USE_PSA_CRYPTO */
{
unsigned char iv[12];
unsigned char *dynamic_iv;
Expand Down Expand Up @@ -897,7 +919,11 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl,
else
#endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C || MBEDTLS_CHACHAPOLY_C */
#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING )
#else
if( mode == MBEDTLS_MODE_CBC )
#endif /* MBEDTLS_USE_PSA_CRYPTO */
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
size_t padlen, i;
Expand Down Expand Up @@ -1092,7 +1118,9 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
mbedtls_record *rec )
{
size_t olen;
#if !defined(MBEDTLS_USE_PSA_CRYPTO)
mbedtls_cipher_mode_t mode;
#endif /* MBEDTLS_USE_PSA_CRYPTO */
int ret, auth_done = 0;
#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
size_t padlen = 0, correct = 1;
Expand All @@ -1117,7 +1145,9 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
}

data = rec->buf + rec->data_offset;
#if !defined(MBEDTLS_USE_PSA_CRYPTO)
mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_dec );
#endif /* MBEDTLS_USE_PSA_CRYPTO */

#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
/*
Expand All @@ -1131,7 +1161,11 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
#endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */

#if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER )
#else
if( mode == MBEDTLS_MODE_STREAM )
#endif /* MBEDTLS_USE_PSA_CRYPTO */
{
padlen = 0;
#if defined(MBEDTLS_USE_PSA_CRYPTO)
Expand Down Expand Up @@ -1198,9 +1232,18 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
#if defined(MBEDTLS_GCM_C) || \
defined(MBEDTLS_CCM_C) || \
defined(MBEDTLS_CHACHAPOLY_C)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == PSA_ALG_GCM ||
/* PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to
psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 )
in tls context (TLS only uses the default taglen or 8) */
PSA_ALG_IS_AEAD( transform->psa_alg ) ||
transform->psa_alg == PSA_ALG_CHACHA20_POLY1305 )
#else
if( mode == MBEDTLS_MODE_GCM ||
mode == MBEDTLS_MODE_CCM ||
mode == MBEDTLS_MODE_CHACHAPOLY )
#endif /* MBEDTLS_USE_PSA_CRYPTO */
{
unsigned char iv[12];
unsigned char *dynamic_iv;
Expand Down Expand Up @@ -1322,7 +1365,11 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl,
else
#endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C */
#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC)
#if defined(MBEDTLS_USE_PSA_CRYPTO)
if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING )
#else
if( mode == MBEDTLS_MODE_CBC )
#endif /* MBEDTLS_USE_PSA_CRYPTO */
{
size_t minlen = 0;
#if defined(MBEDTLS_USE_PSA_CRYPTO)
Expand Down Expand Up @@ -5047,12 +5094,62 @@ int mbedtls_ssl_get_record_expansion( const mbedtls_ssl_context *ssl )
size_t transform_expansion = 0;
const mbedtls_ssl_transform *transform = ssl->transform_out;
unsigned block_size;
#if defined(MBEDTLS_USE_PSA_CRYPTO)
psa_key_attributes_t attr = PSA_KEY_ATTRIBUTES_INIT;
psa_key_type_t key_type;
#endif /* MBEDTLS_USE_PSA_CRYPTO */

size_t out_hdr_len = mbedtls_ssl_out_hdr_len( ssl );

if( transform == NULL )
return( (int) out_hdr_len );


#if defined(MBEDTLS_USE_PSA_CRYPTO)
switch( transform->psa_alg )
{
case PSA_ALG_GCM:
case PSA_ALG_CHACHA20_POLY1305:
case MBEDTLS_SSL_NULL_CIPHER:
transform_expansion = transform->minlen;
break;

case PSA_ALG_CBC_NO_PADDING:
(void) psa_get_key_attributes( transform->psa_key_enc, &attr );
key_type = psa_get_key_type( &attr );

block_size = PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type );

/* Expansion due to the addition of the MAC. */
transform_expansion += transform->maclen;

/* Expansion due to the addition of CBC padding;
* Theoretically up to 256 bytes, but we never use
* more than the block size of the underlying cipher. */
transform_expansion += block_size;

/* For TLS 1.2 or higher, an explicit IV is added
* after the record header. */
#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
transform_expansion += block_size;
#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
break;

default:
/* Handle CCM case in default:
PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to
psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 )
in tls context (TLS only uses the default taglen or 8) */
if ( PSA_ALG_IS_AEAD( transform->psa_alg ) )
{
transform_expansion = transform->minlen;
break;
}

MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
}
#else
switch( mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc ) )
{
case MBEDTLS_MODE_GCM:
Expand Down Expand Up @@ -5087,6 +5184,7 @@ int mbedtls_ssl_get_record_expansion( const mbedtls_ssl_context *ssl )
MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
}
#endif /* MBEDTLS_USE_PSA_CRYPTO */

#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
if( transform->out_cid_len != 0 )
Expand Down Expand Up @@ -5591,13 +5689,13 @@ void mbedtls_ssl_transform_free( mbedtls_ssl_transform *transform )
if( transform == NULL )
return;

mbedtls_cipher_free( &transform->cipher_ctx_enc );
mbedtls_cipher_free( &transform->cipher_ctx_dec );

#if defined(MBEDTLS_USE_PSA_CRYPTO)
psa_destroy_key( transform->psa_key_enc );
psa_destroy_key( transform->psa_key_dec );
#endif
#else
mbedtls_cipher_free( &transform->cipher_ctx_enc );
mbedtls_cipher_free( &transform->cipher_ctx_dec );
#endif /* MBEDTLS_USE_PSA_CRYPTO */

#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
mbedtls_md_free( &transform->md_ctx_enc );
Expand Down
Loading

0 comments on commit 6be9cf5

Please sign in to comment.