Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inconsistent results for local and collaborative LowMC evaluations. #53

Open
CBackyx opened this issue May 15, 2024 · 0 comments
Open

Comments

@CBackyx
Copy link

CBackyx commented May 15, 2024

Hello Rindal,

Thanks a lot for your great work on this library. It has helped me a lot!

I found that there is no unit test comparing the result of evaluating LowMC locally with the result of evaluating the LowMC circuit jointly among three parties. So I wrote a minimal unit test that compares these two cases, in which I set the message to be encrypted to 0 and all the round keys to 0 for simplicity.

However, I found that their encryption results are inconsistent. I wonder if this is to be expected or if I misunderstood the functionality/usage of the LowMC circuit. I would appreciate it a lot if you could answer this question.

Here is the unit test I wrote. It should compile if you append it to aby3-DB_tests/lowMC.cpp and include some necessary headers.

#include "aby3-DB/LowMC.h"
#include <iostream>
#include <cryptoTools/Common/Timer.h>
#include <cryptoTools/Common/TestCollection.h>
#include <cstdio>

#include "aby3/sh3/Sh3Encryptor.h"
#include "aby3/sh3/Sh3BinaryEvaluator.h"
#include "aby3/Circuit/CircuitLibrary.h"
#include "cryptoTools/Network/IOService.h"
#include "cryptoTools/Common/Log.h"
#include "cryptoTools/Crypto/PRNG.h"

#include <cryptoTools/Circuit/BetaLibrary.h>
#include <iomanip>
#include <atomic>
#include <vector>
#include <string>
#include <random>

/// Packs a 256-bit bitset into four 64-bit words.
///
/// Word k of the result holds bits [64*k, 64*k + 63] of `myBitset`, with bit
/// 64*k + j of the bitset stored at bit position j of the word.
///
/// The return type is spelled `uint64_t` (the element type this function has
/// always built internally) so the helper does not depend on a project alias.
///
/// @param myBitset The bitset to pack.
/// @return A vector of exactly four words, lowest-indexed bits first.
std::vector<uint64_t> bitsetToVector(const std::bitset<256>& myBitset) {
    constexpr std::size_t kTotalBits = 256;
    constexpr std::size_t kWordBits = 64;

    std::vector<uint64_t> result;
    result.reserve(kTotalBits / kWordBits);
    for (std::size_t i = 0; i < kTotalBits; i += kWordBits) {
        uint64_t value = 0;
        for (std::size_t j = 0; j < kWordBits; ++j) {
            // Widen to 64 bits BEFORE shifting: the bit is a bool, which would
            // otherwise be promoted to int, making shifts by 32..63 undefined
            // and sign-extending bit 31 into the upper half of the word.
            value |= static_cast<uint64_t>(myBitset[i + j]) << j;
        }
        result.push_back(value);
    }

    return result;
}

