Skip to content

Commit

Permalink
Add a require! macro and use it. (#204)
Browse files Browse the repository at this point in the history
Add a macro that simplifies the case of checking a condition and
returning an error if it doesn't hold.

Closes: #148
  • Loading branch information
nicklan authored May 16, 2024
1 parent 663585d commit e49bf51
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 51 deletions.
34 changes: 18 additions & 16 deletions kernel/src/actions/deletion_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use delta_kernel_derive::Schema;
use roaring::RoaringTreemap;
use url::Url;

use crate::utils::require;
use crate::{DeltaResult, Error, FileSystemClient};

#[derive(Debug, Clone, PartialEq, Eq, Schema)]
Expand Down Expand Up @@ -55,11 +56,10 @@ impl DeletionVectorDescriptor {
match self.storage_type.as_str() {
"u" => {
let path_len = self.path_or_inline_dv.len();
if path_len < 20 {
return Err(Error::deletion_vector(
"Invalid length {path_len}, must be >20",
));
}
require!(
path_len >= 20,
Error::deletion_vector("Invalid length {path_len}, must be >= 20",)
);
let prefix_len = path_len - 20;
let decoded = z85::decode(&self.path_or_inline_dv[prefix_len..])
.map_err(|_| Error::deletion_vector("Failed to decode DV uuid"))?;
Expand Down Expand Up @@ -128,24 +128,26 @@ impl DeletionVectorDescriptor {
.read(&mut version_buf)
.map_err(|err| Error::DeletionVector(err.to_string()))?;
let version = u8::from_be_bytes(version_buf);
if version != 1 {
return Err(Error::DeletionVector(format!("Invalid version: {version}")));
}
require!(
version == 1,
Error::DeletionVector(format!("Invalid version: {version}"))
);

if let Some(offset) = offset {
cursor.set_position(offset as u64);
}
let dv_size = read_u32(&mut cursor, Endian::Big)?;
if dv_size != size_in_bytes as u32 {
return Err(Error::DeletionVector(format!(
require!(
dv_size == size_in_bytes as u32,
Error::DeletionVector(format!(
"DV size mismatch. Log indicates {size_in_bytes}, file says: {dv_size}"
)));
}
))
);
let magic = read_u32(&mut cursor, Endian::Little)?;

if magic != 1681511377 {
return Err(Error::DeletionVector(format!("Invalid magic: {magic}")));
}
require!(
magic == 1681511377,
Error::DeletionVector(format!("Invalid magic: {magic}"))
);

// get the Bytes back out and limit it to dv_size
let position = cursor.position();
Expand Down
19 changes: 11 additions & 8 deletions kernel/src/engine/arrow_data.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::engine_data::{EngineData, EngineList, EngineMap, GetData};
use crate::schema::{DataType, PrimitiveType, Schema, SchemaRef, StructField};
use crate::utils::require;
use crate::{DataVisitor, DeltaResult, Error};

use arrow_array::cast::AsArray;
Expand Down Expand Up @@ -271,20 +272,22 @@ impl ArrowEngineData {
if let ArrowDataType::Struct(fields) = map_field.data_type() {
let mut fcount = 0;
for field in fields {
if field.data_type() != &ArrowDataType::Utf8 {
return Err(Error::UnexpectedColumnType(format!(
require!(
field.data_type() == &ArrowDataType::Utf8,
Error::UnexpectedColumnType(format!(
"On {}: Only support maps of String->String",
field.name()
)));
}
))
);
fcount += 1;
}
if fcount != 2 {
return Err(Error::UnexpectedColumnType(format!(
require!(
fcount == 2,
Error::UnexpectedColumnType(format!(
"On {}: Expect map field struct to have two fields",
field.name()
)));
}
))
);
debug!("Pushing map for {}", field.name);
out_col_array.push(col.as_map());
} else {
Expand Down
11 changes: 5 additions & 6 deletions kernel/src/engine/arrow_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::sync::Arc;

use crate::{schema::SchemaRef, DeltaResult, Error};
use crate::{schema::SchemaRef, utils::require, DeltaResult, Error};

use arrow_array::RecordBatch;
use arrow_schema::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
Expand Down Expand Up @@ -33,11 +33,10 @@ pub(crate) fn get_requested_indices(
.map(|index| (parquet_index, index))
})
.unzip();
if mask_indicies.len() != requested_schema.fields.len() {
return Err(Error::generic(
"Didn't find all requested columns in parquet schema",
));
}
require!(
mask_indicies.len() == requested_schema.fields.len(),
Error::generic("Didn't find all requested columns in parquet schema")
);
Ok((mask_indicies, reorder_indicies))
}

Expand Down
11 changes: 6 additions & 5 deletions kernel/src/engine/sync/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::{
};

use crate::{
schema::SchemaRef, DeltaResult, EngineData, Error, Expression, FileDataReadResultIterator,
FileMeta, JsonHandler,
schema::SchemaRef, utils::require, DeltaResult, EngineData, Error, Expression,
FileDataReadResultIterator, FileMeta, JsonHandler,
};
use arrow_array::{cast::AsArray, RecordBatch};
use arrow_json::ReaderBuilder;
Expand Down Expand Up @@ -65,9 +65,10 @@ impl JsonHandler for SyncJsonHandler {
// TODO: This is taken from the default engine as it's the same. We should share an
// implementation at some point
let json_strings: RecordBatch = ArrowEngineData::try_from_engine_data(json_strings)?.into();
if json_strings.num_columns() != 1 {
return Err(Error::missing_column("Expected single column"));
}
require!(
json_strings.num_columns() == 1,
Error::missing_column("Expected single column")
);
let json_strings =
json_strings
.column(0)
Expand Down
9 changes: 3 additions & 6 deletions kernel/src/expressions/scalars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt::{Display, Formatter};
use chrono::{DateTime, NaiveDate, NaiveDateTime, TimeZone, Utc};

use crate::schema::{DataType, PrimitiveType};
use crate::utils::require;
use crate::Error;

/// A single value, which can be null. Used for representing literal values
Expand Down Expand Up @@ -225,9 +226,7 @@ impl PrimitiveType {
(base, exp[1..].parse()?)
}
};
if base.is_empty() {
return Err(self.parse_error(raw));
}
require!(!base.is_empty(), self.parse_error(raw));

// now split on any '.' and parse
let (int_part, frac_part, frac_digits) = match base.find('.') {
Expand All @@ -249,9 +248,7 @@ impl PrimitiveType {
// most i128::MAX, and 0-i128::MAX doesn't underflow
let scale = frac_digits - exp;
let scale: i8 = scale.try_into().map_err(|_| self.parse_error(raw))?;
if scale != expected_scale {
return Err(self.parse_error(raw));
}
require!(scale == expected_scale, self.parse_error(raw));

let int: i128 = match frac_part {
None => int_part.parse()?,
Expand Down
1 change: 1 addition & 0 deletions kernel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub mod schema;
pub mod snapshot;
pub mod table;
pub mod transaction;
pub(crate) mod utils;

pub use engine_data::{DataVisitor, EngineData};
pub use error::{DeltaResult, Error};
Expand Down
11 changes: 5 additions & 6 deletions kernel/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use indexmap::IndexMap;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::utils::require;
use crate::{DeltaResult, Error};

pub type Schema = StructType;
Expand Down Expand Up @@ -342,12 +343,10 @@ where
D: serde::Deserializer<'de>,
{
let str_value = String::deserialize(deserializer)?;
if !str_value.starts_with("decimal(") || !str_value.ends_with(')') {
return Err(serde::de::Error::custom(format!(
"Invalid decimal: {}",
str_value
)));
}
require!(
str_value.starts_with("decimal(") && str_value.ends_with(')'),
serde::de::Error::custom(format!("Invalid decimal: {}", str_value))
);

let mut parts = str_value[8..str_value.len() - 1].split(',');
let precision = parts
Expand Down
9 changes: 5 additions & 4 deletions kernel/src/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use url::Url;
use crate::actions::{get_log_schema, Metadata, Protocol, METADATA_NAME, PROTOCOL_NAME};
use crate::path::{version_from_location, LogPath};
use crate::schema::{Schema, SchemaRef};
use crate::utils::require;
use crate::{DeltaResult, Engine, Error, FileMeta, FileSystemClient, Version};
use crate::{EngineData, Expression};

Expand Down Expand Up @@ -163,10 +164,10 @@ impl Snapshot {
.ok_or(Error::MissingVersion)?; // TODO: A more descriptive error

if let Some(v) = version {
if version_eff != v {
// TODO more descriptive error
return Err(Error::MissingVersion);
}
require!(
version_eff == v,
Error::MissingVersion // TODO more descriptive error
);
}

let log_segment = LogSegment {
Expand Down
12 changes: 12 additions & 0 deletions kernel/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//! Various utility functions/macros used throughout the kernel
/// convenient way to return an error if a condition isn't true
macro_rules! require {
( $cond:expr, $err:expr ) => {
if !($cond) {
return Err($err);
}
};
}

pub(crate) use require;

0 comments on commit e49bf51

Please sign in to comment.