Skip to content

Commit

Permalink
chore: provide generate_trace function in sha256-air crate
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanpwang committed Dec 31, 2024
1 parent 4f369d2 commit 64d8f06
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 93 deletions.
1 change: 1 addition & 0 deletions crates/circuits/sha256-air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod utils;

pub use air::*;
pub use columns::*;
pub use trace::*;
pub use utils::*;

#[cfg(test)]
Expand Down
101 changes: 11 additions & 90 deletions crates/circuits/sha256-air/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::{array, borrow::BorrowMut, cmp::max, sync::Arc};
use std::{array, cmp::max, sync::Arc};

use openvm_circuit::{
arch::{
instructions::riscv::RV32_CELL_BITS, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS,
},
utils::next_power_of_two_or_zero,
use openvm_circuit::arch::{
instructions::riscv::RV32_CELL_BITS, testing::VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS,
};
use openvm_circuit_primitives::{
bitwise_op_lookup::{BitwiseOperationLookupBus, BitwiseOperationLookupChip},
Expand All @@ -14,9 +11,7 @@ use openvm_stark_backend::{
config::{StarkGenericConfig, Val},
interaction::InteractionBuilder,
p3_air::{Air, BaseAir},
p3_field::{AbstractField, Field, PrimeField32},
p3_matrix::dense::RowMajorMatrix,
p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut},
p3_field::{Field, PrimeField32},
prover::types::AirProofInput,
rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
Chip, ChipUsageGetter,
Expand All @@ -25,8 +20,7 @@ use openvm_stark_sdk::utils::create_seeded_rng;
use rand::Rng;

use crate::{
limbs_into_u32, Sha256Air, Sha256RoundCols, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, SHA256_H,
SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S,
Sha256Air, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK,
};

// A wrapper AIR purely for testing purposes
Expand Down Expand Up @@ -67,85 +61,12 @@ where

fn generate_air_proof_input(self) -> AirProofInput<SC> {
let air = self.air();
let non_padded_height = self.current_trace_height();
let height = next_power_of_two_or_zero(non_padded_height);
let width = self.trace_width();
let mut values = Val::<SC>::zero_vec(height * width);

struct BlockContext {
prev_hash: [u32; 8],
local_block_idx: u32,
global_block_idx: u32,
input: [u8; SHA256_BLOCK_U8S],
is_last_block: bool,
}
let mut block_ctx: Vec<BlockContext> = Vec::with_capacity(self.records.len());
let mut prev_hash = SHA256_H;
let mut local_block_idx = 0;
let mut global_block_idx = 1;
for (input, is_last_block) in self.records {
block_ctx.push(BlockContext {
prev_hash,
local_block_idx,
global_block_idx,
input,
is_last_block,
});
global_block_idx += 1;
if is_last_block {
local_block_idx = 0;
prev_hash = SHA256_H;
} else {
local_block_idx += 1;
prev_hash = Sha256Air::get_block_hash(&prev_hash, input);
}
}
// first pass
values
.par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK)
.zip(block_ctx)
.for_each(|(block, ctx)| {
let BlockContext {
prev_hash,
local_block_idx,
global_block_idx,
input,
is_last_block,
} = ctx;
let input_words = array::from_fn(|i| {
limbs_into_u32::<SHA256_WORD_U8S>(array::from_fn(|j| {
input[i * SHA256_WORD_U8S + j] as u32
}))
});
self.air.sub_air.generate_block_trace(
block,
width,
0,
&input_words,
self.bitwise_lookup_chip.as_ref(),
&prev_hash,
is_last_block,
global_block_idx,
local_block_idx,
&[[Val::<SC>::ZERO; 16]; 4],
);
});
// second pass: padding rows
values[width * non_padded_height..]
.par_chunks_mut(width)
.for_each(|row| {
let cols: &mut Sha256RoundCols<Val<SC>> = row.borrow_mut();
self.air.sub_air.generate_default_row(cols);
});
// second pass: non-padding rows
values[width..]
.par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
.take(non_padded_height / SHA256_ROWS_PER_BLOCK)
.for_each(|chunk| {
self.air.sub_air.generate_missing_cells(chunk, width, 0);
});

AirProofInput::simple(air, RowMajorMatrix::new(values, width), vec![])
let trace = crate::generate_trace::<Val<SC>>(
&self.air.sub_air,
&self.bitwise_lookup_chip,
self.records,
);
AirProofInput::simple(air, trace, vec![])
}
}

