From b84edbbd5f1febb8fe9c742ccaa281f7b00fd342 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Mon, 4 Dec 2023 18:03:00 +0000 Subject: [PATCH 01/13] Implement Keccak sponge --- src/lib/keccak.ts | 385 ++++++++++++++++++++++++++++++++++-- src/lib/keccak.unit-test.ts | 133 +++++-------- 2 files changed, 416 insertions(+), 102 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index dc01bebc2d..d01f0cfd97 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -1,5 +1,11 @@ import { Field } from './field.js'; import { Gadgets } from './gadgets/gadgets.js'; +import { assert } from './errors.js'; +import { existsOne, exists } from './gadgets/common.js'; +import { TupleN } from './util/types.js'; +import { rangeCheck8 } from './gadgets/range-check.js'; + +export { preNist, nistSha3, ethereum }; // KECCAK CONSTANTS @@ -64,14 +70,153 @@ const ROUND_CONSTANTS = [ 0x8000000000008080n, 0x0000000080000001n, 0x8000000080008008n, -].map(Field.from); +]; + +function checkBytesToWord(word: Field, wordBytes: Field[]): void { + let composition = wordBytes.reduce((acc, x, i) => { + const shift = Field.from(2n ** BigInt(8 * i)); + return acc.add(x.mul(shift)); + }, Field.from(0)); + + word.assertEquals(composition); +} // Return a keccak state where all lanes are equal to 0 const getKeccakStateZeros = (): Field[][] => Array.from(Array(KECCAK_DIM), (_) => Array(KECCAK_DIM).fill(Field.from(0))); +// Converts a list of bytes to a matrix of Field elements +function getKeccakStateOfBytes(bytestring: Field[]): Field[][] { + assert(bytestring.length === 200, 'improper bytestring length'); + + const bytestringArray = Array.from(bytestring); + const state: Field[][] = getKeccakStateZeros(); + + for (let y = 0; y < KECCAK_DIM; y++) { + for (let x = 0; x < KECCAK_DIM; x++) { + const idx = BYTES_PER_WORD * (KECCAK_DIM * y + x); + // Create an array containing the 8 bytes starting on idx that correspond to the word in [x,y] + const wordBytes = bytestringArray.slice(idx, idx + BYTES_PER_WORD); + + for (let z = 0; z < BYTES_PER_WORD; z++) { + // Field element containing value 2^(8*z) + const shift = Field.from(2n ** BigInt(8 * z)); + state[x][y] = state[x][y].add(shift.mul(wordBytes[z])); + } + } + } + + return state; +} + +// Converts a state of cvars to a list of bytes as cvars and creates constraints for it +function keccakStateToBytes(state: Field[][]): Field[] { + const stateLengthInBytes = KECCAK_STATE_LENGTH / 8; + const bytestring: Field[] = Array.from( + { length: stateLengthInBytes }, + (_, idx) => + existsOne(() => { + // idx = z + 8 * ((dim * y) + x) + const z = idx % BYTES_PER_WORD; + const x = Math.floor(idx / BYTES_PER_WORD) % KECCAK_DIM; + const y = Math.floor(idx / BYTES_PER_WORD / KECCAK_DIM); + // [7 6 5 4 3 2 1 0] [x=0,y=1] [x=0,y=2] [x=0,y=3] [x=0,y=4] + // [x=1,y=0] [x=1,y=1] [x=1,y=2] [x=1,y=3] [x=1,y=4] + // [x=2,y=0] [x=2,y=1] [x=2,y=2] [x=2,y=3] [x=2,y=4] + // [x=3,y=0] [x=3,y=1] [x=3,y=2] [x=3,y=3] [x=3,y=4] + // [x=4,y=0] [x=4,y=1] [x=4,y=0] [x=4,y=3] [x=4,y=4] + const word = state[x][y].toBigInt(); + const byte = (word >> BigInt(8 * z)) & BigInt('0xff'); + return byte; + }) + ); + + // Check all words are composed correctly from bytes + for (let y = 0; y < KECCAK_DIM; y++) { + for (let x = 0; x < KECCAK_DIM; x++) { + const idx = BYTES_PER_WORD * (KECCAK_DIM * y + x); + // Create an array containing the 8 bytes starting on idx that correspond to the word in [x,y] + const word_bytes = bytestring.slice(idx, idx + BYTES_PER_WORD); 
+ // Assert correct decomposition of bytes from state + checkBytesToWord(state[x][y], word_bytes); + } + } + + return bytestring; +} + +function keccakStateXor(a: Field[][], b: Field[][]): Field[][] { + assert( + a.length === KECCAK_DIM && a[0].length === KECCAK_DIM, + 'Invalid input1 dimensions' + ); + assert( + b.length === KECCAK_DIM && b[0].length === KECCAK_DIM, + 'Invalid input2 dimensions' + ); + + return a.map((row, rowIndex) => + row.map((element, columnIndex) => + Gadgets.xor(element, b[rowIndex][columnIndex], 64) + ) + ); +} + // KECCAK HASH FUNCTION +// Computes the number of required extra bytes to pad a message of length bytes +function bytesToPad(rate: number, length: number): number { + return Math.floor(rate / 8) - (length % Math.floor(rate / 8)); +} + +// Pads a message M as: +// M || pad[x](|M|) +// Padding rule 0x06 ..0*..1. +// The padded message vector will start with the message vector +// followed by the 0*1 rule to fulfill a length that is a multiple of rate (in bytes) +// (This means a 0110 sequence, followed with as many 0s as needed, and a final 1 bit) +function padNist(message: Field[], rate: number): Field[] { + // Find out desired length of the padding in bytes + // If message is already rate bits, need to pad full rate again + const extraBytes = bytesToPad(rate, message.length); + + // 0x06 0x00 ... 0x00 0x80 or 0x86 + const lastField = BigInt(2) ** BigInt(7); + const last = Field.from(lastField); + + // Create the padding vector + const pad = Array(extraBytes).fill(Field.from(0)); + pad[0] = Field.from(6); + pad[extraBytes - 1] = pad[extraBytes - 1].add(last); + + // Return the padded message + return [...message, ...pad]; +} + +// Pads a message M as: +// M || pad[x](|M|) +// Padding rule 10*1. +// The padded message vector will start with the message vector +// followed by the 10*1 rule to fulfill a length that is a multiple of rate (in bytes) +// (This means a 1 bit, followed with as many 0s as needed, and a final 1 bit) +function pad101(message: Field[], rate: number): Field[] { + // Find out desired length of the padding in bytes + // If message is already rate bits, need to pad full rate again + const extraBytes = bytesToPad(rate, message.length); + + // 0x01 0x00 ... 0x00 0x80 or 0x81 + const lastField = BigInt(2) ** BigInt(7); + const last = Field.from(lastField); + + // Create the padding vector + const pad = Array(extraBytes).fill(Field.from(0)); + pad[0] = Field.from(1); + pad[extraBytes - 1] = pad[extraBytes - 1].add(last); + + // Return the padded message + return [...message, ...pad]; +} + // ROUND TRANSFORMATION // First algorithm in the compression step of Keccak for 64-bit words. 
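The two padding helpers added above (padNist and pad101) differ only in the first pad byte; both append the 0x80 terminator and, when the message already fills a block, a whole extra block of padding. A minimal out-of-circuit sketch over plain numbers (padBytes and the sample values are illustrative, not part of this patch):

function padBytes(message: number[], rateBits: number, nist: boolean): number[] {
  const rateBytes = rateBits / 8;
  // if the message length is already a multiple of the rate, a full extra block is appended
  const extra = rateBytes - (message.length % rateBytes);
  const pad: number[] = Array(extra).fill(0);
  pad[0] = nist ? 0x06 : 0x01; // domain byte: 0x06 for NIST SHA-3, 0x01 for pre-NIST Keccak
  pad[extra - 1] += 0x80; // final 1 bit of the pad; collapses to 0x86 / 0x81 when extra === 1
  return [...message, ...pad];
}

// Ethereum parameters (rate = 1088 bits = 136 bytes):
// padBytes([0xab], 1088, false) -> [0xab, 0x01, 0x00, ..., 0x00, 0x80], length 136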
@@ -209,18 +354,230 @@ function permutation(state: Field[][], rc: Field[]): Field[][] { ); } -// TESTING +// Absorb padded message into a keccak state with given rate and capacity +function absorb( + paddedMessage: Field[], + capacity: number, + rate: number, + rc: Field[] +): Field[][] { + let state = getKeccakStateZeros(); -const blockTransformation = (state: Field[][]): Field[][] => - permutation(state, ROUND_CONSTANTS); + // split into blocks of rate bits + // for each block of rate bits in the padded message -> this is rate/8 bytes + const chunks = []; + // (capacity / 8) zero bytes + const zeros = Array(capacity / 8).fill(Field.from(0)); -export { - KECCAK_DIM, - ROUND_CONSTANTS, - theta, - piRho, - chi, - iota, - round, - blockTransformation, -}; + for (let i = 0; i < paddedMessage.length; i += rate / 8) { + const block = paddedMessage.slice(i, i + rate / 8); + // pad the block with 0s to up to 1600 bits + const paddedBlock = block.concat(zeros); + // padded with zeros each block until they are 1600 bit long + assert( + paddedBlock.length * 8 === KECCAK_STATE_LENGTH, + 'improper Keccak block length' + ); + const blockState = getKeccakStateOfBytes(paddedBlock); + // xor the state with the padded block + const stateXor = keccakStateXor(state, blockState); + // apply the permutation function to the xored state + const statePerm = permutation(stateXor, rc); + state = statePerm; + } + + return state; +} + +// Squeeze state until it has a desired length in bits +function squeeze( + state: Field[][], + length: number, + rate: number, + rc: Field[] +): Field[] { + const copy = ( + bytestring: Field[], + outputArray: Field[], + start: number, + length: number + ) => { + for (let i = 0; i < length; i++) { + outputArray[start + i] = bytestring[i]; + } + }; + + let newState = state; + + // bytes per squeeze + const bytesPerSqueeze = rate / 8; + // number of squeezes + const squeezes = Math.floor(length / rate) + 1; + // multiple of rate that is larger than output_length, in bytes + const outputLength = squeezes * bytesPerSqueeze; + // array with sufficient space to store the output + const outputArray = Array(outputLength).fill(Field.from(0)); + // first state to be squeezed + const bytestring = keccakStateToBytes(state); + const outputBytes = bytestring.slice(0, bytesPerSqueeze); + copy(outputBytes, outputArray, 0, bytesPerSqueeze); + // for the rest of squeezes + for (let i = 1; i < squeezes; i++) { + // apply the permutation function to the state + newState = permutation(newState, rc); + // append the output of the permutation function to the output + const bytestringI = keccakStateToBytes(state); + const outputBytesI = bytestringI.slice(0, bytesPerSqueeze); + copy(outputBytesI, outputArray, bytesPerSqueeze * i, bytesPerSqueeze); + } + // Obtain the hash selecting the first bitlength/8 bytes of the output array + const hashed = outputArray.slice(0, length / 8); + + return hashed; +} + +// Keccak sponge function for 1600 bits of state width +// Need to split the message into blocks of 1088 bits. 
+function sponge( + paddedMessage: Field[], + length: number, + capacity: number, + rate: number +): Field[] { + // check that the padded message is a multiple of rate + if ((paddedMessage.length * 8) % rate !== 0) { + throw new Error('Invalid padded message length'); + } + + // setup cvars for round constants + let rc = exists(24, () => TupleN.fromArray(24, ROUND_CONSTANTS)); + + // absorb + const state = absorb(paddedMessage, capacity, rate, rc); + + // squeeze + const hashed = squeeze(state, length, rate, rc); + + return hashed; +} + +// TODO(jackryanservia): Use lookup argument once issue is resolved +// Checks in the circuit that a list of cvars are at most 8 bits each +function checkBytes(inputs: Field[]): void { + inputs.map(rangeCheck8); +} + +// Keccak hash function with input message passed as list of Cvar bytes. +// The message will be parsed as follows: +// - the first byte of the message will be the least significant byte of the first word of the state (A[0][0]) +// - the 10*1 pad will take place after the message, until reaching the bit length rate. +// - then, {0} pad will take place to finish the 1600 bits of the state. +function hash( + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false, + message: Field[] = [], + length: number, + capacity: number, + nistVersion: boolean +): Field[] { + assert(capacity > 0, 'capacity must be positive'); + assert(capacity < KECCAK_STATE_LENGTH, 'capacity must be less than 1600'); + assert(length > 0, 'length must be positive'); + assert(length % 8 === 0, 'length must be a multiple of 8'); + + // Set input to Big Endian format + let messageFormatted = inpEndian === 'Big' ? message : message.reverse(); + + // Check each cvar input is 8 bits at most if it was not done before at creation time + if (byteChecks) { + checkBytes(messageFormatted); + } + + const rate = KECCAK_STATE_LENGTH - capacity; + + let padded; + if (nistVersion) { + padded = padNist(messageFormatted, rate); + } else { + padded = pad101(messageFormatted, rate); + } + + const hash = sponge(padded, length, capacity, rate); + + // Check each cvar output is 8 bits at most. Always because they are created here + checkBytes(hash); + + // Set input to desired endianness + const hashFormatted = outEndian === 'Big' ? hash : hash.reverse(); + + // Check each cvar output is 8 bits at most + return hashFormatted; +} + +// Gadget for NIST SHA-3 function for output lengths 224/256/384/512. +// Input and output endianness can be specified. Default is big endian. +// Note that when calling with output length 256 this is equivalent to the ethereum function +function nistSha3( + len: number, + message: Field[], + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false +): Field[] { + let output: Field[]; + + switch (len) { + case 224: + output = hash(inpEndian, outEndian, byteChecks, message, 224, 448, true); + break; + case 256: + output = hash(inpEndian, outEndian, byteChecks, message, 256, 512, true); + break; + case 384: + output = hash(inpEndian, outEndian, byteChecks, message, 384, 768, true); + break; + case 512: + output = hash(inpEndian, outEndian, byteChecks, message, 512, 1024, true); + break; + default: + throw new Error('Invalid length'); + } + + return output; +} + +// Gadget for Keccak hash function for the parameters used in Ethereum. +// Input and output endianness can be specified. Default is big endian. 
+function ethereum( + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false, + message: Field[] = [] +): Field[] { + return hash(inpEndian, outEndian, byteChecks, message, 256, 512, false); +} + +// Gadget for pre-NIST SHA-3 function for output lengths 224/256/384/512. +// Input and output endianness can be specified. Default is big endian. +// Note that when calling with output length 256 this is equivalent to the ethereum function +function preNist( + len: number, + message: Field[], + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false +): Field[] { + switch (len) { + case 224: + return hash(inpEndian, outEndian, byteChecks, message, 224, 448, false); + case 256: + return ethereum(inpEndian, outEndian, byteChecks, message); + case 384: + return hash(inpEndian, outEndian, byteChecks, message, 384, 768, false); + case 512: + return hash(inpEndian, outEndian, byteChecks, message, 512, 1024, false); + default: + throw new Error('Invalid length'); + } +} diff --git a/src/lib/keccak.unit-test.ts b/src/lib/keccak.unit-test.ts index 1c20b2ec6f..cd477ddb5c 100644 --- a/src/lib/keccak.unit-test.ts +++ b/src/lib/keccak.unit-test.ts @@ -1,106 +1,63 @@ import { Field } from './field.js'; import { Provable } from './provable.js'; +import { preNist, nistSha3, ethereum } from './keccak.js'; +import { sha3_256, keccak_256 } from '@noble/hashes/sha3'; import { ZkProgram } from './proof_system.js'; -import { constraintSystem, print } from './testing/constraint-system.js'; -import { - KECCAK_DIM, - ROUND_CONSTANTS, - theta, - piRho, - chi, - iota, - round, - blockTransformation, -} from './keccak.js'; -const KECCAK_TEST_STATE = [ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 1, 0, 0, 0], -].map((row) => row.map((elem) => Field.from(elem))); - -let KeccakBlockTransformation = ZkProgram({ - name: 'KeccakBlockTransformation', - publicInput: Provable.Array(Provable.Array(Field, KECCAK_DIM), KECCAK_DIM), - publicOutput: Provable.Array(Provable.Array(Field, KECCAK_DIM), KECCAK_DIM), +let Keccak = ZkProgram({ + name: 'keccak', + publicInput: Provable.Array(Field, 100), + publicOutput: Provable.Array(Field, 32), methods: { - Theta: { - privateInputs: [], - method(input: Field[][]) { - return theta(input); - }, - }, - PiRho: { + preNist: { privateInputs: [], - method(input: Field[][]) { - return piRho(input); + method(preImage) { + return preNist(256, preImage); }, }, - Chi: { + nistSha3: { privateInputs: [], - method(input: Field[][]) { - return chi(input); + method(preImage) { + return nistSha3(256, preImage); }, }, - Iota: { + ethereum: { privateInputs: [], - method(input: Field[][]) { - return iota(input, ROUND_CONSTANTS[0]); - }, - }, - Round: { - privateInputs: [], - method(input: Field[][]) { - return round(input, ROUND_CONSTANTS[0]); - }, - }, - BlockTransformation: { - privateInputs: [], - method(input: Field[][]) { - return blockTransformation(input); + method(preImage) { + return ethereum(preImage); }, }, }, }); -// constraintSystem.fromZkProgram( -// KeccakBlockTransformation, -// 'BlockTransformation', -// print -// ); - -console.log('KECCAK_TEST_STATE: ', KECCAK_TEST_STATE.toString()); - -console.log('Compiling...'); -await KeccakBlockTransformation.compile(); -console.log('Done!'); -console.log('Generating proof...'); -let proof0 = await KeccakBlockTransformation.BlockTransformation( - KECCAK_TEST_STATE -); -console.log('Done!'); -console.log('Output:', 
proof0.publicOutput.toString()); -console.log('Verifying...'); -proof0.verify(); -console.log('Done!'); - -/* -[RUST IMPLEMENTATION OUTPUT](https://github.com/BaldyAsh/keccak-rust) - -INPUT: -[[0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 1, 0, 0, 0]] - -OUTPUT: -[[8771753707458093707, 14139250443469741764, 11827767624278131459, 2757454755833177578, 5758014717183214102], -[3389583698920935946, 1287099063347104936, 15030403046357116816, 17185756281681305858, 9708367831595350450], -[1416127551095004411, 16037937966823201128, 9518790688640222300, 1997971396112921437, 4893561083608951508], -[8048617297177300085, 10306645194383020789, 2789881727527423094, 7603160281577405588, 12935834807086847890], -[9476112750389234330, 13193683191463706918, 4460519148532423021, 7183125267124224670, 1393214916959060614]] -*/ +console.log("compiling keccak"); +await Keccak.compile(); +console.log("done compiling keccak"); + +const runs = 2; + +let preImage = [ + 236, 185, 24, 61, 138, 249, 61, 13, 226, 103, 152, 232, 104, 234, 170, 26, + 46, 54, 157, 146, 17, 240, 10, 193, 214, 110, 134, 47, 97, 241, 172, 198, + 80, 95, 136, 185, 62, 156, 246, 210, 207, 129, 93, 162, 215, 77, 3, 38, + 194, 86, 75, 100, 64, 87, 6, 18, 4, 159, 235, 53, 87, 124, 216, 241, 179, + 201, 111, 168, 72, 181, 28, 65, 142, 243, 224, 69, 58, 178, 114, 3, 112, + 23, 15, 208, 103, 231, 114, 64, 89, 172, 240, 81, 27, 215, 129, 3, 16, + 173, 133, 160, +] + +let preNistProof = await Keccak.preNist(preImage.map(Field.from)); +console.log(preNistProof.publicOutput.toString()); +console.log(keccak_256(new Uint8Array(preImage))); +let nistSha3Proof = await Keccak.nistSha3(preImage.map(Field.from)); +console.log(nistSha3Proof.publicOutput.toString()); +console.log(sha3_256(new Uint8Array(preImage))); +let ethereumProof = await Keccak.ethereum(preImage.map(Field.from)); +console.log(ethereumProof.publicOutput.toString()); + +console.log('verifying'); +preNistProof.verify(); +nistSha3Proof.verify(); +ethereumProof.verify(); +console.log('done verifying'); From b5b0ee444f71390d2590d3ac0fff3dc26cee3a43 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:29:31 +0000 Subject: [PATCH 02/13] Improved test and cleaned up implementation --- package-lock.json | 13 +++ package.json | 1 + src/lib/keccak.ts | 92 +++++++++++++++------ src/lib/keccak.unit-test.ts | 161 ++++++++++++++++++++++++++---------- 4 files changed, 196 insertions(+), 71 deletions(-) diff --git a/package-lock.json b/package-lock.json index eff466f3fb..a518898026 100644 --- a/package-lock.json +++ b/package-lock.json @@ -21,6 +21,7 @@ "snarky-run": "src/build/run.js" }, "devDependencies": { + "@noble/hashes": "^1.3.2", "@playwright/test": "^1.25.2", "@types/isomorphic-fetch": "^0.0.36", "@types/jest": "^27.0.0", @@ -1486,6 +1487,18 @@ "@jridgewell/sourcemap-codec": "^1.4.10" } }, + "node_modules/@noble/hashes": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.2.tgz", + "integrity": "sha512-MVC8EAQp7MvEcm30KWENFjgR+Mkmf+D189XJTkFIlwohU5hcBbn1ZkKq7KVTi2Hme3PMGF390DaL52beVrIihQ==", + "dev": true, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", diff --git a/package.json b/package.json index b890fd3dba..5b6a3f46f9 100644 --- a/package.json +++ 
b/package.json @@ -72,6 +72,7 @@ }, "author": "O(1) Labs", "devDependencies": { + "@noble/hashes": "^1.3.2", "@playwright/test": "^1.25.2", "@types/isomorphic-fetch": "^0.0.36", "@types/jest": "^27.0.0", diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index d01f0cfd97..c6b6db9ad3 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -5,7 +5,39 @@ import { existsOne, exists } from './gadgets/common.js'; import { TupleN } from './util/types.js'; import { rangeCheck8 } from './gadgets/range-check.js'; -export { preNist, nistSha3, ethereum }; +export { Keccak }; + +const Keccak = { + /** TODO */ + preNist( + len: number, + message: Field[], + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false + ) { + return preNist(len, message, inpEndian, outEndian, byteChecks); + }, + /** TODO */ + nistSha3( + len: number, + message: Field[], + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false + ) { + return nistSha3(len, message, inpEndian, outEndian, byteChecks); + }, + /** TODO */ + ethereum( + message: Field[], + inpEndian: 'Big' | 'Little' = 'Big', + outEndian: 'Big' | 'Little' = 'Big', + byteChecks: boolean = false + ) { + return ethereum(message, inpEndian, outEndian, byteChecks); + }, +}; // KECCAK CONSTANTS @@ -72,6 +104,9 @@ const ROUND_CONSTANTS = [ 0x8000000080008008n, ]; +// AUXILARY FUNCTIONS + +// Auxiliary function to check composition of 8 bytes into a 64-bit word function checkBytesToWord(word: Field, wordBytes: Field[]): void { let composition = wordBytes.reduce((acc, x, i) => { const shift = Field.from(2n ** BigInt(8 * i)); @@ -81,6 +116,8 @@ function checkBytesToWord(word: Field, wordBytes: Field[]): void { word.assertEquals(composition); } +// KECCAK STATE FUNCTIONS + // Return a keccak state where all lanes are equal to 0 const getKeccakStateZeros = (): Field[][] => Array.from(Array(KECCAK_DIM), (_) => Array(KECCAK_DIM).fill(Field.from(0))); @@ -105,11 +142,10 @@ function getKeccakStateOfBytes(bytestring: Field[]): Field[][] { } } } - return state; } -// Converts a state of cvars to a list of bytes as cvars and creates constraints for it +// Converts a state of Fields to a list of bytes as Fields and creates constraints for it function keccakStateToBytes(state: Field[][]): Field[] { const stateLengthInBytes = KECCAK_STATE_LENGTH / 8; const bytestring: Field[] = Array.from( @@ -141,20 +177,21 @@ function keccakStateToBytes(state: Field[][]): Field[] { checkBytesToWord(state[x][y], word_bytes); } } - return bytestring; } +// XOR two states together and return the result function keccakStateXor(a: Field[][], b: Field[][]): Field[][] { assert( a.length === KECCAK_DIM && a[0].length === KECCAK_DIM, - 'Invalid input1 dimensions' + 'Invalid a dimensions' ); assert( b.length === KECCAK_DIM && b[0].length === KECCAK_DIM, - 'Invalid input2 dimensions' + 'Invalid b dimensions' ); + // Calls Gadgets.xor on each pair (x,y) of the states input1 and input2 and outputs the output Fields as a new matrix return a.map((row, rowIndex) => row.map((element, columnIndex) => Gadgets.xor(element, b[rowIndex][columnIndex], 64) @@ -181,8 +218,7 @@ function padNist(message: Field[], rate: number): Field[] { const extraBytes = bytesToPad(rate, message.length); // 0x06 0x00 ... 
0x00 0x80 or 0x86 - const lastField = BigInt(2) ** BigInt(7); - const last = Field.from(lastField); + const last = Field.from(BigInt(2) ** BigInt(7)); // Create the padding vector const pad = Array(extraBytes).fill(Field.from(0)); @@ -205,8 +241,7 @@ function pad101(message: Field[], rate: number): Field[] { const extraBytes = bytesToPad(rate, message.length); // 0x01 0x00 ... 0x00 0x80 or 0x81 - const lastField = BigInt(2) ** BigInt(7); - const last = Field.from(lastField); + const last = Field.from(BigInt(2) ** BigInt(7)); // Create the padding vector const pad = Array(extraBytes).fill(Field.from(0)); @@ -354,6 +389,8 @@ function permutation(state: Field[][], rc: Field[]): Field[][] { ); } +// KECCAK SPONGE + // Absorb padded message into a keccak state with given rate and capacity function absorb( paddedMessage: Field[], @@ -363,13 +400,12 @@ function absorb( ): Field[][] { let state = getKeccakStateZeros(); - // split into blocks of rate bits - // for each block of rate bits in the padded message -> this is rate/8 bytes - const chunks = []; // (capacity / 8) zero bytes const zeros = Array(capacity / 8).fill(Field.from(0)); for (let i = 0; i < paddedMessage.length; i += rate / 8) { + // split into blocks of rate bits + // for each block of rate bits in the padded message -> this is rate/8 bytes const block = paddedMessage.slice(i, i + rate / 8); // pad the block with 0s to up to 1600 bits const paddedBlock = block.concat(zeros); @@ -385,7 +421,6 @@ function absorb( const statePerm = permutation(stateXor, rc); state = statePerm; } - return state; } @@ -396,6 +431,7 @@ function squeeze( rate: number, rc: Field[] ): Field[] { + // Copies a section of bytes in the bytestring into the output array const copy = ( bytestring: Field[], outputArray: Field[], @@ -421,6 +457,7 @@ function squeeze( const bytestring = keccakStateToBytes(state); const outputBytes = bytestring.slice(0, bytesPerSqueeze); copy(outputBytes, outputArray, 0, bytesPerSqueeze); + // for the rest of squeezes for (let i = 1; i < squeezes; i++) { // apply the permutation function to the state @@ -430,9 +467,9 @@ function squeeze( const outputBytesI = bytestringI.slice(0, bytesPerSqueeze); copy(outputBytesI, outputArray, bytesPerSqueeze * i, bytesPerSqueeze); } + // Obtain the hash selecting the first bitlength/8 bytes of the output array const hashed = outputArray.slice(0, length / 8); - return hashed; } @@ -449,7 +486,7 @@ function sponge( throw new Error('Invalid padded message length'); } - // setup cvars for round constants + // load round constants into Fields let rc = exists(24, () => TupleN.fromArray(24, ROUND_CONSTANTS)); // absorb @@ -462,12 +499,12 @@ function sponge( } // TODO(jackryanservia): Use lookup argument once issue is resolved -// Checks in the circuit that a list of cvars are at most 8 bits each +// Checks in the circuit that a list of Fields are at most 8 bits each function checkBytes(inputs: Field[]): void { inputs.map(rangeCheck8); } -// Keccak hash function with input message passed as list of Cvar bytes. +// Keccak hash function with input message passed as list of Field bytes. // The message will be parsed as follows: // - the first byte of the message will be the least significant byte of the first word of the state (A[0][0]) // - the 10*1 pad will take place after the message, until reaching the bit length rate. 
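The assertions touched in the next hunk encode the usual sponge parameter relation: every supported variant uses capacity = 2 × output length and rate = 1600 − capacity, so Keccak-256/SHA3-256 absorbs 1088-bit (136-byte) blocks. A tiny illustrative check (not part of the patch):

const STATE_BITS = 1600; // KECCAK_DIM ** 2 * KECCAK_WORD, as defined above
for (const len of [224, 256, 384, 512]) {
  const capacity = 2 * len; // e.g. 512 for the 256-bit variants
  const rate = STATE_BITS - capacity; // e.g. 1088 -> 136-byte blocks in absorb()
  console.assert(rate + capacity === STATE_BITS && rate % 8 === 0 && len % 8 === 0);
}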
@@ -482,14 +519,17 @@ function hash( nistVersion: boolean ): Field[] { assert(capacity > 0, 'capacity must be positive'); - assert(capacity < KECCAK_STATE_LENGTH, 'capacity must be less than 1600'); + assert( + capacity < KECCAK_STATE_LENGTH, + 'capacity must be less than KECCAK_STATE_LENGTH' + ); assert(length > 0, 'length must be positive'); assert(length % 8 === 0, 'length must be a multiple of 8'); - // Set input to Big Endian format + // Input endianness conversion let messageFormatted = inpEndian === 'Big' ? message : message.reverse(); - // Check each cvar input is 8 bits at most if it was not done before at creation time + // Check each Field input is 8 bits at most if it was not done before at creation time if (byteChecks) { checkBytes(messageFormatted); } @@ -505,19 +545,17 @@ function hash( const hash = sponge(padded, length, capacity, rate); - // Check each cvar output is 8 bits at most. Always because they are created here + // Always check each Field output is 8 bits at most because they are created here checkBytes(hash); // Set input to desired endianness const hashFormatted = outEndian === 'Big' ? hash : hash.reverse(); - // Check each cvar output is 8 bits at most return hashFormatted; } // Gadget for NIST SHA-3 function for output lengths 224/256/384/512. // Input and output endianness can be specified. Default is big endian. -// Note that when calling with output length 256 this is equivalent to the ethereum function function nistSha3( len: number, message: Field[], @@ -550,10 +588,10 @@ function nistSha3( // Gadget for Keccak hash function for the parameters used in Ethereum. // Input and output endianness can be specified. Default is big endian. function ethereum( + message: Field[] = [], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', - byteChecks: boolean = false, - message: Field[] = [] + byteChecks: boolean = false ): Field[] { return hash(inpEndian, outEndian, byteChecks, message, 256, 512, false); } @@ -572,7 +610,7 @@ function preNist( case 224: return hash(inpEndian, outEndian, byteChecks, message, 224, 448, false); case 256: - return ethereum(inpEndian, outEndian, byteChecks, message); + return ethereum(message, inpEndian, outEndian, byteChecks); case 384: return hash(inpEndian, outEndian, byteChecks, message, 384, 768, false); case 512: diff --git a/src/lib/keccak.unit-test.ts b/src/lib/keccak.unit-test.ts index cd477ddb5c..e72e655093 100644 --- a/src/lib/keccak.unit-test.ts +++ b/src/lib/keccak.unit-test.ts @@ -1,63 +1,136 @@ import { Field } from './field.js'; import { Provable } from './provable.js'; -import { preNist, nistSha3, ethereum } from './keccak.js'; -import { sha3_256, keccak_256 } from '@noble/hashes/sha3'; +import { Keccak } from './keccak.js'; +import { keccak_256, sha3_256, keccak_512, sha3_512 } from '@noble/hashes/sha3'; import { ZkProgram } from './proof_system.js'; +import { Random } from './testing/random.js'; +import { array, equivalentAsync, fieldWithRng } from './testing/equivalent.js'; +import { constraintSystem, contains } from './testing/constraint-system.js'; +const PREIMAGE_LENGTH = 75; +const RUNS = 1; -let Keccak = ZkProgram({ - name: 'keccak', - publicInput: Provable.Array(Field, 100), +let uint = (length: number) => fieldWithRng(Random.biguint(length)); + +let Keccak256 = ZkProgram({ + name: 'keccak256', + publicInput: Provable.Array(Field, PREIMAGE_LENGTH), publicOutput: Provable.Array(Field, 32), methods: { - preNist: { + ethereum: { privateInputs: [], method(preImage) { - return preNist(256, 
preImage); + return Keccak.ethereum(preImage); }, }, + // No need for preNist Keccak_256 because it's identical to ethereum nistSha3: { privateInputs: [], method(preImage) { - return nistSha3(256, preImage); - }, - }, - ethereum: { - privateInputs: [], - method(preImage) { - return ethereum(preImage); + return Keccak.nistSha3(256, preImage); }, }, }, }); -console.log("compiling keccak"); -await Keccak.compile(); -console.log("done compiling keccak"); - -const runs = 2; - -let preImage = [ - 236, 185, 24, 61, 138, 249, 61, 13, 226, 103, 152, 232, 104, 234, 170, 26, - 46, 54, 157, 146, 17, 240, 10, 193, 214, 110, 134, 47, 97, 241, 172, 198, - 80, 95, 136, 185, 62, 156, 246, 210, 207, 129, 93, 162, 215, 77, 3, 38, - 194, 86, 75, 100, 64, 87, 6, 18, 4, 159, 235, 53, 87, 124, 216, 241, 179, - 201, 111, 168, 72, 181, 28, 65, 142, 243, 224, 69, 58, 178, 114, 3, 112, - 23, 15, 208, 103, 231, 114, 64, 89, 172, 240, 81, 27, 215, 129, 3, 16, - 173, 133, 160, -] - -let preNistProof = await Keccak.preNist(preImage.map(Field.from)); -console.log(preNistProof.publicOutput.toString()); -console.log(keccak_256(new Uint8Array(preImage))); -let nistSha3Proof = await Keccak.nistSha3(preImage.map(Field.from)); -console.log(nistSha3Proof.publicOutput.toString()); -console.log(sha3_256(new Uint8Array(preImage))); -let ethereumProof = await Keccak.ethereum(preImage.map(Field.from)); -console.log(ethereumProof.publicOutput.toString()); - -console.log('verifying'); -preNistProof.verify(); -nistSha3Proof.verify(); -ethereumProof.verify(); -console.log('done verifying'); +await Keccak256.compile(); + +await equivalentAsync( + { + from: [array(uint(8), PREIMAGE_LENGTH)], + to: array(uint(8), 32), + }, + { runs: RUNS } +)( + (x) => { + let uint8Array = new Uint8Array(x.map(Number)); + let result = keccak_256(uint8Array); + return Array.from(result).map(BigInt); + }, + async (x) => { + let proof = await Keccak256.ethereum(x); + return proof.publicOutput; + } +); + +await equivalentAsync( + { + from: [array(uint(8), PREIMAGE_LENGTH)], + to: array(uint(8), 32), + }, + { runs: RUNS } +)( + (x) => { + let thing = x.map(Number); + let result = sha3_256(new Uint8Array(thing)); + return Array.from(result).map(BigInt); + }, + async (x) => { + let proof = await Keccak256.nistSha3(x); + return proof.publicOutput; + } +); + +// let Keccak512 = ZkProgram({ +// name: 'keccak512', +// publicInput: Provable.Array(Field, PREIMAGE_LENGTH), +// publicOutput: Provable.Array(Field, 64), +// methods: { +// preNist: { +// privateInputs: [], +// method(preImage) { +// return Keccak.preNist(512, preImage, 'Big', 'Big', true); +// }, +// }, +// nistSha3: { +// privateInputs: [], +// method(preImage) { +// return Keccak.nistSha3(512, preImage, 'Big', 'Big', true); +// }, +// }, +// }, +// }); + +// await Keccak512.compile(); + +// await equivalentAsync( +// { +// from: [array(uint(8), PREIMAGE_LENGTH)], +// to: array(uint(8), 64), +// }, +// { runs: RUNS } +// )( +// (x) => { +// let uint8Array = new Uint8Array(x.map(Number)); +// let result = keccak_512(uint8Array); +// return Array.from(result).map(BigInt); +// }, +// async (x) => { +// let proof = await Keccak512.preNist(x); +// return proof.publicOutput; +// } +// ); + +// await equivalentAsync( +// { +// from: [array(uint(8), PREIMAGE_LENGTH)], +// to: array(uint(8), 64), +// }, +// { runs: RUNS } +// )( +// (x) => { +// let thing = x.map(Number); +// let result = sha3_512(new Uint8Array(thing)); +// return Array.from(result).map(BigInt); +// }, +// async (x) => { +// let proof = await 
Keccak512.nistSha3(x); +// return proof.publicOutput; +// } +// ); + +constraintSystem.fromZkProgram( + Keccak256, + 'ethereum', + contains([['Generic'], ['Xor16'], ['Zero'], ['Rot64'], ['RangeCheck0']]) +); From 991008b8e56440bc28172223ac1a7424c108f0f0 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:50:51 +0000 Subject: [PATCH 03/13] Cleaned up and made variable names consistent --- src/lib/keccak.ts | 225 +++++++++++++++++------------------- src/lib/keccak.unit-test.ts | 33 +++--- 2 files changed, 122 insertions(+), 136 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index c6b6db9ad3..813b53d576 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -9,33 +9,33 @@ export { Keccak }; const Keccak = { /** TODO */ - preNist( + nistSha3( len: number, message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false - ) { - return preNist(len, message, inpEndian, outEndian, byteChecks); + ): Field[] { + return nistSha3(len, message, inpEndian, outEndian, byteChecks); }, /** TODO */ - nistSha3( - len: number, + ethereum( message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false - ) { - return nistSha3(len, message, inpEndian, outEndian, byteChecks); + ): Field[] { + return ethereum(message, inpEndian, outEndian, byteChecks); }, /** TODO */ - ethereum( + preNist( + len: number, message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false - ) { - return ethereum(message, inpEndian, outEndian, byteChecks); + ): Field[] { + return preNist(len, message, inpEndian, outEndian, byteChecks); }, }; @@ -60,7 +60,7 @@ const KECCAK_STATE_LENGTH = KECCAK_DIM ** 2 * KECCAK_WORD; const KECCAK_ROUNDS = 12 + 2 * KECCAK_ELL; // Creates the 5x5 table of rotation offset for Keccak modulo 64 -// | x \ y | 0 | 1 | 2 | 3 | 4 | +// | i \ j | 0 | 1 | 2 | 3 | 4 | // | ----- | -- | -- | -- | -- | -- | // | 0 | 0 | 36 | 3 | 41 | 18 | // | 1 | 1 | 44 | 10 | 45 | 2 | @@ -106,13 +106,14 @@ const ROUND_CONSTANTS = [ // AUXILARY FUNCTIONS -// Auxiliary function to check composition of 8 bytes into a 64-bit word -function checkBytesToWord(word: Field, wordBytes: Field[]): void { - let composition = wordBytes.reduce((acc, x, i) => { - const shift = Field.from(2n ** BigInt(8 * i)); - return acc.add(x.mul(shift)); +// Auxiliary function to check the composition of 8 byte values (LE) into a 64-bit word and create constraints for it +function checkBytesToWord(wordBytes: Field[], word: Field): void { + const composition = wordBytes.reduce((acc, value, idx) => { + const shift = 1n << BigInt(8 * idx); + return acc.add(value.mul(shift)); }, Field.from(0)); + // Create constraints to check that the word is composed correctly from bytes word.assertEquals(composition); } @@ -124,21 +125,24 @@ const getKeccakStateZeros = (): Field[][] => // Converts a list of bytes to a matrix of Field elements function getKeccakStateOfBytes(bytestring: Field[]): Field[][] { - assert(bytestring.length === 200, 'improper bytestring length'); + assert( + bytestring.length === 200, + 'improper bytestring length (should be 200)' + ); const bytestringArray = Array.from(bytestring); const state: Field[][] = getKeccakStateZeros(); - for (let y = 0; y < KECCAK_DIM; y++) { - for (let x = 0; x < KECCAK_DIM; x++) { - const idx = BYTES_PER_WORD * (KECCAK_DIM * y + x); - // Create an array containing the 8 
bytes starting on idx that correspond to the word in [x,y] + for (let j = 0; j < KECCAK_DIM; j++) { + for (let i = 0; i < KECCAK_DIM; i++) { + const idx = BYTES_PER_WORD * (KECCAK_DIM * j + i); + // Create an array containing the 8 bytes starting on idx that correspond to the word in [i,j] const wordBytes = bytestringArray.slice(idx, idx + BYTES_PER_WORD); - for (let z = 0; z < BYTES_PER_WORD; z++) { - // Field element containing value 2^(8*z) - const shift = Field.from(2n ** BigInt(8 * z)); - state[x][y] = state[x][y].add(shift.mul(wordBytes[z])); + for (let k = 0; k < BYTES_PER_WORD; k++) { + // Field element containing value 2^(8*k) + const shift = 1n << BigInt(8 * k); + state[i][j] = state[i][j].add(wordBytes[k].mul(shift)); } } } @@ -152,29 +156,24 @@ function keccakStateToBytes(state: Field[][]): Field[] { { length: stateLengthInBytes }, (_, idx) => existsOne(() => { - // idx = z + 8 * ((dim * y) + x) - const z = idx % BYTES_PER_WORD; - const x = Math.floor(idx / BYTES_PER_WORD) % KECCAK_DIM; - const y = Math.floor(idx / BYTES_PER_WORD / KECCAK_DIM); - // [7 6 5 4 3 2 1 0] [x=0,y=1] [x=0,y=2] [x=0,y=3] [x=0,y=4] - // [x=1,y=0] [x=1,y=1] [x=1,y=2] [x=1,y=3] [x=1,y=4] - // [x=2,y=0] [x=2,y=1] [x=2,y=2] [x=2,y=3] [x=2,y=4] - // [x=3,y=0] [x=3,y=1] [x=3,y=2] [x=3,y=3] [x=3,y=4] - // [x=4,y=0] [x=4,y=1] [x=4,y=0] [x=4,y=3] [x=4,y=4] - const word = state[x][y].toBigInt(); - const byte = (word >> BigInt(8 * z)) & BigInt('0xff'); + // idx = k + 8 * ((dim * j) + i) + const i = Math.floor(idx / BYTES_PER_WORD) % KECCAK_DIM; + const j = Math.floor(idx / BYTES_PER_WORD / KECCAK_DIM); + const k = idx % BYTES_PER_WORD; + const word = state[i][j].toBigInt(); + const byte = (word >> BigInt(8 * k)) & 0xffn; return byte; }) ); // Check all words are composed correctly from bytes - for (let y = 0; y < KECCAK_DIM; y++) { - for (let x = 0; x < KECCAK_DIM; x++) { - const idx = BYTES_PER_WORD * (KECCAK_DIM * y + x); - // Create an array containing the 8 bytes starting on idx that correspond to the word in [x,y] + for (let j = 0; j < KECCAK_DIM; j++) { + for (let i = 0; i < KECCAK_DIM; i++) { + const idx = BYTES_PER_WORD * (KECCAK_DIM * j + i); + // Create an array containing the 8 bytes starting on idx that correspond to the word in [i,j] const word_bytes = bytestring.slice(idx, idx + BYTES_PER_WORD); // Assert correct decomposition of bytes from state - checkBytesToWord(state[x][y], word_bytes); + checkBytesToWord(word_bytes, state[i][j]); } } return bytestring; @@ -184,18 +183,16 @@ function keccakStateToBytes(state: Field[][]): Field[] { function keccakStateXor(a: Field[][], b: Field[][]): Field[][] { assert( a.length === KECCAK_DIM && a[0].length === KECCAK_DIM, - 'Invalid a dimensions' + `invalid \`a\` dimensions (should be ${KECCAK_DIM})` ); assert( b.length === KECCAK_DIM && b[0].length === KECCAK_DIM, - 'Invalid b dimensions' + `invalid \`b\` dimensions (should be ${KECCAK_DIM})` ); - // Calls Gadgets.xor on each pair (x,y) of the states input1 and input2 and outputs the output Fields as a new matrix - return a.map((row, rowIndex) => - row.map((element, columnIndex) => - Gadgets.xor(element, b[rowIndex][columnIndex], 64) - ) + // Calls Gadgets.xor on each pair (i,j) of the states input1 and input2 and outputs the output Fields as a new matrix + return a.map((row, i) => + row.map((value, j) => Gadgets.xor(value, b[i][j], 64)) ); } @@ -255,21 +252,21 @@ function pad101(message: Field[], rate: number): Field[] { // ROUND TRANSFORMATION // First algorithm in the compression step of Keccak for 64-bit 
words. -// C[x] = A[x,0] xor A[x,1] xor A[x,2] xor A[x,3] xor A[x,4] -// D[x] = C[x-1] xor ROT(C[x+1], 1) -// E[x,y] = A[x,y] xor D[x] +// C[i] = A[i,0] xor A[i,1] xor A[i,2] xor A[i,3] xor A[i,4] +// D[i] = C[i-1] xor ROT(C[i+1], 1) +// E[i,j] = A[i,j] xor D[i] // In the Keccak reference, it corresponds to the `theta` algorithm. -// We use the first index of the state array as the x coordinate and the second index as the y coordinate. +// We use the first index of the state array as the i coordinate and the second index as the j coordinate. const theta = (state: Field[][]): Field[][] => { const stateA = state; // XOR the elements of each row together - // for all x in {0..4}: C[x] = A[x,0] xor A[x,1] xor A[x,2] xor A[x,3] xor A[x,4] + // for all i in {0..4}: C[i] = A[i,0] xor A[i,1] xor A[i,2] xor A[i,3] xor A[i,4] const stateC = stateA.map((row) => - row.reduce((acc, next) => Gadgets.xor(acc, next, KECCAK_WORD)) + row.reduce((acc, value) => Gadgets.xor(acc, value, KECCAK_WORD)) ); - // for all x in {0..4}: D[x] = C[x-1] xor ROT(C[x+1], 1) + // for all i in {0..4}: D[i] = C[i-1] xor ROT(C[i+1], 1) const stateD = Array.from({ length: KECCAK_DIM }, (_, x) => Gadgets.xor( stateC[(x + KECCAK_DIM - 1) % KECCAK_DIM], @@ -278,7 +275,7 @@ const theta = (state: Field[][]): Field[][] => { ) ); - // for all x in {0..4} and y in {0..4}: E[x,y] = A[x,y] xor D[x] + // for all i in {0..4} and j in {0..4}: E[i,j] = A[i,j] xor D[i] const stateE = stateA.map((row, index) => row.map((elem) => Gadgets.xor(elem, stateD[index], KECCAK_WORD)) ); @@ -287,40 +284,40 @@ const theta = (state: Field[][]): Field[][] => { }; // Second and third steps in the compression step of Keccak for 64-bit words. -// pi: A[x,y] = ROT(E[x,y], r[x,y]) -// rho: A[x,y] = A'[y, 2x+3y mod KECCAK_DIM] -// piRho: B[y,2x+3y] = ROT(E[x,y], r[x,y]) +// pi: A[i,j] = ROT(E[i,j], r[i,j]) +// rho: A[i,j] = A'[j, 2i+3j mod KECCAK_DIM] +// piRho: B[j,2i+3j] = ROT(E[i,j], r[i,j]) // which is equivalent to the `rho` algorithm followed by the `pi` algorithm in the Keccak reference as follows: // rho: // A[0,0] = a[0,0] -// | x | = | 1 | -// | y | = | 0 | +// | i | = | 1 | +// | j | = | 0 | // for t = 0 to 23 do -// A[x,y] = ROT(a[x,y], (t+1)(t+2)/2 mod 64))) -// | x | = | 0 1 | | x | +// A[i,j] = ROT(a[i,j], (t+1)(t+2)/2 mod 64))) +// | i | = | 0 1 | | i | // | | = | | * | | -// | y | = | 2 3 | | y | +// | j | = | 2 3 | | j | // end for // pi: -// for x = 0 to 4 do -// for y = 0 to 4 do -// | X | = | 0 1 | | x | +// for i = 0 to 4 do +// for j = 0 to 4 do +// | I | = | 0 1 | | i | // | | = | | * | | -// | Y | = | 2 3 | | y | -// A[X,Y] = a[x,y] +// | J | = | 2 3 | | j | +// A[I,J] = a[i,j] // end for // end for -// We use the first index of the state array as the x coordinate and the second index as the y coordinate. +// We use the first index of the state array as the i coordinate and the second index as the j coordinate. 
function piRho(state: Field[][]): Field[][] { const stateE = state; const stateB: Field[][] = getKeccakStateZeros(); - // for all x in {0..4} and y in {0..4}: B[y,2x+3y] = ROT(E[x,y], r[x,y]) - for (let x = 0; x < KECCAK_DIM; x++) { - for (let y = 0; y < KECCAK_DIM; y++) { - stateB[y][(2 * x + 3 * y) % KECCAK_DIM] = Gadgets.rotate( - stateE[x][y], - ROT_TABLE[x][y], + // for all i in {0..4} and j in {0..4}: B[y,2x+3y] = ROT(E[i,j], r[i,j]) + for (let i = 0; i < KECCAK_DIM; i++) { + for (let j = 0; j < KECCAK_DIM; j++) { + stateB[j][(2 * i + 3 * j) % KECCAK_DIM] = Gadgets.rotate( + stateE[i][j], + ROT_TABLE[i][j], 'left' ); } @@ -330,26 +327,26 @@ function piRho(state: Field[][]): Field[][] { } // Fourth step of the compression function of Keccak for 64-bit words. -// F[x,y] = B[x,y] xor ((not B[x+1,y]) and B[x+2,y]) +// F[i,j] = B[i,j] xor ((not B[i+1,j]) and B[i+2,j]) // It corresponds to the chi algorithm in the Keccak reference. -// for y = 0 to 4 do -// for x = 0 to 4 do -// A[x,y] = a[x,y] xor ((not a[x+1,y]) and a[x+2,y]) +// for j = 0 to 4 do +// for i = 0 to 4 do +// A[i,j] = a[i,j] xor ((not a[i+1,j]) and a[i+2,j]) // end for // end for function chi(state: Field[][]): Field[][] { const stateB = state; const stateF = getKeccakStateZeros(); - // for all x in {0..4} and y in {0..4}: F[x,y] = B[x,y] xor ((not B[x+1,y]) and B[x+2,y]) - for (let x = 0; x < KECCAK_DIM; x++) { - for (let y = 0; y < KECCAK_DIM; y++) { - stateF[x][y] = Gadgets.xor( - stateB[x][y], + // for all i in {0..4} and j in {0..4}: F[i,j] = B[i,j] xor ((not B[i+1,j]) and B[i+2,j]) + for (let i = 0; i < KECCAK_DIM; i++) { + for (let j = 0; j < KECCAK_DIM; j++) { + stateF[i][j] = Gadgets.xor( + stateB[i][j], Gadgets.and( // We can use unchecked NOT because the length of the input is constrained to be 64 bits thanks to the fact that it is the output of a previous Xor64 - Gadgets.not(stateB[(x + 1) % KECCAK_DIM][y], KECCAK_WORD, false), - stateB[(x + 2) % KECCAK_DIM][y], + Gadgets.not(stateB[(i + 1) % KECCAK_DIM][j], KECCAK_WORD, false), + stateB[(i + 2) % KECCAK_DIM][j], KECCAK_WORD ), KECCAK_WORD @@ -383,10 +380,7 @@ function round(state: Field[][], rc: Field): Field[][] { // Keccak permutation function with a constant number of rounds function permutation(state: Field[][], rc: Field[]): Field[][] { - return rc.reduce( - (currentState, rcValue) => round(currentState, rcValue), - state - ); + return rc.reduce((acc, value) => round(acc, value), state); } // KECCAK SPONGE @@ -403,16 +397,16 @@ function absorb( // (capacity / 8) zero bytes const zeros = Array(capacity / 8).fill(Field.from(0)); - for (let i = 0; i < paddedMessage.length; i += rate / 8) { + for (let idx = 0; idx < paddedMessage.length; idx += rate / 8) { // split into blocks of rate bits // for each block of rate bits in the padded message -> this is rate/8 bytes - const block = paddedMessage.slice(i, i + rate / 8); + const block = paddedMessage.slice(idx, idx + rate / 8); // pad the block with 0s to up to 1600 bits const paddedBlock = block.concat(zeros); // padded with zeros each block until they are 1600 bit long assert( paddedBlock.length * 8 === KECCAK_STATE_LENGTH, - 'improper Keccak block length' + `improper Keccak block length (should be ${KECCAK_STATE_LENGTH})` ); const blockState = getKeccakStateOfBytes(paddedBlock); // xor the state with the padded block @@ -438,8 +432,8 @@ function squeeze( start: number, length: number ) => { - for (let i = 0; i < length; i++) { - outputArray[start + i] = bytestring[i]; + for (let idx = 0; idx < length; idx++) { 
+ outputArray[start + idx] = bytestring[idx]; } }; @@ -487,7 +481,7 @@ function sponge( } // load round constants into Fields - let rc = exists(24, () => TupleN.fromArray(24, ROUND_CONSTANTS)); + const rc = exists(24, () => TupleN.fromArray(24, ROUND_CONSTANTS)); // absorb const state = absorb(paddedMessage, capacity, rate, rc); @@ -518,37 +512,34 @@ function hash( capacity: number, nistVersion: boolean ): Field[] { + // Throw errors if used improperly assert(capacity > 0, 'capacity must be positive'); assert( capacity < KECCAK_STATE_LENGTH, - 'capacity must be less than KECCAK_STATE_LENGTH' + `capacity must be less than ${KECCAK_STATE_LENGTH}` ); assert(length > 0, 'length must be positive'); assert(length % 8 === 0, 'length must be a multiple of 8'); // Input endianness conversion - let messageFormatted = inpEndian === 'Big' ? message : message.reverse(); + const messageFormatted = inpEndian === 'Big' ? message : message.reverse(); // Check each Field input is 8 bits at most if it was not done before at creation time - if (byteChecks) { - checkBytes(messageFormatted); - } + byteChecks && checkBytes(messageFormatted); const rate = KECCAK_STATE_LENGTH - capacity; - let padded; - if (nistVersion) { - padded = padNist(messageFormatted, rate); - } else { - padded = pad101(messageFormatted, rate); - } + const padded = + nistVersion === true + ? padNist(messageFormatted, rate) + : pad101(messageFormatted, rate); const hash = sponge(padded, length, capacity, rate); // Always check each Field output is 8 bits at most because they are created here checkBytes(hash); - // Set input to desired endianness + // Output endianness conversion const hashFormatted = outEndian === 'Big' ? hash : hash.reverse(); return hashFormatted; @@ -563,26 +554,18 @@ function nistSha3( outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - let output: Field[]; - switch (len) { case 224: - output = hash(inpEndian, outEndian, byteChecks, message, 224, 448, true); - break; + return hash(inpEndian, outEndian, byteChecks, message, 224, 448, true); case 256: - output = hash(inpEndian, outEndian, byteChecks, message, 256, 512, true); - break; + return hash(inpEndian, outEndian, byteChecks, message, 256, 512, true); case 384: - output = hash(inpEndian, outEndian, byteChecks, message, 384, 768, true); - break; + return hash(inpEndian, outEndian, byteChecks, message, 384, 768, true); case 512: - output = hash(inpEndian, outEndian, byteChecks, message, 512, 1024, true); - break; + return hash(inpEndian, outEndian, byteChecks, message, 512, 1024, true); default: throw new Error('Invalid length'); } - - return output; } // Gadget for Keccak hash function for the parameters used in Ethereum. 
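The rewritten unit test below checks the gadgets against @noble/hashes via equivalentAsync. The reference ("spec") side of that comparison amounts to the helper sketched here; referenceKeccak256 is an illustrative name and is not added by the patch:

import { keccak_256 } from '@noble/hashes/sha3';

function referenceKeccak256(bytes: bigint[]): bigint[] {
  // each input value is assumed to already fit in a byte (0..255), as enforced by rangeCheck8 in-circuit
  const input = new Uint8Array(bytes.map(Number));
  // keccak_256 returns a 32-byte Uint8Array; map it back to bigints for comparison with publicOutput
  return Array.from(keccak_256(input)).map(BigInt);
}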
diff --git a/src/lib/keccak.unit-test.ts b/src/lib/keccak.unit-test.ts index e72e655093..34e497804e 100644 --- a/src/lib/keccak.unit-test.ts +++ b/src/lib/keccak.unit-test.ts @@ -7,12 +7,15 @@ import { Random } from './testing/random.js'; import { array, equivalentAsync, fieldWithRng } from './testing/equivalent.js'; import { constraintSystem, contains } from './testing/constraint-system.js'; +// TODO(jackryanservia): Add test to assert fail for byte that's larger than 255 +// TODO(jackryanservia): Add random length with three runs + const PREIMAGE_LENGTH = 75; const RUNS = 1; -let uint = (length: number) => fieldWithRng(Random.biguint(length)); +const uint = (length: number) => fieldWithRng(Random.biguint(length)); -let Keccak256 = ZkProgram({ +const Keccak256 = ZkProgram({ name: 'keccak256', publicInput: Provable.Array(Field, PREIMAGE_LENGTH), publicOutput: Provable.Array(Field, 32), @@ -43,12 +46,12 @@ await equivalentAsync( { runs: RUNS } )( (x) => { - let uint8Array = new Uint8Array(x.map(Number)); - let result = keccak_256(uint8Array); + const uint8Array = new Uint8Array(x.map(Number)); + const result = keccak_256(uint8Array); return Array.from(result).map(BigInt); }, async (x) => { - let proof = await Keccak256.ethereum(x); + const proof = await Keccak256.ethereum(x); return proof.publicOutput; } ); @@ -61,17 +64,17 @@ await equivalentAsync( { runs: RUNS } )( (x) => { - let thing = x.map(Number); - let result = sha3_256(new Uint8Array(thing)); + const thing = x.map(Number); + const result = sha3_256(new Uint8Array(thing)); return Array.from(result).map(BigInt); }, async (x) => { - let proof = await Keccak256.nistSha3(x); + const proof = await Keccak256.nistSha3(x); return proof.publicOutput; } ); -// let Keccak512 = ZkProgram({ +// const Keccak512 = ZkProgram({ // name: 'keccak512', // publicInput: Provable.Array(Field, PREIMAGE_LENGTH), // publicOutput: Provable.Array(Field, 64), @@ -101,12 +104,12 @@ await equivalentAsync( // { runs: RUNS } // )( // (x) => { -// let uint8Array = new Uint8Array(x.map(Number)); -// let result = keccak_512(uint8Array); +// const uint8Array = new Uint8Array(x.map(Number)); +// const result = keccak_512(uint8Array); // return Array.from(result).map(BigInt); // }, // async (x) => { -// let proof = await Keccak512.preNist(x); +// const proof = await Keccak512.preNist(x); // return proof.publicOutput; // } // ); @@ -119,12 +122,12 @@ await equivalentAsync( // { runs: RUNS } // )( // (x) => { -// let thing = x.map(Number); -// let result = sha3_512(new Uint8Array(thing)); +// const thing = x.map(Number); +// const result = sha3_512(new Uint8Array(thing)); // return Array.from(result).map(BigInt); // }, // async (x) => { -// let proof = await Keccak512.nistSha3(x); +// const proof = await Keccak512.nistSha3(x); // return proof.publicOutput; // } // ); From 6e477e260e9dc27ec3be4ee534695cfc76204a87 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Mon, 11 Dec 2023 18:48:44 +0000 Subject: [PATCH 04/13] Fix round constants not constrained --- src/lib/keccak.ts | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index 813b53d576..48e71658d1 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -2,7 +2,6 @@ import { Field } from './field.js'; import { Gadgets } from './gadgets/gadgets.js'; import { assert } from './errors.js'; import { existsOne, exists } from './gadgets/common.js'; -import { TupleN } from './util/types.js'; import { 
rangeCheck8 } from './gadgets/range-check.js'; export { Keccak }; @@ -359,17 +358,17 @@ function chi(state: Field[][]): Field[][] { // Fifth step of the permutation function of Keccak for 64-bit words. // It takes the word located at the position (0,0) of the state and XORs it with the round constant. -function iota(state: Field[][], rc: Field): Field[][] { +function iota(state: Field[][], rc: bigint): Field[][] { const stateG = state; - stateG[0][0] = Gadgets.xor(stateG[0][0], rc, KECCAK_WORD); + stateG[0][0] = Gadgets.xor(stateG[0][0], Field.from(rc), KECCAK_WORD); return stateG; } // One round of the Keccak permutation function. // iota o chi o pi o rho o theta -function round(state: Field[][], rc: Field): Field[][] { +function round(state: Field[][], rc: bigint): Field[][] { const stateA = state; const stateE = theta(stateA); const stateB = piRho(stateE); @@ -379,7 +378,7 @@ function round(state: Field[][], rc: Field): Field[][] { } // Keccak permutation function with a constant number of rounds -function permutation(state: Field[][], rc: Field[]): Field[][] { +function permutation(state: Field[][], rc: bigint[]): Field[][] { return rc.reduce((acc, value) => round(acc, value), state); } @@ -390,7 +389,7 @@ function absorb( paddedMessage: Field[], capacity: number, rate: number, - rc: Field[] + rc: bigint[] ): Field[][] { let state = getKeccakStateZeros(); @@ -423,7 +422,7 @@ function squeeze( state: Field[][], length: number, rate: number, - rc: Field[] + rc: bigint[] ): Field[] { // Copies a section of bytes in the bytestring into the output array const copy = ( @@ -480,8 +479,8 @@ function sponge( throw new Error('Invalid padded message length'); } - // load round constants into Fields - const rc = exists(24, () => TupleN.fromArray(24, ROUND_CONSTANTS)); + // load round constants + const rc = ROUND_CONSTANTS; // absorb const state = absorb(paddedMessage, capacity, rate, rc); From 601e07aed2e1f23dbfdb4cbf3bfd53e3f259e83a Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 02:21:51 +0000 Subject: [PATCH 05/13] Remove switch statements for length --- src/lib/keccak.ts | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index 48e71658d1..d7890353dd 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -9,7 +9,7 @@ export { Keccak }; const Keccak = { /** TODO */ nistSha3( - len: number, + len: 224 | 256 | 384 | 512, message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', @@ -28,7 +28,7 @@ const Keccak = { }, /** TODO */ preNist( - len: number, + len: 224 | 256 | 384 | 512, message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', @@ -547,24 +547,13 @@ function hash( // Gadget for NIST SHA-3 function for output lengths 224/256/384/512. // Input and output endianness can be specified. Default is big endian. 
function nistSha3( - len: number, + len: 224 | 256 | 384 | 512, message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - switch (len) { - case 224: - return hash(inpEndian, outEndian, byteChecks, message, 224, 448, true); - case 256: - return hash(inpEndian, outEndian, byteChecks, message, 256, 512, true); - case 384: - return hash(inpEndian, outEndian, byteChecks, message, 384, 768, true); - case 512: - return hash(inpEndian, outEndian, byteChecks, message, 512, 1024, true); - default: - throw new Error('Invalid length'); - } + return hash(inpEndian, outEndian, byteChecks, message, len, 2 * len, true); } // Gadget for Keccak hash function for the parameters used in Ethereum. @@ -582,22 +571,11 @@ function ethereum( // Input and output endianness can be specified. Default is big endian. // Note that when calling with output length 256 this is equivalent to the ethereum function function preNist( - len: number, + len: 224 | 256 | 384 | 512, message: Field[], inpEndian: 'Big' | 'Little' = 'Big', outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - switch (len) { - case 224: - return hash(inpEndian, outEndian, byteChecks, message, 224, 448, false); - case 256: - return ethereum(message, inpEndian, outEndian, byteChecks); - case 384: - return hash(inpEndian, outEndian, byteChecks, message, 384, 768, false); - case 512: - return hash(inpEndian, outEndian, byteChecks, message, 512, 1024, false); - default: - throw new Error('Invalid length'); - } + return hash(inpEndian, outEndian, byteChecks, message, len, 2 * len, false); } From 8fc218b96f6b8bf858d6d0b234667a58ea7752ce Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 02:42:46 +0000 Subject: [PATCH 06/13] Remove endian conversion --- src/lib/keccak.ts | 55 +++++++++++------------------------------------ 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index d7890353dd..585348c51a 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -1,7 +1,7 @@ import { Field } from './field.js'; import { Gadgets } from './gadgets/gadgets.js'; import { assert } from './errors.js'; -import { existsOne, exists } from './gadgets/common.js'; +import { existsOne } from './gadgets/common.js'; import { rangeCheck8 } from './gadgets/range-check.js'; export { Keccak }; @@ -11,30 +11,21 @@ const Keccak = { nistSha3( len: 224 | 256 | 384 | 512, message: Field[], - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - return nistSha3(len, message, inpEndian, outEndian, byteChecks); + return nistSha3(len, message, byteChecks); }, /** TODO */ - ethereum( - message: Field[], - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', - byteChecks: boolean = false - ): Field[] { - return ethereum(message, inpEndian, outEndian, byteChecks); + ethereum(message: Field[], byteChecks: boolean = false): Field[] { + return ethereum(message, byteChecks); }, /** TODO */ preNist( len: 224 | 256 | 384 | 512, message: Field[], - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - return preNist(len, message, inpEndian, outEndian, byteChecks); + return preNist(len, message, byteChecks); }, }; @@ -503,8 +494,6 @@ function checkBytes(inputs: Field[]): void { // - the 10*1 pad will take place after the message, until 
reaching the bit length rate. // - then, {0} pad will take place to finish the 1600 bits of the state. function hash( - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false, message: Field[] = [], length: number, @@ -520,62 +509,42 @@ function hash( assert(length > 0, 'length must be positive'); assert(length % 8 === 0, 'length must be a multiple of 8'); - // Input endianness conversion - const messageFormatted = inpEndian === 'Big' ? message : message.reverse(); - // Check each Field input is 8 bits at most if it was not done before at creation time - byteChecks && checkBytes(messageFormatted); + byteChecks && checkBytes(message); const rate = KECCAK_STATE_LENGTH - capacity; const padded = - nistVersion === true - ? padNist(messageFormatted, rate) - : pad101(messageFormatted, rate); + nistVersion === true ? padNist(message, rate) : pad101(message, rate); const hash = sponge(padded, length, capacity, rate); // Always check each Field output is 8 bits at most because they are created here checkBytes(hash); - // Output endianness conversion - const hashFormatted = outEndian === 'Big' ? hash : hash.reverse(); - - return hashFormatted; + return hash; } // Gadget for NIST SHA-3 function for output lengths 224/256/384/512. -// Input and output endianness can be specified. Default is big endian. function nistSha3( len: 224 | 256 | 384 | 512, message: Field[], - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - return hash(inpEndian, outEndian, byteChecks, message, len, 2 * len, true); + return hash(byteChecks, message, len, 2 * len, true); } // Gadget for Keccak hash function for the parameters used in Ethereum. -// Input and output endianness can be specified. Default is big endian. -function ethereum( - message: Field[] = [], - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', - byteChecks: boolean = false -): Field[] { - return hash(inpEndian, outEndian, byteChecks, message, 256, 512, false); +function ethereum(message: Field[] = [], byteChecks: boolean = false): Field[] { + return hash(byteChecks, message, 256, 512, false); } // Gadget for pre-NIST SHA-3 function for output lengths 224/256/384/512. -// Input and output endianness can be specified. Default is big endian. 
// Note that when calling with output length 256 this is equivalent to the ethereum function function preNist( len: 224 | 256 | 384 | 512, message: Field[], - inpEndian: 'Big' | 'Little' = 'Big', - outEndian: 'Big' | 'Little' = 'Big', byteChecks: boolean = false ): Field[] { - return hash(inpEndian, outEndian, byteChecks, message, len, 2 * len, false); + return hash(byteChecks, message, len, 2 * len, false); } From 86ccad775045801c02e9f7a4bae690c2bb34fdcf Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 02:50:17 +0000 Subject: [PATCH 07/13] Replace copy in squeeze with splice builtin --- src/lib/keccak.ts | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index 585348c51a..d64963b25c 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -415,18 +415,6 @@ function squeeze( rate: number, rc: bigint[] ): Field[] { - // Copies a section of bytes in the bytestring into the output array - const copy = ( - bytestring: Field[], - outputArray: Field[], - start: number, - length: number - ) => { - for (let idx = 0; idx < length; idx++) { - outputArray[start + idx] = bytestring[idx]; - } - }; - let newState = state; // bytes per squeeze @@ -440,7 +428,8 @@ function squeeze( // first state to be squeezed const bytestring = keccakStateToBytes(state); const outputBytes = bytestring.slice(0, bytesPerSqueeze); - copy(outputBytes, outputArray, 0, bytesPerSqueeze); + // copies a section of bytes in the bytestring into the output array + outputArray.splice(0, bytesPerSqueeze, ...outputBytes); // for the rest of squeezes for (let i = 1; i < squeezes; i++) { @@ -449,7 +438,8 @@ function squeeze( // append the output of the permutation function to the output const bytestringI = keccakStateToBytes(state); const outputBytesI = bytestringI.slice(0, bytesPerSqueeze); - copy(outputBytesI, outputArray, bytesPerSqueeze * i, bytesPerSqueeze); + // copies a section of bytes in the bytestring into the output array + outputArray.splice(bytesPerSqueeze * i, bytesPerSqueeze, ...outputBytesI); } // Obtain the hash selecting the first bitlength/8 bytes of the output array From 3d3977fb28c313fa6d4d8c1578f2c620be03366b Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 04:29:26 +0000 Subject: [PATCH 08/13] Combines pad functions and change rate argument to bytes --- src/lib/keccak.ts | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index d64963b25c..bac20da5d8 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -190,49 +190,26 @@ function keccakStateXor(a: Field[][], b: Field[][]): Field[][] { // Computes the number of required extra bytes to pad a message of length bytes function bytesToPad(rate: number, length: number): number { - return Math.floor(rate / 8) - (length % Math.floor(rate / 8)); + return rate - (length % rate); } // Pads a message M as: // M || pad[x](|M|) -// Padding rule 0x06 ..0*..1. 
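For readers skimming PATCH 07 above: the hand-rolled byte copy is replaced by the Array.prototype.splice overwrite pattern. A minimal standalone illustration with plain numbers instead of Field elements (the variable names here are only for illustration):

// Overwrite output[0..3] in place with the contents of block,
// mirroring `outputArray.splice(0, bytesPerSqueeze, ...outputBytes)` in the diff above.
const output = Array(8).fill(0);
const block = [1, 2, 3, 4];
output.splice(0, block.length, ...block);
console.log(output); // [1, 2, 3, 4, 0, 0, 0, 0]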
-// The padded message vector will start with the message vector -// followed by the 0*1 rule to fulfill a length that is a multiple of rate (in bytes) -// (This means a 0110 sequence, followed with as many 0s as needed, and a final 1 bit) -function padNist(message: Field[], rate: number): Field[] { +// The padded message will begin with the message and end with the padding rule (below) to fulfill a length that is a multiple of rate (in bytes). +// If nist is true, then the padding rule is 0x06 ..0*..1. +// If nist is false, then the padding rule is 10*1. +function pad(message: Field[], rate: number, nist: boolean): Field[] { // Find out desired length of the padding in bytes // If message is already rate bits, need to pad full rate again const extraBytes = bytesToPad(rate, message.length); // 0x06 0x00 ... 0x00 0x80 or 0x86 - const last = Field.from(BigInt(2) ** BigInt(7)); + const first = nist ? 0x06n : 0x01n; + const last = 0x80n; // Create the padding vector const pad = Array(extraBytes).fill(Field.from(0)); - pad[0] = Field.from(6); - pad[extraBytes - 1] = pad[extraBytes - 1].add(last); - - // Return the padded message - return [...message, ...pad]; -} - -// Pads a message M as: -// M || pad[x](|M|) -// Padding rule 10*1. -// The padded message vector will start with the message vector -// followed by the 10*1 rule to fulfill a length that is a multiple of rate (in bytes) -// (This means a 1 bit, followed with as many 0s as needed, and a final 1 bit) -function pad101(message: Field[], rate: number): Field[] { - // Find out desired length of the padding in bytes - // If message is already rate bits, need to pad full rate again - const extraBytes = bytesToPad(rate, message.length); - - // 0x01 0x00 ... 0x00 0x80 or 0x81 - const last = Field.from(BigInt(2) ** BigInt(7)); - - // Create the padding vector - const pad = Array(extraBytes).fill(Field.from(0)); - pad[0] = Field.from(1); + pad[0] = Field.from(first); pad[extraBytes - 1] = pad[extraBytes - 1].add(last); // Return the padded message @@ -504,8 +481,7 @@ function hash( const rate = KECCAK_STATE_LENGTH - capacity; - const padded = - nistVersion === true ? 
padNist(message, rate) : pad101(message, rate); + const padded = pad(message, rate, nistVersion); const hash = sponge(padded, length, capacity, rate); From 45fb784290f2cdce18f36e69fc431aec71924df7 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 07:06:41 +0000 Subject: [PATCH 09/13] Changes all rate, length, and capacity units to bytes --- src/lib/keccak.ts | 68 +++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index bac20da5d8..55782214f3 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -46,6 +46,9 @@ const BYTES_PER_WORD = KECCAK_WORD / 8; // Length of the state in bits, meaning the 5x5 matrix of words in bits (1600) const KECCAK_STATE_LENGTH = KECCAK_DIM ** 2 * KECCAK_WORD; +// Length of the state in bytes, meaning the 5x5 matrix of words in bytes (200) +const KECCAK_STATE_LENGTH_BYTES = KECCAK_STATE_LENGTH / 8; + // Number of rounds of the Keccak permutation function depending on the value `l` (24) const KECCAK_ROUNDS = 12 + 2 * KECCAK_ELL; @@ -141,9 +144,8 @@ function getKeccakStateOfBytes(bytestring: Field[]): Field[][] { // Converts a state of Fields to a list of bytes as Fields and creates constraints for it function keccakStateToBytes(state: Field[][]): Field[] { - const stateLengthInBytes = KECCAK_STATE_LENGTH / 8; const bytestring: Field[] = Array.from( - { length: stateLengthInBytes }, + { length: KECCAK_STATE_LENGTH_BYTES }, (_, idx) => existsOne(() => { // idx = k + 8 * ((dim * j) + i) @@ -361,19 +363,19 @@ function absorb( ): Field[][] { let state = getKeccakStateZeros(); - // (capacity / 8) zero bytes - const zeros = Array(capacity / 8).fill(Field.from(0)); + // array of capacity zero bytes + const zeros = Array(capacity).fill(Field.from(0)); - for (let idx = 0; idx < paddedMessage.length; idx += rate / 8) { + for (let idx = 0; idx < paddedMessage.length; idx += rate) { // split into blocks of rate bits - // for each block of rate bits in the padded message -> this is rate/8 bytes - const block = paddedMessage.slice(idx, idx + rate / 8); - // pad the block with 0s to up to 1600 bits + // for each block of rate bits in the padded message -> this is rate bytes + const block = paddedMessage.slice(idx, idx + rate); + // pad the block with 0s to up to 200 bytes const paddedBlock = block.concat(zeros); - // padded with zeros each block until they are 1600 bit long + // padded with zeros each block until they are 200 bytes long assert( - paddedBlock.length * 8 === KECCAK_STATE_LENGTH, - `improper Keccak block length (should be ${KECCAK_STATE_LENGTH})` + paddedBlock.length === KECCAK_STATE_LENGTH_BYTES, + `improper Keccak block length (should be ${KECCAK_STATE_LENGTH_BYTES})` ); const blockState = getKeccakStateOfBytes(paddedBlock); // xor the state with the padded block @@ -394,19 +396,17 @@ function squeeze( ): Field[] { let newState = state; - // bytes per squeeze - const bytesPerSqueeze = rate / 8; // number of squeezes const squeezes = Math.floor(length / rate) + 1; // multiple of rate that is larger than output_length, in bytes - const outputLength = squeezes * bytesPerSqueeze; + const outputLength = squeezes * rate; // array with sufficient space to store the output const outputArray = Array(outputLength).fill(Field.from(0)); // first state to be squeezed const bytestring = keccakStateToBytes(state); - const outputBytes = bytestring.slice(0, bytesPerSqueeze); + const outputBytes = 
bytestring.slice(0, rate); // copies a section of bytes in the bytestring into the output array - outputArray.splice(0, bytesPerSqueeze, ...outputBytes); + outputArray.splice(0, rate, ...outputBytes); // for the rest of squeezes for (let i = 1; i < squeezes; i++) { @@ -414,18 +414,17 @@ function squeeze( newState = permutation(newState, rc); // append the output of the permutation function to the output const bytestringI = keccakStateToBytes(state); - const outputBytesI = bytestringI.slice(0, bytesPerSqueeze); + const outputBytesI = bytestringI.slice(0, rate); // copies a section of bytes in the bytestring into the output array - outputArray.splice(bytesPerSqueeze * i, bytesPerSqueeze, ...outputBytesI); + outputArray.splice(rate * i, rate, ...outputBytesI); } - // Obtain the hash selecting the first bitlength/8 bytes of the output array - const hashed = outputArray.slice(0, length / 8); + // Obtain the hash selecting the first bitlength bytes of the output array + const hashed = outputArray.slice(0, length); return hashed; } -// Keccak sponge function for 1600 bits of state width -// Need to split the message into blocks of 1088 bits. +// Keccak sponge function for 200 bytes of state width function sponge( paddedMessage: Field[], length: number, @@ -433,7 +432,7 @@ function sponge( rate: number ): Field[] { // check that the padded message is a multiple of rate - if ((paddedMessage.length * 8) % rate !== 0) { + if (paddedMessage.length % rate !== 0) { throw new Error('Invalid padded message length'); } @@ -459,7 +458,7 @@ function checkBytes(inputs: Field[]): void { // The message will be parsed as follows: // - the first byte of the message will be the least significant byte of the first word of the state (A[0][0]) // - the 10*1 pad will take place after the message, until reaching the bit length rate. -// - then, {0} pad will take place to finish the 1600 bits of the state. +// - then, {0} pad will take place to finish the 200 bytes of the state. function hash( byteChecks: boolean = false, message: Field[] = [], @@ -470,16 +469,15 @@ function hash( // Throw errors if used improperly assert(capacity > 0, 'capacity must be positive'); assert( - capacity < KECCAK_STATE_LENGTH, - `capacity must be less than ${KECCAK_STATE_LENGTH}` + capacity < KECCAK_STATE_LENGTH_BYTES, + `capacity must be less than ${KECCAK_STATE_LENGTH_BYTES}` ); assert(length > 0, 'length must be positive'); - assert(length % 8 === 0, 'length must be a multiple of 8'); // Check each Field input is 8 bits at most if it was not done before at creation time byteChecks && checkBytes(message); - const rate = KECCAK_STATE_LENGTH - capacity; + const rate = KECCAK_STATE_LENGTH_BYTES - capacity; const padded = pad(message, rate, nistVersion); @@ -497,12 +495,7 @@ function nistSha3( message: Field[], byteChecks: boolean = false ): Field[] { - return hash(byteChecks, message, len, 2 * len, true); -} - -// Gadget for Keccak hash function for the parameters used in Ethereum. -function ethereum(message: Field[] = [], byteChecks: boolean = false): Field[] { - return hash(byteChecks, message, 256, 512, false); + return hash(byteChecks, message, len / 8, len / 4, true); } // Gadget for pre-NIST SHA-3 function for output lengths 224/256/384/512. @@ -512,5 +505,10 @@ function preNist( message: Field[], byteChecks: boolean = false ): Field[] { - return hash(byteChecks, message, len, 2 * len, false); + return hash(byteChecks, message, len / 8, len / 4, false); +} + +// Gadget for Keccak hash function for the parameters used in Ethereum. 
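Since PATCH 09 expresses every sponge parameter in bytes, the concrete numbers behind the hash(byteChecks, message, len / 8, len / 4, ...) calls above are easy to tabulate. A short out-of-circuit sketch (the label strings are illustrative only):

const KECCAK_STATE_LENGTH_BYTES = 200; // 5 x 5 lanes of 8 bytes each
for (const len of [224, 256, 384, 512] as const) {
  const digestBytes = len / 8; // 28, 32, 48, 64
  const capacityBytes = len / 4; // twice the digest length
  const rateBytes = KECCAK_STATE_LENGTH_BYTES - capacityBytes; // 144, 136, 104, 72
  console.log(`Keccak/SHA3-${len}: digest ${digestBytes} B, capacity ${capacityBytes} B, rate ${rateBytes} B`);
}
// The 256-bit variants (including the Ethereum gadget below) therefore absorb the message in 136-byte blocks.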
+function ethereum(message: Field[] = [], byteChecks: boolean = false): Field[] { + return preNist(256, message, byteChecks); } From 1003cc538c65f67f7bba914cd2ae0b6132dbe407 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 07:50:03 +0000 Subject: [PATCH 10/13] Make Keccak test run for random preimage/digest length --- src/lib/keccak.unit-test.ts | 152 +++++++++++++++--------------------- 1 file changed, 65 insertions(+), 87 deletions(-) diff --git a/src/lib/keccak.unit-test.ts b/src/lib/keccak.unit-test.ts index 34e497804e..69da64a7cf 100644 --- a/src/lib/keccak.unit-test.ts +++ b/src/lib/keccak.unit-test.ts @@ -1,139 +1,117 @@ import { Field } from './field.js'; import { Provable } from './provable.js'; import { Keccak } from './keccak.js'; -import { keccak_256, sha3_256, keccak_512, sha3_512 } from '@noble/hashes/sha3'; import { ZkProgram } from './proof_system.js'; import { Random } from './testing/random.js'; import { array, equivalentAsync, fieldWithRng } from './testing/equivalent.js'; import { constraintSystem, contains } from './testing/constraint-system.js'; +import { + keccak_224, + keccak_256, + keccak_384, + keccak_512, + sha3_224, + sha3_256, + sha3_384, + sha3_512, +} from '@noble/hashes/sha3'; -// TODO(jackryanservia): Add test to assert fail for byte that's larger than 255 -// TODO(jackryanservia): Add random length with three runs - -const PREIMAGE_LENGTH = 75; const RUNS = 1; +const testImplementations = { + sha3: { + 224: sha3_224, + 256: sha3_256, + 384: sha3_384, + 512: sha3_512, + }, + preNist: { + 224: keccak_224, + 256: keccak_256, + 384: keccak_384, + 512: keccak_512, + }, +}; + const uint = (length: number) => fieldWithRng(Random.biguint(length)); -const Keccak256 = ZkProgram({ - name: 'keccak256', - publicInput: Provable.Array(Field, PREIMAGE_LENGTH), - publicOutput: Provable.Array(Field, 32), +// Choose a test length at random +const digestLength = [224, 256, 384, 512][Math.floor(Math.random() * 4)] as + | 224 + | 256 + | 384 + | 512; + +// Chose a random preimage length +const preImageLength = digestLength / Math.floor(Math.random() * 4 + 2); + +// No need to test Ethereum because it's just a special case of preNist +const KeccakProgram = ZkProgram({ + name: 'keccak-test', + publicInput: Provable.Array(Field, preImageLength), + publicOutput: Provable.Array(Field, digestLength / 8), methods: { - ethereum: { + nistSha3: { privateInputs: [], method(preImage) { - return Keccak.ethereum(preImage); + return Keccak.nistSha3(digestLength, preImage); }, }, - // No need for preNist Keccak_256 because it's identical to ethereum - nistSha3: { + preNist: { privateInputs: [], method(preImage) { - return Keccak.nistSha3(256, preImage); + return Keccak.preNist(digestLength, preImage); }, }, }, }); -await Keccak256.compile(); +await KeccakProgram.compile(); +// SHA-3 await equivalentAsync( { - from: [array(uint(8), PREIMAGE_LENGTH)], - to: array(uint(8), 32), + from: [array(uint(8), preImageLength)], + to: array(uint(8), digestLength / 8), }, { runs: RUNS } )( (x) => { - const uint8Array = new Uint8Array(x.map(Number)); - const result = keccak_256(uint8Array); + const byteArray = new Uint8Array(x.map(Number)); + const result = testImplementations.sha3[digestLength](byteArray); return Array.from(result).map(BigInt); }, async (x) => { - const proof = await Keccak256.ethereum(x); + const proof = await KeccakProgram.nistSha3(x); + await KeccakProgram.verify(proof); return proof.publicOutput; } ); +// PreNIST 
Keccak await equivalentAsync( { - from: [array(uint(8), PREIMAGE_LENGTH)], - to: array(uint(8), 32), + from: [array(uint(8), preImageLength)], + to: array(uint(8), digestLength / 8), }, { runs: RUNS } )( (x) => { - const thing = x.map(Number); - const result = sha3_256(new Uint8Array(thing)); + const byteArray = new Uint8Array(x.map(Number)); + const result = testImplementations.preNist[digestLength](byteArray); return Array.from(result).map(BigInt); }, async (x) => { - const proof = await Keccak256.nistSha3(x); + const proof = await KeccakProgram.preNist(x); + await KeccakProgram.verify(proof); return proof.publicOutput; } ); -// const Keccak512 = ZkProgram({ -// name: 'keccak512', -// publicInput: Provable.Array(Field, PREIMAGE_LENGTH), -// publicOutput: Provable.Array(Field, 64), -// methods: { -// preNist: { -// privateInputs: [], -// method(preImage) { -// return Keccak.preNist(512, preImage, 'Big', 'Big', true); -// }, -// }, -// nistSha3: { -// privateInputs: [], -// method(preImage) { -// return Keccak.nistSha3(512, preImage, 'Big', 'Big', true); -// }, -// }, -// }, -// }); - -// await Keccak512.compile(); - -// await equivalentAsync( -// { -// from: [array(uint(8), PREIMAGE_LENGTH)], -// to: array(uint(8), 64), -// }, -// { runs: RUNS } -// )( -// (x) => { -// const uint8Array = new Uint8Array(x.map(Number)); -// const result = keccak_512(uint8Array); -// return Array.from(result).map(BigInt); -// }, -// async (x) => { -// const proof = await Keccak512.preNist(x); -// return proof.publicOutput; -// } +// This takes a while and doesn't do much, so I commented it out +// Constraint system sanity check +// constraintSystem.fromZkProgram( +// KeccakTest, +// 'preNist', +// contains([['Generic'], ['Xor16'], ['Zero'], ['Rot64'], ['RangeCheck0']]) // ); - -// await equivalentAsync( -// { -// from: [array(uint(8), PREIMAGE_LENGTH)], -// to: array(uint(8), 64), -// }, -// { runs: RUNS } -// )( -// (x) => { -// const thing = x.map(Number); -// const result = sha3_512(new Uint8Array(thing)); -// return Array.from(result).map(BigInt); -// }, -// async (x) => { -// const proof = await Keccak512.nistSha3(x); -// return proof.publicOutput; -// } -// ); - -constraintSystem.fromZkProgram( - Keccak256, - 'ethereum', - contains([['Generic'], ['Xor16'], ['Zero'], ['Rot64'], ['RangeCheck0']]) -); From bfb99e25988adab75bbf5e66d1eab05990f1bac9 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 08:30:44 +0000 Subject: [PATCH 11/13] Fixes random preImageLength in Keccak unit test --- src/lib/keccak.unit-test.ts | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/lib/keccak.unit-test.ts b/src/lib/keccak.unit-test.ts index 69da64a7cf..70aa377724 100644 --- a/src/lib/keccak.unit-test.ts +++ b/src/lib/keccak.unit-test.ts @@ -42,14 +42,17 @@ const digestLength = [224, 256, 384, 512][Math.floor(Math.random() * 4)] as | 384 | 512; +// Digest length in bytes +const digestLengthBytes = digestLength / 8; + // Chose a random preimage length -const preImageLength = digestLength / Math.floor(Math.random() * 4 + 2); +const preImageLength = Math.floor(digestLength / (Math.random() * 4 + 2)); // No need to test Ethereum because it's just a special case of preNist const KeccakProgram = ZkProgram({ name: 'keccak-test', publicInput: Provable.Array(Field, preImageLength), - publicOutput: Provable.Array(Field, digestLength / 8), + publicOutput: Provable.Array(Field, digestLengthBytes), methods: { nistSha3: { privateInputs: [], 
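As background for the equivalence tests in this unit-test diff: the reference digests come from @noble/hashes, where the keccak_* functions use the original 10*1 (pre-NIST) padding and the sha3_* functions use the NIST 0x06 padding, matching the preNist/nistSha3 split being tested. A minimal out-of-circuit sketch, assuming only that @noble/hashes is installed (bytesToHex is taken from its utils module):

import { keccak_256, sha3_256 } from '@noble/hashes/sha3';
import { bytesToHex } from '@noble/hashes/utils';

const preimage = new Uint8Array([0x61, 0x62, 0x63]); // "abc"
// Same permutation, different padding rule, hence different digests:
console.log(bytesToHex(keccak_256(preimage))); // Ethereum-style Keccak-256
console.log(bytesToHex(sha3_256(preimage))); // NIST SHA3-256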
@@ -72,7 +75,7 @@ await KeccakProgram.compile(); await equivalentAsync( { from: [array(uint(8), preImageLength)], - to: array(uint(8), digestLength / 8), + to: array(uint(8), digestLengthBytes), }, { runs: RUNS } )( @@ -92,7 +95,7 @@ await equivalentAsync( await equivalentAsync( { from: [array(uint(8), preImageLength)], - to: array(uint(8), digestLength / 8), + to: array(uint(8), digestLengthBytes), }, { runs: RUNS } )( From 898e761767650d34202c814eae48c2e23bb22cd2 Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 09:01:57 +0000 Subject: [PATCH 12/13] Cleans up asserts --- src/lib/keccak.ts | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index 55782214f3..5c2e38b5bf 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -119,8 +119,10 @@ const getKeccakStateZeros = (): Field[][] => // Converts a list of bytes to a matrix of Field elements function getKeccakStateOfBytes(bytestring: Field[]): Field[][] { assert( - bytestring.length === 200, - 'improper bytestring length (should be 200)' + bytestring.length === KECCAK_DIM ** 2 * BYTES_PER_WORD, + `improper bytestring length (should be ${ + KECCAK_DIM ** 2 * BYTES_PER_WORD + }})` ); const bytestringArray = Array.from(bytestring); @@ -197,7 +199,7 @@ function bytesToPad(rate: number, length: number): number { // Pads a message M as: // M || pad[x](|M|) -// The padded message will begin with the message and end with the padding rule (below) to fulfill a length that is a multiple of rate (in bytes). +// The padded message will start with the message argument followed by the padding rule (below) to fulfill a length that is a multiple of rate (in bytes). // If nist is true, then the padding rule is 0x06 ..0*..1. // If nist is false, then the padding rule is 10*1. 
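To make the combined padding rule easier to check by hand, here is a plain-number sketch of what the pad function below computes in-circuit; padBytes is an illustrative name, not part of the library:

function padBytes(message: number[], rate: number, nist: boolean): number[] {
  const extraBytes = rate - (message.length % rate); // always at least one byte of padding
  const pad = Array<number>(extraBytes).fill(0);
  pad[0] = nist ? 0x06 : 0x01;
  pad[extraBytes - 1] |= 0x80; // when extraBytes === 1 this collapses to 0x86 / 0x81
  return [...message, ...pad];
}

// With the 136-byte rate of the 256-bit variants: a 3-byte message gets 133 bytes of padding,
// and a message of exactly 136 bytes gets a full extra 136-byte block.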
function pad(message: Field[], rate: number, nist: boolean): Field[] { @@ -361,6 +363,15 @@ function absorb( rate: number, rc: bigint[] ): Field[][] { + assert( + rate + capacity === KECCAK_STATE_LENGTH_BYTES, + `invalid rate or capacity (rate + capacity should be ${KECCAK_STATE_LENGTH_BYTES})` + ); + assert( + paddedMessage.length % rate === 0, + 'invalid padded message length (should be multiple of rate)' + ); + let state = getKeccakStateZeros(); // array of capacity zero bytes @@ -370,13 +381,9 @@ function absorb( // split into blocks of rate bits // for each block of rate bits in the padded message -> this is rate bytes const block = paddedMessage.slice(idx, idx + rate); - // pad the block with 0s to up to 200 bytes + // pad the block with 0s to up to KECCAK_STATE_LENGTH_BYTES bytes const paddedBlock = block.concat(zeros); - // padded with zeros each block until they are 200 bytes long - assert( - paddedBlock.length === KECCAK_STATE_LENGTH_BYTES, - `improper Keccak block length (should be ${KECCAK_STATE_LENGTH_BYTES})` - ); + // convert the padded block byte array to a Keccak state const blockState = getKeccakStateOfBytes(paddedBlock); // xor the state with the padded block const stateXor = keccakStateXor(state, blockState); From b45ba97022f1cd31d319df64e9a84d970782fadf Mon Sep 17 00:00:00 2001 From: jackryanservia <90076280+jackryanservia@users.noreply.github.com> Date: Tue, 12 Dec 2023 09:09:33 +0000 Subject: [PATCH 13/13] Removes loop in squeeze bc standard length+capacity only does one sequeeze --- src/lib/keccak.ts | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/src/lib/keccak.ts b/src/lib/keccak.ts index 5c2e38b5bf..7a0ccb7db1 100644 --- a/src/lib/keccak.ts +++ b/src/lib/keccak.ts @@ -395,16 +395,10 @@ function absorb( } // Squeeze state until it has a desired length in bits -function squeeze( - state: Field[][], - length: number, - rate: number, - rc: bigint[] -): Field[] { - let newState = state; - +function squeeze(state: Field[][], length: number, rate: number): Field[] { // number of squeezes const squeezes = Math.floor(length / rate) + 1; + assert(squeezes === 1, 'squeezes should be 1'); // multiple of rate that is larger than output_length, in bytes const outputLength = squeezes * rate; // array with sufficient space to store the output @@ -415,17 +409,6 @@ function squeeze( // copies a section of bytes in the bytestring into the output array outputArray.splice(0, rate, ...outputBytes); - // for the rest of squeezes - for (let i = 1; i < squeezes; i++) { - // apply the permutation function to the state - newState = permutation(newState, rc); - // append the output of the permutation function to the output - const bytestringI = keccakStateToBytes(state); - const outputBytesI = bytestringI.slice(0, rate); - // copies a section of bytes in the bytestring into the output array - outputArray.splice(rate * i, rate, ...outputBytesI); - } - // Obtain the hash selecting the first bitlength bytes of the output array const hashed = outputArray.slice(0, length); return hashed; @@ -450,7 +433,7 @@ function sponge( const state = absorb(paddedMessage, capacity, rate, rc); // squeeze - const hashed = squeeze(state, length, rate, rc); + const hashed = squeeze(state, length, rate); return hashed; }
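Finally, a quick out-of-circuit check of the invariant PATCH 13 relies on: for every supported digest length the digest fits into a single rate-sized block, so the former squeeze loop always ran exactly once. A sketch with illustrative names only:

const KECCAK_STATE_LENGTH_BYTES = 200;
for (const len of [224, 256, 384, 512] as const) {
  const digestBytes = len / 8; // 28, 32, 48, 64
  const rateBytes = KECCAK_STATE_LENGTH_BYTES - len / 4; // 144, 136, 104, 72
  const squeezes = Math.floor(digestBytes / rateBytes) + 1;
  console.assert(squeezes === 1, `unexpected squeeze count for output length ${len}`);
}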