Skip to content

Commit

Permalink
Add init_circuit_data. (starkware-libs#5481)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware authored Apr 30, 2024
1 parent 5c95ac1 commit ce3e4a7
Show file tree
Hide file tree
Showing 17 changed files with 347 additions and 21 deletions.
15 changes: 14 additions & 1 deletion corelib/src/circuit.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
pub type u96 = core::internal::BoundedInt<0, 79228162514264337593543950335>;

pub extern type RangeCheck96;


// Defines an input for a circuit.
#[phantom]
pub extern type CircuitInput<const N: usize>;


// Initializes the input data for running an instance of the circuit.
extern fn init_circuit_data<C>() -> CircuitInputAccumulator<C> implicits(RangeCheck96) nopanic;

// Type for accumulating inputs into the circuit instance's data.
extern type CircuitInputAccumulator<C>;

impl CircuitInputAccumulatorDrop<C> of Drop<CircuitInputAccumulator<C>>;
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use cairo_lang_sierra::extensions::bounded_int::{
use cairo_lang_sierra::extensions::boxing::BoxConcreteLibfunc;
use cairo_lang_sierra::extensions::bytes31::Bytes31ConcreteLibfunc;
use cairo_lang_sierra::extensions::casts::{CastConcreteLibfunc, CastType};
use cairo_lang_sierra::extensions::circuit::CircuitConcreteLibfunc;
use cairo_lang_sierra::extensions::const_type::ConstConcreteLibfunc;
use cairo_lang_sierra::extensions::core::CoreConcreteLibfunc::{self, *};
use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
Expand Down Expand Up @@ -380,6 +381,7 @@ pub fn core_libfunc_ap_change<InfoProvider: InvocationApChangeInfoProvider>(
]
}
},
Circuit(CircuitConcreteLibfunc::InitCircuitData(_)) => vec![ApChange::Known(0)],
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use cairo_lang_sierra::extensions::bounded_int::{
use cairo_lang_sierra::extensions::boxing::BoxConcreteLibfunc;
use cairo_lang_sierra::extensions::bytes31::Bytes31ConcreteLibfunc;
use cairo_lang_sierra::extensions::casts::{CastConcreteLibfunc, CastType};
use cairo_lang_sierra::extensions::circuit::CircuitConcreteLibfunc;
use cairo_lang_sierra::extensions::const_type::ConstConcreteLibfunc;
use cairo_lang_sierra::extensions::core::CoreConcreteLibfunc::{self, *};
use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
Expand Down Expand Up @@ -470,6 +471,7 @@ pub fn core_libfunc_cost(
]
}
},
Circuit(CircuitConcreteLibfunc::InitCircuitData(_)) => vec![ConstCost::steps(0).into()],
}
}

Expand Down
46 changes: 46 additions & 0 deletions crates/cairo-lang-sierra-to-casm/src/invocations/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use cairo_lang_casm::builder::CasmBuilder;
use cairo_lang_casm::casm_build_extend;
use cairo_lang_sierra::extensions::circuit::CircuitConcreteLibfunc;

use super::{CompiledInvocation, CompiledInvocationBuilder, InvocationError};
use crate::invocations::add_input_variables;

/// Builds instructions for Sierra array operations.
pub fn build(
libfunc: &CircuitConcreteLibfunc,
builder: CompiledInvocationBuilder<'_>,
) -> Result<CompiledInvocation, InvocationError> {
match libfunc {
CircuitConcreteLibfunc::InitCircuitData(_) => build_init_circuit_data(builder),
}
}

/// Handles a Sierra statement for initializing circuit data.
fn build_init_circuit_data(
builder: CompiledInvocationBuilder<'_>,
) -> Result<CompiledInvocation, InvocationError> {
let [expr_rc96] = builder.try_get_refs()?;
let rc96 = expr_rc96.try_unpack_single()?;

// TODO(ilya): get n_inputs and n_vals from the libfunc.
let n_inputs = 1;
let n_vals = 2;

let mut casm_builder = CasmBuilder::default();

add_input_variables! {casm_builder,
buffer(1) rc96;
};
casm_build_extend! {casm_builder,
const n_inputs = n_inputs;
let inputs_end = rc96 + n_inputs;
const n_vals = n_vals;
let vals_end = rc96 + n_vals;
};

Ok(builder.build_from_casm_builder(
casm_builder,
[("Fallthrough", &[&[vals_end], &[rc96, inputs_end]], None)],
Default::default(),
))
}
2 changes: 2 additions & 0 deletions crates/cairo-lang-sierra-to-casm/src/invocations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod boolean;
mod boxing;
mod bytes31;
mod casts;
mod circuit;
mod const_type;
mod debug;
mod ec;
Expand Down Expand Up @@ -689,6 +690,7 @@ pub fn compile_invocation(
}
},
BoundedInt(libfunc) => int::bounded::build(libfunc, builder),
Circuit(libfunc) => circuit::build(libfunc, builder),
}
}

