Skip to content

Commit

Permalink
feat(derive): Test Utilities (#62)
Browse files Browse the repository at this point in the history
* feat(derive): trait test utilities

* feat(derive): trait test utilities

* fix(derive): test utils

* feat(derive): l1 traversal tests

* fix(derive): lint fixes

* fix(derive): data availability test utils

* feat(derive): channel bank test utils
  • Loading branch information
refcell authored Mar 30, 2024
1 parent 7348b2e commit 7ee85a3
Show file tree
Hide file tree
Showing 17 changed files with 555 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ unsigned-varint = "0.8.0"
# Optional
serde = { version = "1.0.197", default-features = false, features = ["derive"], optional = true }

[dev-dependencies]
tokio = { version = "1.36", features = ["full"] }

[features]
serde = ["dep:serde", "alloy-primitives/serde"]
83 changes: 76 additions & 7 deletions crates/derive/src/stages/channel_bank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,15 @@ where
self.prev.origin()
}

/// Returns the size of the channel bank by accumulating over all channels.
pub fn size(&self) -> usize {
self.channels.iter().fold(0, |acc, (_, c)| acc + c.size())
}

/// Prunes the Channel bank, until it is below [MAX_CHANNEL_BANK_SIZE].
/// Prunes from the high-priority channel since it failed to be read.
pub fn prune(&mut self) -> StageResult<()> {
// Check total size
let mut total_size = self.channels.iter().fold(0, |acc, (_, c)| acc + c.size());
// Prune until it is reasonable again. The high-priority channel failed to be read,
// so we prune from there.
let mut total_size = self.size();
while total_size > MAX_CHANNEL_BANK_SIZE {
let id = self
.channel_queue
Expand Down Expand Up @@ -122,16 +125,17 @@ where
.ok_or(anyhow!("Channel not found"))?;
let origin = self.origin().ok_or(anyhow!("No origin present"))?;

// Remove all timed out channels from the front of the `channel_queue`.
if channel.open_block_number() + self.cfg.channel_timeout < origin.number {
self.channels.remove(&first);
self.channel_queue.pop_front();
return Ok(None);
}

// At the point we have removed all timed out channels from the front of the `channel_queue`.
// At this point we have removed all timed out channels from the front of the `channel_queue`.
// Pre-Canyon we simply check the first index.
// Post-Canyon we read the entire channelQueue for the first ready channel. If no channel is
// available, we return `nil, io.EOF`.
// Post-Canyon we read the entire channelQueue for the first ready channel.
// If no channel is available, we return StageError::Eof.
// Canyon is activated when the first L1 block whose time >= CanyonTime, not on the L2 timestamp.
if !self.cfg.is_canyon_active(origin.timestamp) {
return self.try_read_channel_at_index(0).map(Some);
Expand Down Expand Up @@ -201,3 +205,68 @@ where
Err(StageError::Eof)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::stages::frame_queue::tests::new_test_frames;
use crate::stages::l1_retrieval::L1Retrieval;
use crate::stages::l1_traversal::tests::new_test_traversal;
use crate::traits::test_utils::TestDAP;
use alloc::vec;

#[test]
fn test_ingest_empty_origin() {
let mut traversal = new_test_traversal(false, false);
traversal.block = None;
let dap = TestDAP::default();
let retrieval = L1Retrieval::new(traversal, dap);
let frame_queue = FrameQueue::new(retrieval);
let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue);
let frame = Frame::default();
let err = channel_bank.ingest_frame(frame).unwrap_err();
assert_eq!(err, StageError::Custom(anyhow!("No origin")));
}

#[test]
fn test_ingest_and_prune_channel_bank() {
let traversal = new_test_traversal(true, true);
let results = vec![Ok(Bytes::from(vec![0x00]))];
let dap = TestDAP { results };
let retrieval = L1Retrieval::new(traversal, dap);
let frame_queue = FrameQueue::new(retrieval);
let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue);
let mut frames = new_test_frames(100000);
// Ingest frames until the channel bank is full and it stops increasing in size
let mut current_size = 0;
let next_frame = frames.pop().unwrap();
channel_bank.ingest_frame(next_frame).unwrap();
while channel_bank.size() > current_size {
current_size = channel_bank.size();
let next_frame = frames.pop().unwrap();
channel_bank.ingest_frame(next_frame).unwrap();
assert!(channel_bank.size() <= MAX_CHANNEL_BANK_SIZE);
}
// There should be a bunch of frames leftover
assert!(!frames.is_empty());
// If we ingest one more frame, the channel bank should prune
// and the size should be the same
let next_frame = frames.pop().unwrap();
channel_bank.ingest_frame(next_frame).unwrap();
assert_eq!(channel_bank.size(), current_size);
}

#[tokio::test]
async fn test_read_empty_channel_bank() {
let traversal = new_test_traversal(true, true);
let results = vec![Ok(Bytes::from(vec![0x00]))];
let dap = TestDAP { results };
let retrieval = L1Retrieval::new(traversal, dap);
let frame_queue = FrameQueue::new(retrieval);
let mut channel_bank = ChannelBank::new(RollupConfig::default(), frame_queue);
let err = channel_bank.read().unwrap_err();
assert_eq!(err, StageError::Eof);
let err = channel_bank.next_data().await.unwrap_err();
assert_eq!(err, StageError::Custom(anyhow!("Not Enough Data")));
}
}
110 changes: 110 additions & 0 deletions crates/derive/src/stages/frame_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ where
if self.queue.is_empty() {
match self.prev.next_data().await {
Ok(data) => {
// TODO: what do we do with frame parsing errors?
if let Ok(frames) = Frame::parse_frames(data.as_ref()) {
self.queue.extend(frames);
}
Expand Down Expand Up @@ -78,3 +79,112 @@ where
Err(StageError::Eof)
}
}

#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::stages::l1_traversal::tests::new_test_traversal;
use crate::traits::test_utils::TestDAP;
use crate::DERIVATION_VERSION_0;
use alloc::vec;
use alloc::vec::Vec;
use alloy_primitives::Bytes;

pub(crate) fn new_test_frames(count: usize) -> Vec<Frame> {
(0..count)
.map(|i| Frame {
id: [0xFF; 16],
number: i as u16,
data: vec![0xDD; 50],
is_last: i == count - 1,
})
.collect()
}

pub(crate) fn new_encoded_test_frames(count: usize) -> Bytes {
let frames = new_test_frames(count);
let mut bytes = Vec::new();
bytes.extend_from_slice(&[DERIVATION_VERSION_0]);
for frame in frames.iter() {
bytes.extend_from_slice(&frame.encode());
}
Bytes::from(bytes)
}

#[tokio::test]
async fn test_frame_queue_empty_bytes() {
let traversal = new_test_traversal(true, true);
let results = vec![Ok(Bytes::from(vec![0x00]))];
let dap = TestDAP { results };
let retrieval = L1Retrieval::new(traversal, dap);
let mut frame_queue = FrameQueue::new(retrieval);
let err = frame_queue.next_frame().await.unwrap_err();
assert_eq!(err, anyhow!("Not enough data").into());
}

#[tokio::test]
async fn test_frame_queue_no_frames_decoded() {
let traversal = new_test_traversal(true, true);
let results = vec![Err(StageError::Eof), Ok(Bytes::default())];
let dap = TestDAP { results };
let retrieval = L1Retrieval::new(traversal, dap);
let mut frame_queue = FrameQueue::new(retrieval);
let err = frame_queue.next_frame().await.unwrap_err();
assert_eq!(err, anyhow!("Not enough data").into());
}

#[tokio::test]
async fn test_frame_queue_wrong_derivation_version() {
let traversal = new_test_traversal(true, true);
let results = vec![Ok(Bytes::from(vec![0x01]))];
let dap = TestDAP { results };
let retrieval = L1Retrieval::new(traversal, dap);
let mut frame_queue = FrameQueue::new(retrieval);
let err = frame_queue.next_frame().await.unwrap_err();
assert_eq!(err, anyhow!("Unsupported derivation version").into());
}

#[tokio::test]
async fn test_frame_queue_frame_too_short() {
let traversal = new_test_traversal(true, true);
let results = vec![Ok(Bytes::from(vec![0x00, 0x01]))];
let dap = TestDAP { results };
let retrieval = L1Retrieval::new(traversal, dap);
let mut frame_queue = FrameQueue::new(retrieval);
let err = frame_queue.next_frame().await.unwrap_err();
assert_eq!(err, anyhow!("Frame too short to decode").into());
}

#[tokio::test]
async fn test_frame_queue_single_frame() {
let data = new_encoded_test_frames(1);
let traversal = new_test_traversal(true, true);
let dap = TestDAP {
results: vec![Ok(data)],
};
let retrieval = L1Retrieval::new(traversal, dap);
let mut frame_queue = FrameQueue::new(retrieval);
let frame_decoded = frame_queue.next_frame().await.unwrap();
let frame = new_test_frames(1);
assert_eq!(frame[0], frame_decoded);
let err = frame_queue.next_frame().await.unwrap_err();
assert_eq!(err, anyhow!("Not enough data").into());
}

#[tokio::test]
async fn test_frame_queue_multiple_frames() {
let data = new_encoded_test_frames(3);
let traversal = new_test_traversal(true, true);
let dap = TestDAP {
results: vec![Ok(data)],
};
let retrieval = L1Retrieval::new(traversal, dap);
let mut frame_queue = FrameQueue::new(retrieval);
for i in 0..3 {
let frame_decoded = frame_queue.next_frame().await.unwrap();
assert_eq!(frame_decoded.number, i);
}
let err = frame_queue.next_frame().await.unwrap_err();
assert_eq!(err, anyhow!("Not enough data").into());
}
}
81 changes: 80 additions & 1 deletion crates/derive/src/stages/l1_retrieval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ where
/// The data availability provider to use for the L1 retrieval stage.
pub provider: DAP,
/// The current data iterator.
data: Option<DAP::DataIter<Bytes>>,
pub(crate) data: Option<DAP::DataIter>,
}

