Fix build break
liurenjie1024 committed Aug 26, 2024
1 parent 434214a commit e99af6e
Showing 9 changed files with 74 additions and 107 deletions.
crates/sqllogictest/Cargo.toml: 4 changes (2 additions & 2 deletions)
@@ -28,10 +28,10 @@ toml = "0.8.19"
 url = {workspace = true}
 iceberg-datafusion = { path = "../integrations/datafusion" }
 iceberg-catalog-rest = { path = "../catalog/rest" }
 
-[dev-dependencies]
-tokio = "1.38.0"
+env_logger = { workspace = true }
+
 [dev-dependencies]
 libtest-mimic = "0.7.3"
 iceberg_test_utils = { path = "../test_utils", features = ["tests"] }
 
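Note: the duplicated `[dev-dependencies]` table in the old manifest is itself a build break: TOML forbids defining the same table twice, so Cargo rejects the manifest before compiling anything. The fix leaves a single `[dev-dependencies]` section and moves `env_logger` up with the regular dependencies.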
crates/sqllogictest/src/engine/datafusion.rs: 27 changes (15 additions & 12 deletions)
@@ -20,16 +20,17 @@ use async_trait::async_trait;
 use datafusion::physical_plan::common::collect;
 use datafusion::physical_plan::execute_stream;
 use datafusion::prelude::{SessionConfig, SessionContext};
-use sqllogictest::DBOutput;
+use sqllogictest::{AsyncDB, DBOutput};
 use std::sync::Arc;
 use std::time::Duration;
 use anyhow::anyhow;
 use datafusion::catalog::CatalogProvider;
 use toml::Table;
-use iceberg_catalog_rest::RestCatalogConfig;
+use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig};
 use iceberg_datafusion::IcebergCatalogProvider;
 use crate::engine::normalize;
 use crate::engine::output::{DFColumnType, DFOutput};
+use crate::error::{Result, Error};
 
 pub struct DataFusionEngine {
     ctx: SessionContext,
@@ -49,12 +50,12 @@ impl Default for DataFusionEngine {
 }
 
 #[async_trait]
-impl sqllogictest::AsyncDB for DataFusionEngine {
-    type Error = anyhow::Error;
+impl AsyncDB for DataFusionEngine {
+    type Error = Error;
     type ColumnType = DFColumnType;
 
-    async fn run(&mut self, sql: &str) -> anyhow::Result<DFOutput> {
-        run_query(&self.ctx, sql).await
+    async fn run(&mut self, sql: &str) -> Result<DFOutput> {
+        run_query(&self.ctx, sql).await.map_err(Box::new)
     }
 
     /// Engine name of current database.
@@ -72,7 +73,7 @@ impl sqllogictest::AsyncDB for DataFusionEngine {
     }
 }
 
-async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> anyhow::Result<DFOutput> {
+async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> Result<DFOutput> {
     let df = ctx.sql(sql.into().as_str()).await?;
     let task_ctx = Arc::new(df.task_ctx());
     let plan = df.create_physical_plan().await?;
@@ -90,7 +91,7 @@ async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> anyhow::Result<DFOutput> {
 }
 
 impl DataFusionEngine {
-    pub async fn new(configs: &Table) -> anyhow::Result<Self> {
+    pub async fn new(configs: &Table) -> Result<Self> {
         let config = SessionConfig::new()
             .with_target_partitions(4);
 
@@ -104,14 +105,16 @@ impl DataFusionEngine {
 
     async fn create_catalog(configs: &Table) -> anyhow::Result<Arc<dyn CatalogProvider>> {
         let rest_catalog_url = configs.get("url")
-            .ok_or_else(anyhow!("url not found datafusion engine!"))?
+            .ok_or_else(|| anyhow!("url not found datafusion engine!"))?
            .as_str()
-            .ok_or_else(anyhow!("url is not str"))?;
+            .ok_or_else(|| anyhow!("url is not str"))?;
 
-        let rest_catalog = RestCatalogConfig::builder()
+        let rest_catalog_config = RestCatalogConfig::builder()
            .uri(rest_catalog_url.to_string())
            .build();
 
-        Ok(Arc::new(IcebergCatalogProvider::try_new(Arc::new(rest_catalog)).await?))
+        let rest_catalog = RestCatalog::new(rest_catalog_config);
+
+        Ok(Arc::new(IcebergCatalogProvider::try_new(Arc::new(rest_catalog)).await?))
     }
 }
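Note: the error-type churn in this file is the core of the build break. sqllogictest's `AsyncDB` trait evidently requires its associated `Error` type to implement `std::error::Error`, and `anyhow::Error` deliberately does not implement that trait; it only converts into a boxed trait object. A minimal sketch of the distinction (assuming the `anyhow` crate; `takes_std_error` and `describe` are illustrative stand-ins, not APIs from this repository):

```rust
use std::error::Error as StdError;

// Hypothetical stand-in for a bound like sqllogictest's `type Error: std::error::Error`.
#[allow(dead_code)]
fn takes_std_error<E: StdError>(_e: E) {}

fn describe(e: &dyn StdError) -> String {
    format!("error: {e}")
}

fn main() {
    let e = anyhow::anyhow!("boom");

    // Would not compile: `anyhow::Error` does not implement `std::error::Error`,
    // which is exactly what the old `type Error = anyhow::Error` tripped over.
    // takes_std_error(e);

    // What anyhow does support: conversion into a boxed trait object,
    // the same shape as the new crate-wide `Error` alias.
    let boxed: Box<dyn StdError + Send + Sync + 'static> = e.into();
    println!("{}", describe(boxed.as_ref()));
}
```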
crates/sqllogictest/src/engine/mod.rs: 17 changes (10 additions & 7 deletions)
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use anyhow::anyhow;
+use anyhow::{anyhow, bail};
 pub use datafusion::*;
 use sqllogictest::{strict_column_validator, AsyncDB, MakeConnection, Runner};
 use std::sync::Arc;
@@ -30,6 +30,7 @@ pub use spark::*;
 
 mod datafusion;
 pub use datafusion::*;
+use crate::error::Result;
 
 
 #[derive(Clone)]
@@ -39,7 +40,7 @@ pub enum Engine {
 }
 
 impl Engine {
-    pub async fn new(typ: &str, configs: &Table) -> anyhow::Result<Self> {
+    pub async fn new(typ: &str, configs: &Table) -> Result<Self> {
         let configs = Arc::new(configs.clone());
         match typ {
             "spark" => {
@@ -48,30 +49,32 @@ impl Engine {
             "datafusion" => {
                 Ok(Engine::DataFusion(configs))
             }
-            other => Err(anyhow!("Unknown engine type: {other}"))
+            other => Err(anyhow!("Unknown engine type: {other}").into())
         }
     }
 
     pub async fn run_slt_file(self, slt_file: impl Into<String>) -> anyhow::Result<()> {
-        let absolute_file = format!("{}/testdata/slts/{}", env!("CARGO_MANIFEST_DIR"), slt_file);
+        let absolute_file = format!("{}/testdata/slts/{}", env!("CARGO_MANIFEST_DIR"), slt_file.into());
 
         match self {
             Engine::DataFusion(configs) => {
                 let configs = configs.clone();
-                let runner = Runner::new(async || DataFusionEngine::new(&*configs).await);
+                let runner = Runner::new(|| async {
+                    DataFusionEngine::new(&*configs).await
+                });
                 Self::run_with_runner(runner, absolute_file).await
             }
             Engine::SparkSQL(configs) => {
                 let configs = configs.clone();
-                let runner = Runner::new(async || {
+                let runner = Runner::new(|| async {
                     SparkSqlEngine::new(&*configs).await
                 });
                 Self::run_with_runner(runner, absolute_file).await
             }
         }
     }
 
-    async fn run_with_runner<D: AsyncDB, M: MakeConnection>(mut runner: Runner<D, M>,
+    async fn run_with_runner<D: AsyncDB, M: MakeConnection<Conn = D>>(mut runner: Runner<D, M>,
                              slt_file: String) -> anyhow::Result<()> {
         runner.with_column_validator(strict_column_validator);
         Ok(runner
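Note: `Runner::new(async || ...)` is async-closure syntax, which was nightly-only when this commit landed (feature `async_closure`; stabilized much later, in Rust 1.85), so a stable toolchain rejects it, hence the rewrite to an ordinary closure that returns an `async` block. The `MakeConnection<Conn = D>` bound likewise pins the factory's connection type to the engine type so the `Runner<D, M>` bounds line up. A sketch of the stable closure pattern (assuming `tokio` with the `rt` and `macros` features; this crate already has a `tokio` dev-dependency):

```rust
// Stand-in for an engine constructor like `DataFusionEngine::new`.
async fn make_engine(name: &str) -> Result<String, std::convert::Infallible> {
    Ok(format!("engine: {name}"))
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Nightly-only at the time (feature `async_closure`):
    // let factory = async || make_engine("datafusion").await;

    // Stable spelling used by the fix: a plain closure returning an async block.
    let factory = || async { make_engine("datafusion").await };

    let engine = factory().await.unwrap();
    println!("{engine}");
}
```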
crates/sqllogictest/src/engine/normalize.rs: 99 changes (22 additions & 77 deletions)
@@ -15,19 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow_schema::Fields;
-use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
-use datafusion_common::DataFusionError;
-use std::path::PathBuf;
-use std::sync::OnceLock;
-use arrow_array::{ArrayRef, RecordBatch};
 use crate::engine::output::DFColumnType;
+use anyhow::anyhow;
+use arrow_array::{ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray};
+use arrow_schema::{DataType, Fields};
+use datafusion::arrow::util::display::ArrayFormatter;
+use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
 
 use crate::engine::conversion::*;
-use crate::engine::datafusion::error::{DFSqlLogicTestError, Result};
 
 /// Converts `batches` to a result as expected by sqllogicteset.
-pub(crate) fn convert_batches(batches: Vec<RecordBatch>) -> Result<Vec<Vec<String>>> {
+pub(crate) fn convert_batches(batches: Vec<RecordBatch>) -> anyhow::Result<Vec<Vec<String>>> {
     if batches.is_empty() {
         Ok(vec![])
     } else {
@@ -36,19 +34,17 @@ pub(crate) fn convert_batches(batches: Vec<RecordBatch>) -> Result<Vec<Vec<String>>> {
         for batch in batches {
             // Verify schema
             if !schema.contains(&batch.schema()) {
-                return Err(DFSqlLogicTestError::DataFusion(DataFusionError::Internal(
-                    format!(
+                return Err(anyhow!(
                     "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}",
                     &schema,
                     batch.schema()
                 ),
-                )));
+                );
             }
 
             let new_rows = convert_batch(batch)?
                 .into_iter()
-                .flat_map(expand_row)
-                .map(normalize_paths);
+                .flat_map(expand_row);
             rows.extend(new_rows);
         }
         Ok(rows)
@@ -77,7 +73,7 @@ pub(crate) fn convert_batches(batches: Vec<RecordBatch>) -> Result<Vec<Vec<String>>> {
 /// "|-- Projection: d.b, MAX(d.a) AS max_a",
 /// ]
 /// ```
-fn expand_row(mut row: Vec<String>) -> impl Iterator<Item = Vec<String>> {
+fn expand_row(mut row: Vec<String>) -> impl Iterator<Item=Vec<String>> {
     use itertools::Either;
     use std::iter::once;
 
@@ -115,65 +111,15 @@ fn expand_row(mut row: Vec<String>) -> impl Iterator<Item = Vec<String>> {
     }
 }
 
-/// normalize path references
-///
-/// ```text
-/// CsvExec: files={1 group: [[path/to/datafusion/testing/data/csv/aggregate_test_100.csv]]}, ...
-/// ```
-///
-/// into:
-///
-/// ```text
-/// CsvExec: files={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, ...
-/// ```
-fn normalize_paths(mut row: Vec<String>) -> Vec<String> {
-    row.iter_mut().for_each(|s| {
-        let workspace_root: &str = workspace_root().as_ref();
-        if s.contains(workspace_root) {
-            *s = s.replace(workspace_root, "WORKSPACE_ROOT");
-        }
-    });
-    row
-}
-
-/// return the location of the datafusion checkout
-fn workspace_root() -> &'static object_store::path::Path {
-    static WORKSPACE_ROOT_LOCK: OnceLock<object_store::path::Path> = OnceLock::new();
-    WORKSPACE_ROOT_LOCK.get_or_init(|| {
-        // e.g. /Software/datafusion/datafusion/core
-        let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
-
-        // e.g. /Software/datafusion/datafusion
-        let workspace_root = dir
-            .parent()
-            .expect("Can not find parent of datafusion/core")
-            // e.g. /Software/datafusion
-            .parent()
-            .expect("parent of datafusion")
-            .to_string_lossy();
-
-        let sanitized_workplace_root = if cfg!(windows) {
-            // Object store paths are delimited with `/`, e.g. `/datafusion/datafusion/testing/data/csv/aggregate_test_100.csv`.
-            // The default windows delimiter is `\`, so the workplace path is `datafusion\datafusion`.
-            workspace_root
-                .replace(std::path::MAIN_SEPARATOR, object_store::path::DELIMITER)
-        } else {
-            workspace_root.to_string()
-        };
-
-        object_store::path::Path::parse(sanitized_workplace_root).unwrap()
-    })
-}
-
 /// Convert a single batch to a `Vec<Vec<String>>` for comparison
-fn convert_batch(batch: RecordBatch) -> Result<Vec<Vec<String>>> {
+fn convert_batch(batch: RecordBatch) -> anyhow::Result<Vec<Vec<String>>> {
     (0..batch.num_rows())
         .map(|row| {
             batch
                 .columns()
                 .iter()
                 .map(|col| cell_to_string(col, row))
-                .collect::<Result<Vec<String>>>()
+                .collect::<anyhow::Result<Vec<String>>>()
         })
        .collect()
 }
@@ -196,43 +142,43 @@ macro_rules! get_row_value {
 ///
 /// Floating numbers are rounded to have a consistent representation with the Postgres runner.
 ///
-pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
+pub fn cell_to_string(col: &ArrayRef, row: usize) -> anyhow::Result<String> {
     if !col.is_valid(row) {
         // represent any null value with the string "NULL"
         Ok(NULL_STR.to_string())
     } else {
         match col.data_type() {
             DataType::Null => Ok(NULL_STR.to_string()),
             DataType::Boolean => {
-                Ok(bool_to_str(get_row_value!(array::BooleanArray, col, row)))
+                Ok(bool_to_str(get_row_value!(BooleanArray, col, row)))
             }
             DataType::Float16 => {
-                Ok(f16_to_str(get_row_value!(array::Float16Array, col, row)))
+                Ok(f16_to_str(get_row_value!(Float16Array, col, row)))
             }
             DataType::Float32 => {
-                Ok(f32_to_str(get_row_value!(array::Float32Array, col, row)))
+                Ok(f32_to_str(get_row_value!(Float32Array, col, row)))
            }
             DataType::Float64 => {
-                Ok(f64_to_str(get_row_value!(array::Float64Array, col, row)))
+                Ok(f64_to_str(get_row_value!(Float64Array, col, row)))
             }
             DataType::Decimal128(precision, scale) => {
-                let value = get_row_value!(array::Decimal128Array, col, row);
+                let value = get_row_value!(Decimal128Array, col, row);
                 Ok(i128_to_str(value, precision, scale))
             }
             DataType::Decimal256(precision, scale) => {
-                let value = get_row_value!(array::Decimal256Array, col, row);
+                let value = get_row_value!(Decimal256Array, col, row);
                 Ok(i256_to_str(value, precision, scale))
             }
             DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!(
-                array::LargeStringArray,
+                LargeStringArray,
                 col,
                 row
             ))),
             DataType::Utf8 => {
-                Ok(varchar_to_str(get_row_value!(array::StringArray, col, row)))
+                Ok(varchar_to_str(get_row_value!(StringArray, col, row)))
             }
             DataType::Utf8View => Ok(varchar_to_str(get_row_value!(
-                array::StringViewArray,
+                StringViewArray,
                 col,
                 row
             ))),
@@ -241,7 +187,6 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
                 Ok(f.unwrap().value(row).to_string())
             }
         }
-        .map_err(DFSqlLogicTestError::Arrow)
     }
 }
 
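Note: besides moving this module onto `anyhow::Result` and direct imports (`BooleanArray` rather than `array::BooleanArray`), the diff deletes the `normalize_paths`/`workspace_root` helpers wholesale; they leaned on `object_store` and a DataFusion-style workspace layout with no equivalent in this repository. The `get_row_value!` macro itself is not visible in this commit; the following is only a plausible sketch of the downcast-and-read pattern such a macro would wrap (illustrative names, using the `arrow-array` crate):

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, BooleanArray, Float64Array};

// Presumed shape of a `get_row_value!`-style accessor: downcast the dynamically
// typed column to a concrete arrow array, then read one row.
fn bool_cell(col: &ArrayRef, row: usize) -> Option<bool> {
    let arr = col.as_any().downcast_ref::<BooleanArray>()?;
    if arr.is_valid(row) {
        Some(arr.value(row))
    } else {
        None
    }
}

fn main() {
    let col: ArrayRef = Arc::new(BooleanArray::from(vec![Some(true), None]));
    assert_eq!(bool_cell(&col, 0), Some(true));
    assert_eq!(bool_cell(&col, 1), None); // null cell

    let floats: ArrayRef = Arc::new(Float64Array::from(vec![1.5]));
    // Wrong concrete type: `downcast_ref` returns None instead of panicking.
    assert_eq!(bool_cell(&floats, 0), None);
}
```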
crates/sqllogictest/src/engine/spark.rs: 19 changes (14 additions & 5 deletions)
@@ -22,21 +22,30 @@ use itertools::Itertools;
 use spark_connect_rs::{SparkSession, SparkSessionBuilder};
 use sqllogictest::{AsyncDB, DBOutput};
 use std::time::Duration;
+use async_trait::async_trait;
 use toml::Table;
+use crate::error::*;
 
 /// SparkSql engine implementation for sqllogictest.
 pub struct SparkSqlEngine {
     session: SparkSession,
 }
 
+#[async_trait]
 impl AsyncDB for SparkSqlEngine {
-    type Error = anyhow::Error;
+    type Error = Error;
     type ColumnType = DFColumnType;
 
-    async fn run(&mut self, sql: &str) -> anyhow::Result<DBOutput<DFColumnType>> {
-        let results = self.session.sql(sql).await?.collect()?;
+    async fn run(&mut self, sql: &str) -> Result<DBOutput<DFColumnType>> {
+        let results = self.session
+            .sql(sql)
+            .await
+            .map_err(Box::new)?
+            .collect()
+            .await
+            .map_err(Box::new)?;
         let types = normalize::convert_schema_to_types(results.schema().fields());
-        let rows = crate::engine::normalize::convert_batches(results)?;
+        let rows = normalize::convert_batches(results)?;
 
         if rows.is_empty() && types.is_empty() {
             Ok(DBOutput::StatementComplete(0))
@@ -61,7 +70,7 @@ impl AsyncDB for SparkSqlEngine {
 }
 
 impl SparkSqlEngine {
-    pub async fn new(configs: &Table) -> anyhow::Result<Self> {
+    pub async fn new(configs: &Table) -> Result<Self> {
         let url = configs.get("url")
             .ok_or_else(|| anyhow!("url property doesn't exist for spark engine"))?
             .as_str()
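Note: the rewritten `run` both awaits `collect()` (the old one-liner never awaited it) and boxes each concrete error with `map_err(Box::new)` so it fits the new boxed `Error` alias. A minimal sketch of that boxing step, with `parse` as an illustrative stand-in for a fallible call like `session.sql(...)`:

```rust
use std::error::Error as StdError;
use std::num::ParseIntError;

// Same shape as the crate's new aliases.
type Error = Box<dyn StdError + Send + Sync + 'static>;
type Result<T> = std::result::Result<T, Error>;

// Illustrative stand-in for a fallible engine call.
fn parse(s: &str) -> std::result::Result<i64, ParseIntError> {
    s.parse()
}

fn run(input: &str) -> Result<i64> {
    // The explicit form used in the commit: box the concrete error, then `?`.
    let v = parse(input).map_err(Box::new)?;
    Ok(v)
}

fn main() {
    assert_eq!(run("42").unwrap(), 42);
    assert!(run("not a number").is_err());
}
```

Since std provides a blanket `From<E: StdError + Send + Sync + 'static>` into the boxed alias, a bare `?` without the `map_err` would convert just as well; `Box::new` merely makes the step explicit.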
crates/sqllogictest/src/error.rs: 4 changes (4 additions & 0 deletions, new file)
@@ -0,0 +1,4 @@
+
+
+pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+pub type Result<T> = std::result::Result<T, Error>;
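Note: this tiny module is the pivot of the fix: one shared alias that every concrete error implementing `std::error::Error` can flow into via `?`, with no dependence on `anyhow` at the trait boundary. A self-contained sketch of how the alias behaves (`SchemaMismatch` is illustrative, not a type from this crate):

```rust
use std::fmt;

// Same aliases as the new error.rs.
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T> = std::result::Result<T, Error>;

// An illustrative custom error; any `std::error::Error` works the same way.
#[derive(Debug)]
struct SchemaMismatch(String);

impl fmt::Display for SchemaMismatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "schema mismatch: {}", self.0)
    }
}

impl std::error::Error for SchemaMismatch {}

fn check(expected: &str, actual: &str) -> Result<()> {
    if expected != actual {
        // Custom errors are boxed explicitly...
        return Err(Box::new(SchemaMismatch(format!("{expected} vs {actual}"))));
    }
    // ...while std errors (here ParseIntError) convert implicitly through `?`.
    let _n: u32 = "42".parse()?;
    Ok(())
}

fn main() {
    assert!(check("a", "a").is_ok());
    assert!(check("a", "b").is_err());
}
```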
crates/sqllogictest/src/lib.rs: 2 changes (2 additions & 0 deletions)
@@ -19,3 +19,5 @@
 // [Apache Datafusion](https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest)
 mod engine;
 pub mod schedule;
+mod error;
+pub use error::*;