Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improve][Connector-v2] Use regex to match fieldName placeholders in jdbc sink #8222

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.seatunnel.connectors.seatunnel.jdbc.internal.executor;

import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

Expand Down Expand Up @@ -47,6 +49,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.apache.seatunnel.shade.com.google.common.base.Preconditions.checkArgument;
import static org.apache.seatunnel.shade.com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -669,29 +673,26 @@ public static FieldNamedPreparedStatement prepareStatement(
connection.prepareStatement(parsedSQL), indexMapping);
}

private static String parseNamedStatement(String sql, Map<String, List<Integer>> paramMap) {
StringBuilder parsedSql = new StringBuilder();
int fieldIndex = 1; // SQL statement parameter index starts from 1
int length = sql.length();
for (int i = 0; i < length; i++) {
char c = sql.charAt(i);
if (':' == c) {
int j = i + 1;
while (j < length && Character.isJavaIdentifierPart(sql.charAt(j))) {
j++;
}
String parameterName = sql.substring(i + 1, j);
checkArgument(
!parameterName.isEmpty(),
"Named parameters in SQL statement must not be empty.");
paramMap.computeIfAbsent(parameterName, n -> new ArrayList<>()).add(fieldIndex);
fieldIndex++;
i = j - 1;
parsedSql.append('?');
} else {
parsedSql.append(c);
}
/**
 * Rewrites a SQL string that uses named placeholders of the form {@code :name} into JDBC
 * positional form.
 *
 * <p>Every placeholder matched by the pattern below is replaced with {@code ?}, and the
 * 1-based position of that {@code ?} is appended to the list stored in {@code paramMap} under
 * the placeholder's name. A name that occurs several times therefore ends up with several
 * positions, in order of appearance.
 *
 * @param sql the SQL text containing {@code :name} placeholders
 * @param paramMap output map from parameter name to the positions at which it occurs; lists
 *     already present for a name are appended to, not replaced
 * @return the SQL text with each matched placeholder replaced by {@code ?}
 */
@VisibleForTesting
public static String parseNamedStatement(String sql, Map<String, List<Integer>> paramMap) {
dailai marked this conversation as resolved.
Show resolved Hide resolved
// A placeholder is a colon followed by one or more characters that are Unicode letters
// (\p{L}), letter numbers (\p{Nl}), decimal digits (\p{Nd}), connector punctuation such as
// '_' (\p{Pc}), or one of the listed symbols: $ - . @ % & * # ~ ! ? ^ + = < > |
// NOTE(review): the pattern is compiled on every call; it could be hoisted to a
// static final constant.
// NOTE(review): matching is purely textual over the whole statement, so a colon followed by
// these characters inside a quoted literal would presumably be rewritten too - confirm.
Pattern pattern =
Pattern.compile(":([\\p{L}\\p{Nl}\\p{Nd}\\p{Pc}\\$\\-\\.@%&*#~!?^+=<>|]+)");
dailai marked this conversation as resolved.
Show resolved Hide resolved
Matcher matcher = pattern.matcher(sql);

StringBuffer result = new StringBuffer();
// JDBC parameter positions start at 1.
int fieldIndex = 1;

while (matcher.find()) {
String parameterName = matcher.group(1);
// The '+' quantifier in the pattern means group 1 always holds at least one character,
// so this precondition cannot fail for a match; a bare ':' is simply not matched.
checkArgument(
!parameterName.isEmpty(),
"Named parameters in SQL statement must not be empty.");
paramMap.computeIfAbsent(parameterName, n -> new ArrayList<>()).add(fieldIndex++);
// Copies the text since the previous match, then emits '?' in place of ":name".
matcher.appendReplacement(result, "?");
}
// NOTE(review): parsedSql is not declared in this version of the method; this looks like the
// removed line of the diff rather than live code - confirm.
return parsedSql.toString();

// Copy whatever follows the last placeholder.
matcher.appendTail(result);

return result.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.connectors.seatunnel.jdbc.internal.executor;

import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

/** Unit tests for {@code FieldNamedPreparedStatement.parseNamedStatement}. */
public class FieldNamedPreparedStatementTest {

    /**
     * Column names covering the special characters that named placeholders are expected to
     * accept. The order here is the order in which the placeholders appear in the generated
     * statement, so index {@code i} must map to JDBC position {@code i + 1}.
     */
    private static final String[] SPECIAL_FIELD_NAMES =
            new String[] {
                "USER@TOKEN",
                "字段%名称",
                "field_name",
                "field.name",
                "field-name",
                "$fieldName",
                "field&key",
                "field*value",
                "field#1",
                "field~test",
                "field!data",
                "field?question",
                "field^caret",
                "field+add",
                "field=value",
                "fieldmax",
                "field|pipe"
            };

    @Test
    public void testParseNamedStatementWithSpecialCharacters() {
        // Build both statements from SPECIAL_FIELD_NAMES so the SQL text and the asserted
        // parameter order cannot drift apart when a name is added or changed.
        int fieldCount = SPECIAL_FIELD_NAMES.length;
        String[] quotedColumns = new String[fieldCount];
        String[] namedPlaceholders = new String[fieldCount];
        String[] positionalPlaceholders = new String[fieldCount];
        String[] updateClauses = new String[fieldCount];
        for (int i = 0; i < fieldCount; i++) {
            String name = SPECIAL_FIELD_NAMES[i];
            quotedColumns[i] = "`" + name + "`";
            namedPlaceholders[i] = ":" + name;
            positionalPlaceholders[i] = "?";
            updateClauses[i] = "`" + name + "`=VALUES(`" + name + "`)";
        }

        String head =
                "INSERT INTO `nhp_emr_ws`.`cm_prescriptiondetails_cs` ("
                        + String.join(", ", quotedColumns)
                        + ") VALUES (";
        String tail = ") ON DUPLICATE KEY UPDATE " + String.join(", ", updateClauses);

        String sql = head + String.join(", ", namedPlaceholders) + tail;
        String expectedPreparedStatement = head + String.join(", ", positionalPlaceholders) + tail;

        Map<String, List<Integer>> paramMap = new HashMap<>();
        String actualSQL = FieldNamedPreparedStatement.parseNamedStatement(sql, paramMap);

        assertEquals(expectedPreparedStatement, actualSQL);
        assertEquals(fieldCount, paramMap.size());
        for (int i = 0; i < fieldCount; i++) {
            String name = SPECIAL_FIELD_NAMES[i];
            assertTrue(paramMap.containsKey(name));
            List<Integer> positions = paramMap.get(name);
            // Each name is used exactly once, at the position matching its array index.
            assertEquals(1, positions.size());
            assertEquals(i + 1, positions.get(0).intValue());
        }
    }

    @Test
    public void testParseNamedStatement() {
        String sql = "UPDATE table SET col1 = :param1, col2 = :param1 WHERE col3 = :param2";
        Map<String, List<Integer>> paramMap = new HashMap<>();
        String expectedSQL = "UPDATE table SET col1 = ?, col2 = ? WHERE col3 = ?";

        String actualSQL = FieldNamedPreparedStatement.parseNamedStatement(sql, paramMap);

        assertEquals(expectedSQL, actualSQL);
        assertTrue(paramMap.containsKey("param1"));
        assertTrue(paramMap.containsKey("param2"));
        // A repeated name collects one position per occurrence.
        assertEquals(2, paramMap.get("param1").size());
        assertEquals(1, paramMap.get("param1").get(0).intValue());
        assertEquals(2, paramMap.get("param1").get(1).intValue());
        assertEquals(1, paramMap.get("param2").size());
        assertEquals(3, paramMap.get("param2").get(0).intValue());
    }

    @Test
    public void testParseNamedStatementWithNoNamedParameters() {
        String sql = "SELECT * FROM table";
        Map<String, List<Integer>> paramMap = new HashMap<>();
        String expectedSQL = "SELECT * FROM table";

        String actualSQL = FieldNamedPreparedStatement.parseNamedStatement(sql, paramMap);

        // Without placeholders the statement passes through untouched and nothing is recorded.
        assertEquals(expectedSQL, actualSQL);
        assertTrue(paramMap.isEmpty());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public class JdbcMysqlIT extends AbstractJdbcIT {
private static final String CREATE_SQL =
"CREATE TABLE IF NOT EXISTS %s\n"
+ "(\n"
+ " `c_bit_1` bit(1) DEFAULT NULL,\n"
+ " `c-bit_1` bit(1) DEFAULT NULL,\n"
+ " `c_bit_8` bit(8) DEFAULT NULL,\n"
+ " `c_bit_16` bit(16) DEFAULT NULL,\n"
+ " `c_bit_32` bit(32) DEFAULT NULL,\n"
Expand Down Expand Up @@ -191,7 +191,7 @@ protected void checkResult(
String executeKey, TestContainer container, Container.ExecResult execResult) {
String[] fieldNames =
new String[] {
"c_bit_1",
"c-bit_1",
"c_bit_8",
"c_bit_16",
"c_bit_32",
Expand Down Expand Up @@ -249,7 +249,7 @@ String driverUrl() {
Pair<String[], List<SeaTunnelRow>> initTestData() {
String[] fieldNames =
new String[] {
"c_bit_1",
"c-bit_1",
"c_bit_8",
"c_bit_16",
"c_bit_32",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ sink {
user = "root"
password = "Abc!@#135_seatunnel"

query = """insert into sink (c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
query = """insert into sink (`c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ CREATE TABLE sink_table WITH (


INSERT INTO sink_table
SELECT c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
SELECT `c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ sink {
user = "root"
password = "Abc!@#135_seatunnel"
connection_check_timeout_sec = 100
query = """insert into sink (c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
query = """insert into sink (`c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ CREATE TABLE sink_table WITH (


CREATE TABLE temp1 AS
SELECT c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
SELECT `c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand All @@ -58,4 +58,4 @@ CREATE TABLE temp1 AS


INSERT INTO sink_table SELECT * FROM temp1;

Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ sink {
user = "root"
password = "Abc!@#135_seatunnel"
connection_check_timeout_sec = 100
query = """insert into sink (c_bit_1, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
query = """insert into sink (`c-bit_1`, c_bit_8, c_bit_16, c_bit_32, c_bit_64, c_boolean, c_tinyint, c_tinyint_unsigned, c_smallint, c_smallint_unsigned,
c_mediumint, c_mediumint_unsigned, c_int, c_integer, c_bigint, c_bigint_unsigned,
c_decimal, c_decimal_unsigned, c_float, c_float_unsigned, c_double, c_double_unsigned,
c_char, c_tinytext, c_mediumtext, c_text, c_varchar, c_json, c_longtext, c_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public SeaTunnelRowType typeMapping(List<String> inputColumnsMapping) {
for (SelectItem selectItem : selectItems) {
if (selectItem.getExpression() instanceof AllColumns) {
for (int i = 0; i < inputRowType.getFieldNames().length; i++) {
fieldNames[idx] = inputRowType.getFieldName(i);
fieldNames[idx] = cleanEscape(inputRowType.getFieldName(i));
seaTunnelDataTypes[idx] = inputRowType.getFieldType(i);
if (inputColumnsMapping != null) {
inputColumnsMapping.set(idx, inputRowType.getFieldName(i));
Expand All @@ -194,16 +194,12 @@ public SeaTunnelRowType typeMapping(List<String> inputColumnsMapping) {
Expression expression = selectItem.getExpression();
if (selectItem.getAlias() != null) {
String aliasName = selectItem.getAlias().getName();
if (aliasName.startsWith(ESCAPE_IDENTIFIER)
&& aliasName.endsWith(ESCAPE_IDENTIFIER)) {
aliasName = aliasName.substring(1, aliasName.length() - 1);
}
fieldNames[idx] = aliasName;
fieldNames[idx] = cleanEscape(aliasName);
} else {
if (expression instanceof Column) {
fieldNames[idx] = ((Column) expression).getColumnName();
fieldNames[idx] = cleanEscape(((Column) expression).getColumnName());
} else {
fieldNames[idx] = expression.toString();
fieldNames[idx] = cleanEscape(expression.toString());
}
}

Expand All @@ -225,6 +221,13 @@ public SeaTunnelRowType typeMapping(List<String> inputColumnsMapping) {
fieldNames, seaTunnelDataTypes, lateralViews, inputColumnsMapping);
}

/**
 * Removes one pair of surrounding escape identifiers from a column name.
 *
 * <p>The name is returned unchanged unless it both starts and ends with {@code
 * ESCAPE_IDENTIFIER} and is long enough to hold an opening and a closing delimiter.
 *
 * @param columnName the column name or alias as written in the query
 * @return the name without its enclosing escape identifiers, or the input itself when it is
 *     not wrapped
 */
private static String cleanEscape(String columnName) {
    int delimiterLength = ESCAPE_IDENTIFIER.length();
    // A name that is just the escape identifier itself satisfies startsWith and endsWith at
    // the same time; without the length check the substring call below would be handed an
    // end index smaller than its start index and throw.
    boolean wrapped =
            columnName.length() >= 2 * delimiterLength
                    && columnName.startsWith(ESCAPE_IDENTIFIER)
                    && columnName.endsWith(ESCAPE_IDENTIFIER);
    if (!wrapped) {
        return columnName;
    }
    return columnName.substring(delimiterLength, columnName.length() - delimiterLength);
}

@Override
public List<SeaTunnelRow> transformBySQL(SeaTunnelRow inputRow, SeaTunnelRowType outRowType) {
// ------Physical Query Plan Execution------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,13 @@ public void testEscapeIdentifier() {
ReadonlyConfig.fromMap(
Collections.singletonMap(
"query",
"select id, trim(`apply`) as `apply` from test where `apply` = 'a'"));
"select `id`, trim(`apply`) as `apply` from test where `apply` = 'a'"));
SQLTransform sqlTransform = new SQLTransform(config, table);
TableSchema tableSchema = sqlTransform.transformTableSchema();
List<SeaTunnelRow> result =
sqlTransform.transformRow(
new SeaTunnelRow(new Object[] {Integer.valueOf(1), String.valueOf("a")}));
Assertions.assertEquals("id", tableSchema.getFieldNames()[0]);
Assertions.assertEquals("apply", tableSchema.getFieldNames()[1]);
Assertions.assertEquals("a", result.get(0).getField(1));
result =
Expand Down
Loading