Expand Down
12 changes: 10 additions & 2 deletions crates/cairo-lang-sierra-type-size/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use cairo_lang_sierra::extensions::circuit::CircuitTypeConcrete;
use cairo_lang_sierra::extensions::core::{CoreLibfunc, CoreType, CoreTypeConcrete};
use cairo_lang_sierra::extensions::starknet::StarkNetTypeConcrete;
use cairo_lang_sierra::ids::ConcreteTypeId;
Expand Down Expand Up @@ -68,14 +69,21 @@ pub fn get_type_size_map(
Some(size)
}
CoreTypeConcrete::Struct(struct_type) => {
if !struct_type.info.storable {
// If the struct is not storable, it should not have a size.
continue;
}
let mut size = 0;
for member in &struct_type.members {
size += type_sizes.get(member).cloned()?;
}
Some(size)
}
// Const types are not moved around and should not have a size.
CoreTypeConcrete::Const(_) => continue,

CoreTypeConcrete::Circuit(CircuitTypeConcrete::CircuitInputAccumulator(_)) => Some(2),

// Const and circuit types are not moved around and should not have a size.
CoreTypeConcrete::Const(_) | CoreTypeConcrete::Circuit(_) => continue,
}?;
type_sizes.insert(declaration.id.clone(), size);
}
Expand Down
3 changes: 3 additions & 0 deletions crates/cairo-lang-sierra/src/extensions/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::bounded_int::{BoundedIntLibfunc, BoundedIntType};
use super::branch_align::BranchAlignLibfunc;
use super::bytes31::{Bytes31Libfunc, Bytes31Type};
use super::casts::CastLibfunc;
use super::circuit::{CircuitLibFunc, CircuitType};
use super::const_type::{ConstLibfunc, ConstType};
use super::coupon::{CouponLibfunc, CouponType};
use super::debug::DebugLibfunc;
Expand Down Expand Up @@ -56,6 +57,7 @@ define_type_hierarchy! {
Coupon(CouponType),
Bitwise(BitwiseType),
Box(BoxType),
Circuit(CircuitType),
Const(ConstType),
EcOp(EcOpType),
EcPoint(EcPointType),
Expand Down Expand Up @@ -103,6 +105,7 @@ define_libfunc_hierarchy! {
Bool(BoolLibfunc),
Box(BoxLibfunc),
Cast(CastLibfunc),
Circuit(CircuitLibFunc),
Coupon(CouponLibfunc),
CouponCall(CouponCallLibfunc),
Drop(DropLibfunc),
Expand Down
14 changes: 14 additions & 0 deletions crates/cairo-lang-sierra/src/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub use self::lib_func::{
OutputVarReferenceInfo, SignatureBasedConcreteLibfunc,
};
pub use self::modules::*;
use self::type_specialization_context::TypeSpecializationContext;
pub use self::types::{
ConcreteType, GenericType, GenericTypeEx, NamedType, NoGenericArgsGenericType,
};
Expand Down Expand Up @@ -59,5 +60,18 @@ fn args_as_single_user_func(args: &[GenericArg]) -> Result<FunctionId, Specializ
}
}

/// Extracts the generic args of `ty`, additionally validates it is of generic type `T`.
fn extract_type_generic_args<T: NamedType>(
context: &dyn TypeSpecializationContext,
ty: &ConcreteTypeId,
) -> Result<Vec<GenericArg>, SpecializationError> {
let long_id = context.get_type_info(ty.clone())?.long_id;
if long_id.generic_id != T::ID {
Err(SpecializationError::UnsupportedGenericArg)
} else {
Ok(long_id.generic_args)
}
}

#[cfg(test)]
mod test;
206 changes: 206 additions & 0 deletions crates/cairo-lang-sierra/src/extensions/modules/circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use num_bigint::BigInt;

use super::range_check::RangeCheck96Type;
use super::structure::StructType;
use crate::extensions::lib_func::{
DeferredOutputKind, LibfuncSignature, OutputVarInfo, ParamSignature, SierraApChange,
SignatureOnlyGenericLibfunc, SignatureSpecializationContext,
};
use crate::extensions::type_specialization_context::TypeSpecializationContext;
use crate::extensions::types::TypeInfo;
use crate::extensions::{
args_as_single_type, args_as_single_value, extract_type_generic_args, ConcreteType, NamedType,
OutputVarReferenceInfo, SpecializationError,
};
use crate::ids::{ConcreteTypeId, GenericTypeId, UserTypeId};
use crate::program::{ConcreteTypeLongId, GenericArg};
use crate::{define_libfunc_hierarchy, define_type_hierarchy};

define_type_hierarchy! {
pub enum CircuitType {
CircuitInput(CircuitInput),
CircuitInputAccumulator(CircuitInputAccumulator),
}, CircuitTypeConcrete
}

define_libfunc_hierarchy! {
pub enum CircuitLibFunc {
InitCircuitData(InitCircuitData),
}, CircuitConcreteLibfunc
}

