diff --git a/airbyte-integrations/connectors/source-jdbc/src/main/java/io/airbyte/integrations/source/jdbc/AbstractJdbcSource.java b/airbyte-integrations/connectors/source-jdbc/src/main/java/io/airbyte/integrations/source/jdbc/AbstractJdbcSource.java index ba4f620d5b40a..774fb1e447ec2 100644 --- a/airbyte-integrations/connectors/source-jdbc/src/main/java/io/airbyte/integrations/source/jdbc/AbstractJdbcSource.java +++ b/airbyte-integrations/connectors/source-jdbc/src/main/java/io/airbyte/integrations/source/jdbc/AbstractJdbcSource.java @@ -55,8 +55,10 @@ import java.sql.PreparedStatement; import java.sql.SQLException; import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -137,8 +139,11 @@ public AirbyteCatalog discover(JsonNode config) throws Exception { .stream() .map(t -> CatalogHelpers.createAirbyteStream(t.getName(), t.getFields()) .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey( - t.getPrimaryKeys().stream().filter(Objects::nonNull).map(Collections::singletonList).collect(Collectors.toList()))) + .withSourceDefinedPrimaryKey(t.getPrimaryKeys() + .stream() + .filter(Objects::nonNull) + .map(Collections::singletonList) + .collect(Collectors.toList()))) .collect(Collectors.toList())); } } @@ -280,8 +285,9 @@ private List getTables(final JdbcDatabase database, final Optional databaseOptional, final Optional schemaOptional) throws Exception { - - return discoverInternal(database, databaseOptional, schemaOptional).stream() + final List tableInfos = discoverInternal(database, databaseOptional, schemaOptional); + final Map> tablePrimaryKeys = discoverPrimaryKeys(database, databaseOptional, schemaOptional, tableInfos); + return tableInfos.stream() .map(t -> { // some databases return multiple copies of the same record 
for a column (e.g. redshift) because // they have at least once delivery guarantees. we want to dedupe these, but first we check that the @@ -292,12 +298,76 @@ private List<TableInfo> getTables(final JdbcDatabase database, .map(f -> Field.of(f.getColumnName(), JdbcUtils.getType(f.getColumnType()))) .distinct() .collect(Collectors.toList()); - - return new TableInfo(JdbcUtils.getFullyQualifiedTableName(t.getSchemaName(), t.getName()), fields, t.getPrimaryKeys()); + final String streamName = JdbcUtils.getFullyQualifiedTableName(t.getSchemaName(), t.getName()); + final List<String> primaryKeys = tablePrimaryKeys.getOrDefault(streamName, Collections.emptyList()); + return new TableInfo(streamName, fields, primaryKeys); }) .collect(Collectors.toList()); } + /** + * Discovers the primary keys of every table. + * + * Calling conn.getMetaData().getPrimaryKeys() without a table name fails on some databases (for + * example MySQL) but works on others (for instance Postgres). To avoid one query per table, all + * keys are requested in a single call first; if that call throws or returns nothing, each table is queried individually. + * + * @return a map from schema.table stream name to the list of that table's primary key columns + */ + private Map<String, List<String>> discoverPrimaryKeys(JdbcDatabase database, + Optional<String> databaseOptional, + Optional<String> schemaOptional, + List<TableInfoInternal> tableInfos) { + try { + // Get all primary keys without specifying a table name + final Map<String, List<String>> tablePrimaryKeys = aggregatePrimaryKeys(database.bufferedResultSetQuery( + conn -> conn.getMetaData().getPrimaryKeys(databaseOptional.orElse(null), schemaOptional.orElse(null), null), + r -> { + final String schemaName = + r.getObject(JDBC_COLUMN_SCHEMA_NAME) != null ? 
r.getString(JDBC_COLUMN_SCHEMA_NAME) : r.getString(JDBC_COLUMN_DATABASE_NAME); + final String streamName = JdbcUtils.getFullyQualifiedTableName(schemaName, r.getString(JDBC_COLUMN_TABLE_NAME)); + final String primaryKey = r.getString(JDBC_COLUMN_COLUMN_NAME); + return new SimpleImmutableEntry<>(streamName, primaryKey); + })); + if (!tablePrimaryKeys.isEmpty()) { + return tablePrimaryKeys; + } + } catch (SQLException e) { + LOGGER.debug(String.format("Could not retrieve primary keys without a table name (%s), retrying", e)); + } + // Get primary keys one table at a time + return tableInfos.stream() + .collect(Collectors.toMap( + tableInfo -> JdbcUtils.getFullyQualifiedTableName(tableInfo.getSchemaName(), tableInfo.getName()), + tableInfo -> { + final String streamName = JdbcUtils.getFullyQualifiedTableName(tableInfo.getSchemaName(), tableInfo.getName()); + try { + final Map<String, List<String>> primaryKeys = aggregatePrimaryKeys(database.bufferedResultSetQuery( + conn -> conn.getMetaData().getPrimaryKeys(databaseOptional.orElse(null), tableInfo.getSchemaName(), tableInfo.getName()), + r -> new SimpleImmutableEntry<>(streamName, r.getString(JDBC_COLUMN_COLUMN_NAME)))); + return primaryKeys.getOrDefault(streamName, Collections.emptyList()); + } catch (SQLException e) { + LOGGER.error(String.format("Could not retrieve primary keys for %s: %s", streamName, e)); + return Collections.emptyList(); + } + })); + } + + /** + * @param entries (stream name, primary key column) pairs, one entry per key column + * @return a map from stream name to its primary key columns, in the order the entries were given + */ + private static Map<String, List<String>> aggregatePrimaryKeys(List<SimpleImmutableEntry<String, String>> entries) { + final Map<String, List<String>> result = new HashMap<>(); + entries.forEach(entry -> { + // computeIfAbsent creates the per-stream list the first time a stream name is seen + result + .computeIfAbsent(entry.getKey(), streamName -> new ArrayList<>()) + .add(entry.getValue()); + }); + return result; + } + private static void assertColumnsWithSameNameAreSame(String schemaName, String tableName, List<ColumnInfo> columns) { columns.stream() 
.collect(Collectors.groupingBy(ColumnInfo::getColumnName)) @@ -320,7 +390,7 @@ private List discoverInternal(final JdbcDatabase database, final Optional schemaOptional) throws Exception { final Set internalSchemas = new HashSet<>(getExcludedInternalSchemas()); - final List result = database.bufferedResultSetQuery( + return database.bufferedResultSetQuery( conn -> conn.getMetaData().getColumns(databaseOptional.orElse(null), schemaOptional.orElse(null), null, null), resultSet -> Jsons.jsonNode(ImmutableMap.builder() // we always want a namespace, if we cannot get a schema, use db name. @@ -358,17 +428,6 @@ private List discoverInternal(final JdbcDatabase database, }) .collect(Collectors.toList()))) .collect(Collectors.toList()); - result.forEach(t -> { - try { - final List primaryKeys = database.bufferedResultSetQuery( - conn -> conn.getMetaData().getPrimaryKeys(databaseOptional.orElse(null), t.getSchemaName(), t.getName()), - resultSet -> resultSet.getString(JDBC_COLUMN_COLUMN_NAME)); - t.addPrimaryKeys(primaryKeys); - } catch (SQLException e) { - LOGGER.warn(String.format("Could not find primary keys for %s.%s: %s", t.getSchemaName(), t.getName(), e)); - } - }); - return result; } private static AutoCloseableIterator getMessageIterator(AutoCloseableIterator recordIterator, @@ -481,13 +540,11 @@ protected static class TableInfoInternal { private final String schemaName; private final String name; private final List fields; - private final List primaryKeys; public TableInfoInternal(String schemaName, String tableName, List fields) { this.schemaName = schemaName; this.name = tableName; this.fields = fields; - this.primaryKeys = new ArrayList<>(); } public String getSchemaName() { @@ -502,14 +559,6 @@ public List getFields() { return fields; } - public void addPrimaryKeys(List primaryKeys) { - this.primaryKeys.addAll(primaryKeys); - } - - public List getPrimaryKeys() { - return primaryKeys; - } - } protected static class ColumnInfo { diff --git 
a/airbyte-integrations/connectors/source-jdbc/src/testFixtures/java/io/airbyte/integrations/source/jdbc/test/JdbcSourceStandardTest.java b/airbyte-integrations/connectors/source-jdbc/src/testFixtures/java/io/airbyte/integrations/source/jdbc/test/JdbcSourceStandardTest.java index 81752a9bb79cb..17b17fbcf9aab 100644 --- a/airbyte-integrations/connectors/source-jdbc/src/testFixtures/java/io/airbyte/integrations/source/jdbc/test/JdbcSourceStandardTest.java +++ b/airbyte-integrations/connectors/source-jdbc/src/testFixtures/java/io/airbyte/integrations/source/jdbc/test/JdbcSourceStandardTest.java @@ -89,7 +89,7 @@ public abstract class JdbcSourceStandardTest { private static final String TABLE_NAME = "id_and_name"; private static final String TABLE_NAME_WITHOUT_PK = "id_and_name_without_pk"; - private static final String TABLE_NAME_FULL_NAMES = "full_names"; + private static final String TABLE_NAME_COMPOSITE_PK = "full_name_composite_pk"; private JsonNode config; private JdbcDatabase database; @@ -164,11 +164,11 @@ public void setup() throws Exception { connection.createStatement() .execute( String.format("CREATE TABLE %s(first_name VARCHAR(200), last_name VARCHAR(200), updated_at DATE, PRIMARY KEY (first_name, last_name));", - getFullyQualifiedTableName(TABLE_NAME_FULL_NAMES))); + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK))); connection.createStatement().execute( String.format( "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('first' ,'picard', '2004-10-19'), ('second', 'crusher', '2005-10-19'), ('third', 'vash', '2006-10-19');", - getFullyQualifiedTableName(TABLE_NAME_FULL_NAMES))); + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK))); }); } @@ -627,7 +627,7 @@ private static AirbyteCatalog getCatalog(final String defaultNamespace) { .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) .withSourceDefinedPrimaryKey(Collections.emptyList()), CatalogHelpers.createAirbyteStream( - defaultNamespace + "." 
+ TABLE_NAME_FULL_NAMES, + defaultNamespace + "." + TABLE_NAME_COMPOSITE_PK, Field.of("first_name", JsonSchemaPrimitive.STRING), Field.of("last_name", JsonSchemaPrimitive.STRING), Field.of("updated_at", JsonSchemaPrimitive.STRING))