Skip to content

Commit

Permalink
Separated python hint prints for addresses and non-addresses types.
Browse files Browse the repository at this point in the history
commit-id:dfa86564
  • Loading branch information
orizi committed Jul 9, 2023
1 parent b477296 commit 8c2409f
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 40 deletions.
105 changes: 69 additions & 36 deletions crates/cairo-lang-casm/src/hints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,35 @@ impl<'a> Display for DerefOrImmediateFormatter<'a> {
}
}

struct ResOperandFormatter<'a>(&'a ResOperand);
impl<'a> Display for ResOperandFormatter<'a> {
struct ResOperandAsIntegerFormatter<'a>(&'a ResOperand);
impl<'a> Display for ResOperandAsIntegerFormatter<'a> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.0 {
ResOperand::Deref(d) => write!(f, "memory{d}"),
ResOperand::DoubleDeref(d, i) => write!(f, "memory[memory{d} + {i}]"),
ResOperand::Immediate(i) => write!(f, "{}", i.value),
ResOperand::BinOp(bin_op) => {
write!(
f,
"(memory{} {} {}) % PRIME",
bin_op.a,
bin_op.op,
DerefOrImmediateFormatter(&bin_op.b)
)
}
}
}
}

struct ResOperandAsAddressFormatter<'a>(&'a ResOperand);
impl<'a> Display for ResOperandAsAddressFormatter<'a> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.0 {
ResOperand::Deref(d) => write!(f, "memory{d}"),
ResOperand::DoubleDeref(d, i) => write!(f, "memory[memory{d} + {i}]"),
ResOperand::Immediate(i) => {
unreachable!("Address cannot be an immediate: {}.", i.value)
}
ResOperand::BinOp(bin_op) => {
write!(
f,
Expand Down Expand Up @@ -301,7 +323,7 @@ impl PythonicHint for CoreHint {
match self {
CoreHint::AllocSegment { dst } => format!("memory{dst} = segments.add()"),
CoreHint::AllocFelt252Dict { segment_arena_ptr } => {
let segment_arena_ptr = ResOperandFormatter(segment_arena_ptr);
let segment_arena_ptr = ResOperandAsAddressFormatter(segment_arena_ptr);
formatdoc!(
"
Expand Down Expand Up @@ -333,7 +355,8 @@ impl PythonicHint for CoreHint {
)
}
CoreHint::Felt252DictEntryInit { dict_ptr, key } => {
let (dict_ptr, key) = (ResOperandFormatter(dict_ptr), ResOperandFormatter(key));
let (dict_ptr, key) =
(ResOperandAsAddressFormatter(dict_ptr), ResOperandAsIntegerFormatter(key));
formatdoc!(
"
Expand All @@ -344,7 +367,8 @@ impl PythonicHint for CoreHint {
)
}
CoreHint::Felt252DictEntryUpdate { dict_ptr, value } => {
let (dict_ptr, value) = (ResOperandFormatter(dict_ptr), ResOperandFormatter(value));
let (dict_ptr, value) =
(ResOperandAsAddressFormatter(dict_ptr), ResOperandAsIntegerFormatter(value));
formatdoc!(
"
Expand All @@ -354,22 +378,26 @@ impl PythonicHint for CoreHint {
)
}
CoreHint::TestLessThan { lhs, rhs, dst } => {
format!("memory{dst} = {} < {}", ResOperandFormatter(lhs), ResOperandFormatter(rhs))
format!(
"memory{dst} = {} < {}",
ResOperandAsIntegerFormatter(lhs),
ResOperandAsIntegerFormatter(rhs)
)
}
CoreHint::TestLessThanOrEqual { lhs, rhs, dst } => format!(
"memory{dst} = {} <= {}",
ResOperandFormatter(lhs),
ResOperandFormatter(rhs)
ResOperandAsIntegerFormatter(lhs),
ResOperandAsIntegerFormatter(rhs)
),
CoreHint::WideMul128 { lhs, rhs, high, low } => format!(
"(memory{high}, memory{low}) = divmod({} * {}, 2**128)",
ResOperandFormatter(lhs),
ResOperandFormatter(rhs)
ResOperandAsIntegerFormatter(lhs),
ResOperandAsIntegerFormatter(rhs)
),
CoreHint::DivMod { lhs, rhs, quotient, remainder } => format!(
"(memory{quotient}, memory{remainder}) = divmod({}, {})",
ResOperandFormatter(lhs),
ResOperandFormatter(rhs)
ResOperandAsIntegerFormatter(lhs),
ResOperandAsIntegerFormatter(rhs)
),
CoreHint::Uint256DivMod {
dividend0,
Expand All @@ -382,10 +410,10 @@ impl PythonicHint for CoreHint {
remainder1,
} => {
let (dividend0, dividend1, divisor0, divisor1) = (
ResOperandFormatter(dividend0),
ResOperandFormatter(dividend1),
ResOperandFormatter(divisor0),
ResOperandFormatter(divisor1),
ResOperandAsIntegerFormatter(dividend0),
ResOperandAsIntegerFormatter(dividend1),
ResOperandAsIntegerFormatter(divisor0),
ResOperandAsIntegerFormatter(divisor1),
);
formatdoc!(
"
Expand Down Expand Up @@ -416,7 +444,7 @@ impl PythonicHint for CoreHint {
} => {
let [dividend0, dividend1, dividend2, dividend3, divisor0, divisor1] =
[dividend0, dividend1, dividend2, dividend3, divisor0, divisor1]
.map(ResOperandFormatter);
.map(ResOperandAsIntegerFormatter);
formatdoc!(
"
Expand All @@ -440,7 +468,7 @@ impl PythonicHint for CoreHint {
import math
memory{dst} = math.isqrt({})
",
ResOperandFormatter(value)
ResOperandAsIntegerFormatter(value)
)
}
CoreHint::Uint256SquareRoot {
Expand All @@ -452,8 +480,10 @@ impl PythonicHint for CoreHint {
remainder_high,
sqrt_mul_2_minus_remainder_ge_u128,
} => {
let (value_low, value_high) =
(ResOperandFormatter(value_low), ResOperandFormatter(value_high));
let (value_low, value_high) = (
ResOperandAsIntegerFormatter(value_low),
ResOperandAsIntegerFormatter(value_high),
);
formatdoc!(
"
Expand All @@ -471,9 +501,9 @@ impl PythonicHint for CoreHint {
}
CoreHint::LinearSplit { value, scalar, max_x, x, y } => {
let (value, scalar, max_x) = (
ResOperandFormatter(value),
ResOperandFormatter(scalar),
ResOperandFormatter(max_x),
ResOperandAsIntegerFormatter(value),
ResOperandAsIntegerFormatter(scalar),
ResOperandAsIntegerFormatter(max_x),
);
formatdoc!(
"
Expand Down Expand Up @@ -509,7 +539,7 @@ impl PythonicHint for CoreHint {
else:
memory{sqrt} = sqrt(val * 3, FIELD_PRIME)
",
ResOperandFormatter(val)
ResOperandAsIntegerFormatter(val)
)
}
CoreHint::GetCurrentAccessIndex { range_check_ptr } => formatdoc!(
Expand All @@ -519,7 +549,7 @@ impl PythonicHint for CoreHint {
current_access_index = current_access_indices.pop()
memory[{}] = current_access_index
",
ResOperandFormatter(range_check_ptr)
ResOperandAsAddressFormatter(range_check_ptr)
),
CoreHint::ShouldSkipSquashLoop { should_skip_loop } => {
format!("memory{should_skip_loop} = 0 if current_access_indices else 1")
Expand All @@ -542,7 +572,7 @@ impl PythonicHint for CoreHint {
"
),
CoreHint::GetSegmentArenaIndex { dict_end_ptr, dict_index } => {
let dict_end_ptr = ResOperandFormatter(dict_end_ptr);
let dict_end_ptr = ResOperandAsAddressFormatter(dict_end_ptr);
formatdoc!(
"
Expand All @@ -560,9 +590,9 @@ impl PythonicHint for CoreHint {
first_key,
} => {
let (dict_accesses, ptr_diff, n_accesses) = (
ResOperandFormatter(dict_accesses),
ResOperandFormatter(ptr_diff),
ResOperandFormatter(n_accesses),
ResOperandAsAddressFormatter(dict_accesses),
ResOperandAsIntegerFormatter(ptr_diff),
ResOperandAsIntegerFormatter(n_accesses),
);
formatdoc!(
"
Expand Down Expand Up @@ -591,9 +621,9 @@ impl PythonicHint for CoreHint {
}
CoreHint::AssertLeFindSmallArcs { range_check_ptr, a, b } => {
let (range_check_ptr, a, b) = (
ResOperandFormatter(range_check_ptr),
ResOperandFormatter(a),
ResOperandFormatter(b),
ResOperandAsAddressFormatter(range_check_ptr),
ResOperandAsIntegerFormatter(a),
ResOperandAsIntegerFormatter(b),
);
formatdoc!(
"
Expand Down Expand Up @@ -636,8 +666,8 @@ impl PythonicHint for CoreHint {
print(hex(memory[curr]))
curr += 1
",
ResOperandFormatter(start),
ResOperandFormatter(end),
ResOperandAsAddressFormatter(start),
ResOperandAsAddressFormatter(end),
),
CoreHint::AllocConstantSize { size, dst } => {
formatdoc!(
Expand All @@ -648,7 +678,7 @@ impl PythonicHint for CoreHint {
memory{dst} = __boxed_segment
__boxed_segment += {}
",
ResOperandFormatter(size)
ResOperandAsIntegerFormatter(size)
)
}
}
Expand All @@ -659,7 +689,10 @@ impl PythonicHint for StarknetHint {
fn get_pythonic_hint(&self) -> String {
match self {
StarknetHint::SystemCall { system } => {
format!("syscall_handler.syscall(syscall_ptr={})", ResOperandFormatter(system))
format!(
"syscall_handler.syscall(syscall_ptr={})",
ResOperandAsAddressFormatter(system)
)
}
StarknetHint::Cheatcode { .. } => "raise NotImplementedError".to_string(),
}
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test_data/libfuncs/u16
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn foo(a: u16, b: u16) -> Result::<u16, u16> {
}

//! > casm
%{ memory[ap + 0] = memory[fp + -4] + memory[fp + -3] < 65536 %}
%{ memory[ap + 0] = (memory[fp + -4] + memory[fp + -3]) % PRIME < 65536 %}
jmp rel 8 if [ap + 0] != 0, ap++;
[ap + 0] = [fp + -4] + [fp + -3], ap++;
[ap + -1] = [ap + 0] + 65536, ap++;
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test_data/libfuncs/u32
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn foo(a: u32, b: u32) -> Result::<u32, u32> {
}

//! > casm
%{ memory[ap + 0] = memory[fp + -4] + memory[fp + -3] < 4294967296 %}
%{ memory[ap + 0] = (memory[fp + -4] + memory[fp + -3]) % PRIME < 4294967296 %}
jmp rel 8 if [ap + 0] != 0, ap++;
[ap + 0] = [fp + -4] + [fp + -3], ap++;
[ap + -1] = [ap + 0] + 4294967296, ap++;
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test_data/libfuncs/u64
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn foo(a: u64, b: u64) -> Result::<u64, u64> {
}

//! > casm
%{ memory[ap + 0] = memory[fp + -4] + memory[fp + -3] < 18446744073709551616 %}
%{ memory[ap + 0] = (memory[fp + -4] + memory[fp + -3]) % PRIME < 18446744073709551616 %}
jmp rel 8 if [ap + 0] != 0, ap++;
[ap + 0] = [fp + -4] + [fp + -3], ap++;
[ap + -1] = [ap + 0] + 18446744073709551616, ap++;
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_test_data/libfuncs/u8
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn foo(a: u8, b: u8) -> Result::<u8, u8> {
}

//! > casm
%{ memory[ap + 0] = memory[fp + -4] + memory[fp + -3] < 256 %}
%{ memory[ap + 0] = (memory[fp + -4] + memory[fp + -3]) % PRIME < 256 %}
jmp rel 8 if [ap + 0] != 0, ap++;
[ap + 0] = [fp + -4] + [fp + -3], ap++;
[ap + -1] = [ap + 0] + 256, ap++;
Expand Down

0 comments on commit 8c2409f

Please sign in to comment.