void lowMC_CircuitEval_Simplest_test() {
    // >>>>> Eval locally
    LowMC2<> cipher2(false, 1);
    LowMC2<>::block m = 0;
    for (u64 i = 0; i < 13; ++i) cipher2.roundkeys[i] = 0;
    auto c2 = cipher2.encrypt(m);
    std::vector<u64> vec = bitsetToVector(c2);
    printf("Eval locally: ");
    for (u64 i = 0; i < 4; ++i) printf("%lu ", vec[i]);
    printf("\n");

    // >>>>> Eval jointly
    IOService ios;
    Session s01(ios, "127.0.0.1", SessionMode::Server, "01");
    Session s10(ios, "127.0.0.1", SessionMode::Client, "01");
    Session s02(ios, "127.0.0.1", SessionMode::Server, "02");
    Session s20(ios, "127.0.0.1", SessionMode::Client, "02");
    Session s12(ios, "127.0.0.1", SessionMode::Server, "12");
    Session s21(ios, "127.0.0.1", SessionMode::Client, "12");

    Channel chl01 = s01.addChannel("c");
    Channel chl10 = s10.addChannel("c");
    Channel chl02 = s02.addChannel("c");
    Channel chl20 = s20.addChannel("c");
    Channel chl12 = s12.addChannel("c");
    Channel chl21 = s21.addChannel("c");

    CommPkg comms[3], debugComm[3];
    comms[0] = { chl02, chl01 };
    comms[1] = { chl10, chl12 };
    comms[2] = { chl21, chl20 };

    BetaCircuit lowMCCir;

    LowMC2<> cipher1(false, 1);
    for (u64 i = 0; i < 13; ++i) cipher1.roundkeys[i] = 0;
    cipher1.to_enc_circuit(lowMCCir);
    lowMCCir.levelByAndDepth();

    u64 rounds = lowMCCir.mInputs.size() - 1;
    u64 blockSize = 256;
    u64 wordSize = blockSize / 64;

    std::vector<u64> serialized = {0, 0, 0, 0}; // The message to be encrypted
    u64 width = 1;
    i64Matrix kv(width, wordSize);
    std::vector<i64Matrix> keys(rounds);
    for (u64 i = 0; i < rounds; ++i) {
        keys[i].resize(1, wordSize);
        for (u64 j = 0; j < wordSize; ++j) keys[i](0, j) = 0; // The round keys
    }
    for (u64 i = 0; i < serialized.size(); ++i) {
        kv(i / wordSize, i % wordSize) = serialized[i];
    }

    bool success = true;

    auto routine = [&](int pIdx) {
        Sh3Runtime rt(pIdx, comms[pIdx]);
        Sh3Encryptor enc;
        enc.init(pIdx, toBlock(pIdx), toBlock((pIdx + 1) % 3));
        Sh3BinaryEvaluator eval;
        eval.mPrng.SetSeed(toBlock(pIdx));
        Sh3ShareGen gen;
        gen.init(toBlock(pIdx), toBlock((pIdx + 1) % 3));

        // Load the input to secret form
        sPackedBin KV(width, blockSize);
        sPackedBin encKV(width, blockSize);
        std::vector<sPackedBin> Keys(rounds);
        for (u64 i = 0; i < rounds; ++i) {
            Keys[i].reset(1, blockSize);
        }        

        auto task = rt.noDependencies();
        if (pIdx == 0) {
            task = enc.localPackedBinary(task, kv, KV);
        } else {
            task = enc.remotePackedBinary(task, KV);
        }
        if (pIdx == 0) {
            for (u64 i = 0; i < rounds; ++i) {
                task = enc.localPackedBinary(task, keys[i], Keys[i]);
            }
        } else {
            for (u64 i = 0; i < rounds; ++i) {
                task = enc.remotePackedBinary(task, Keys[i]);
            }            
        }
        task.get();

        // Eval the circuit and get the result
        eval.setCir(&lowMCCir, width, gen);
        eval.setInput(0, KV);
        for (u64 i = 0; i < rounds; ++i) {
            eval.setInput(i + 1, Keys[i]);
        }
        eval.asyncEvaluate(rt.noDependencies()).get();
        eval.getOutput(0, encKV);
        i64Matrix enckv(width, wordSize);
        enc.revealAll(rt.noDependencies(), encKV, enckv).get();

        // Compare the result
        if (pIdx == 0) {
            printf("Eval jointly: ");
            for (u64 i = 0; i < wordSize; ++i) printf("%lu ", enckv(0, i));
            printf("\n");
            for (u64 i = 0; i < wordSize; ++i) {
                if (vec[i] != enckv(0, i)) {
                    success = false;
                    break;
                }
            }
        }
    };

    auto t0 = std::thread(routine, 0);
    auto t1 = std::thread(routine, 1);
    auto t2 = std::thread(routine, 2);

    t0.join();
    t1.join();
    t2.join();    

    if (!success) {
        throw UnitTestFail();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant