Skip to content

Commit

Permalink
fix: pass transaction tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roeap authored and rtyler committed May 26, 2024
1 parent 1b1e431 commit a02ca5e
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 77 deletions.
2 changes: 1 addition & 1 deletion crates/core/src/kernel/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod schema;
pub use actions::*;
pub use schema::*;

#[derive(Debug, Hash, PartialEq, Eq, Clone)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
/// The type of action that was performed on the table
pub enum ActionType {
/// modify the data in a table by adding individual logical files
Expand Down
101 changes: 84 additions & 17 deletions crates/core/src/kernel/snapshot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
//!
//!
use std::collections::{HashMap, HashSet};
use std::io::{BufRead, BufReader, Cursor};
use std::sync::Arc;

use ::serde::{Deserialize, Serialize};
use arrow_array::RecordBatch;
use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use hashbrown::HashSet;
use object_store::path::Path;
use object_store::ObjectStore;

Expand All @@ -31,6 +31,7 @@ use self::parse::{read_adds, read_removes};
use self::replay::{LogMapper, LogReplayScanner, ReplayStream};
use super::{
Action, Add, AddCDCFile, CommitInfo, DataType, Metadata, Protocol, Remove, StructField,
Transaction,
};
use crate::kernel::parse::read_cdf_adds;
use crate::kernel::{ActionType, StructType};
Expand Down Expand Up @@ -206,7 +207,7 @@ impl Snapshot {
pub fn files<'a>(
&self,
store: Arc<dyn ObjectStore>,
visitors: Vec<&'a mut dyn ReplayVisitor>,
visitors: &'a mut Vec<Box<dyn ReplayVisitor>>,
) -> DeltaResult<ReplayStream<'a, BoxStream<'_, DeltaResult<RecordBatch>>>> {
let mut schema_actions: HashSet<_> =
visitors.iter().flat_map(|v| v.required_actions()).collect();
Expand Down Expand Up @@ -361,6 +362,11 @@ impl Snapshot {
#[derive(Debug, Clone, PartialEq)]
pub struct EagerSnapshot {
snapshot: Snapshot,
// additional actions that should be tracked during log replay.
tracked_actions: HashSet<ActionType>,

transactions: Option<HashMap<String, Transaction>>,

// NOTE: this is a Vec of RecordBatch instead of a single RecordBatch because
// we do not yet enforce a consistent schema across all batches we read from the log.
files: Vec<RecordBatch>,
Expand All @@ -374,7 +380,7 @@ impl EagerSnapshot {
config: DeltaTableConfig,
version: Option<i64>,
) -> DeltaResult<Self> {
Self::try_new_with_visitor(table_root, store, config, version, vec![]).await
Self::try_new_with_visitor(table_root, store, config, version, Default::default()).await
}

/// Create a new [`EagerSnapshot`] instance
Expand All @@ -383,11 +389,42 @@ impl EagerSnapshot {
store: Arc<dyn ObjectStore>,
config: DeltaTableConfig,
version: Option<i64>,
visitors: Vec<&mut dyn ReplayVisitor>,
tracked_actions: HashSet<ActionType>,
) -> DeltaResult<Self> {
let mut visitors = tracked_actions
.iter()
.flat_map(|a| get_visitor(a))
.collect::<Vec<_>>();
let snapshot = Snapshot::try_new(table_root, store.clone(), config, version).await?;
let files = snapshot.files(store, visitors)?.try_collect().await?;
Ok(Self { snapshot, files })
let files = snapshot.files(store, &mut visitors)?.try_collect().await?;

let mut sn = Self {
snapshot,
files,
tracked_actions,
transactions: None,
};

sn.process_visitors(visitors)?;

Ok(sn)
}

fn process_visitors(&mut self, visitors: Vec<Box<dyn ReplayVisitor>>) -> DeltaResult<()> {
for visitor in visitors {
if let Some(tv) = visitor
.as_ref()
.as_any()
.downcast_ref::<AppTransactionVisitor>()
{
if self.transactions.is_none() {
self.transactions = Some(tv.app_transaction_version.clone());
} else {
self.transactions = Some(tv.merge(self.transactions.as_ref().unwrap()));
}
}
}
Ok(())
}

#[cfg(test)]
Expand All @@ -401,15 +438,19 @@ impl EagerSnapshot {
.into_iter()
.map(|b| mapper.map_batch(b))
.collect::<DeltaResult<Vec<_>>>()?;
Ok(Self { snapshot, files })
Ok(Self {
snapshot,
files,
tracked_actions: Default::default(),
transactions: None,
})
}

/// Update the snapshot to the given version
pub async fn update<'a>(
&mut self,
log_store: Arc<dyn LogStore>,
target_version: Option<i64>,
visitors: Vec<&'a mut dyn ReplayVisitor>,
) -> DeltaResult<()> {
if Some(self.version()) == target_version {
return Ok(());
Expand Down Expand Up @@ -438,13 +479,23 @@ impl EagerSnapshot {
.boxed()
};
let mapper = LogMapper::try_new(&self.snapshot, None)?;
let files =
ReplayStream::try_new(log_stream, checkpoint_stream, &self.snapshot, visitors)?
.map(|batch| batch.and_then(|b| mapper.map_batch(b)))
.try_collect()
.await?;
let mut visitors = self
.tracked_actions
.iter()
.flat_map(|a| get_visitor(a))
.collect::<Vec<_>>();
let files = ReplayStream::try_new(
log_stream,
checkpoint_stream,
&self.snapshot,
&mut visitors,
)?
.map(|batch| batch.and_then(|b| mapper.map_batch(b)))
.try_collect()
.await?;

self.files = files;
self.process_visitors(visitors)?;
}
Ok(())
}
Expand Down Expand Up @@ -517,11 +568,21 @@ impl EagerSnapshot {
Ok(self.files.iter().flat_map(|b| read_cdf_adds(b)).flatten())
}

/// Iterate over all latest app transactions
pub fn transactions(&self) -> DeltaResult<impl Iterator<Item = Transaction> + '_> {
self.transactions
.as_ref()
.map(|t| t.values().cloned())
.ok_or(DeltaTableError::Generic(
"Transactions are not available. Please enable tracking of transactions."
.to_string(),
))
}

/// Advance the snapshot based on the given commit actions
pub fn advance<'a>(
&mut self,
commits: impl IntoIterator<Item = &'a CommitData>,
mut visitors: Vec<&'a mut dyn ReplayVisitor>,
) -> DeltaResult<i64> {
let mut metadata = None;
let mut protocol = None;
Expand Down Expand Up @@ -550,6 +611,11 @@ impl EagerSnapshot {

let mut files = Vec::new();
let mut scanner = LogReplayScanner::new();
let mut visitors = self
.tracked_actions
.iter()
.flat_map(|a| get_visitor(a))
.collect::<Vec<_>>();

for batch in actions {
let batch = batch?;
Expand Down Expand Up @@ -583,6 +649,7 @@ impl EagerSnapshot {
if let Some(protocol) = protocol {
self.snapshot.protocol = protocol;
}
self.process_visitors(visitors)?;

Ok(self.snapshot.version())
}
Expand Down Expand Up @@ -715,7 +782,7 @@ mod tests {
assert_eq!(tombstones.len(), 31);

let batches = snapshot
.files(store.clone(), vec![])?
.files(store.clone(), &mut vec![])?
.try_collect::<Vec<_>>()
.await?;
let expected = [
Expand Down Expand Up @@ -745,7 +812,7 @@ mod tests {
)
.await?;
let batches = snapshot
.files(store.clone(), vec![])?
.files(store.clone(), &mut vec![])?
.try_collect::<Vec<_>>()
.await?;
let num_files = batches.iter().map(|b| b.num_rows() as i64).sum::<i64>();
Expand Down Expand Up @@ -848,7 +915,7 @@ mod tests {
Vec::new(),
)];

let new_version = snapshot.advance(&actions, vec![])?;
let new_version = snapshot.advance(&actions)?;
assert_eq!(new_version, version + 1);

let new_files = snapshot.file_actions()?.map(|f| f.path).collect::<Vec<_>>();
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/kernel/snapshot/replay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pin_project! {

mapper: Arc<LogMapper>,

visitors: Vec<&'a mut dyn ReplayVisitor>,
visitors: &'a mut Vec<Box<dyn ReplayVisitor>>,

#[pin]
commits: S,
Expand All @@ -48,7 +48,7 @@ impl<'a, S> ReplayStream<'a, S> {
commits: S,
checkpoint: S,
snapshot: &Snapshot,
visitors: Vec<&'a mut dyn ReplayVisitor>,
visitors: &'a mut Vec<Box<dyn ReplayVisitor>>,
) -> DeltaResult<Self> {
let stats_schema = Arc::new((&snapshot.stats_schema(None)?).try_into()?);
let mapper = Arc::new(LogMapper {
Expand Down
15 changes: 14 additions & 1 deletion crates/core/src/kernel/snapshot/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ impl Serialize for EagerSnapshot {
{
let mut seq = serializer.serialize_seq(None)?;
seq.serialize_element(&self.snapshot)?;
seq.serialize_element(&self.tracked_actions)?;
seq.serialize_element(&self.transactions)?;
for batch in self.files.iter() {
let mut buffer = vec![];
let mut writer = FileWriter::try_new(&mut buffer, batch.schema().as_ref())
Expand Down Expand Up @@ -155,6 +157,12 @@ impl<'de> Visitor<'de> for EagerSnapshotVisitor {
let snapshot = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let tracked_actions = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
let transactions = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
let mut files = Vec::new();
while let Some(elem) = seq.next_element::<Vec<u8>>()? {
let mut reader =
Expand All @@ -169,7 +177,12 @@ impl<'de> Visitor<'de> for EagerSnapshotVisitor {
})?;
files.push(rb);
}
Ok(EagerSnapshot { snapshot, files })
Ok(EagerSnapshot {
snapshot,
files,
tracked_actions,
transactions,
})
}
}

Expand Down
21 changes: 17 additions & 4 deletions crates/core/src/kernel/snapshot/visitors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@ use crate::kernel::Transaction;

/// Allows hooking into the reading of commit files and checkpoints whenever a table is loaded or updated.
pub trait ReplayVisitor: std::fmt::Debug + Send + Sync {
fn as_any(&self) -> &dyn std::any::Any;

/// Process a batch
fn visit_batch(&mut self, batch: &RecordBatch) -> DeltaResult<()>;

/// return all relevant actions for the visitor
fn required_actions(&self) -> Vec<ActionType>;
}

/// Get the relevant visitor for the given action type
pub fn get_visitor(action: &ActionType) -> Option<Box<dyn ReplayVisitor>> {
match action {
ActionType::Txn => Some(Box::new(AppTransactionVisitor::new())),
_ => None,
}
}

#[derive(Debug, Default)]
pub(crate) struct AppTransactionVisitor {
pub(crate) app_transaction_version: HashMap<String, Transaction>,
Expand All @@ -35,17 +45,20 @@ impl AppTransactionVisitor {
}

impl AppTransactionVisitor {
pub fn merge(self, map: &HashMap<String, Transaction>) -> HashMap<String, Transaction> {
pub fn merge(&self, map: &HashMap<String, Transaction>) -> HashMap<String, Transaction> {
let mut clone = map.clone();
for (key, value) in self.app_transaction_version {
clone.insert(key, value);
for (key, value) in &self.app_transaction_version {
clone.insert(key.clone(), value.clone());
}

return clone;
}
}

impl ReplayVisitor for AppTransactionVisitor {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn visit_batch(&mut self, batch: &arrow_array::RecordBatch) -> DeltaResult<()> {
if batch.column_by_name("txn").is_none() {
return Ok(());
Expand Down
19 changes: 12 additions & 7 deletions crates/core/src/operations/transaction/application.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#[cfg(test)]
mod tests {
use crate::{
checkpoints, kernel::Txn, operations::transaction::CommitProperties, protocol::SaveMode,
writer::test_utils::get_record_batch, DeltaOps, DeltaTableBuilder,
checkpoints, kernel::Transaction, operations::transaction::CommitProperties,
protocol::SaveMode, writer::test_utils::get_record_batch, DeltaOps, DeltaTableBuilder,
};

#[tokio::test]
Expand All @@ -24,7 +24,8 @@ mod tests {
.with_save_mode(SaveMode::ErrorIfExists)
.with_partition_columns(["modified"])
.with_commit_properties(
CommitProperties::default().with_application_transaction(Txn::new("my-app", 1)),
CommitProperties::default()
.with_application_transaction(Transaction::new("my-app", 1)),
)
.await
.unwrap();
Expand All @@ -51,7 +52,8 @@ mod tests {
let table = DeltaOps::from(table)
.write(vec![get_record_batch(None, false)])
.with_commit_properties(
CommitProperties::default().with_application_transaction(Txn::new("my-app", 3)),
CommitProperties::default()
.with_application_transaction(Transaction::new("my-app", 3)),
)
.await
.unwrap();
Expand Down Expand Up @@ -94,7 +96,8 @@ mod tests {
.with_save_mode(SaveMode::ErrorIfExists)
.with_partition_columns(["modified"])
.with_commit_properties(
CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 1)),
CommitProperties::default()
.with_application_transaction(Transaction::new(&"my-app", 1)),
)
.await
.unwrap();
Expand All @@ -109,7 +112,8 @@ mod tests {
let table = DeltaOps::from(table)
.write(vec![get_record_batch(None, false)])
.with_commit_properties(
CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 2)),
CommitProperties::default()
.with_application_transaction(Transaction::new(&"my-app", 2)),
)
.await
.unwrap();
Expand All @@ -118,7 +122,8 @@ mod tests {
let res = DeltaOps::from(table2)
.write(vec![get_record_batch(None, false)])
.with_commit_properties(
CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 3)),
CommitProperties::default()
.with_application_transaction(Transaction::new(&"my-app", 3)),
)
.await;

Expand Down
Loading

0 comments on commit a02ca5e

Please sign in to comment.