Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generated columns [WIP] #3123

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions crates/core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot, StructTy
use crate::logstore::LogStoreRef;
use crate::table::builder::ensure_table_uri;
use crate::table::state::DeltaTableState;
use crate::table::Constraint;
use crate::table::{Constraint, GeneratedColumn};
use crate::{open_table, open_table_with_storage_options, DeltaTable};

pub(crate) const PATH_COLUMN: &str = "__delta_rs_path";
Expand Down Expand Up @@ -1159,6 +1159,7 @@ pub(crate) async fn execute_plan_to_batch(
pub struct DeltaDataChecker {
constraints: Vec<Constraint>,
invariants: Vec<Invariant>,
generated_columns: Vec<GeneratedColumn>,
non_nullable_columns: Vec<String>,
ctx: SessionContext,
}
Expand All @@ -1169,6 +1170,7 @@ impl DeltaDataChecker {
Self {
invariants: vec![],
constraints: vec![],
generated_columns: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1179,6 +1181,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints: vec![],
generated_columns: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1189,6 +1192,7 @@ impl DeltaDataChecker {
Self {
constraints,
invariants: vec![],
generated_columns: vec![],
non_nullable_columns: vec![],
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1209,6 +1213,10 @@ impl DeltaDataChecker {
/// Create a new DeltaDataChecker
pub fn new(snapshot: &DeltaTableState) -> Self {
let invariants = snapshot.schema().get_invariants().unwrap_or_default();
let generated_columns = snapshot
.schema()
.get_generated_columns()
.unwrap_or_default();
let constraints = snapshot.table_config().get_constraints();
let non_nullable_columns = snapshot
.schema()
Expand All @@ -1224,6 +1232,7 @@ impl DeltaDataChecker {
Self {
invariants,
constraints,
generated_columns,
non_nullable_columns,
ctx: DeltaSessionContext::default().into(),
}
Expand All @@ -1236,7 +1245,9 @@ impl DeltaDataChecker {
/// Check that a record batch satisfies the table's nullability rules, invariants,
/// CHECK constraints, and generated-column expressions.
///
/// Returns the first violation encountered as a [`DeltaTableError`]; checks run in
/// the order: nullability, invariants, constraints, generated columns.
pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
    self.check_nullability(record_batch)?;
    self.enforce_checks(record_batch, &self.invariants).await?;
    // NOTE(review): the diff residue contained both the old (no `?`) and new form of the
    // constraints check; only the propagating form is kept so generated columns are enforced too.
    self.enforce_checks(record_batch, &self.constraints).await?;
    self.enforce_checks(record_batch, &self.generated_columns)
        .await
}

/// Return true if all the nullability checks are valid
Expand Down
9 changes: 9 additions & 0 deletions crates/core/src/kernel/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ pub enum Error {
line: String,
},

/// Error returned when a generation expression contains invalid JSON.
#[error("Invalid JSON in generation expression, line=`{line}`, err=`{json_err}`")]
InvalidGenerationExpressionJson {
/// JSON error details returned when parsing the generation expression JSON.
json_err: serde_json::error::Error,
/// Generation expression.
line: String,
},

#[error("Table metadata is invalid: {0}")]
MetadataError(String),

Expand Down
101 changes: 90 additions & 11 deletions crates/core/src/kernel/models/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use std::collections::{HashMap, HashSet};
use std::fmt::{self, Display};
use std::str::FromStr;

use delta_kernel::schema::{DataType, StructField};
use maplit::hashset;
use serde::{Deserialize, Serialize};
use tracing::warn;
use url::Url;

use super::schema::StructType;
use super::StructTypeExt;
use crate::kernel::{error::Error, DeltaResult};
use crate::TableProperty;
use delta_kernel::table_features::{ReaderFeatures, WriterFeatures};
Expand Down Expand Up @@ -115,6 +117,19 @@ impl Metadata {
}
}

/// Returns `true` if any field — including fields nested inside arrays and structs —
/// uses the `timestamp_ntz` (timestamp without timezone) data type.
pub fn contains_timestampntz<'a>(fields: impl Iterator<Item = &'a StructField>) -> bool {
    // Iterative depth-first walk over the type tree; avoids recursion by keeping
    // the data types still to be inspected on an explicit work stack.
    let mut pending: Vec<&DataType> = fields.map(|field| field.data_type()).collect();
    while let Some(dtype) = pending.pop() {
        match dtype {
            &DataType::TIMESTAMP_NTZ => return true,
            DataType::Array(inner) => pending.push(inner.element_type()),
            DataType::Struct(inner) => pending.extend(inner.fields().map(|f| f.data_type())),
            _ => {}
        }
    }
    false
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
#[serde(rename_all = "camelCase")]
/// Defines a protocol action
Expand Down Expand Up @@ -146,8 +161,8 @@ impl Protocol {
}
}

/// set the reader features in the protocol action, automatically bumps min_reader_version
pub fn with_reader_features(
/// Append the reader features in the protocol action, automatically bumps min_reader_version
pub fn append_reader_features(
mut self,
reader_features: impl IntoIterator<Item = impl Into<ReaderFeatures>>,
) -> Self {
Expand All @@ -156,14 +171,20 @@ impl Protocol {
.map(Into::into)
.collect::<HashSet<_>>();
if !all_reader_features.is_empty() {
self.min_reader_version = 3
self.min_reader_version = 3;
match self.reader_features {
Some(mut features) => {
features.extend(all_reader_features);
self.reader_features = Some(features);
}
None => self.reader_features = Some(all_reader_features),
};
}
self.reader_features = Some(all_reader_features);
self
}

/// set the writer features in the protocol action, automatically bumps min_writer_version
pub fn with_writer_features(
/// Append the writer features in the protocol action, automatically bumps min_writer_version
pub fn append_writer_features(
mut self,
writer_features: impl IntoIterator<Item = impl Into<WriterFeatures>>,
) -> Self {
Expand All @@ -172,9 +193,16 @@ impl Protocol {
.map(|c| c.into())
.collect::<HashSet<_>>();
if !all_writer_feautures.is_empty() {
self.min_writer_version = 7
self.min_writer_version = 7;

match self.writer_features {
Some(mut features) => {
features.extend(all_writer_feautures);
self.writer_features = Some(features);
}
None => self.writer_features = Some(all_writer_feautures),
};
}
self.writer_features = Some(all_writer_feautures);
self
}

Expand Down Expand Up @@ -255,6 +283,32 @@ impl Protocol {
}
self
}

/// Will apply the column metadata to the protocol by either bumping the version or setting
/// features.
///
/// Inspects `schema` for timestamp_ntz columns, generated-column expressions, and
/// invariants, enabling the corresponding protocol capability for each one found.
///
/// # Errors
/// Propagates any error from parsing the schema's generated-column or invariant metadata.
pub fn apply_column_metadata_to_protocol(
    mut self,
    schema: &StructType,
) -> DeltaResult<Protocol> {
    if !schema.get_generated_columns()?.is_empty() {
        self = self.enable_generated_columns();
    }
    if !schema.get_invariants()?.is_empty() {
        self = self.enable_invariants();
    }
    if self.contains_timestampntz(schema.fields()) {
        self = self.enable_timestamp_ntz();
    }
    Ok(self)
}

/// Will apply the properties to the protocol by either bumping the version or setting
/// features
pub fn apply_properties_to_protocol(
Expand Down Expand Up @@ -391,10 +445,35 @@ impl Protocol {
}
Ok(self)
}

/// checks if table contains timestamp_ntz in any field including nested fields.
///
/// Thin method wrapper delegating to the free function of the same name so the
/// check can be invoked via `self` inside other `Protocol` methods.
fn contains_timestampntz<'a>(&self, fields: impl Iterator<Item = &'a StructField>) -> bool {
    contains_timestampntz(fields)
}

/// Enable timestamp_ntz in the protocol
pub fn enable_timestamp_ntz(mut self) -> Protocol {
self = self.with_reader_features(vec![ReaderFeatures::TimestampWithoutTimezone]);
self = self.with_writer_features(vec![WriterFeatures::TimestampWithoutTimezone]);
/// Enable timestamp_ntz support by adding the TimestampWithoutTimezone feature
/// to both the reader and writer feature sets (bumping min versions as needed).
fn enable_timestamp_ntz(self) -> Self {
    self.append_reader_features([ReaderFeatures::TimestampWithoutTimezone])
        .append_writer_features([WriterFeatures::TimestampWithoutTimezone])
}

/// Enable generated columns in the protocol.
fn enable_generated_columns(mut self) -> Self {
    // Generated columns require at least writer version 4.
    if self.min_writer_version < 4 {
        self.min_writer_version = 4;
    }
    // On feature-based protocols (writer version 7+) the capability must also be
    // listed explicitly as a writer feature; versions 4-6 imply it by version alone.
    if self.min_writer_version >= 7 {
        self = self.append_writer_features([WriterFeatures::GeneratedColumns]);
    }
    self
}

/// Enable column invariants in the protocol.
// NOTE: previous doc comment said "Enabled generated columns" — copy-paste error.
fn enable_invariants(mut self) -> Self {
    // Only feature-based protocols (writer version 7+) need the explicit writer feature;
    // presumably lower writer versions support invariants implicitly — TODO confirm
    // against the Delta protocol spec.
    if self.min_writer_version >= 7 {
        self = self.append_writer_features([WriterFeatures::Invariants]);
    }
    self
}
}
Expand Down
31 changes: 31 additions & 0 deletions crates/core/src/kernel/models/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde_json::Value;

use crate::kernel::error::Error;
use crate::kernel::DataCheck;
use crate::table::GeneratedColumn;

/// Type alias for a top level schema
pub type Schema = StructType;
Expand Down Expand Up @@ -49,9 +50,39 @@ impl DataCheck for Invariant {
pub trait StructTypeExt {
/// Get all invariants in the schemas
fn get_invariants(&self) -> Result<Vec<Invariant>, Error>;

/// Get all generated column expressions
fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error>;
}

impl StructTypeExt for StructType {
/// Collect every generated-column definition declared in this schema.
///
/// Scans the fields' metadata for a generation-expression entry; each entry is a
/// JSON-encoded string holding the SQL expression that generates the column.
///
/// # Errors
/// Returns [`Error::InvalidGenerationExpressionJson`] when a generation-expression
/// metadata value is not valid JSON.
fn get_generated_columns(&self) -> Result<Vec<GeneratedColumn>, Error> {
    let mut stack: Vec<(String, StructField)> = self
        .fields()
        .map(|field| (field.name.clone(), field.clone()))
        .collect();
    let mut generated: Vec<GeneratedColumn> = Vec::new();

    while let Some((path, field)) = stack.pop() {
        // Only fields carrying a string-valued generation-expression metadata key matter.
        let expr_json = match field
            .metadata
            .get(ColumnMetadataKey::GenerationExpression.as_ref())
        {
            Some(MetadataValue::String(s)) => s,
            _ => continue,
        };
        let parsed: Value = serde_json::from_str(expr_json).map_err(|json_err| {
            Error::InvalidGenerationExpressionJson {
                json_err,
                line: expr_json.to_string(),
            }
        })?;
        // The expression is stored as a JSON string; other JSON shapes are ignored.
        if let Value::String(sql) = parsed {
            generated.push(GeneratedColumn::new(&path, &sql, field.data_type()));
        }
    }
    Ok(generated)
}

/// Get all invariants in the schemas
fn get_invariants(&self) -> Result<Vec<Invariant>, Error> {
let mut remaining_fields: Vec<(String, StructField)> = self
Expand Down
39 changes: 19 additions & 20 deletions crates/core/src/operations/add_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use itertools::Itertools;

use super::transaction::{CommitBuilder, CommitProperties, PROTOCOL};
use super::{CustomExecuteHandler, Operation};
use crate::kernel::StructField;
use crate::kernel::{StructField, StructTypeExt};
use crate::logstore::LogStoreRef;
use crate::operations::cast::merge_schema::merge_delta_struct;
use crate::protocol::DeltaOperation;
Expand Down Expand Up @@ -85,27 +85,26 @@ impl std::future::IntoFuture for AddColumnBuilder {
this.pre_execute(operation_id).await?;

let fields_right = &StructType::new(fields.clone());

if !fields_right
.get_generated_columns()
.unwrap_or_default()
.is_empty()
{
return Err(DeltaTableError::Generic(
"New columns cannot be a generated column".to_string(),
));
}

let table_schema = this.snapshot.schema();
let new_table_schema = merge_delta_struct(table_schema, fields_right)?;

// TODO(ion): Think of a way how we can simply this checking through the API or centralize some checks.
let contains_timestampntz = PROTOCOL.contains_timestampntz(fields.iter());
let protocol = this.snapshot.protocol();

let maybe_new_protocol = if contains_timestampntz {
let updated_protocol = protocol.clone().enable_timestamp_ntz();
if !(protocol.min_reader_version == 3 && protocol.min_writer_version == 7) {
// Convert existing properties to features since we advanced the protocol to v3,7
Some(
updated_protocol
.move_table_properties_into_features(&metadata.configuration),
)
} else {
Some(updated_protocol)
}
} else {
None
};
let current_protocol = this.snapshot.protocol();

let new_protocol = current_protocol
.clone()
.apply_column_metadata_to_protocol(&new_table_schema)?
.move_table_properties_into_features(&metadata.configuration);

let operation = DeltaOperation::AddColumn {
fields: fields.into_iter().collect_vec(),
Expand All @@ -115,7 +114,7 @@ impl std::future::IntoFuture for AddColumnBuilder {

let mut actions = vec![metadata.into()];

if let Some(new_protocol) = maybe_new_protocol {
if current_protocol != &new_protocol {
actions.push(new_protocol.into())
}

Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/operations/add_feature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ impl std::future::IntoFuture for AddTableFeatureBuilder {
}
}

protocol = protocol.with_reader_features(reader_features);
protocol = protocol.with_writer_features(writer_features);
protocol = protocol.append_reader_features(reader_features);
protocol = protocol.append_writer_features(writer_features);

let operation = DeltaOperation::AddFeature {
name: name.to_vec(),
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/operations/cdc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ mod tests {
#[tokio::test]
async fn test_should_write_cdc_v7_table_with_writer_feature() {
let protocol =
Protocol::new(1, 7).with_writer_features(vec![WriterFeatures::ChangeDataFeed]);
Protocol::new(1, 7).append_writer_features(vec![WriterFeatures::ChangeDataFeed]);
let actions = vec![Action::Protocol(protocol)];
let mut table: DeltaTable = DeltaOps::new_in_memory()
.create()
Expand Down
Loading
Loading