Skip to content

Commit

Permalink
fix(cheatcodes): Fix expectCall behavior (#4912)
Browse files Browse the repository at this point in the history
* chore: add tests to test proper behavior

* fix(cheatcodes): properly handle all cases for expectCall

* chore: allow too many arguments

* chore: store calldata as a vec instead of bytes to avoid interior mutability lint

* chore: more clippy

* chore: add more docs and abstract signature
Evalir authored May 12, 2023

Verified

This commit was signed with the committer’s verified signature. The key has expired.
miri64 Martine Lenders
1 parent cbad9c9 commit bd4b290
Showing 3 changed files with 333 additions and 145 deletions.
267 changes: 165 additions & 102 deletions evm/src/executor/inspector/cheatcodes/expect.rs
Original file line number Diff line number Diff line change
@@ -168,16 +168,26 @@ pub fn handle_expect_emit(state: &mut Cheatcodes, log: RawLog, address: &Address

#[derive(Clone, Debug, Default)]
pub struct ExpectedCallData {
/// The expected calldata
pub calldata: Bytes,
/// The expected value sent in the call
pub value: Option<U256>,
/// The expected gas supplied to the call
pub gas: Option<u64>,
/// The expected *minimum* gas supplied to the call
pub min_gas: Option<u64>,
/// The number of times the call is expected to be made
pub count: Option<u64>,
/// The number of times the call is expected to be made.
/// If the type of call is `NonCount`, this is the lower bound for the number of calls
/// that must be seen.
/// If the type of call is `Count`, this is the exact number of calls that must be seen.
pub count: u64,
/// The type of call
pub call_type: ExpectedCallType,
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum ExpectedCallType {
#[default]
Count,
NonCount,
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
@@ -220,6 +230,71 @@ fn expect_safe_memory(state: &mut Cheatcodes, start: u64, end: u64, depth: u64)
Ok(Bytes::new())
}

/// Handles expected calls specified by the `vm.expectCall` cheatcode.
///
/// It can handle calls in two ways:
/// - If the cheatcode was used with a `count` argument, it will expect the call to be made exactly
/// `count` times.
/// e.g. `vm.expectCall(address(0xc4f3), abi.encodeWithSelector(0xd34db33f), 4)` will expect the
/// call to address(0xc4f3) with selector `0xd34db33f` to be made exactly 4 times. If the amount of
/// calls is less or more than 4, the test will fail. Note that the `count` argument cannot be
/// overwritten with another `vm.expectCall`. If this is attempted, `expectCall` will revert.
/// - If the cheatcode was used without a `count` argument, it will expect the call to be made at
/// least the amount of times the cheatcode
/// was called. This means that `vm.expectCall` without a count argument can be called many times,
/// but cannot be called with a `count` argument after it was called without one. If the latter
/// happens, `expectCall` will revert. e.g `vm.expectCall(address(0xc4f3),
/// abi.encodeWithSelector(0xd34db33f))` will expect the call to address(0xc4f3) and selector
/// `0xd34db33f` to be made at least once. If the amount of calls is 0, the test will fail. If the
/// call is made more than once, the test will pass.
#[allow(clippy::too_many_arguments)]
fn expect_call(
state: &mut Cheatcodes,
target: H160,
calldata: Vec<u8>,
value: Option<U256>,
gas: Option<u64>,
min_gas: Option<u64>,
count: u64,
call_type: ExpectedCallType,
) -> Result {
match call_type {
ExpectedCallType::Count => {
// Get the expected calls for this target.
let expecteds = state.expected_calls.entry(target).or_default();
// In this case, as we're using counted expectCalls, we should not be able to set them
// more than once.
ensure!(
!expecteds.contains_key(&calldata),
"Counted expected calls can only bet set once."
);
expecteds
.insert(calldata, (ExpectedCallData { value, gas, min_gas, count, call_type }, 0));
Ok(Bytes::new())
}
ExpectedCallType::NonCount => {
let expecteds = state.expected_calls.entry(target).or_default();
// Check if the expected calldata exists.
// If it does, increment the count by one as we expect to see it one more time.
if let Some(expected) = expecteds.get_mut(&calldata) {
// Ensure we're not overwriting a counted expectCall.
ensure!(
expected.0.call_type == ExpectedCallType::NonCount,
"Cannot overwrite a counted expectCall with a non-counted expectCall."
);
expected.0.count += 1;
} else {
// If it does not exist, then create it.
expecteds.insert(
calldata,
(ExpectedCallData { value, gas, min_gas, count, call_type }, 0),
);
}
Ok(Bytes::new())
}
}
}

pub fn apply<DB: DatabaseExt>(
state: &mut Cheatcodes,
data: &mut EVMData<'_, DB>,
@@ -267,125 +342,113 @@ pub fn apply<DB: DatabaseExt>(
});
Ok(Bytes::new())
}
HEVMCalls::ExpectCall0(inner) => {
state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.1.to_vec().into(),
value: None,
gas: None,
min_gas: None,
count: None,
},
0,
));
Ok(Bytes::new())
}
HEVMCalls::ExpectCall1(inner) => {
state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.1.to_vec().into(),
value: None,
gas: None,
min_gas: None,
count: Some(inner.2),
},
0,
));
Ok(Bytes::new())
}
HEVMCalls::ExpectCall2(inner) => {
state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.2.to_vec().into(),
value: Some(inner.1),
gas: None,
min_gas: None,
count: None,
},
0,
));
Ok(Bytes::new())
}
HEVMCalls::ExpectCall3(inner) => {
state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.2.to_vec().into(),
value: Some(inner.1),
gas: None,
min_gas: None,
count: Some(inner.3),
},
0,
));
Ok(Bytes::new())
}
HEVMCalls::ExpectCall0(inner) => expect_call(
state,
inner.0,
inner.1.to_vec(),
None,
None,
None,
1,
ExpectedCallType::NonCount,
),
HEVMCalls::ExpectCall1(inner) => expect_call(
state,
inner.0,
inner.1.to_vec(),
None,
None,
None,
inner.2,
ExpectedCallType::Count,
),
HEVMCalls::ExpectCall2(inner) => expect_call(
state,
inner.0,
inner.2.to_vec(),
Some(inner.1),
None,
None,
1,
ExpectedCallType::NonCount,
),
HEVMCalls::ExpectCall3(inner) => expect_call(
state,
inner.0,
inner.2.to_vec(),
Some(inner.1),
None,
None,
inner.3,
ExpectedCallType::Count,
),
HEVMCalls::ExpectCall4(inner) => {
let value = inner.1;

// If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas
// to ensure that the basic fallback function can be called.
let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 };

