diff --git a/crates/circuits/sha256-air/src/lib.rs b/crates/circuits/sha256-air/src/lib.rs index e5616b5921..48bdaee5f9 100644 --- a/crates/circuits/sha256-air/src/lib.rs +++ b/crates/circuits/sha256-air/src/lib.rs @@ -8,6 +8,7 @@ mod utils; pub use air::*; pub use columns::*; +pub use trace::*; pub use utils::*; #[cfg(test)] diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs index 43721f8589..804d226a52 100644 --- a/crates/circuits/sha256-air/src/tests.rs +++ b/crates/circuits/sha256-air/src/tests.rs @@ -1,10 +1,7 @@ -use std::{array, borrow::BorrowMut, cmp::max, sync::Arc}; +use std::{array, cmp::max, sync::Arc}; -use openvm_circuit::{ - arch::{ - instructions::riscv::RV32_CELL_BITS, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - }, - utils::next_power_of_two_or_zero, +use openvm_circuit::arch::{ + instructions::riscv::RV32_CELL_BITS, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip}, @@ -14,9 +11,7 @@ use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, BaseAir}, - p3_field::{AbstractField, Field, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut}, + p3_field::{Field, PrimeField32}, prover::types::AirProofInput, rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, Chip, ChipUsageGetter, @@ -25,8 +20,7 @@ use openvm_stark_sdk::utils::create_seeded_rng; use rand::Rng; use crate::{ - limbs_into_u32, Sha256Air, Sha256RoundCols, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, SHA256_H, - SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, + Sha256Air, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, }; // A wrapper AIR purely for testing purposes @@ -67,85 +61,12 @@ where fn generate_air_proof_input(self) -> AirProofInput { let air = self.air(); - let non_padded_height = self.current_trace_height(); - let height = next_power_of_two_or_zero(non_padded_height); - let width = self.trace_width(); - let mut values = Val::::zero_vec(height * width); - - struct BlockContext { - prev_hash: [u32; 8], - local_block_idx: u32, - global_block_idx: u32, - input: [u8; SHA256_BLOCK_U8S], - is_last_block: bool, - } - let mut block_ctx: Vec = Vec::with_capacity(self.records.len()); - let mut prev_hash = SHA256_H; - let mut local_block_idx = 0; - let mut global_block_idx = 1; - for (input, is_last_block) in self.records { - block_ctx.push(BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - }); - global_block_idx += 1; - if is_last_block { - local_block_idx = 0; - prev_hash = SHA256_H; - } else { - local_block_idx += 1; - prev_hash = Sha256Air::get_block_hash(&prev_hash, input); - } - } - // first pass - values - .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(block_ctx) - .for_each(|(block, ctx)| { - let BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - } = ctx; - let input_words = array::from_fn(|i| { - limbs_into_u32::(array::from_fn(|j| { - input[i * SHA256_WORD_U8S + j] as u32 - })) - }); - self.air.sub_air.generate_block_trace( - block, - width, - 0, - &input_words, - self.bitwise_lookup_chip.as_ref(), - &prev_hash, - is_last_block, - global_block_idx, - local_block_idx, - &[[Val::::ZERO; 16]; 4], - ); - }); - // second pass: padding rows - values[width * non_padded_height..] - .par_chunks_mut(width) - .for_each(|row| { - let cols: &mut Sha256RoundCols> = row.borrow_mut(); - self.air.sub_air.generate_default_row(cols); - }); - // second pass: non-padding rows - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.air.sub_air.generate_missing_cells(chunk, width, 0); - }); - - AirProofInput::simple(air, RowMajorMatrix::new(values, width), vec![]) + let trace = crate::generate_trace::>( + &self.air.sub_air, + &self.bitwise_lookup_chip, + self.records, + ); + AirProofInput::simple(air, trace, vec![]) } } diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs index d5dfa6a256..c2d4f74b85 100644 --- a/crates/circuits/sha256-air/src/trace.rs +++ b/crates/circuits/sha256-air/src/trace.rs @@ -1,7 +1,12 @@ use std::{array, borrow::BorrowMut, ops::Range}; -use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupChip; -use openvm_stark_backend::p3_field::PrimeField32; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupChip, utils::next_power_of_two_or_zero, +}; +use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::*, +}; use sha2::{compress256, digest::generic_array::GenericArray}; use super::{ @@ -20,7 +25,7 @@ use crate::{ /// The first pass should do `get_block_trace` for every block and generate the invalid rows through `get_default_row` /// The second pass should go through all the blocks and call `generate_missing_values` impl Sha256Air { - /// This function takes the input_massage (should be already padded), the previous hash, + /// This function takes the input_message (padding not handled), the previous hash, /// and returns the new hash after processing the block input pub fn get_block_hash( prev_hash: &[u32; SHA256_HASH_WORDS], @@ -454,3 +459,89 @@ impl Sha256Air { } } } + +/// `records` consists of pairs of `(input_block, is_last_block)`. +pub fn generate_trace( + sub_air: &Sha256Air, + bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, + records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, +) -> RowMajorMatrix { + let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; + let height = next_power_of_two_or_zero(non_padded_height); + let width = >::width(sub_air); + let mut values = F::zero_vec(height * width); + + struct BlockContext { + prev_hash: [u32; 8], + local_block_idx: u32, + global_block_idx: u32, + input: [u8; SHA256_BLOCK_U8S], + is_last_block: bool, + } + let mut block_ctx: Vec = Vec::with_capacity(records.len()); + let mut prev_hash = SHA256_H; + let mut local_block_idx = 0; + let mut global_block_idx = 1; + for (input, is_last_block) in records { + block_ctx.push(BlockContext { + prev_hash, + local_block_idx, + global_block_idx, + input, + is_last_block, + }); + global_block_idx += 1; + if is_last_block { + local_block_idx = 0; + prev_hash = SHA256_H; + } else { + local_block_idx += 1; + prev_hash = Sha256Air::get_block_hash(&prev_hash, input); + } + } + // first pass + values + .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) + .zip(block_ctx) + .for_each(|(block, ctx)| { + let BlockContext { + prev_hash, + local_block_idx, + global_block_idx, + input, + is_last_block, + } = ctx; + let input_words = array::from_fn(|i| { + limbs_into_u32::(array::from_fn(|j| { + input[i * SHA256_WORD_U8S + j] as u32 + })) + }); + sub_air.generate_block_trace( + block, + width, + 0, + &input_words, + bitwise_lookup_chip, + &prev_hash, + is_last_block, + global_block_idx, + local_block_idx, + &[[F::ZERO; 16]; 4], + ); + }); + // second pass: padding rows + values[width * non_padded_height..] + .par_chunks_mut(width) + .for_each(|row| { + let cols: &mut Sha256RoundCols = row.borrow_mut(); + sub_air.generate_default_row(cols); + }); + // second pass: non-padding rows + values[width..] + .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) + .take(non_padded_height / SHA256_ROWS_PER_BLOCK) + .for_each(|chunk| { + sub_air.generate_missing_cells(chunk, width, 0); + }); + RowMajorMatrix::new(values, width) +}