impl<DAP, CP> L1Retrieval<DAP, CP>
Expand Down Expand Up @@ -83,3 +83,82 @@ where
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::stages::l1_traversal::tests::new_test_traversal;
use crate::traits::test_utils::{TestDAP, TestIter};
use alloc::vec;
use alloy_primitives::Address;

#[tokio::test]
async fn test_l1_retrieval_origin() {
let traversal = new_test_traversal(true, true);
let dap = TestDAP { results: vec![] };
let retrieval = L1Retrieval::new(traversal, dap);
let expected = BlockInfo::default();
assert_eq!(retrieval.origin(), Some(&expected));
}

#[tokio::test]
async fn test_l1_retrieval_next_data() {
let traversal = new_test_traversal(true, true);
let results = vec![Err(StageError::Eof), Ok(Bytes::default())];
let dap = TestDAP { results };
let mut retrieval = L1Retrieval::new(traversal, dap);
assert_eq!(retrieval.data, None);
let data = retrieval.next_data().await.unwrap();
assert_eq!(data, Bytes::default());
assert!(retrieval.data.is_some());
let retrieval_data = retrieval.data.as_ref().unwrap();
assert_eq!(retrieval_data.open_data_calls.len(), 1);
assert_eq!(retrieval_data.open_data_calls[0].0, BlockInfo::default());
assert_eq!(retrieval_data.open_data_calls[0].1, Address::default());
// Data should be reset to none and the error should be bubbled up.
let data = retrieval.next_data().await.unwrap_err();
assert_eq!(data, StageError::Eof);
assert!(retrieval.data.is_none());
}