state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.3.to_vec().into(),
value: Some(value),
gas: Some(inner.2 + positive_value_cost_stipend),
min_gas: None,
count: None,
},
0,
));
Ok(Bytes::new())
expect_call(
state,
inner.0,
inner.3.to_vec(),
Some(value),
Some(inner.2 + positive_value_cost_stipend),
None,
1,
ExpectedCallType::NonCount,
)
}
HEVMCalls::ExpectCall5(inner) => {
let value = inner.1;
// If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas
// to ensure that the basic fallback function can be called.
let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 };
state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.3.to_vec().into(),
value: Some(value),
gas: Some(inner.2 + positive_value_cost_stipend),
min_gas: None,
count: Some(inner.4),
},
0,
));
Ok(Bytes::new())

expect_call(
state,
inner.0,
inner.3.to_vec(),
Some(value),
Some(inner.2 + positive_value_cost_stipend),
None,
inner.4,
ExpectedCallType::Count,
)
}
HEVMCalls::ExpectCallMinGas0(inner) => {
let value = inner.1;

// If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas
// to ensure that the basic fallback function can be called.
let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 };

state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.3.to_vec().into(),
value: Some(value),
gas: None,
min_gas: Some(inner.2 + positive_value_cost_stipend),
count: None,
},
0,
));
Ok(Bytes::new())
expect_call(
state,
inner.0,
inner.3.to_vec(),
Some(value),
None,
Some(inner.2 + positive_value_cost_stipend),
1,
ExpectedCallType::NonCount,
)
}
HEVMCalls::ExpectCallMinGas1(inner) => {
let value = inner.1;
// If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas
// to ensure that the basic fallback function can be called.
let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 };
state.expected_calls.entry(inner.0).or_default().push((
ExpectedCallData {
calldata: inner.3.to_vec().into(),
value: Some(value),
gas: None,
min_gas: Some(inner.2 + positive_value_cost_stipend),
count: Some(inner.4),
},
0,
));
Ok(Bytes::new())

