Skip to content

Commit

Permalink
fix: add decimal expr parsing Signed-off-by: Ion Koutsouris
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Dec 29, 2024
1 parent cc863c0 commit 8bda53c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
26 changes: 21 additions & 5 deletions crates/core/src/delta_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,10 @@ impl fmt::Display for ScalarValueFormat<'_> {
ScalarValue::UInt16(e) => format_option!(f, e)?,
ScalarValue::UInt32(e) => format_option!(f, e)?,
ScalarValue::UInt64(e) => format_option!(f, e)?,
ScalarValue::Decimal128(e, precision, scale) => match e {
Some(e) => write!(f, "'{e}'::decimal({precision}, {scale})",)?,
None => write!(f, "NULL")?,
},
ScalarValue::Date32(e) => match e {
Some(e) => write!(
f,
Expand Down Expand Up @@ -657,6 +661,11 @@ mod test {
DataType::Primitive(PrimitiveType::Binary),
true,
),
StructField::new(
"_decimal".to_string(),
DataType::Primitive(PrimitiveType::Decimal(2, 2)),
true,
),
StructField::new(
"_struct".to_string(),
DataType::Struct(Box::new(StructType::new(vec![
Expand Down Expand Up @@ -887,6 +896,18 @@ mod test {
)
)),
},
ParseTest {
expr: col("_decimal").eq(lit(ScalarValue::Decimal128(Some(1),2,2))),
expected: "_decimal = '1'::decimal(2, 2)".to_string(),
override_expected_expr: Some(col("_decimal").eq(
Expr::Cast(
Cast {
expr: Box::from(lit("1")),
data_type: arrow_schema::DataType::Decimal128(2, 2)
}
)
)),
},
];

let session: SessionContext = DeltaSessionContext::default().into();
Expand All @@ -908,11 +929,6 @@ mod test {
}

let unsupported_types = vec![
/* TODO: Determine proper way to display decimal values in an sql expression*/
simple!(
col("money").gt(lit(ScalarValue::Decimal128(Some(100), 12, 2))),
"money > 0.1".to_string()
),
simple!(
col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), None))),
"".to_string()
Expand Down
42 changes: 42 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import os
import pathlib
from decimal import Decimal

import pyarrow as pa
import pyarrow.parquet as pq
Expand Down Expand Up @@ -1149,3 +1150,44 @@ def test_merge_when_wrong_but_castable_type_passed_while_merge(
tmp_path / dt.get_add_actions().column(0)[0].as_py()
).schema
assert table_schema.field("price").type == sample_table["price"].type


def test_merge_on_decimal_3033(tmp_path):
data = {
"timestamp": [datetime.datetime(2024, 3, 20, 12, 30, 0)],
"altitude": [Decimal("150.5")],
}

table = pa.Table.from_pydict(data)

schema = pa.schema(
[
("timestamp", pa.timestamp("us")),
("altitude", pa.decimal128(6, 1)),
]
)

dt = DeltaTable.create(tmp_path, schema=schema)

write_deltalake(dt, table, mode="append")

dt.merge(
source=table,
predicate="target.timestamp = source.timestamp",
source_alias="source",
target_alias="target",
).when_matched_update_all().when_not_matched_insert_all().execute()

dt.merge(
source=table,
predicate="target.timestamp = source.timestamp AND target.altitude = source.altitude",
source_alias="source",
target_alias="target",
).when_matched_update_all().when_not_matched_insert_all().execute()

string_predicate = dt.history(1)[0]["operationParameters"]["predicate"]

assert (
string_predicate
== "timestamp BETWEEN arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND altitude BETWEEN '1505'::decimal(4, 1) AND '1505'::decimal(4, 1)"
)

0 comments on commit 8bda53c

Please sign in to comment.