#[tokio::test]
async fn test_l1_retrieval_existing_data_is_respected() {
let data = TestIter {
open_data_calls: vec![(BlockInfo::default(), Address::default())],
results: vec![Ok(Bytes::default())],
};
// Create a new traversal with no blocks or receipts.
// This would bubble up an error if the prev stage
// (traversal) is called in the retrieval stage.
let traversal = new_test_traversal(false, false);
let dap = TestDAP { results: vec![] };
let mut retrieval = L1Retrieval {
prev: traversal,
provider: dap,
data: Some(data),
};
let data = retrieval.next_data().await.unwrap();
assert_eq!(data, Bytes::default());
assert!(retrieval.data.is_some());
let retrieval_data = retrieval.data.as_ref().unwrap();
assert_eq!(retrieval_data.open_data_calls.len(), 1);
}

#[tokio::test]
async fn test_l1_retrieval_existing_data_errors() {
let data = TestIter {
open_data_calls: vec![(BlockInfo::default(), Address::default())],
results: vec![Err(StageError::Eof)],
};
let traversal = new_test_traversal(true, true);
let dap = TestDAP { results: vec![] };
let mut retrieval = L1Retrieval {
prev: traversal,
provider: dap,
data: Some(data),
};
let data = retrieval.next_data().await.unwrap_err();
assert_eq!(data, StageError::Eof);
assert!(retrieval.data.is_none());
}
}
Loading

0 comments on commit 7ee85a3

Please sign in to comment.