Skip to content

Commit

Permalink
Extend SqlQueryMetadataCache to include enclave-required keys (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnny Pham authored Jun 3, 2021
1 parent 688b931 commit 89e15c4
Show file tree
Hide file tree
Showing 15 changed files with 235 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data;
Expand Down Expand Up @@ -149,8 +150,17 @@ private enum EXECTYPE
// cached metadata
private _SqlMetaDataSet _cachedMetaData;

private Dictionary<int, SqlTceCipherInfoEntry> keysToBeSentToEnclave;
private bool requiresEnclaveComputations = false;
internal ConcurrentDictionary<int, SqlTceCipherInfoEntry> keysToBeSentToEnclave;
internal bool requiresEnclaveComputations = false;

private bool ShouldCacheEncryptionMetadata
{
get
{
return !requiresEnclaveComputations || _activeConnection.Parser.AreEnclaveRetriesSupported;
}
}

internal EnclavePackage enclavePackage = null;
private SqlEnclaveAttestationParameters enclaveAttestationParameters = null;
private byte[] customData = null;
Expand Down Expand Up @@ -3435,10 +3445,7 @@ private void ResetEncryptionState()
}
}

if (keysToBeSentToEnclave != null)
{
keysToBeSentToEnclave.Clear();
}
keysToBeSentToEnclave?.Clear();
enclavePackage = null;
requiresEnclaveComputations = false;
enclaveAttestationParameters = null;
Expand Down Expand Up @@ -4143,7 +4150,6 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi
enclaveMetadataExists = false;
}