Expand Down
97 changes: 94 additions & 3 deletions crates/circuits/sha256-air/src/trace.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::{array, borrow::BorrowMut, ops::Range};

use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupChip;
use openvm_stark_backend::p3_field::PrimeField32;
use openvm_circuit_primitives::{
bitwise_op_lookup::BitwiseOperationLookupChip, utils::next_power_of_two_or_zero,
};
use openvm_stark_backend::{
p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix,
p3_maybe_rayon::prelude::*,
};
use sha2::{compress256, digest::generic_array::GenericArray};

use super::{
Expand All @@ -20,7 +25,7 @@ use crate::{
/// The first pass should do `get_block_trace` for every block and generate the invalid rows through `get_default_row`
/// The second pass should go through all the blocks and call `generate_missing_values`
impl Sha256Air {
/// This function takes the input_massage (should be already padded), the previous hash,
/// This function takes the input_message (padding not handled), the previous hash,
/// and returns the new hash after processing the block input
pub fn get_block_hash(
prev_hash: &[u32; SHA256_HASH_WORDS],
Expand Down Expand Up @@ -454,3 +459,89 @@ impl Sha256Air {
}
}
}

/// `records` consists of pairs of `(input_block, is_last_block)`.
pub fn generate_trace<F: PrimeField32>(
sub_air: &Sha256Air,
bitwise_lookup_chip: &BitwiseOperationLookupChip<8>,
records: Vec<([u8; SHA256_BLOCK_U8S], bool)>,
) -> RowMajorMatrix<F> {
let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK;
let height = next_power_of_two_or_zero(non_padded_height);
let width = <Sha256Air as BaseAir<F>>::width(sub_air);
let mut values = F::zero_vec(height * width);

struct BlockContext {
prev_hash: [u32; 8],
local_block_idx: u32,
global_block_idx: u32,
input: [u8; SHA256_BLOCK_U8S],
is_last_block: bool,
}
let mut block_ctx: Vec<BlockContext> = Vec::with_capacity(records.len());
let mut prev_hash = SHA256_H;
let mut local_block_idx = 0;
let mut global_block_idx = 1;
for (input, is_last_block) in records {
block_ctx.push(BlockContext {
prev_hash,
local_block_idx,
global_block_idx,
input,
is_last_block,
});
global_block_idx += 1;
if is_last_block {
local_block_idx = 0;
prev_hash = SHA256_H;
} else {
local_block_idx += 1;
prev_hash = Sha256Air::get_block_hash(&prev_hash, input);
}
}
// first pass
values
.par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK)
.zip(block_ctx)
.for_each(|(block, ctx)| {
let BlockContext {
prev_hash,
local_block_idx,
global_block_idx,
input,
is_last_block,
} = ctx;
let input_words = array::from_fn(|i| {
limbs_into_u32::<SHA256_WORD_U8S>(array::from_fn(|j| {
input[i * SHA256_WORD_U8S + j] as u32
}))
});
sub_air.generate_block_trace(
block,
width,
0,
&input_words,
bitwise_lookup_chip,
&prev_hash,
is_last_block,
global_block_idx,
local_block_idx,
&[[F::ZERO; 16]; 4],
);
});
// second pass: padding rows
values[width * non_padded_height..]
.par_chunks_mut(width)
.for_each(|row| {
let cols: &mut Sha256RoundCols<F> = row.borrow_mut();
sub_air.generate_default_row(cols);
});
// second pass: non-padding rows
values[width..]
.par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
.take(non_padded_height / SHA256_ROWS_PER_BLOCK)
.for_each(|chunk| {
sub_air.generate_missing_cells(chunk, width, 0);
});
RowMajorMatrix::new(values, width)
}

0 comments on commit 64d8f06

Please sign in to comment.