Skip to content

Commit

Permalink
SKR_TRACE_ON env flag (#48)
Browse files Browse the repository at this point in the history
* including SKR trace on environment variable

* including SKR trace on environment variable - correction

* support for SKR_TRACE_ON 1 and 2 for masked logs and complte logs respectively.

* update function name to reduct_log.
  • Loading branch information
pankajosh authored Dec 15, 2023
1 parent 69c44bc commit e672f9e
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 50 deletions.
39 changes: 21 additions & 18 deletions cvm-securekey-release-app/AttestationUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
using namespace attest;
using json = nlohmann::json;

bool Util::isTraceOn = false;
int Util::traceLevel = 1;

/// \copydoc Util::base64_to_binary()
std::vector<BYTE> Util::base64_to_binary(const std::string &base64_data)
{
Expand Down Expand Up @@ -250,12 +253,12 @@ std::string Util::GetIMDSToken(const std::string &KEKUrl)

curl_easy_cleanup(curl);

TRACE_OUT("Response: %s\n", responseStr.c_str());
TRACE_OUT("Response: %s\n", Util::reduct_log(responseStr).c_str());

json json_object = json::parse(responseStr.c_str());
std::string access_token = json_object["access_token"].get<std::string>();

TRACE_OUT("Access Token: %s\n", access_token.c_str());
TRACE_OUT("Access Token: %s\n", Util::reduct_log(access_token).c_str());

TRACE_OUT("Exiting Util::GetIMDSToken()");

Expand Down Expand Up @@ -347,7 +350,7 @@ std::string Util::GetMAAToken(const std::string &attestation_url, const std::str
}

AttestationClient *attestation_client = nullptr;
AttestationLogger *log_handle = new Logger();
AttestationLogger *log_handle = new Logger(Util::get_trace());

// Initialize attestation client
if (!Initialize(log_handle, &attestation_client))
Expand Down Expand Up @@ -562,7 +565,7 @@ std::string Util::GetKeyVaultResponse(const std::string &requestUri,
std::ostringstream bearerToken;
bearerToken << "Authorization: Bearer " << access_token;
headers = curl_slist_append(headers, bearerToken.str().c_str());
TRACE_OUT("Bearer token: %s", bearerToken.str().c_str());
TRACE_OUT("Bearer token: %s", Util::reduct_log(bearerToken.str()).c_str());
headers = curl_slist_append(headers, "Content-Type: application/json");
headers = curl_slist_append(headers, "Accept: application/json");
headers = curl_slist_append(headers, "User-Agent: AzureDiskEncryption");
Expand Down Expand Up @@ -675,7 +678,7 @@ std::string Util::GetKeyVaultResponse(const std::string &requestUri,
curl_slist_free_all(headers);
curl_easy_cleanup(curl);

TRACE_OUT("SKR response: %s", responseStr.c_str());
TRACE_OUT("SKR response: %s", Util::reduct_log(responseStr).c_str());

TRACE_OUT("Exiting Util::GetKeyVaultResponse()");
return responseStr;
Expand All @@ -692,7 +695,7 @@ bool Util::doSKR(const std::string &attestation_url,
try
{
std::string attest_token(Util::GetMAAToken(attestation_url, nonce));
TRACE_OUT("MAA Token: %s", attest_token.c_str());
TRACE_OUT("MAA Token: %s", Util::reduct_log(attest_token).c_str());

// Get Akv access token either using IMDS or Service Principal
std::string access_token;
Expand All @@ -705,15 +708,15 @@ bool Util::doSKR(const std::string &attestation_url,
access_token = std::move(Util::GetIMDSToken(KEKUrl));
}

TRACE_OUT("AkvMsiAccessToken: %s", access_token.c_str());
TRACE_OUT("AkvMsiAccessToken: %s", Util::reduct_log(access_token).c_str());

std::string requestUri = Util::GetKeyVaultSKRurl(KEKUrl);
std::string responseStr = Util::GetKeyVaultResponse(requestUri, access_token, attest_token, nonce);

// Parse the response:
json skrJson = json::parse(responseStr.c_str());
std::string skrToken = skrJson["value"];
TRACE_OUT("SKR token: %s", skrToken.c_str());
TRACE_OUT("SKR token: %s", Util::reduct_log(skrToken).c_str());
std::vector<std::string> tokenParts = Util::SplitString(skrToken, '.');
if (tokenParts.size() != 3)
{
Expand All @@ -722,18 +725,18 @@ bool Util::doSKR(const std::string &attestation_url,

std::vector<BYTE> tokenPayload(Util::base64url_to_binary(tokenParts[1]));
std::string tokenPayloadStr(tokenPayload.begin(), tokenPayload.end());
TRACE_OUT("SKR token payload: %s", tokenPayloadStr.c_str());
TRACE_OUT("SKR token payload: %s", Util::reduct_log(tokenPayloadStr).c_str());
json skrPayloadJson = json::parse(tokenPayloadStr.c_str());
std::vector<BYTE> key_hsm = Util::base64url_to_binary(skrPayloadJson["response"]["key"]["key"]["key_hsm"]);
TRACE_OUT("SKR key_hsm: %s", Util::binary_to_base64url(key_hsm).c_str());
TRACE_OUT("SKR key_hsm: %s", Util::reduct_log(Util::binary_to_base64url(key_hsm)).c_str());
json cipherTextJson = json::parse(key_hsm);
std::vector<BYTE> cipherText = Util::base64url_to_binary(cipherTextJson["ciphertext"]);
TRACE_OUT("Encrypted bytes length: %ld", cipherText.size());
std::string cipherTextStr(cipherText.begin(), cipherText.end());
TRACE_OUT("Encrypted bytes: %s", Util::binary_to_base64url(cipherText).c_str());
TRACE_OUT("Encrypted bytes: %s", Util::reduct_log(Util::binary_to_base64url(cipherText)).c_str());

AttestationClient *attestation_client = nullptr;
AttestationLogger *log_handle = new Logger();
AttestationLogger *log_handle = new Logger(Util::get_trace());

// Initialize attestation client
if (!Initialize(log_handle, &attestation_client))
Expand Down Expand Up @@ -767,7 +770,7 @@ bool Util::doSKR(const std::string &attestation_url,
else
{
std::vector<BYTE> decryptedAESBytesVec(decryptedAESBytes, decryptedAESBytes + decryptedBytesSize);
TRACE_OUT("Decrypted Transfer key: %s\n", Util::binary_to_base64url(decryptedAESBytesVec).c_str());
TRACE_OUT("Decrypted Transfer key: %s\n", Util::reduct_log(Util::binary_to_base64url(decryptedAESBytesVec)).c_str());
}

// The remaining bytes are the encrypted CMK bytes with the decrypted AES key.
Expand All @@ -787,8 +790,8 @@ bool Util::doSKR(const std::string &attestation_url,
{
TRACE_OUT("CMK private key has length=%d", private_key_len);
std::vector<BYTE> privateKeyVec(private_key, private_key + private_key_len);
TRACE_OUT("Decrypted CMK in base64url: %s", Util::binary_to_base64url(privateKeyVec).c_str());
TRACE_OUT("Decrypted CMK in hex: %s", Util::binary_to_hex(privateKeyVec).c_str());
TRACE_OUT("Decrypted CMK in base64url: %s", Util::reduct_log(Util::binary_to_base64url(privateKeyVec)).c_str());
TRACE_OUT("Decrypted CMK in hex: %s", Util::reduct_log(Util::binary_to_hex(privateKeyVec)).c_str());

// PKCS#8
BIO *bio_key = BIO_new_mem_buf(privateKeyVec.data(), (int)privateKeyVec.size());
Expand Down Expand Up @@ -970,7 +973,7 @@ std::string Util::WrapKey(const std::string &attestation_url,
}

int rsaSize = RSA_get_size(pkey);
TRACE_OUT("Wrapping: %s", sym_key.c_str());
TRACE_OUT("Wrapping: %s", Util::reduct_log(sym_key).c_str());

size_t encrypted_length = 0;
PBYTE encryptedKey;
Expand All @@ -984,7 +987,7 @@ std::string Util::WrapKey(const std::string &attestation_url,
TRACE_OUT("Wrapping the symmetric key succeeded: encrypted_length=%ld\n", encrypted_length);
std::vector<BYTE> encryptedKeyVector(encryptedKey, encryptedKey + encrypted_length);
std::string cipherText = Util::binary_to_base64(encryptedKeyVector);
TRACE_OUT("Wrapped symmetric key in base64: %s\n", cipherText.c_str());
TRACE_OUT("Wrapped symmetric key in base64: %s\n", Util::reduct_log(cipherText).c_str());

// Cleanup
OPENSSL_free(encryptedKey);
Expand Down Expand Up @@ -1035,7 +1038,7 @@ std::string Util::UnwrapKey(const std::string &attestation_url,
TRACE_OUT("Unwrapping the symmetric key succeeded: decrypted_length=%lud", decrypted_length);
std::vector<BYTE> decryptedKeyVector(decryptedKey, decryptedKey + decrypted_length);
std::string plainText = Util::binary_to_base64(decryptedKeyVector);
TRACE_OUT("Unwrapped symmetric key in base64: %s", plainText.c_str());
TRACE_OUT("Unwrapped symmetric key in base64: %s", Util::reduct_log(plainText).c_str());

TRACE_OUT("Exiting Util::UnwrapKey()");

Expand Down
81 changes: 62 additions & 19 deletions cvm-securekey-release-app/AttestationUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,70 @@ inline static void Check_HResult(std::string fileName, std::string funcName, int
exit(gle); \
} while (0);

inline static void TRACE_OUT(std::string fmt, ...)
{
#ifdef TRACE
va_list args;
va_start(args, fmt);
vfprintf(stderr, fmt.c_str(), args);
fprintf(stderr, "\n");
va_end(args);
#endif
}

inline static void OSSL_BN_TRACE_OUT(const BIGNUM *bn)
{
#ifdef TRACE
BN_print_fp(stderr, bn);
#endif
}
#define TRACE_OUT Util::trace_out
#define OSSL_BN_TRACE_OUT Util::ossl_bn_trace_out

class Util
{
class Util{
private:
static bool isTraceOn;
static int traceLevel; //1: enable Util::reduct_log, 2: do nothing.
static size_t lengthMask;
public:

static void set_trace(bool traceOn)
{
isTraceOn = traceOn;
}

static bool get_trace()
{
return isTraceOn;
}

static void set_trace_level(int trLevel)
{
traceLevel = trLevel;
}

static int get_trace_level()
{
return traceLevel;
}

inline static void trace_out(std::string fmt, ...)
{
if (isTraceOn)
{
va_list args;
va_start(args, fmt);
vfprintf(stderr, fmt.c_str(), args);
fprintf(stderr, "\n");
va_end(args);
}
}

inline static std::string reduct_log(const std::string& str)
{
std::string retStr(str);
if(traceLevel==1){
//mask 85% of string
size_t lengthMask = retStr.size()*0.15;
if(retStr.size()>lengthMask){
retStr.resize(lengthMask);
retStr.append("...");
}
}
return retStr.c_str();
}

inline static void ossl_bn_trace_out(const BIGNUM *bn)
{
if(isTraceOn)
{
BN_print_fp(stderr, bn);
}
}

enum class AkvCredentialSource
{
Imds,
Expand Down
23 changes: 11 additions & 12 deletions cvm-securekey-release-app/Logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@ void Logger::Log(const char *log_tag,
...)
{

// uncomment the below statement and rebuild if details debug logs are needed
/*
va_list args;
va_start(args, fmt);
size_t len = std::vsnprintf(NULL, 0, fmt, args);
va_end(args);
if (isTraceOn){
va_list args;
va_start(args, fmt);
size_t len = std::vsnprintf(NULL, 0, fmt, args);
va_end(args);

std::vector<char> str(len + 1);
std::vector<char> str(len + 1);

va_start(args, fmt);
std::vsnprintf(&str[0], len + 1, fmt, args);
va_end(args);
va_start(args, fmt);
std::vsnprintf(&str[0], len + 1, fmt, args);
va_end(args);

printf("Level: %s Tag: %s %s:%d:%s\n", attest::AttestationLogger::LogLevelStrings[level].c_str(), log_tag, function, line, &str[0]);
*/
printf("Level: %s Tag: %s %s:%d:%s\n", attest::AttestationLogger::LogLevelStrings[level].c_str(), log_tag, function, line, &str[0]);
}
}
8 changes: 8 additions & 0 deletions cvm-securekey-release-app/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ using namespace attest;

class Logger : public AttestationLogger
{
private:
bool isTraceOn = false;
public:
Logger() = default;

Logger(bool isTraceOn){
this->isTraceOn = isTraceOn;
}

void Log(const char *log_tag,
AttestationLogger::LogLevel level,
const char *function,
Expand Down
22 changes: 21 additions & 1 deletion cvm-securekey-release-app/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,30 @@ enum class Operation
Undefined
};

// Check if tracing is to be enabled for SKR in the env.
void set_tracing(void)
{
auto skr_trace_flag = std::getenv("SKR_TRACE_ON");
if(skr_trace_flag != nullptr && strlen(skr_trace_flag) > 0)
{
if(strcmp(skr_trace_flag, "1") ==0 || strcmp(skr_trace_flag, "2") ==0)
{
std::cout<< "Tracing is enabled" <<std::endl;
Util::set_trace(true);
Util::set_trace_level(atoi(skr_trace_flag));
}
else
{
std::cerr<<"Invalid value for SKR_TRACE_ON!"<<std::endl;
exit(-1);
}
}
}

int main(int argc, char *argv[])
{
set_tracing();
TRACE_OUT("Main started");

std::string attestation_url;
std::string nonce;
std::string sym_key;
Expand Down

0 comments on commit e672f9e

Please sign in to comment.