From 8948ec8dc0fbeec85122f154f11881ec3c4b23a5 Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Mon, 25 Nov 2024 07:10:19 -0500 Subject: [PATCH] Fixed the PrintValue() function family for plaintext and its derived classes --- src/pke/include/encoding/ckkspackedencoding.h | 26 +++------ src/pke/include/encoding/coefpackedencoding.h | 46 +++++++++------- src/pke/include/encoding/packedencoding.h | 32 +++++++---- src/pke/include/encoding/plaintext.h | 55 ++++++++++--------- src/pke/include/encoding/stringencoding.h | 4 +- 5 files changed, 88 insertions(+), 75 deletions(-) diff --git a/src/pke/include/encoding/ckkspackedencoding.h b/src/pke/include/encoding/ckkspackedencoding.h index 4a3237028..c75f70660 100644 --- a/src/pke/include/encoding/ckkspackedencoding.h +++ b/src/pke/include/encoding/ckkspackedencoding.h @@ -232,23 +232,11 @@ class CKKSPackedEncoding : public PlaintextImpl { */ static void Destroy(); - void PrintValue(std::ostream& out) const override { - // for sanity's sake, trailing zeros get elided into "..." - // out.precision(15); - out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != std::complex(0, 0)) - break; - - for (size_t j = 0; j <= i; j++) { - out << value[j].real() << ", "; - } - - out << " ... ); "; - out << "Estimated precision: " << encodingParams->GetPlaintextModulus() - m_logError << " bits" << std::endl; - } - + /** + * @brief GetFormattedValues() is called by operator<< and requires a precision as an argument + * @param precision number of decimal digits of precision to print + * @return string with all values and "estimated precision" + */ std::string GetFormattedValues(int64_t precision) const override { std::stringstream ss; ss << "("; @@ -279,6 +267,10 @@ class CKKSPackedEncoding : public PlaintextImpl { double m_logError = 0; protected: + void PrintValue(std::ostream& out) const override { + out << GetFormattedValues(8) << std::endl; + } + usint GetDefaultSlotSize() { auto batchSize = GetEncodingParams()->GetBatchSize(); return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize; diff --git a/src/pke/include/encoding/coefpackedencoding.h b/src/pke/include/encoding/coefpackedencoding.h index 089bbe1a8..235274dd8 100644 --- a/src/pke/include/encoding/coefpackedencoding.h +++ b/src/pke/include/encoding/coefpackedencoding.h @@ -48,6 +48,32 @@ namespace lbcrypto { class CoefPackedEncoding : public PlaintextImpl { std::vector value; +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out stream to print to + */ + void PrintValue(std::ostream& out) const override { + out << "("; + + // for sanity's sake: get rid of all trailing zeroes and print "..." instead + size_t i = value.size(); + bool allZeroes = true; + while (i > 0) { + --i; + if (value[i] != 0) { + allZeroes = false; + break; + } + } + + if (allZeroes == false) { + for (size_t j = 0; j <= i; ++j) + out << value[j] << ", "; + } + out << "... )"; + } + public: template ::value || std::is_same::value || @@ -64,7 +90,7 @@ class CoefPackedEncoding : public PlaintextImpl { SCHEME schemeId = SCHEME::INVALID_SCHEME) : PlaintextImpl(vp, ep, schemeId), value(coeffs) {} - virtual ~CoefPackedEncoding() = default; + ~CoefPackedEncoding() = default; /** * GetCoeffsValue @@ -130,24 +156,6 @@ class CoefPackedEncoding : public PlaintextImpl { const auto& oth = static_cast(other); return oth.value == this->value; } - - /** - * PrintValue - used by operator<< for this object - * @param out - */ - void PrintValue(std::ostream& out) const { - // for sanity's sake, trailing zeros get elided into "..." - out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != 0) - break; - - for (size_t j = 0; j <= i; j++) - out << ' ' << value[j]; - - out << " ... )"; - } }; } /* namespace lbcrypto */ diff --git a/src/pke/include/encoding/packedencoding.h b/src/pke/include/encoding/packedencoding.h index 97a02c4a0..16365c2fe 100644 --- a/src/pke/include/encoding/packedencoding.h +++ b/src/pke/include/encoding/packedencoding.h @@ -185,18 +185,30 @@ class PackedEncoding : public PlaintextImpl { */ static void Destroy(); - void PrintValue(std::ostream& out) const { - // for sanity's sake, trailing zeros get elided into "..." +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out stream to print to + */ + void PrintValue(std::ostream& out) const override { out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != 0) - break; - for (size_t j = 0; j <= i; j++) - out << ' ' << value[j]; - - out << " ... )"; + // for sanity's sake: get rid of all trailing zeroes and print "..." instead + size_t i = value.size(); + bool allZeroes = true; + while (i > 0) { + --i; + if (value[i] != 0) { + allZeroes = false; + break; + } + } + + if (allZeroes == false) { + for (size_t j = 0; j <= i; ++j) + out << value[j] << ", "; + } + out << "... )"; } private: diff --git a/src/pke/include/encoding/plaintext.h b/src/pke/include/encoding/plaintext.h index 9e419c599..82097b41c 100644 --- a/src/pke/include/encoding/plaintext.h +++ b/src/pke/include/encoding/plaintext.h @@ -85,6 +85,13 @@ class PlaintextImpl { usint slots = 0; SCHEME schemeID; +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out + */ + virtual void PrintValue(std::ostream& out) const = 0; + public: PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) @@ -137,7 +144,7 @@ class PlaintextImpl { slots(rhs.slots), schemeID(rhs.schemeID) {} - virtual ~PlaintextImpl() {} + virtual ~PlaintextImpl() = default; /** * GetEncodingType @@ -398,39 +405,33 @@ class PlaintextImpl { } /** - * operator<< for ostream integration - calls PrintValue - * @param out - * @param item - * @return - */ - friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item); - - /** - * PrintValue is called by operator<< - * @param out - */ - virtual void PrintValue(std::ostream& out) const = 0; + * @brief operator<< for ostream integration - calls PrintValue() + * @param out + * @param item + * @return + */ + friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) { + item.PrintValue(out); + return out; + } + friend std::ostream& operator<<(std::ostream& out, const Plaintext& item) { + if (item) + out << *item; // Call the non-pointer version + else + OPENFHE_THROW("Cannot de-reference nullptr for printing"); + return out; + } /** - * GetFormattedValues() has a logic similar to PrintValue(), but requires a precision as an argument - * @param precision number of decimal digits of precision to print - * @return string with all values and "estimated precision" - */ + * @brief GetFormattedValues() is similar to PrintValue() and requires a precision as an argument + * @param precision number of decimal digits of precision to print + * @return string with all values + */ virtual std::string GetFormattedValues(int64_t precision) const { OPENFHE_THROW("not implemented"); } }; -inline std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) { - item.PrintValue(out); - return out; -} - -inline std::ostream& operator<<(std::ostream& out, const Plaintext& item) { - item->PrintValue(out); - return out; -} - inline bool operator==(const Plaintext& p1, const Plaintext& p2) { return *p1 == *p2; } diff --git a/src/pke/include/encoding/stringencoding.h b/src/pke/include/encoding/stringencoding.h index 8ab1fbe27..2cd16d634 100644 --- a/src/pke/include/encoding/stringencoding.h +++ b/src/pke/include/encoding/stringencoding.h @@ -65,7 +65,7 @@ class StringEncoding : public PlaintextImpl { // TODO provide wide-character version (for unicode); right now this class // only supports strings of 7-bit ASCII characters - virtual ~StringEncoding() {} + ~StringEncoding() = default; /** * GetStringValue @@ -128,7 +128,7 @@ class StringEncoding : public PlaintextImpl { * PrintValue - used by operator<< for this object * @param out */ - void PrintValue(std::ostream& out) const { + void PrintValue(std::ostream& out) const override { out << ptx; } };