Skip to content

Commit

Permalink
Add db-specific quoting for non-simple names in AbstractJbdcSource (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cgardens authored Jan 28, 2021
1 parent d488588 commit c34bcd2
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 14 deletions.
34 changes: 34 additions & 0 deletions airbyte-db/src/main/java/io/airbyte/db/jdbc/JdbcUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.Field.JsonSchemaPrimitive;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.Date;
import java.sql.JDBCType;
import java.sql.PreparedStatement;
Expand All @@ -41,8 +42,10 @@
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.StringJoiner;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
Expand Down Expand Up @@ -237,4 +240,35 @@ private interface SQLSupplier<O> {

}

/**
* Given a database connection and identifier, adds db-specific quoting.
*
* @param connection database connection
* @param identifier identifier to quote
* @return quoted identifier
* @throws SQLException throws if there are any issues fulling the quoting metadata from the db.
*/
public static String enquoteIdentifier(Connection connection, String identifier) throws SQLException {
final String identifierQuoteString = connection.getMetaData().getIdentifierQuoteString();

return identifierQuoteString + identifier + identifierQuoteString;
}

/**
* Given a database connection and identifiers, adds db-specific quoting to each identifier.
*
* @param connection database connection
* @param identifiers identifiers to quote
* @return quoted identifiers
* @throws SQLException throws if there are any issues fulling the quoting metadata from the db.
*/
public static String enquoteIdentifierList(Connection connection, List<String> identifiers) throws SQLException {
final StringJoiner joiner = new StringJoiner(",");
for (String col : identifiers) {
String s = JdbcUtils.enquoteIdentifier(connection, col);
joiner.add(s);
}
return joiner.toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import io.airbyte.commons.lang.Exceptions;
import io.airbyte.commons.resources.MoreResources;
import io.airbyte.commons.stream.MoreStreams;
import io.airbyte.commons.string.Strings;
import io.airbyte.db.Databases;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.db.jdbc.JdbcStreamingQueryConfiguration;
Expand All @@ -53,6 +52,7 @@
import io.airbyte.protocol.models.Field.JsonSchemaPrimitive;
import io.airbyte.protocol.models.SyncMode;
import java.io.IOException;
import java.sql.Connection;
import java.sql.JDBCType;
import java.sql.PreparedStatement;
import java.sql.SQLException;
Expand Down Expand Up @@ -280,6 +280,7 @@ private List<TableInfo> getTables(final JdbcDatabase database,
final Optional<String> databaseOptional,
final Optional<String> schemaOptional)
throws Exception {

return discoverInternal(database, databaseOptional, schemaOptional).stream()
.map(t -> {
// some databases return multiple copies of the same record for a column (e.g. redshift) because
Expand Down Expand Up @@ -313,6 +314,11 @@ private static void assertColumnsWithSameNameAreSame(String schemaName, String t
});
}

private static String getFullyQualifiedTableNameWithQuoting(Connection connection, String schemaName, String tableName) throws SQLException {
final String quotedTableName = JdbcUtils.enquoteIdentifier(connection, tableName);
return schemaName != null ? JdbcUtils.enquoteIdentifier(connection, schemaName) + "." + quotedTableName : quotedTableName;
}

private static String getFullyQualifiedTableName(String schemaName, String tableName) {
return schemaName != null ? schemaName + "." + tableName : tableName;
}
Expand Down Expand Up @@ -376,7 +382,9 @@ public static Stream<JsonNode> queryTableFullRefresh(JdbcDatabase database, List
throws SQLException {
return database.query(
connection -> {
final String sql = String.format("SELECT %s FROM %s", Strings.join(columnNames, ","), getFullyQualifiedTableName(schemaName, tableName));
final String sql = String.format("SELECT %s FROM %s",
JdbcUtils.enquoteIdentifierList(connection, columnNames),
getFullyQualifiedTableNameWithQuoting(connection, schemaName, tableName));
return connection.prepareStatement(sql);
},
JdbcUtils::rowToJson);
Expand All @@ -390,12 +398,14 @@ public static Stream<JsonNode> queryTableIncremental(JdbcDatabase database,
JDBCType cursorFieldType,
String cursor)
throws SQLException {
final String sql = String.format("SELECT %s FROM %s WHERE %s > ?",
Strings.join(columnNames, ","),
getFullyQualifiedTableName(schemaName, tableName), cursorField);

return database.query(
connection -> {
final String sql = String.format("SELECT %s FROM %s WHERE %s > ?",
JdbcUtils.enquoteIdentifierList(connection, columnNames),
getFullyQualifiedTableNameWithQuoting(connection, schemaName, tableName),
JdbcUtils.enquoteIdentifier(connection, cursorField));

final PreparedStatement preparedStatement = connection.prepareStatement(sql);
JdbcUtils.setStatementField(preparedStatement, 1, cursorFieldType, cursor);
return preparedStatement;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.airbyte.commons.resources.MoreResources;
import io.airbyte.db.Databases;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.db.jdbc.JdbcUtils;
import io.airbyte.integrations.source.jdbc.AbstractJdbcSource;
import io.airbyte.integrations.source.jdbc.models.JdbcState;
import io.airbyte.integrations.source.jdbc.models.JdbcStreamState;
Expand All @@ -57,6 +58,7 @@
import io.airbyte.protocol.models.Field;
import io.airbyte.protocol.models.Field.JsonSchemaPrimitive;
import io.airbyte.protocol.models.SyncMode;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
Expand Down Expand Up @@ -266,6 +268,55 @@ void testReadMultipleTables() throws Exception {
assertEquals(expectedMessages, actualMessages);
}

private ConfiguredAirbyteStream createTableWithSpaces() throws SQLException {
// test table name with space.
final String tableNameWithSpaces = "id and name2";
final String streamName2 = getDefaultSchemaName().map(val -> val + "." + tableNameWithSpaces).orElse(tableNameWithSpaces);;
// test column name with space.
final String lastNameField = "last name";
database.execute(connection -> {
connection.createStatement().execute(String.format("CREATE TABLE %s (id INTEGER, %s VARCHAR(200));",
JdbcUtils.enquoteIdentifier(connection, tableNameWithSpaces), JdbcUtils.enquoteIdentifier(connection, lastNameField)));
connection.createStatement().execute(String.format("INSERT INTO %s (id, %s) VALUES (1,'picard'), (2, 'crusher'), (3, 'vash');",
JdbcUtils.enquoteIdentifier(connection, tableNameWithSpaces), JdbcUtils.enquoteIdentifier(connection, lastNameField)));
});

return CatalogHelpers.createConfiguredAirbyteStream(
streamName2,
Field.of("id", JsonSchemaPrimitive.NUMBER),
Field.of(lastNameField, JsonSchemaPrimitive.STRING));
}

@Test
void testTablesWithQuoting() throws Exception {
final ConfiguredAirbyteStream streamForTableWithSpaces = createTableWithSpaces();

final ConfiguredAirbyteCatalog catalog = new ConfiguredAirbyteCatalog().withStreams(Lists.newArrayList(
getConfiguredCatalog().getStreams().get(0),
streamForTableWithSpaces));
final List<AirbyteMessage> actualMessages = source.read(config, catalog, null).collect(Collectors.toList());

actualMessages.forEach(r -> {
if (r.getRecord() != null) {
r.getRecord().setEmittedAt(null);
}
});

final List<AirbyteMessage> secondStreamExpectedMessages = getTestMessages()
.stream()
.map(Jsons::clone)
.peek(m -> {
m.getRecord().setStream(streamForTableWithSpaces.getStream().getName());
((ObjectNode) m.getRecord().getData()).set("last name", ((ObjectNode) m.getRecord().getData()).remove("name"));
((ObjectNode) m.getRecord().getData()).remove("updated_at");
})
.collect(Collectors.toList());
final List<AirbyteMessage> expectedMessages = new ArrayList<>(getTestMessages());
expectedMessages.addAll(secondStreamExpectedMessages);

assertEquals(expectedMessages, actualMessages);
}

@SuppressWarnings("ResultOfMethodCallIgnored")
@Test
void testReadFailure() {
Expand Down Expand Up @@ -303,6 +354,31 @@ void testIncrementalStringCheckCursor() throws Exception {
Lists.newArrayList(getTestMessages().get(0), getTestMessages().get(2)));
}

@Test
void testIncrementalStringCheckCursorSpaceInColumnName() throws Exception {
final ConfiguredAirbyteStream streamWithSpaces = createTableWithSpaces();

final AirbyteMessage firstMessage = getTestMessages().get(0);
firstMessage.getRecord().setStream(streamWithSpaces.getStream().getName());
((ObjectNode) firstMessage.getRecord().getData()).remove("updated_at");
((ObjectNode) firstMessage.getRecord().getData()).set("last name", ((ObjectNode) firstMessage.getRecord().getData()).remove("name"));

final AirbyteMessage secondMessage = getTestMessages().get(2);
secondMessage.getRecord().setStream(streamWithSpaces.getStream().getName());
((ObjectNode) secondMessage.getRecord().getData()).remove("updated_at");
((ObjectNode) secondMessage.getRecord().getData()).set("last name", ((ObjectNode) secondMessage.getRecord().getData()).remove("name"));

Lists.newArrayList(getTestMessages().get(0), getTestMessages().get(2));

incrementalCursorCheck(
"last name",
"last name",
"patent",
"vash",
Lists.newArrayList(firstMessage, secondMessage),
streamWithSpaces);
}

@Test
void testIncrementalTimestampCheckCursor() throws Exception {
incrementalCursorCheck(
Expand Down Expand Up @@ -462,18 +538,29 @@ private void incrementalCursorCheck(
String endCursorValue,
List<AirbyteMessage> expectedRecordMessages)
throws Exception {
final ConfiguredAirbyteCatalog configuredCatalog = getConfiguredCatalog();
configuredCatalog.getStreams().forEach(airbyteStream -> {
airbyteStream.setSyncMode(SyncMode.INCREMENTAL);
airbyteStream.setCursorField(Lists.newArrayList(cursorField));
});
incrementalCursorCheck(initialCursorField, cursorField, initialCursorValue, endCursorValue, expectedRecordMessages,
getConfiguredCatalog().getStreams().get(0));
}

private void incrementalCursorCheck(
String initialCursorField,
String cursorField,
String initialCursorValue,
String endCursorValue,
List<AirbyteMessage> expectedRecordMessages,
ConfiguredAirbyteStream airbyteStream)
throws Exception {
airbyteStream.setSyncMode(SyncMode.INCREMENTAL);
airbyteStream.setCursorField(Lists.newArrayList(cursorField));

final JdbcState state = new JdbcState()
.withStreams(Lists.newArrayList(new JdbcStreamState()
.withStreamName(streamName)
.withStreamName(airbyteStream.getStream().getName())
.withCursorField(ImmutableList.of(initialCursorField))
.withCursor(initialCursorValue)));

final ConfiguredAirbyteCatalog configuredCatalog = new ConfiguredAirbyteCatalog().withStreams(ImmutableList.of(airbyteStream));

final List<AirbyteMessage> actualMessages = source.read(config, configuredCatalog, Jsons.jsonNode(state)).collect(Collectors.toList());

actualMessages.forEach(r -> {
Expand All @@ -488,7 +575,7 @@ private void incrementalCursorCheck(
.withState(new AirbyteStateMessage()
.withData(Jsons.jsonNode(new JdbcState()
.withStreams(Lists.newArrayList(new JdbcStreamState()
.withStreamName(streamName)
.withStreamName(airbyteStream.getStream().getName())
.withCursorField(ImmutableList.of(cursorField))
.withCursor(endCursorValue)))))));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import org.jooq.SQLDialect;
import org.testcontainers.containers.MySQLContainer;

public class MySqlSourceStandardTest extends StandardSourceTest {
public class MySqlStandardTest extends StandardSourceTest {

private static final String STREAM_NAME = "id_and_name";
private static final String STREAM_NAME2 = "public.starships";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.testcontainers.containers.MySQLContainer;

class MySqlStandardSourceTest extends JdbcSourceStandardTest {
class MySqlJdbcStandardTest extends JdbcSourceStandardTest {

private static final String TEST_USER = "test";
private static final String TEST_PASSWORD = "test";
Expand Down

0 comments on commit c34bcd2

Please sign in to comment.