diff --git a/docs/tutorial/sha256_tutorial.ipynb b/docs/tutorial/sha256_tutorial.ipynb index 26cce894e9..c0b84aaead 100644 --- a/docs/tutorial/sha256_tutorial.ipynb +++ b/docs/tutorial/sha256_tutorial.ipynb @@ -539,36 +539,32 @@ "# Number of rounds must be 64 to have correct SHA256\n", "# If looking to get a faster run, reduce the number of rounds (but it will not be correct)\n", "\n", - "def get_sha256(number_of_rounds=None):\n", - " if number_of_rounds is None:\n", - " number_of_rounds=64\n", - " def sha256(data):\n", - " h_chunks = fhe.zeros((len(h_in), NUM_CHUNKS))\n", - " k_chunks = fhe.zeros((len(k_in), NUM_CHUNKS))\n", - " h_chunks += h_in\n", - " k_chunks += k_in\n", - "\n", - " num_of_iters = data.shape[0]*32//512\n", - " for chunk_iter in range(0, num_of_iters):\n", - " \n", - " # Initializing the variables\n", - " chunk = data[chunk_iter*16:(chunk_iter+1)*16]\n", - " w = [None for _ in range(number_of_rounds)]\n", - " # Starting the main loop and expansion\n", - " working_vars = h_chunks\n", - " for j in range(0, number_of_rounds):\n", - " if j<16:\n", - " w[j] = chunk[j]\n", - " else:\n", - " w[j] = add_four_32_bits(w[j-16], s0(w[j-15]), w[j-7], s1(w[j-2]))\n", - " w_i_k_i = add_two_32_bits(w[j], k_chunks[j])\n", - " working_vars = main_loop(working_vars,w_i_k_i)\n", - " \n", - " # Accumulating the results\n", - " for j in range(8):\n", - " h_chunks[j] = add_two_32_bits(h_chunks[j], working_vars[j])\n", - " return h_chunks\n", - " return sha256" + "def sha256(data, number_of_rounds=64):\n", + " h_chunks = fhe.zeros((len(h_in), NUM_CHUNKS))\n", + " k_chunks = fhe.zeros((len(k_in), NUM_CHUNKS))\n", + " h_chunks += h_in\n", + " k_chunks += k_in\n", + "\n", + " num_of_iters = data.shape[0]*32//512\n", + " for chunk_iter in range(0, num_of_iters):\n", + " \n", + " # Initializing the variables\n", + " chunk = data[chunk_iter*16:(chunk_iter+1)*16]\n", + " w = [None for _ in range(number_of_rounds)]\n", + " # Starting the main loop and expansion\n", + " working_vars = h_chunks\n", + " for j in range(0, number_of_rounds):\n", + " if j<16:\n", + " w[j] = chunk[j]\n", + " else:\n", + " w[j] = add_four_32_bits(w[j-16], s0(w[j-15]), w[j-7], s1(w[j-2]))\n", + " w_i_k_i = add_two_32_bits(w[j], k_chunks[j])\n", + " working_vars = main_loop(working_vars,w_i_k_i)\n", + " \n", + " # Accumulating the results\n", + " for j in range(8):\n", + " h_chunks[j] = add_two_32_bits(h_chunks[j], working_vars[j])\n", + " return h_chunks" ] }, { @@ -609,7 +605,7 @@ " b\"Curabitur bibendum, urna eu bibendum egestas, neque augue eleifend odio, et sagittis viverra. and more than 150\"\n", ")\n", "\n", - "result = get_sha256()(sha256_preprocess(np.frombuffer(text, dtype=np.uint8)))\n", + "result = sha256(sha256_preprocess(np.frombuffer(text, dtype=np.uint8)))\n", "\n", "m = hashlib.sha256()\n", "m.update(text)\n", @@ -638,7 +634,7 @@ " for _ in range(100)\n", " ]\n", " # Compilation of the circuit should take a few minutes\n", - " compiler = fhe.Compiler(get_sha256(self.number_of_rounds), {\"data\": \"encrypted\"})\n", + " compiler = fhe.Compiler(lambda data: sha256(data, self.number_of_rounds), {\"data\": \"encrypted\"})\n", " self.circuit = compiler.compile(\n", " inputset=inputset,\n", " configuration=fhe.Configuration(\n", @@ -655,7 +651,7 @@ " return self.circuit.encrypt_run_decrypt(sha256_preprocess(data))\n", "\n", " def getPlainSHA(self, data):\n", - " return get_sha256(self.number_of_rounds)(sha256_preprocess(data))" + " return sha256(sha256_preprocess(data), self.number_of_rounds)" ] }, {