/// Returns true if `garg` is a type that is considered a circuit component.
fn is_circuit_component(
context: &dyn TypeSpecializationContext,
garg: &GenericArg,
) -> Result<bool, SpecializationError> {
let GenericArg::Type(ty) = garg else {
return Err(SpecializationError::UnsupportedGenericArg);
};

let long_id = context.get_type_info(ty.clone())?.long_id;
let generic_id = long_id.generic_id;
if generic_id == CircuitInput::ID {
return Ok(true);
}
Ok(false)
}

/// Circuit input type.
#[derive(Default)]
pub struct CircuitInput {}
impl NamedType for CircuitInput {
type Concrete = ConcreteCircuitInput;
const ID: GenericTypeId = GenericTypeId::new_inline("CircuitInput");

fn specialize(
&self,
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
Self::Concrete::new(context, args)
}
}

/// Defines an input for a circuit.
pub struct ConcreteCircuitInput {
/// The type info of the concrete type.
pub info: TypeInfo,
/// The index of the circuit input.
pub idx: BigInt,
}
impl ConcreteCircuitInput {
fn new(
_context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self, SpecializationError> {
let idx = args_as_single_value(args)?;
Ok(Self {
info: TypeInfo {
long_id: ConcreteTypeLongId {
generic_id: "CircuitInput".into(),
generic_args: args.to_vec(),
},
duplicatable: false,
droppable: false,
storable: false,
zero_sized: false,
},
idx,
})
}
}

impl ConcreteType for ConcreteCircuitInput {
fn info(&self) -> &TypeInfo {
&self.info
}
}

/// Type for accumulating inputs into the circuit instance's data.
#[derive(Default)]
pub struct CircuitInputAccumulator {}
impl NamedType for CircuitInputAccumulator {
type Concrete = ConcreteCircuitInputAccumulator;
const ID: GenericTypeId = GenericTypeId::new_inline("CircuitInputAccumulator");

fn specialize(
&self,
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
Self::Concrete::new(context, args)
}
}

pub struct ConcreteCircuitInputAccumulator {
pub info: TypeInfo,
}

impl ConcreteCircuitInputAccumulator {
fn new(
context: &dyn TypeSpecializationContext,
args: &[GenericArg],
) -> Result<Self, SpecializationError> {
let circ_ty = args_as_single_type(args)?;
validate_is_circuit(context, circ_ty)?;
Ok(Self {
info: TypeInfo {
long_id: ConcreteTypeLongId {
generic_id: "CircuitInputAccumulator".into(),
generic_args: args.to_vec(),
},
duplicatable: false,
droppable: true,
storable: true,
zero_sized: false,
},
})
}
}

impl ConcreteType for ConcreteCircuitInputAccumulator {
fn info(&self) -> &TypeInfo {
&self.info
}
}

/// Validate that `circ_ty` is a circuit type.
fn validate_is_circuit(
context: &dyn TypeSpecializationContext,
circ_ty: ConcreteTypeId,
) -> Result<(), SpecializationError> {
let struct_generic_args = extract_type_generic_args::<StructType>(context, &circ_ty)?;

let mut gargs = struct_generic_args.iter();
if !matches!(
gargs.next(),
Some(GenericArg::UserType(ut))
if (*ut == UserTypeId::from_string("Tuple"))

) {
return Err(SpecializationError::UnsupportedGenericArg);
}

for garg in gargs {
// Note that its enough to check the topmost types as they validate their children.
if !is_circuit_component(context, garg)? {
return Err(SpecializationError::UnsupportedGenericArg);
}
}

Ok(())
}

/// Libfunc for initializing the input data for running an instance of the circuit.
#[derive(Default)]
pub struct InitCircuitData {}
impl SignatureOnlyGenericLibfunc for InitCircuitData {
const STR_ID: &'static str = "init_circuit_data";

fn specialize_signature(
&self,
context: &dyn SignatureSpecializationContext,
generic_args: &[GenericArg],
) -> Result<LibfuncSignature, SpecializationError> {
let range_check96_type = context.get_concrete_type(RangeCheck96Type::id(), &[])?;
let circuit_input_accumulator_ty =
context.get_concrete_type(CircuitInputAccumulator::id(), generic_args)?;
Ok(LibfuncSignature::new_non_branch_ex(
vec![ParamSignature::new(range_check96_type.clone()).with_allow_add_const()],
vec![
OutputVarInfo {
ty: range_check96_type.clone(),
ref_info: OutputVarReferenceInfo::Deferred(DeferredOutputKind::AddConst {
param_idx: 0,
}),
},
OutputVarInfo {
ty: circuit_input_accumulator_ty.clone(),
ref_info: OutputVarReferenceInfo::Deferred(DeferredOutputKind::Generic),
},
],
SierraApChange::Known { new_vars_only: true },
))
}
}
Loading

0 comments on commit ce3e4a7

Please sign in to comment.