Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the PrintValue() function family for plaintext and its derived classes #915

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions src/pke/include/encoding/ckkspackedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,23 +232,11 @@ class CKKSPackedEncoding : public PlaintextImpl {
*/
static void Destroy();

void PrintValue(std::ostream& out) const override {
// for sanity's sake, trailing zeros get elided into "..."
// out.precision(15);
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != std::complex<double>(0, 0))
break;

for (size_t j = 0; j <= i; j++) {
out << value[j].real() << ", ";
}

out << " ... ); ";
out << "Estimated precision: " << encodingParams->GetPlaintextModulus() - m_logError << " bits" << std::endl;
}

/**
* @brief GetFormattedValues() is called by operator<< and requires a precision as an argument
* @param precision number of decimal digits of precision to print
* @return string with all values and "estimated precision"
*/
std::string GetFormattedValues(int64_t precision) const override {
std::stringstream ss;
ss << "(";
Expand Down Expand Up @@ -279,6 +267,10 @@ class CKKSPackedEncoding : public PlaintextImpl {
double m_logError = 0;

protected:
void PrintValue(std::ostream& out) const override {
out << GetFormattedValues(8) << std::endl;
}

usint GetDefaultSlotSize() {
auto batchSize = GetEncodingParams()->GetBatchSize();
return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize;
Expand Down
46 changes: 27 additions & 19 deletions src/pke/include/encoding/coefpackedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,32 @@ namespace lbcrypto {
class CoefPackedEncoding : public PlaintextImpl {
std::vector<int64_t> value;

protected:
/**
* @brief PrintValue() is called by operator<<
* @param out stream to print to
*/
void PrintValue(std::ostream& out) const override {
out << "(";

// for sanity's sake: get rid of all trailing zeroes and print "..." instead
size_t i = value.size();
bool allZeroes = true;
while (i > 0) {
--i;
if (value[i] != 0) {
allZeroes = false;
break;
}
}

if (allZeroes == false) {
for (size_t j = 0; j <= i; ++j)
out << value[j] << ", ";
}
out << "... )";
}

public:
template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
std::is_same<T, NativePoly::Params>::value ||
Expand All @@ -64,7 +90,7 @@ class CoefPackedEncoding : public PlaintextImpl {
SCHEME schemeId = SCHEME::INVALID_SCHEME)
: PlaintextImpl(vp, ep, schemeId), value(coeffs) {}

virtual ~CoefPackedEncoding() = default;
~CoefPackedEncoding() = default;

/**
* GetCoeffsValue
Expand Down Expand Up @@ -130,24 +156,6 @@ class CoefPackedEncoding : public PlaintextImpl {
const auto& oth = static_cast<const CoefPackedEncoding&>(other);
return oth.value == this->value;
}

/**
* PrintValue - used by operator<< for this object
* @param out
*/
void PrintValue(std::ostream& out) const {
// for sanity's sake, trailing zeros get elided into "..."
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != 0)
break;

for (size_t j = 0; j <= i; j++)
out << ' ' << value[j];

out << " ... )";
}
};

} /* namespace lbcrypto */
Expand Down
32 changes: 22 additions & 10 deletions src/pke/include/encoding/packedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,30 @@ class PackedEncoding : public PlaintextImpl {
*/
static void Destroy();

void PrintValue(std::ostream& out) const {
// for sanity's sake, trailing zeros get elided into "..."
protected:
/**
* @brief PrintValue() is called by operator<<
* @param out stream to print to
*/
void PrintValue(std::ostream& out) const override {
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != 0)
break;

for (size_t j = 0; j <= i; j++)
out << ' ' << value[j];

out << " ... )";
// for sanity's sake: get rid of all trailing zeroes and print "..." instead
size_t i = value.size();
bool allZeroes = true;
while (i > 0) {
--i;
if (value[i] != 0) {
allZeroes = false;
break;
}
}

if (allZeroes == false) {
for (size_t j = 0; j <= i; ++j)
out << value[j] << ", ";
}
out << "... )";
}

private:
Expand Down
55 changes: 28 additions & 27 deletions src/pke/include/encoding/plaintext.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ class PlaintextImpl {
usint slots = 0;
SCHEME schemeID;

protected:
/**
* @brief PrintValue() is called by operator<<
* @param out
*/
virtual void PrintValue(std::ostream& out) const = 0;

public:
PlaintextImpl(const std::shared_ptr<Poly::Params>& vp, EncodingParams ep, SCHEME schemeTag = SCHEME::INVALID_SCHEME,
bool isEncoded = false)
Expand Down Expand Up @@ -137,7 +144,7 @@ class PlaintextImpl {
slots(rhs.slots),
schemeID(rhs.schemeID) {}

virtual ~PlaintextImpl() {}
virtual ~PlaintextImpl() = default;

/**
* GetEncodingType
Expand Down Expand Up @@ -398,39 +405,33 @@ class PlaintextImpl {
}

/**
* operator<< for ostream integration - calls PrintValue
* @param out
* @param item
* @return
*/
friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item);

/**
* PrintValue is called by operator<<
* @param out
*/
virtual void PrintValue(std::ostream& out) const = 0;
* @brief operator<< for ostream integration - calls PrintValue()
* @param out
* @param item
* @return
*/
friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) {
item.PrintValue(out);
return out;
}
friend std::ostream& operator<<(std::ostream& out, const Plaintext& item) {
if (item)
out << *item; // Call the non-pointer version
else
OPENFHE_THROW("Cannot de-reference nullptr for printing");
return out;
}

/**
* GetFormattedValues() has a logic similar to PrintValue(), but requires a precision as an argument
* @param precision number of decimal digits of precision to print
* @return string with all values and "estimated precision"
*/
* @brief GetFormattedValues() is similar to PrintValue() and requires a precision as an argument
* @param precision number of decimal digits of precision to print
* @return string with all values
*/
virtual std::string GetFormattedValues(int64_t precision) const {
OPENFHE_THROW("not implemented");
}
};

inline std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) {
item.PrintValue(out);
return out;
}

inline std::ostream& operator<<(std::ostream& out, const Plaintext& item) {
item->PrintValue(out);
return out;
}

inline bool operator==(const Plaintext& p1, const Plaintext& p2) {
return *p1 == *p2;
}
Expand Down
4 changes: 2 additions & 2 deletions src/pke/include/encoding/stringencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class StringEncoding : public PlaintextImpl {
// TODO provide wide-character version (for unicode); right now this class
// only supports strings of 7-bit ASCII characters

virtual ~StringEncoding() {}
~StringEncoding() = default;

/**
* GetStringValue
Expand Down Expand Up @@ -128,7 +128,7 @@ class StringEncoding : public PlaintextImpl {
* PrintValue - used by operator<< for this object
* @param out
*/
void PrintValue(std::ostream& out) const {
void PrintValue(std::ostream& out) const override {
out << ptx;
}
};
Expand Down
Loading