Skip to content

Commit

Permalink
[#2402] test(spark-connector): Add the integration tests for CTAS and…
Browse files Browse the repository at this point in the history
… ITAS (#2669)

### What changes were proposed in this pull request?
Add integration tests for `CREATE TABLE ... AS SELECT ...` (CTAS) and
`INSERT INTO ... SELECT ...` (ITAS).

### Why are the changes needed?
Fix: #2402 

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
IT
  • Loading branch information
FANNG1 authored Mar 26, 2024
1 parent 682e13f commit 3847159
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
public abstract class SparkCommonIT extends SparkEnvIT {

// To generate test data for write&read table.
private static final Map<DataType, String> typeConstant =
protected static final Map<DataType, String> typeConstant =
ImmutableMap.of(
DataTypes.IntegerType,
"2",
Expand Down Expand Up @@ -505,6 +505,75 @@ void testCreateSortBucketTable() {
checkTableReadWrite(tableInfo);
}

// Spark CTAS doesn't copy table properties and partition schema from source table.
@Test
void testCreateTableAsSelect() {
String tableName = "ctas_table";
dropTableIfExists(tableName);
createSimpleTable(tableName);
SparkTableInfo tableInfo = getTableInfo(tableName);
checkTableReadWrite(tableInfo);

String newTableName = "new_" + tableName;
dropTableIfExists(newTableName);
createTableAsSelect(tableName, newTableName);

SparkTableInfo newTableInfo = getTableInfo(newTableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create().withName(newTableName).withColumns(getSimpleTableColumn());
checker.check(newTableInfo);

List<String> tableData = getTableData(newTableName);
Assertions.assertTrue(tableData.size() == 1);
Assertions.assertEquals(getExpectedTableData(newTableInfo), tableData.get(0));
}

@Test
void testInsertTableAsSelect() {
String tableName = "insert_select_table";
String newTableName = "new_" + tableName;

dropTableIfExists(tableName);
createSimpleTable(tableName);
SparkTableInfo tableInfo = getTableInfo(tableName);
checkTableReadWrite(tableInfo);

dropTableIfExists(newTableName);
createSimpleTable(newTableName);
insertTableAsSelect(tableName, newTableName);

SparkTableInfo newTableInfo = getTableInfo(newTableName);
String expectedTableData = getExpectedTableData(newTableInfo);
List<String> tableData = getTableData(newTableName);
Assertions.assertTrue(tableData.size() == 1);
Assertions.assertEquals(expectedTableData, tableData.get(0));
}

@Test
void testInsertDatasourceFormatPartitionTableAsSelect() {
String tableName = "insert_select_partition_table";
String newTableName = "new_" + tableName;
dropTableIfExists(tableName);
dropTableIfExists(newTableName);

createSimpleTable(tableName);
String createTableSql = getCreateSimpleTableString(newTableName);
createTableSql += "PARTITIONED BY (name, age)";
sql(createTableSql);

SparkTableInfo tableInfo = getTableInfo(tableName);
checkTableReadWrite(tableInfo);

insertTableAsSelect(tableName, newTableName);

SparkTableInfo newTableInfo = getTableInfo(newTableName);
checkPartitionDirExists(newTableInfo);
String expectedTableData = getExpectedTableData(newTableInfo);
List<String> tableData = getTableData(newTableName);
Assertions.assertTrue(tableData.size() == 1);
Assertions.assertEquals(expectedTableData, tableData.get(0));
}

protected void checkPartitionDirExists(SparkTableInfo table) {
Assertions.assertTrue(table.isPartitionTable(), "Not a partition table");
String tableLocation = table.getTableLocation();
Expand Down Expand Up @@ -539,43 +608,46 @@ protected void checkTableReadWrite(SparkTableInfo table) {
}
sql(insertDataSQL);

// do something to match the query result:
// 1. remove "'" from values, such as 'a' is trans to a
// 2. remove "array" from values, such as array(1, 2, 3) is trans to [1, 2, 3]
// 3. remove "map" from values, such as map('a', 1, 'b', 2) is trans to {a=1, b=2}
// 4. remove "struct" from values, such as struct(1, 'a') is trans to 1,a
String checkValues =
table.getColumns().stream()
.map(columnInfo -> typeConstant.get(columnInfo.getType()))
.map(Object::toString)
.map(
s -> {
String tmp = org.apache.commons.lang3.StringUtils.remove(s, "'");
if (org.apache.commons.lang3.StringUtils.isEmpty(tmp)) {
return tmp;
} else if (tmp.startsWith("array")) {
return tmp.replace("array", "").replace("(", "[").replace(")", "]");
} else if (tmp.startsWith("map")) {
return tmp.replace("map", "")
.replace("(", "{")
.replace(")", "}")
.replace(", ", "=");
} else if (tmp.startsWith("struct")) {
return tmp.replace("struct", "")
.replace("(", "")
.replace(")", "")
.replace(", ", ",");
}
return tmp;
})
.collect(Collectors.joining(","));
String checkValues = getExpectedTableData(table);

List<String> queryResult = getTableData(name);
Assertions.assertTrue(
queryResult.size() == 1, "Should just one row, table content: " + queryResult);
Assertions.assertEquals(checkValues, queryResult.get(0));
}

protected String getExpectedTableData(SparkTableInfo table) {
// Do something to match the query result:
// 1. remove "'" from values, such as 'a' is trans to a
// 2. remove "array" from values, such as array(1, 2, 3) is trans to [1, 2, 3]
// 3. remove "map" from values, such as map('a', 1, 'b', 2) is trans to {a=1, b=2}
// 4. remove "struct" from values, such as struct(1, 'a') is trans to 1,a
return table.getColumns().stream()
.map(columnInfo -> typeConstant.get(columnInfo.getType()))
.map(Object::toString)
.map(
s -> {
String tmp = org.apache.commons.lang3.StringUtils.remove(s, "'");
if (org.apache.commons.lang3.StringUtils.isEmpty(tmp)) {
return tmp;
} else if (tmp.startsWith("array")) {
return tmp.replace("array", "").replace("(", "[").replace(")", "]");
} else if (tmp.startsWith("map")) {
return tmp.replace("map", "")
.replace("(", "{")
.replace(")", "}")
.replace(", ", "=");
} else if (tmp.startsWith("struct")) {
return tmp.replace("struct", "")
.replace("(", "")
.replace(")", "")
.replace(", ", ",");
}
return tmp;
})
.collect(Collectors.joining(","));
}

protected String getCreateSimpleTableString(String tableName) {
return String.format(
"CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT)",
Expand All @@ -595,7 +667,7 @@ protected String getDefaultDatabase() {

// Helper method to create a simple table, and could use corresponding
// getSimpleTableColumn to check table column.
private void createSimpleTable(String identifier) {
protected void createSimpleTable(String identifier) {
String createTableSql = getCreateSimpleTableString(identifier);
sql(createTableSql);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ private void initSparkEnv() {
.config("spark.plugins", GravitinoSparkPlugin.class.getName())
.config(GravitinoSparkConfig.GRAVITINO_URI, gravitinoUri)
.config(GravitinoSparkConfig.GRAVITINO_METALAKE, metalakeName)
.config("hive.exec.dynamic.partition.mode", "nonstrict")
.config(
"spark.sql.warehouse.dir",
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,68 @@ public void testWriteHiveDynamicPartition() {
Path partitionPath = new Path(location, partitionExpression);
checkDirExists(partitionPath);
}

@Test
public void testInsertHiveFormatPartitionTableAsSelect() {
String tableName = "insert_hive_partition_table";
String newTableName = "new_" + tableName;

// create source table
dropTableIfExists(tableName);
createSimpleTable(tableName);
SparkTableInfo tableInfo = getTableInfo(tableName);
checkTableReadWrite(tableInfo);

// insert into partition ((name = %s, age = %s) select xx
dropTableIfExists(newTableName);
String createTableSql =
String.format(
"CREATE TABLE %s (id INT) PARTITIONED BY (name STRING, age INT)", newTableName);
sql(createTableSql);
String insertPartitionSql =
String.format(
"INSERT OVERWRITE TABLE %s PARTITION (name = %s, age = %s) SELECT id FROM %s",
newTableName,
typeConstant.get(DataTypes.StringType),
typeConstant.get(DataTypes.IntegerType),
tableName);
sql(insertPartitionSql);

SparkTableInfo newTableInfo = getTableInfo(newTableName);
checkPartitionDirExists(newTableInfo);
String expectedData = getExpectedTableData(newTableInfo);
List<String> tableData = getTableData(newTableName);
Assertions.assertTrue(tableData.size() == 1);
Assertions.assertEquals(expectedData, tableData.get(0));

// insert into partition ((name = %s, age) select xx
dropTableIfExists(newTableName);
sql(createTableSql);
insertPartitionSql =
String.format(
"INSERT OVERWRITE TABLE %s PARTITION (name = %s, age) SELECT id, age FROM %s",
newTableName, typeConstant.get(DataTypes.StringType), tableName);
sql(insertPartitionSql);

newTableInfo = getTableInfo(newTableName);
checkPartitionDirExists(newTableInfo);
tableData = getTableData(newTableName);
Assertions.assertTrue(tableData.size() == 1);
Assertions.assertEquals(expectedData, tableData.get(0));

// insert into partition ((name, age) select xx
dropTableIfExists(newTableName);
sql(createTableSql);
insertPartitionSql =
String.format(
"INSERT OVERWRITE TABLE %s PARTITION (name , age) SELECT * FROM %s",
newTableName, tableName);
sql(insertPartitionSql);

newTableInfo = getTableInfo(newTableName);
checkPartitionDirExists(newTableInfo);
tableData = getTableData(newTableName);
Assertions.assertTrue(tableData.size() == 1);
Assertions.assertEquals(expectedData, tableData.get(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ protected boolean tableExists(String tableName) {
}
}

protected void createTableAsSelect(String tableName, String newName) {
sql(String.format("CREATE TABLE %s AS SELECT * FROM %s", newName, tableName));
}

protected void insertTableAsSelect(String tableName, String newName) {
sql(String.format("INSERT INTO TABLE %s SELECT * FROM %s", newName, tableName));
}

private static String getSelectAllSql(String tableName) {
return String.format("SELECT * FROM %s", tableName);
}
Expand Down

0 comments on commit 3847159

Please sign in to comment.