Skip to content

Commit

Permalink
Corrected operator== for the family of cryptoparameters classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Nov 25, 2024
1 parent 62e8fa4 commit 8a5a5d7
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 54 deletions.
22 changes: 16 additions & 6 deletions src/pke/include/schemebase/base-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class CryptoParametersBase : public Serializable {
using TugType = typename Element::TugType;

public:
CryptoParametersBase() {}
CryptoParametersBase() = default;

virtual ~CryptoParametersBase() {}
virtual ~CryptoParametersBase() = default;

/**
* Returns the value of plaintext modulus p
Expand Down Expand Up @@ -99,11 +99,11 @@ class CryptoParametersBase : public Serializable {
m_encodingParams->SetPlaintextModulus(plaintextModulus);
}

virtual bool operator==(const CryptoParametersBase<Element>& cmp) const {
return *m_encodingParams == *(cmp.GetEncodingParams()) && *m_params == *(cmp.GetElementParams());
bool operator==(const CryptoParametersBase<Element>& rhs) const {
return CompareTo(rhs);
}
virtual bool operator!=(const CryptoParametersBase<Element>& cmp) const {
return !(*this == cmp);
bool operator!=(const CryptoParametersBase<Element>& rhs) const {
return !(*this == rhs);
}

/**
Expand Down Expand Up @@ -194,6 +194,16 @@ class CryptoParametersBase : public Serializable {
m_params = newElemParms;
}

/**
* @brief CompareTo() is a method to compare two CryptoParametersBase objects. It is called by operator==()
*
* @param rhs - the other CryptoParametersBase object to compare to.
* @return whether the two CryptoParametersBase objects are equivalent.
*/
virtual bool CompareTo(const CryptoParametersBase<Element>& rhs) const {
return (*m_encodingParams == *(rhs.GetEncodingParams()) && *m_params == *(rhs.GetElementParams()));
}

virtual void PrintParameters(std::ostream& out) const {
out << "Element Parameters: " << *m_params << std::endl;
out << "Encoding Parameters: " << *m_encodingParams << std::endl;
Expand Down
58 changes: 29 additions & 29 deletions src/pke/include/schemebase/rlwe-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ class CryptoParametersRLWE : public CryptoParametersBase<Element> {
}

/**
* Destructor
* Virtual Destructor
*/
virtual ~CryptoParametersRLWE() {}
~CryptoParametersRLWE() = default;

/**
* Returns the value of standard deviation r for discrete Gaussian
Expand Down Expand Up @@ -414,33 +414,6 @@ class CryptoParametersRLWE : public CryptoParametersBase<Element> {
m_thresholdNumOfParties = thresholdNumOfParties;
}

/**
* == operator to compare to this instance of CryptoParametersRLWE object.
*
* @param &rhs CryptoParameters to check equality against.
*/
bool operator==(const CryptoParametersBase<Element>& rhs) const {
const auto* el = dynamic_cast<const CryptoParametersRLWE<Element>*>(&rhs);

if (el == nullptr)
return false;

return CryptoParametersBase<Element>::operator==(*el) &&
this->GetPlaintextModulus() == el->GetPlaintextModulus() &&
*this->GetElementParams() == *el->GetElementParams() &&
*this->GetEncodingParams() == *el->GetEncodingParams() &&
m_distributionParameter == el->GetDistributionParameter() &&
m_assuranceMeasure == el->GetAssuranceMeasure() && m_noiseScale == el->GetNoiseScale() &&
m_digitSize == el->GetDigitSize() && m_secretKeyDist == el->GetSecretKeyDist() &&
m_stdLevel == el->GetStdLevel() && m_maxRelinSkDeg == el->GetMaxRelinSkDeg() &&
m_PREMode == el->GetPREMode() && m_multipartyMode == el->GetMultipartyMode() &&
m_executionMode == el->GetExecutionMode() &&
m_floodingDistributionParameter == el->GetFloodingDistributionParameter() &&
m_statisticalSecurity == el->GetStatisticalSecurity() &&
m_numAdversarialQueries == el->GetNumAdversarialQueries() &&
m_thresholdNumOfParties == el->GetThresholdNumOfParties();
}

void PrintParameters(std::ostream& os) const {
CryptoParametersBase<Element>::PrintParameters(os);

Expand Down Expand Up @@ -541,6 +514,33 @@ class CryptoParametersRLWE : public CryptoParametersBase<Element> {
double m_numAdversarialQueries = 1;

usint m_thresholdNumOfParties = 1;

/**
* @brief CompareTo() is a method to compare two CryptoParametersRLWE objects.
* It is called by CryptoParametersBase::operator==()
* @param rhs - the other CryptoParametersRLWE object to compare to.
* @return whether the two CryptoParametersRLWE objects are equivalent.
*/
bool CompareTo(const CryptoParametersBase<Element>& rhs) const override {
const auto* el = dynamic_cast<const CryptoParametersRLWE<Element>*>(&rhs);
if (el == nullptr)
return false;

return CryptoParametersBase<Element>::CompareTo(*el) &&
this->GetPlaintextModulus() == el->GetPlaintextModulus() &&
*this->GetElementParams() == *el->GetElementParams() &&
*this->GetEncodingParams() == *el->GetEncodingParams() &&
m_distributionParameter == el->GetDistributionParameter() &&
m_assuranceMeasure == el->GetAssuranceMeasure() && m_noiseScale == el->GetNoiseScale() &&
m_digitSize == el->GetDigitSize() && m_secretKeyDist == el->GetSecretKeyDist() &&
m_stdLevel == el->GetStdLevel() && m_maxRelinSkDeg == el->GetMaxRelinSkDeg() &&
m_PREMode == el->GetPREMode() && m_multipartyMode == el->GetMultipartyMode() &&
m_executionMode == el->GetExecutionMode() &&
m_floodingDistributionParameter == el->GetFloodingDistributionParameter() &&
m_statisticalSecurity == el->GetStatisticalSecurity() &&
m_numAdversarialQueries == el->GetNumAdversarialQueries() &&
m_thresholdNumOfParties == el->GetThresholdNumOfParties();
}
};

} // namespace lbcrypto
Expand Down
38 changes: 19 additions & 19 deletions src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,25 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
m_MPIntBootCiphertextCompressionLevel = mPIntBootCiphertextCompressionLevel;
}

virtual ~CryptoParametersRNS() {}
~CryptoParametersRNS() = default;

/**
* @brief CompareTo() is a method to compare two CryptoParametersRNS objects.
* It is called by CryptoParametersBase::operator==()
* @param rhs - the other CryptoParametersRNS object to compare to.
* @return whether the two CryptoParametersRNS objects are equivalent.
*/
bool CompareTo(const CryptoParametersBase<DCRTPoly>& rhs) const override {
const auto* el = dynamic_cast<const CryptoParametersRNS*>(&rhs);
if (el == nullptr)
return false;

return CryptoParametersRLWE<DCRTPoly>::CompareTo(rhs) && m_scalTechnique == el->GetScalingTechnique() &&
m_ksTechnique == el->GetKeySwitchTechnique() && m_multTechnique == el->GetMultiplicationTechnique() &&
m_encTechnique == el->GetEncryptionTechnique() && m_numPartQ == el->GetNumPartQ() &&
m_auxBits == el->GetAuxBits() && m_extraBits == el->GetExtraBits() && m_PREMode == el->GetPREMode() &&
m_multipartyMode == el->GetMultipartyMode() && m_executionMode == el->GetExecutionMode();
}

public:
/**
Expand Down Expand Up @@ -183,24 +201,6 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
return static_cast<double>(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY);
}

/**
* == operator to compare to this instance of CryptoParametersBase object.
*
* @param &rhs CryptoParameters to check equality against.
*/
bool operator==(const CryptoParametersBase<DCRTPoly>& rhs) const override {
const auto* el = dynamic_cast<const CryptoParametersRNS*>(&rhs);

if (el == nullptr)
return false;

return CryptoParametersBase<DCRTPoly>::operator==(rhs) && m_scalTechnique == el->GetScalingTechnique() &&
m_ksTechnique == el->GetKeySwitchTechnique() && m_multTechnique == el->GetMultiplicationTechnique() &&
m_encTechnique == el->GetEncryptionTechnique() && m_numPartQ == el->GetNumPartQ() &&
m_auxBits == el->GetAuxBits() && m_extraBits == el->GetExtraBits() && m_PREMode == el->GetPREMode() &&
m_multipartyMode == el->GetMultipartyMode() && m_executionMode == el->GetExecutionMode();
}

void PrintParameters(std::ostream& os) const override {
CryptoParametersBase<DCRTPoly>::PrintParameters(os);
}
Expand Down

0 comments on commit 8a5a5d7

Please sign in to comment.