expect_call(
state,
inner.0,
inner.3.to_vec(),
Some(value),
None,
Some(inner.2 + positive_value_cost_stipend),
inner.4,
ExpectedCallType::Count,
)
}
HEVMCalls::MockCall0(inner) => {
// TODO: Does this increase gas usage?
133 changes: 90 additions & 43 deletions evm/src/executor/inspector/cheatcodes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use self::{
env::Broadcast,
expect::{handle_expect_emit, handle_expect_revert},
expect::{handle_expect_emit, handle_expect_revert, ExpectedCallType},
util::{check_if_fixed_gas_limit, process_create, BroadcastableTransactions},
};
use crate::{
@@ -69,6 +69,15 @@ mod error;
pub(crate) use error::{bail, ensure, err};
pub use error::{Error, Result};

/// Tracks the expected calls per address.
/// For each address, we track the expected calls per call data. We track it in such manner
/// so that we don't mix together calldatas that only contain selectors and calldatas that contain
/// selector and arguments (partial and full matches).
/// This then allows us to customize the matching behavior for each call data on the
/// `ExpectedCallData` struct and track how many times we've actually seen the call on the second
/// element of the tuple.
pub type ExpectedCallTracker = BTreeMap<Address, BTreeMap<Vec<u8>, (ExpectedCallData, u64)>>;

/// An inspector that handles calls to various cheatcodes, each with their own behavior.
///
/// Cheatcodes can be called by contracts during execution to modify the VM environment, such as
@@ -126,7 +135,7 @@ pub struct Cheatcodes {
pub mocked_calls: BTreeMap<Address, BTreeMap<MockCallDataContext, MockCallReturnData>>,

/// Expected calls
pub expected_calls: BTreeMap<Address, Vec<(ExpectedCallData, u64)>>,
pub expected_calls: ExpectedCallTracker,

/// Expected emits
pub expected_emits: Vec<ExpectedEmit>,
@@ -565,15 +574,29 @@ where
}
} else if call.contract != h160_to_b160(HARDHAT_CONSOLE_ADDRESS) {
// Handle expected calls
if let Some(expecteds) = self.expected_calls.get_mut(&(b160_to_h160(call.contract))) {
if let Some((_, count)) = expecteds.iter_mut().find(|(expected, _)| {
expected.calldata.len() <= call.input.len() &&
expected.calldata == call.input[..expected.calldata.len()] &&
expected.value.map_or(true, |value| value == call.transfer.value.into()) &&

// Grab the different calldatas expected.
if let Some(expected_calls_for_target) =
self.expected_calls.get_mut(&(b160_to_h160(call.contract)))
{
// Match every partial/full calldata
for (calldata, (expected, actual_count)) in expected_calls_for_target.iter_mut() {
// Increment actual times seen if...
// The calldata is at most, as big as this call's input, and
if calldata.len() <= call.input.len() &&
// Both calldata match, taking the length of the assumed smaller one (which will have at least the selector), and
*calldata == call.input[..calldata.len()] &&
// The value matches, if provided
expected
.value
.map_or(true, |value| value == call.transfer.value.into()) &&
// The gas matches, if provided
expected.gas.map_or(true, |gas| gas == call.gas_limit) &&
// The minimum gas matches, if provided
expected.min_gas.map_or(true, |min_gas| min_gas <= call.gas_limit)
}) {
*count += 1;
{
*actual_count += 1;
}
}
}

@@ -782,43 +805,67 @@ where

// If the depth is 0, then this is the root call terminating
if data.journaled_state.depth() == 0 {
for (address, expecteds) in &self.expected_calls {
for (expected, actual_count) in expecteds {
let ExpectedCallData { calldata, gas, min_gas, value, count } = expected;
let calldata = calldata.clone();
let expected_values = [
Some(format!("data {calldata}")),
value.map(|v| format!("value {v}")),
gas.map(|g| format!("gas {g}")),
min_gas.map(|g| format!("minimum gas {g}")),
]
.into_iter()
.flatten()
.join(" and ");
if count.is_none() {
if *actual_count == 0 {
return (
InstructionResult::Revert,
remaining_gas,
format!("Expected at least one call to {address:?} with {expected_values}, but got none")
// Match expected calls
for (address, calldatas) in &self.expected_calls {
// Loop over each address, and for each address, loop over each calldata it expects.
for (calldata, (expected, actual_count)) in calldatas {
// Grab the values we expect to see
let ExpectedCallData { gas, min_gas, value, count, call_type } = expected;
let calldata = Bytes::from(calldata.clone());

// We must match differently depending on the type of call we expect.
match call_type {
// If the cheatcode was called with a `count` argument,
// we must check that the EVM performed a CALL with this calldata exactly
// `count` times.
ExpectedCallType::Count => {
if *count != *actual_count {
let expected_values = [
Some(format!("data {calldata}")),
value.map(|v| format!("value {v}")),
gas.map(|g| format!("gas {g}")),
min_gas.map(|g| format!("minimum gas {g}")),
]
.into_iter()
.flatten()
.join(" and ");
return (
InstructionResult::Revert,
remaining_gas,
format!(
"Expected call to {address:?} with {expected_values} to be called {count} time(s), but was called {actual_count} time(s)"
)
.encode()
.into(),
)
)
}
}
// If the cheatcode was called without a `count` argument,
// we must check that the EVM performed a CALL with this calldata at least
// `count` times. The amount of times to check was
// the amount of time the cheatcode was called.
ExpectedCallType::NonCount => {
if *count > *actual_count {
let expected_values = [
Some(format!("data {calldata}")),
value.map(|v| format!("value {v}")),
gas.map(|g| format!("gas {g}")),
min_gas.map(|g| format!("minimum gas {g}")),
]
.into_iter()
.flatten()
.join(" and ");
return (
InstructionResult::Revert,
remaining_gas,
format!(
"Expected call to {address:?} with {expected_values} to be called at least {count} time(s), but was called {actual_count} time(s)"
)
.encode()
.into(),
)
}
}
} else if *count != Some(*actual_count) {
return (
InstructionResult::Revert,
remaining_gas,
format!(
"Expected call to {:?} with {} to be made {} time(s), but was called {} time(s)",
address,
expected_values,
count.unwrap(),
actual_count,
)
.encode()
.into(),
)
}
}
}
78 changes: 78 additions & 0 deletions testdata/cheats/ExpectCall.t.sol
Original file line number Diff line number Diff line change
@@ -62,6 +62,33 @@ contract ExpectCallTest is DSTest {
target.add(1, 2);
}

function testExpectMultipleCallsWithDataAdditive() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
target.add(1, 2);
target.add(1, 2);
}

function testExpectMultipleCallsWithDataAdditiveLowerBound() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
target.add(1, 2);
target.add(1, 2);
target.add(1, 2);
}

function testFailExpectMultipleCallsWithDataAdditive() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
// Not enough calls to satisfy the additive expectCall, which expects 3 calls.
target.add(1, 2);
target.add(1, 2);
}

function testFailExpectCallWithData() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
@@ -312,3 +339,54 @@ contract ExpectCallCountTest is DSTest {
target.addHardGasLimit();
}
}

contract ExpectCallMixedTest is DSTest {
Cheats constant cheats = Cheats(HEVM_ADDRESS);

function testFailOverrideNoCountWithCount() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
// You should not be able to overwrite a expectCall that had no count with some count.
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2);
target.add(1, 2);
target.add(1, 2);
}

function testFailOverrideCountWithCount() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2);
// You should not be able to overwrite a expectCall that had a count with some count.
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 1);
target.add(1, 2);
target.add(1, 2);
}

function testFailOverrideCountWithNoCount() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2);
// You should not be able to overwrite a expectCall that had a count with no count.
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
target.add(1, 2);
target.add(1, 2);
}

function testExpectMatchPartialAndFull() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector), 2);
// Even if a partial match is speciifed, you should still be able to look for full matches
// as one does not override the other.
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2));
target.add(1, 2);
target.add(1, 2);
}

function testExpectMatchPartialAndFullFlipped() public {
Contract target = new Contract();
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector));
// Even if a partial match is speciifed, you should still be able to look for full matches
// as one does not override the other.
cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2);
target.add(1, 2);
target.add(1, 2);
}
}

0 comments on commit bd4b290

Please sign in to comment.