diff --git a/kernel/src/actions/deletion_vector.rs b/kernel/src/actions/deletion_vector.rs index 201bdc877..8f5079106 100644 --- a/kernel/src/actions/deletion_vector.rs +++ b/kernel/src/actions/deletion_vector.rs @@ -8,6 +8,7 @@ use delta_kernel_derive::Schema; use roaring::RoaringTreemap; use url::Url; +use crate::utils::require; use crate::{DeltaResult, Error, FileSystemClient}; #[derive(Debug, Clone, PartialEq, Eq, Schema)] @@ -55,11 +56,10 @@ impl DeletionVectorDescriptor { match self.storage_type.as_str() { "u" => { let path_len = self.path_or_inline_dv.len(); - if path_len < 20 { - return Err(Error::deletion_vector( - "Invalid length {path_len}, must be >20", - )); - } + require!( + path_len >= 20, + Error::DeletionVector(format!("Invalid length {path_len}, must be >= 20")) + ); let prefix_len = path_len - 20; let decoded = z85::decode(&self.path_or_inline_dv[prefix_len..]) .map_err(|_| Error::deletion_vector("Failed to decode DV uuid"))?; @@ -128,24 +128,26 @@ impl DeletionVectorDescriptor { .read(&mut version_buf) .map_err(|err| Error::DeletionVector(err.to_string()))?; let version = u8::from_be_bytes(version_buf); - if version != 1 { - return Err(Error::DeletionVector(format!("Invalid version: {version}"))); - } + require!( + version == 1, + Error::DeletionVector(format!("Invalid version: {version}")) + ); if let Some(offset) = offset { cursor.set_position(offset as u64); } let dv_size = read_u32(&mut cursor, Endian::Big)?; - if dv_size != size_in_bytes as u32 { - return Err(Error::DeletionVector(format!( + require!( + dv_size == size_in_bytes as u32, + Error::DeletionVector(format!( "DV size mismatch. 
Log indicates {size_in_bytes}, file says: {dv_size}" - ))); - } + )) + ); let magic = read_u32(&mut cursor, Endian::Little)?; - - if magic != 1681511377 { - return Err(Error::DeletionVector(format!("Invalid magic: {magic}"))); - } + require!( + magic == 1681511377, + Error::DeletionVector(format!("Invalid magic: {magic}")) + ); // get the Bytes back out and limit it to dv_size let position = cursor.position(); diff --git a/kernel/src/engine/arrow_data.rs b/kernel/src/engine/arrow_data.rs index 91b3f7555..bc6bf8e3a 100644 --- a/kernel/src/engine/arrow_data.rs +++ b/kernel/src/engine/arrow_data.rs @@ -1,5 +1,6 @@ use crate::engine_data::{EngineData, EngineList, EngineMap, GetData}; use crate::schema::{DataType, PrimitiveType, Schema, SchemaRef, StructField}; +use crate::utils::require; use crate::{DataVisitor, DeltaResult, Error}; use arrow_array::cast::AsArray; @@ -271,20 +272,22 @@ impl ArrowEngineData { if let ArrowDataType::Struct(fields) = map_field.data_type() { let mut fcount = 0; for field in fields { - if field.data_type() != &ArrowDataType::Utf8 { - return Err(Error::UnexpectedColumnType(format!( + require!( + field.data_type() == &ArrowDataType::Utf8, + Error::UnexpectedColumnType(format!( "On {}: Only support maps of String->String", field.name() - ))); - } + )) + ); fcount += 1; } - if fcount != 2 { - return Err(Error::UnexpectedColumnType(format!( + require!( + fcount == 2, + Error::UnexpectedColumnType(format!( "On {}: Expect map field struct to have two fields", field.name() - ))); - } + )) + ); debug!("Pushing map for {}", field.name); out_col_array.push(col.as_map()); } else { diff --git a/kernel/src/engine/arrow_utils.rs b/kernel/src/engine/arrow_utils.rs index 629e2ebf0..081ae1756 100644 --- a/kernel/src/engine/arrow_utils.rs +++ b/kernel/src/engine/arrow_utils.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use crate::{schema::SchemaRef, DeltaResult, Error}; +use crate::{schema::SchemaRef, utils::require, DeltaResult, Error}; use 
arrow_array::RecordBatch; use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; @@ -33,11 +33,10 @@ pub(crate) fn get_requested_indices( .map(|index| (parquet_index, index)) }) .unzip(); - if mask_indicies.len() != requested_schema.fields.len() { - return Err(Error::generic( - "Didn't find all requested columns in parquet schema", - )); - } + require!( + mask_indicies.len() == requested_schema.fields.len(), + Error::generic("Didn't find all requested columns in parquet schema") + ); Ok((mask_indicies, reorder_indicies)) } diff --git a/kernel/src/engine/sync/json.rs b/kernel/src/engine/sync/json.rs index 1b1c8fe90..855f8ccbe 100644 --- a/kernel/src/engine/sync/json.rs +++ b/kernel/src/engine/sync/json.rs @@ -5,8 +5,8 @@ use std::{ }; use crate::{ - schema::SchemaRef, DeltaResult, EngineData, Error, Expression, FileDataReadResultIterator, - FileMeta, JsonHandler, + schema::SchemaRef, utils::require, DeltaResult, EngineData, Error, Expression, + FileDataReadResultIterator, FileMeta, JsonHandler, }; use arrow_array::{cast::AsArray, RecordBatch}; use arrow_json::ReaderBuilder; @@ -65,9 +65,10 @@ impl JsonHandler for SyncJsonHandler { // TODO: This is taken from the default engine as it's the same. 
We should share an // implementation at some point let json_strings: RecordBatch = ArrowEngineData::try_from_engine_data(json_strings)?.into(); - if json_strings.num_columns() != 1 { - return Err(Error::missing_column("Expected single column")); - } + require!( + json_strings.num_columns() == 1, + Error::missing_column("Expected single column") + ); let json_strings = json_strings .column(0) diff --git a/kernel/src/expressions/scalars.rs b/kernel/src/expressions/scalars.rs index 5165e7ad3..5f8b9a55d 100644 --- a/kernel/src/expressions/scalars.rs +++ b/kernel/src/expressions/scalars.rs @@ -4,6 +4,7 @@ use std::fmt::{Display, Formatter}; use chrono::{DateTime, NaiveDate, NaiveDateTime, TimeZone, Utc}; use crate::schema::{DataType, PrimitiveType}; +use crate::utils::require; use crate::Error; /// A single value, which can be null. Used for representing literal values @@ -225,9 +226,7 @@ impl PrimitiveType { (base, exp[1..].parse()?) } }; - if base.is_empty() { - return Err(self.parse_error(raw)); - } + require!(!base.is_empty(), self.parse_error(raw)); // now split on any '.' 
and parse let (int_part, frac_part, frac_digits) = match base.find('.') { @@ -249,9 +248,7 @@ impl PrimitiveType { // most i128::MAX, and 0-i128::MAX doesn't underflow let scale = frac_digits - exp; let scale: i8 = scale.try_into().map_err(|_| self.parse_error(raw))?; - if scale != expected_scale { - return Err(self.parse_error(raw)); - } + require!(scale == expected_scale, self.parse_error(raw)); let int: i128 = match frac_part { None => int_part.parse()?, diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index 1050d7668..7e297b5cb 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -64,6 +64,7 @@ pub mod schema; pub mod snapshot; pub mod table; pub mod transaction; +pub(crate) mod utils; pub use engine_data::{DataVisitor, EngineData}; pub use error::{DeltaResult, Error}; diff --git a/kernel/src/schema.rs b/kernel/src/schema.rs index faba56c6e..ee5f755dd 100644 --- a/kernel/src/schema.rs +++ b/kernel/src/schema.rs @@ -8,6 +8,7 @@ use indexmap::IndexMap; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use crate::utils::require; use crate::{DeltaResult, Error}; pub type Schema = StructType; @@ -342,12 +343,10 @@ where D: serde::Deserializer<'de>, { let str_value = String::deserialize(deserializer)?; - if !str_value.starts_with("decimal(") || !str_value.ends_with(')') { - return Err(serde::de::Error::custom(format!( - "Invalid decimal: {}", - str_value - ))); - } + require!( + str_value.starts_with("decimal(") && str_value.ends_with(')'), + serde::de::Error::custom(format!("Invalid decimal: {}", str_value)) + ); let mut parts = str_value[8..str_value.len() - 1].split(','); let precision = parts diff --git a/kernel/src/snapshot.rs b/kernel/src/snapshot.rs index 15003bcc3..d22850d7a 100644 --- a/kernel/src/snapshot.rs +++ b/kernel/src/snapshot.rs @@ -12,6 +12,7 @@ use url::Url; use crate::actions::{get_log_schema, Metadata, Protocol, METADATA_NAME, PROTOCOL_NAME}; use crate::path::{version_from_location, LogPath}; use crate::schema::{Schema, 
SchemaRef}; +use crate::utils::require; use crate::{DeltaResult, Engine, Error, FileMeta, FileSystemClient, Version}; use crate::{EngineData, Expression}; @@ -163,10 +164,10 @@ impl Snapshot { .ok_or(Error::MissingVersion)?; // TODO: A more descriptive error if let Some(v) = version { - if version_eff != v { - // TODO more descriptive error - return Err(Error::MissingVersion); - } + require!( + version_eff == v, + Error::MissingVersion // TODO more descriptive error + ); } let log_segment = LogSegment { diff --git a/kernel/src/utils.rs b/kernel/src/utils.rs new file mode 100644 index 000000000..3dbef02c3 --- /dev/null +++ b/kernel/src/utils.rs @@ -0,0 +1,12 @@ +//! Various utility functions/macros used throughout the kernel + +/// Returns `Err($err)` from the enclosing function unless `$cond` is true (a trailing comma is allowed) +macro_rules! require { + ( $cond:expr, $err:expr $(,)? ) => { + if !($cond) { + return Err($err); + } + }; +} + +pub(crate) use require;