if (isRequestedByEnclave)
{
if (string.IsNullOrWhiteSpace(this.Connection.EnclaveAttestationUrl))
Expand Down Expand Up @@ -4173,12 +4179,12 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi

if (keysToBeSentToEnclave == null)
{
keysToBeSentToEnclave = new Dictionary<int, SqlTceCipherInfoEntry>();
keysToBeSentToEnclave.Add(currentOrdinal, cipherInfo);
keysToBeSentToEnclave = new ConcurrentDictionary<int, SqlTceCipherInfoEntry>();
keysToBeSentToEnclave.TryAdd(currentOrdinal, cipherInfo);
}
else if (!keysToBeSentToEnclave.ContainsKey(currentOrdinal))
{
keysToBeSentToEnclave.Add(currentOrdinal, cipherInfo);
keysToBeSentToEnclave.TryAdd(currentOrdinal, cipherInfo);
}

requiresEnclaveComputations = true;
Expand Down Expand Up @@ -4315,7 +4321,6 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi

while (ds.Read())
{

if (attestationInfoRead)
{
throw SQL.MultipleRowsReturnedForAttestationInfo();
Expand Down Expand Up @@ -4357,8 +4362,7 @@ private void ReadDescribeEncryptionParameterResults(SqlDataReader ds, ReadOnlyDi
}

// If we are not in Batch RPC mode, update the query cache with the encryption MD.
// Enclave based Always Encrypted implementation on server side does not support cache at this point. So we should not cache if the query requires keys to be sent to enclave
if (!BatchRPCMode && !requiresEnclaveComputations && (this._parameters != null && this._parameters.Count > 0))
if (!BatchRPCMode && ShouldCacheEncryptionMetadata && (_parameters is not null && _parameters.Count > 0))
{
SqlQueryMetadataCache.GetInstance().AddQueryMetadata(this, ignoreQueriesWithReturnValueParams: true);
}
Expand Down Expand Up @@ -5285,8 +5289,8 @@ internal void OnReturnStatus(int status)
// If we are not in Batch RPC mode, update the query cache with the encryption MD.
// We can do this now that we have distinguished between ReturnValue and ReturnStatus.
// Read comment in AddQueryMetadata() for more details.
// Enclave based Always Encrypted implementation on server side does not support cache at this point. So we should not cache if the query requires keys to be sent to enclave
if (!BatchRPCMode && CachingQueryMetadataPostponed && !requiresEnclaveComputations && (this._parameters != null && this._parameters.Count > 0))
if (!BatchRPCMode && CachingQueryMetadataPostponed &&
ShouldCacheEncryptionMetadata && (_parameters is not null && _parameters.Count > 0))
{
SqlQueryMetadataCache.GetInstance().AddQueryMetadata(this, ignoreQueriesWithReturnValueParams: false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2632,6 +2632,7 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
Debug.Assert(_tceVersionSupported <= TdsEnums.MAX_SUPPORTED_TCE_VERSION, "Client support TCE version 2");
_parser.IsColumnEncryptionSupported = true;
_parser.TceVersionSupported = _tceVersionSupported;
_parser.AreEnclaveRetriesSupported = _tceVersionSupported == 3;

if (data.Length > 1)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
Expand All @@ -22,7 +23,7 @@ sealed internal class SqlQueryMetadataCache
const int CacheTrimThreshold = 300; // Threshold above the cache size when we start trimming.

private readonly MemoryCache _cache;
private static readonly SqlQueryMetadataCache _singletonInstance = new SqlQueryMetadataCache();
private static readonly SqlQueryMetadataCache _singletonInstance = new();
private int _inTrim = 0;
private long _cacheHits = 0;
private long _cacheMisses = 0;
Expand Down Expand Up @@ -53,17 +54,17 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
}

// Check the cache to see if we have the MD for this query cached.
string cacheLookupKey = GetCacheLookupKeyFromSqlCommand(sqlCommand);
if (cacheLookupKey == null)
(string cacheLookupKey, string enclaveLookupKey) = GetCacheLookupKeysFromSqlCommand(sqlCommand);
if (cacheLookupKey is null)
{
IncrementCacheMisses();
return false;
}

Dictionary<string, SqlCipherMetadata> ciperMetadataDictionary = _cache.Get(cacheLookupKey) as Dictionary<string, SqlCipherMetadata>;
Dictionary<string, SqlCipherMetadata> cipherMetadataDictionary = _cache.Get(cacheLookupKey) as Dictionary<string, SqlCipherMetadata>;

// If we had a cache miss just return false.
if (ciperMetadataDictionary == null)
if (cipherMetadataDictionary is null)
{
IncrementCacheMisses();
return false;
Expand All @@ -73,7 +74,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
foreach (SqlParameter param in sqlCommand.Parameters)
{
SqlCipherMetadata paramCiperMetadata;
bool found = ciperMetadataDictionary.TryGetValue(param.ParameterNameFixed, out paramCiperMetadata);
bool found = cipherMetadataDictionary.TryGetValue(param.ParameterNameFixed, out paramCiperMetadata);

// If we failed to identify the encryption for a specific parameter, clear up the cipher MD of all parameters and exit.
if (!found)
Expand All @@ -88,7 +89,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
}

// Cached cipher MD should never have an initialized algorithm since this would contain the key.
Debug.Assert(paramCiperMetadata == null || !paramCiperMetadata.IsAlgorithmInitialized());
Debug.Assert(paramCiperMetadata is null || !paramCiperMetadata.IsAlgorithmInitialized());

// We were able to identify the cipher MD for this parameter, so set it on the param.
param.CipherMetadata = paramCiperMetadata;
Expand All @@ -100,7 +101,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
{
SqlCipherMetadata cipherMdCopy = null;

if (param.CipherMetadata != null)
if (param.CipherMetadata is not null)
{
cipherMdCopy = new SqlCipherMetadata(
param.CipherMetadata.EncryptionInfo,
Expand All @@ -113,7 +114,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)

param.CipherMetadata = cipherMdCopy;

if (cipherMdCopy != null)
if (cipherMdCopy is not null)
{
// Try to get the encryption key. If the key information is stale, this might fail.
// In this case, just fail the cache lookup.
Expand Down Expand Up @@ -143,6 +144,13 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
}
}

ConcurrentDictionary<int, SqlTceCipherInfoEntry> enclaveKeys =
_cache.Get(enclaveLookupKey) as ConcurrentDictionary<int, SqlTceCipherInfoEntry>;
if (enclaveKeys is not null)
{
sqlCommand.keysToBeSentToEnclave = CreateCopyOfEnclaveKeys(enclaveKeys);
}

IncrementCacheHits();
return true;
}
Expand Down Expand Up @@ -178,19 +186,19 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu
}

// Construct the entry and put it in the cache.
string cacheLookupKey = GetCacheLookupKeyFromSqlCommand(sqlCommand);
if (cacheLookupKey == null)
(string cacheLookupKey, string enclaveLookupKey) = GetCacheLookupKeysFromSqlCommand(sqlCommand);
if (cacheLookupKey is null)
{
return;
}

Dictionary<string, SqlCipherMetadata> ciperMetadataDictionary = new Dictionary<string, SqlCipherMetadata>(sqlCommand.Parameters.Count);
Dictionary<string, SqlCipherMetadata> cipherMetadataDictionary = new(sqlCommand.Parameters.Count);

// Create a copy of the cipherMD that doesn't have the algorithm and put it in the cache.
foreach (SqlParameter param in sqlCommand.Parameters)
{
SqlCipherMetadata cipherMdCopy = null;
if (param.CipherMetadata != null)
if (param.CipherMetadata is not null)
{
cipherMdCopy = new SqlCipherMetadata(
param.CipherMetadata.EncryptionInfo,
Expand All @@ -202,9 +210,9 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu
}

// Cached cipher MD should never have an initialized algorithm since this would contain the key.
Debug.Assert(cipherMdCopy == null || !cipherMdCopy.IsAlgorithmInitialized());
Debug.Assert(cipherMdCopy is null || !cipherMdCopy.IsAlgorithmInitialized());

ciperMetadataDictionary.Add(param.ParameterNameFixed, cipherMdCopy);
cipherMetadataDictionary.Add(param.ParameterNameFixed, cipherMdCopy);
}

// If the size of the cache exceeds the threshold, set that we are in trimming and trim the cache accordingly.
Expand All @@ -228,21 +236,27 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu
}

// By default evict after 10 hours.
_cache.Set(cacheLookupKey, ciperMetadataDictionary, DateTimeOffset.UtcNow.AddHours(10));
_cache.Set(cacheLookupKey, cipherMetadataDictionary, DateTimeOffset.UtcNow.AddHours(10));
if (sqlCommand.requiresEnclaveComputations)
{
ConcurrentDictionary<int, SqlTceCipherInfoEntry> keysToBeCached = CreateCopyOfEnclaveKeys(sqlCommand.keysToBeSentToEnclave);
_cache.Set(enclaveLookupKey, keysToBeCached, DateTimeOffset.UtcNow.AddHours(10));
}
}

/// <summary>
/// <para> Remove the metadata for a specific query from the cache.</para>
/// </summary>
internal void InvalidateCacheEntry(SqlCommand sqlCommand)
{
string cacheLookupKey = GetCacheLookupKeyFromSqlCommand(sqlCommand);
if (cacheLookupKey == null)
(string cacheLookupKey, string enclaveLookupKey) = GetCacheLookupKeysFromSqlCommand(sqlCommand);
if (cacheLookupKey is null)
{
return;
}

_cache.Remove(cacheLookupKey);
_cache.Remove(enclaveLookupKey);
}


Expand Down Expand Up @@ -271,26 +285,46 @@ private void ResetCacheCounts()
_cacheMisses = 0;
}

private String GetCacheLookupKeyFromSqlCommand(SqlCommand sqlCommand)
private (string, string) GetCacheLookupKeysFromSqlCommand(SqlCommand sqlCommand)
{
const int SqlIdentifierLength = 128;

SqlConnection connection = sqlCommand.Connection;

// Return null if we have no connection.
if (connection == null)
if (connection is null)
{
return null;
return (null, null);
}

StringBuilder cacheLookupKeyBuilder = new StringBuilder(connection.DataSource, capacity: connection.DataSource.Length + SqlIdentifierLength + sqlCommand.CommandText.Length + 6);
StringBuilder cacheLookupKeyBuilder = new(connection.DataSource, capacity: connection.DataSource.Length + SqlIdentifierLength + sqlCommand.CommandText.Length + 6);
cacheLookupKeyBuilder.Append(":::");
// Pad database name to 128 characters to avoid any false cache matches because of weird DB names.
cacheLookupKeyBuilder.Append(connection.Database.PadRight(SqlIdentifierLength));
cacheLookupKeyBuilder.Append(":::");
cacheLookupKeyBuilder.Append(sqlCommand.CommandText);

return cacheLookupKeyBuilder.ToString();
string cacheLookupKey = cacheLookupKeyBuilder.ToString();
string enclaveLookupKey = cacheLookupKeyBuilder.Append(":::enclaveKeys").ToString();
return (cacheLookupKey, enclaveLookupKey);
}

private ConcurrentDictionary<int, SqlTceCipherInfoEntry> CreateCopyOfEnclaveKeys(ConcurrentDictionary<int, SqlTceCipherInfoEntry> keysToBeSentToEnclave)
{
ConcurrentDictionary<int, SqlTceCipherInfoEntry> enclaveKeys = new();
foreach (KeyValuePair<int, SqlTceCipherInfoEntry> kvp in keysToBeSentToEnclave)
{
int ordinal = kvp.Key;
SqlTceCipherInfoEntry original = kvp.Value;
SqlTceCipherInfoEntry copy = new(ordinal);
foreach (SqlEncryptionKeyInfo cekInfo in original.ColumnEncryptionKeyValues)
{
copy.Add(cekInfo.encryptedKey, cekInfo.databaseId, cekInfo.cekId, cekInfo.cekVersion,
cekInfo.cekMdVersion, cekInfo.keyPath, cekInfo.keyStoreName, cekInfo.algorithmName);
}
enclaveKeys.TryAdd(ordinal, copy);
}
return enclaveKeys;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ internal static string GetSniContextEnumName(SniContext sniContext)
}

// TCE Related constants
internal const byte MAX_SUPPORTED_TCE_VERSION = 0x02; // max version
internal const byte MAX_SUPPORTED_TCE_VERSION = 0x03; // max version
internal const byte MIN_TCE_VERSION_WITH_ENCLAVE_SUPPORT = 0x02; // min version with enclave support
internal const ushort MAX_TCE_CIPHERINFO_SIZE = 2048; // max size of cipherinfo blob
internal const long MAX_TCE_CIPHERTEXT_SIZE = 2147483648; // max size of encrypted blob- currently 2GB.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ internal sealed partial class TdsParser
/// </summary>
internal byte TceVersionSupported { get; set; }

/// <summary>
/// Server supports retrying when the enclave CEKs sent by the client do not match what is needed for the query to run.
/// </summary>
internal bool AreEnclaveRetriesSupported { get; set; }

/// <summary>
/// Type of enclave being used by the server
/// </summary>
Expand Down
Loading

0 comments on commit 89e15c4

Please sign in to comment.