Skip to content

Commit

Permalink
reuse a single function to create the tpch test contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed Jul 10, 2024
1 parent 585504a commit a1a68b1
Showing 1 changed file with 62 additions and 145 deletions.
207 changes: 62 additions & 145 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,151 +32,22 @@ mod tests {
use std::io::BufReader;
use substrait::proto::Plan;

/// Registers the CSV file at `file_path` on `ctx` under the table name
/// `table_name`, using `CsvReadOptions::default()`.
///
/// Returns whatever `SessionContext::register_csv` returns.
async fn register_csv(
    ctx: &SessionContext,
    table_name: &str,
    file_path: &str,
) -> Result<()> {
    let read_options = CsvReadOptions::default();
    ctx.register_csv(table_name, file_path, read_options).await
}

/// Builds a new `SessionContext` with `lineitem.csv` registered under the
/// table name `FILENAME_PLACEHOLDER_0`.
async fn create_context_tpch1() -> Result<SessionContext> {
    let session = SessionContext::new();

    let table = "FILENAME_PLACEHOLDER_0";
    let csv_path = "tests/testdata/tpch/lineitem.csv";
    register_csv(&session, table, csv_path).await?;

    Ok(session)
}

/// Builds a new `SessionContext` with nine placeholder tables registered,
/// in order, from the TPC-H CSV test data.
///
/// `partsupp`, `supplier`, `nation` and `region` are each registered twice
/// under different placeholder names.
// NOTE(review): presumably the plan reads those tables more than once — confirm
// against query_2.json.
async fn create_context_tpch2() -> Result<SessionContext> {
    let session = SessionContext::new();

    // (table name, CSV path) pairs; registered sequentially, stopping at the
    // first failure.
    let tables = [
        ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"),
        ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
        ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"),
        ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
        ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"),
        ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"),
        ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"),
        ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"),
        ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"),
    ];

    for (name, csv_path) in tables {
        register_csv(&session, name, csv_path).await?;
    }

    Ok(session)
}

/// Builds a new `SessionContext` with `customer`, `orders` and `lineitem`
/// CSV files registered as `FILENAME_PLACEHOLDER_0` through `_2`.
async fn create_context_tpch3() -> Result<SessionContext> {
    let session = SessionContext::new();

    // (table name, CSV path) pairs, registered in this order.
    let tables = [
        ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
        ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
        ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
    ];

    for &(name, csv_path) in tables.iter() {
        register_csv(&session, name, csv_path).await?;
    }

    Ok(session)
}

/// Builds a new `SessionContext` with `orders.csv` and `lineitem.csv`
/// registered as `FILENAME_PLACEHOLDER_0` and `FILENAME_PLACEHOLDER_1`.
async fn create_context_tpch4() -> Result<SessionContext> {
    let session = SessionContext::new();

    // Registered one after the other; the first error aborts the setup.
    register_csv(
        &session,
        "FILENAME_PLACEHOLDER_0",
        "tests/testdata/tpch/orders.csv",
    )
    .await?;
    register_csv(
        &session,
        "FILENAME_PLACEHOLDER_1",
        "tests/testdata/tpch/lineitem.csv",
    )
    .await?;

    Ok(session)
}

/// Builds a new `SessionContext` with six TPC-H CSV files registered.
///
/// Four tables use `FILENAME_PLACEHOLDER_n` names; `nation` and `region` are
/// registered under the literal names `NATION` and `REGION`.
async fn create_context_tpch5() -> Result<SessionContext> {
    let session = SessionContext::new();

    // (table name, CSV path) pairs, registered in this order.
    let tables = [
        ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
        ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
        ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
        ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"),
        ("NATION", "tests/testdata/tpch/nation.csv"),
        ("REGION", "tests/testdata/tpch/region.csv"),
    ];

    for (name, csv_path) in tables {
        register_csv(&session, name, csv_path).await?;
    }

    Ok(session)
}

/// Builds a new `SessionContext` with `lineitem.csv` registered under the
/// table name `FILENAME_PLACEHOLDER_0`.
async fn create_context_tpch6() -> Result<SessionContext> {
    let session = SessionContext::new();

    // Only one table is needed, so it is registered directly.
    register_csv(
        &session,
        "FILENAME_PLACEHOLDER_0",
        "tests/testdata/tpch/lineitem.csv",
    )
    .await?;

    Ok(session)
}
// TODO: missing contexts for queries 7, 8, 9

async fn create_context_tpch10() -> Result<SessionContext> {
async fn create_context(files: Vec<(&str, &str)>) -> Result<SessionContext> {
let ctx = SessionContext::new();

let registrations = vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
];

for (table_name, file_path) in registrations {
register_csv(&ctx, table_name, file_path).await?;
for (table_name, file_path) in files {
ctx.register_csv(table_name, file_path, CsvReadOptions::default())
.await?;
}

Ok(ctx)
}

/// Builds a new `SessionContext` with six placeholder tables registered, in
/// order, from the TPC-H CSV test data.
///
/// `partsupp`, `supplier` and `nation` are each registered twice under
/// different placeholder names.
// NOTE(review): presumably the plan reads those tables more than once — confirm
// against query_11.json.
async fn create_context_tpch11() -> Result<SessionContext> {
    let session = SessionContext::new();

    // (table name, CSV path) pairs; registered sequentially, stopping at the
    // first failure.
    let tables = [
        ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"),
        ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
        ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"),
        ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"),
        ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"),
        ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"),
    ];

    for (name, csv_path) in tables {
        register_csv(&session, name, csv_path).await?;
    }

    Ok(session)
}

#[tokio::test]
async fn tpch_test_1() -> Result<()> {
let ctx = create_context_tpch1().await?;
let ctx = create_context(vec![(
"FILENAME_PLACEHOLDER_0",
"tests/testdata/tpch/lineitem.csv",
)])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_1.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand All @@ -200,7 +71,18 @@ mod tests {

#[tokio::test]
async fn tpch_test_2() -> Result<()> {
let ctx = create_context_tpch2().await?;
let ctx = create_context(vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"),
("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"),
("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"),
("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"),
("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"),
("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"),
])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_2.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand Down Expand Up @@ -242,7 +124,12 @@ mod tests {

#[tokio::test]
async fn tpch_test_3() -> Result<()> {
let ctx = create_context_tpch3().await?;
let ctx = create_context(vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_3.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand All @@ -267,7 +154,11 @@ mod tests {

#[tokio::test]
async fn tpch_test_4() -> Result<()> {
let ctx = create_context_tpch4().await?;
let ctx = create_context(vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"),
])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_4.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand All @@ -289,7 +180,15 @@ mod tests {

#[tokio::test]
async fn tpch_test_5() -> Result<()> {
let ctx = create_context_tpch5().await?;
let ctx = create_context(vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"),
("NATION", "tests/testdata/tpch/nation.csv"),
("REGION", "tests/testdata/tpch/region.csv"),
])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_5.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand Down Expand Up @@ -319,7 +218,11 @@ mod tests {

#[tokio::test]
async fn tpch_test_6() -> Result<()> {
let ctx = create_context_tpch6().await?;
let ctx = create_context(vec![(
"FILENAME_PLACEHOLDER_0",
"tests/testdata/tpch/lineitem.csv",
)])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_6.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand All @@ -338,7 +241,13 @@ mod tests {
// TODO: missing plan 7, 8, 9
#[tokio::test]
async fn tpch_test_10() -> Result<()> {
let ctx = create_context_tpch10().await?;
let ctx = create_context(vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_10.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand All @@ -365,7 +274,15 @@ mod tests {

#[tokio::test]
async fn tpch_test_11() -> Result<()> {
let ctx = create_context_tpch11().await?;
let ctx = create_context(vec![
("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"),
("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"),
("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"),
("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"),
("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"),
])
.await?;
let path = "tests/testdata/tpch_substrait_plans/query_11.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand Down

0 comments on commit a1a68b1

Please sign in to comment.