Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HackerOne specific changes #1

Merged
merged 11 commits into the base branch from the source branch on
Jun 11, 2024
35 changes: 29 additions & 6 deletions tap_postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def sync_method_for_streams(streams, state, default_replication_method):
continue

if replication_method == 'LOG_BASED' and stream_metadata.get((), {}).get('is-view'):
raise Exception(f'Logical Replication is NOT supported for views. ' \
f'Please change the replication method for {stream["tap_stream_id"]}')
continue
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a quick fix for not having to specify each entity that's a view in the Meltano select. Long term we can look at downloading the catalog, modifying all streams marked as views, and using a custom catalog for the extractor.


if replication_method == 'FULL_TABLE':
lookup[stream['tap_stream_id']] = 'full'
Expand Down Expand Up @@ -194,7 +193,7 @@ def sync_traditional_stream(conn_config, stream, state, sync_method, end_lsn):
return state


def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_file):
def sync_logical_streams(conn_config, logical_streams, traditional_streams, state, end_lsn, state_file):
"""
Sync streams that use LOG_BASED method
"""
Expand All @@ -212,10 +211,20 @@ def sync_logical_streams(conn_config, logical_streams, state, end_lsn, state_fil
selected_streams.add("{}".format(stream['tap_stream_id']))

new_state = dict(currently_syncing=state['currently_syncing'], bookmarks={})
traditional_stream_ids = [s['tap_stream_id'] for s in traditional_streams]

for stream, bookmark in state['bookmarks'].items():
if bookmark == {} or bookmark['last_replication_method'] != 'LOG_BASED' or stream in selected_streams:
if (
bookmark == {}
or bookmark['last_replication_method'] != 'LOG_BASED'
or stream in selected_streams
# The first time a LOG_BASED stream runs it needs to do an
# initial full table sync, and so will be treated as a
# traditional stream.
or (stream in traditional_stream_ids and bookmark['last_replication_method'] == 'LOG_BASED')
):
new_state['bookmarks'][stream] = bookmark

state = new_state

state = logical_replication.sync_tables(conn_config, logical_streams, state, end_lsn, state_file)
Expand Down Expand Up @@ -319,7 +328,7 @@ def do_sync(conn_config, catalog, default_replication_method, state, state_file=
for dbname, streams in itertools.groupby(logical_streams,
lambda s: metadata.to_map(s['metadata']).get(()).get('database-name')):
conn_config['dbname'] = dbname
state = sync_logical_streams(conn_config, list(streams), state, end_lsn, state_file)
state = sync_logical_streams(conn_config, list(streams), traditional_streams, state, end_lsn, state_file)
return state


Expand Down Expand Up @@ -405,9 +414,23 @@ def main_impl():
'debug_lsn': args.config.get('debug_lsn') == 'true',
'max_run_seconds': args.config.get('max_run_seconds', 43200),
'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0))
'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0)),
'use_replica': args.config.get('use_replica', False),
'resync_with_commit_timestamp': args.config.get('resync_with_commit_timestamp', False),
}

if conn_config['use_replica']:
replica_config = {
# Required replica config keys
'replica_host': args.config['replica_host'],
'replica_user': args.config['replica_user'],
'replica_password': args.config['replica_password'],
'replica_port': args.config['replica_port'],
'replica_dbname': args.config['replica_dbname'],
}

conn_config = { **conn_config, **replica_config }

if args.config.get('ssl') == 'true':
conn_config['sslmode'] = 'require'

Expand Down
25 changes: 19 additions & 6 deletions tap_postgres/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,27 @@ def fully_qualified_table_name(schema, table):
return '"{}"."{}"'.format(canonicalize_identifier(schema), canonicalize_identifier(table))


def open_connection(conn_config, logical_replication=False):
def open_connection(conn_config, logical_replication=False, primary_connection=False):
if not primary_connection and conn_config['use_replica']:
host_key = "replica_host"
dbname_key = "replica_dbname"
user_key = "replica_user"
password_key = "replica_password"
port_key = "replica_port"
else:
host_key = "host"
dbname_key = "dbname"
user_key = "user"
password_key = "password"
port_key = "port"

cfg = {
'application_name': 'pipelinewise',
'host': conn_config['host'],
'dbname': conn_config['dbname'],
'user': conn_config['user'],
'password': conn_config['password'],
'port': conn_config['port'],
'host': conn_config[host_key],
'dbname': conn_config[dbname_key],
'user': conn_config[user_key],
'password': conn_config[password_key],
'port': conn_config[port_key],
'connect_timeout': 30
}

Expand Down
53 changes: 40 additions & 13 deletions tap_postgres/sync_strategies/full_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,53 @@ def sync_table(conn_info, stream, state, desired_columns, md_map):

fq_table_name = post_db.fully_qualified_table_name(schema_name, stream['table_name'])
xmin = singer.get_bookmark(state, stream['tap_stream_id'], 'xmin')
if xmin:
LOGGER.info("Resuming Full Table replication %s from xmin %s", nascent_stream_version, xmin)
select_sql = """SELECT {}, xmin::text::bigint
FROM {} where age(xmin::xid) <= age('{}'::xid)
ORDER BY xmin::text ASC""".format(','.join(escaped_columns),
fq_table_name,
xmin)

if conn_info['resync_with_commit_timestamp']:
if xmin:
LOGGER.info("Resuming Full Table replication %s from commit timestamp %s", nascent_stream_version, xmin)
select_sql = """
SELECT {},
pg_xact_commit_timestamp(xmin) as xmin
FROM {}
WHERE pg_xact_commit_timestamp(xmin) >= '{}'
ORDER BY pg_xact_commit_timestamp(xmin) asc""".format(','.join(escaped_columns),
fq_table_name,
xmin)
else:
LOGGER.info("Beginning new Full Table replication %s", nascent_stream_version)
select_sql = """
SELECT {},
pg_xact_commit_timestamp(xmin) as xmin
FROM {}
ORDER BY pg_xact_commit_timestamp(xmin) asc""".format(','.join(escaped_columns),
fq_table_name)

else:
LOGGER.info("Beginning new Full Table replication %s", nascent_stream_version)
select_sql = """SELECT {}, xmin::text::bigint
FROM {}
ORDER BY xmin::text ASC""".format(','.join(escaped_columns),
fq_table_name)
if xmin:
LOGGER.info("Resuming Full Table replication %s from xmin %s", nascent_stream_version, xmin)
select_sql = """SELECT {}, xmin::text::bigint
FROM {} where xmin::text::bigint >= '{}'::text::bigint
ORDER BY xmin::text::bigint ASC""".format(','.join(escaped_columns),
fq_table_name,
xmin)
else:
LOGGER.info("Beginning new Full Table replication %s", nascent_stream_version)
select_sql = """SELECT {}, xmin::text::bigint
FROM {}
ORDER BY xmin::text::bigint ASC""".format(','.join(escaped_columns),
fq_table_name)

LOGGER.info("select %s with itersize %s", select_sql, cur.itersize)
cur.execute(select_sql)

rows_saved = 0
for rec in cur:
xmin = rec['xmin']

if rec['xmin'] and conn_info['resync_with_commit_timestamp']:
xmin = rec['xmin'].strftime("%Y-%m-%d")
else:
xmin = rec['xmin']

rec = rec[:-1]
record_message = post_db.selected_row_to_singer_message(stream,
rec,
Expand Down
31 changes: 25 additions & 6 deletions tap_postgres/sync_strategies/logical_replication.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class UnsupportedPayloadKindError(Exception):

# pylint: disable=invalid-name,missing-function-docstring,too-many-branches,too-many-statements,too-many-arguments
def get_pg_version(conn_info):
with post_db.open_connection(conn_info, False) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
cur.execute("SELECT setting::int AS version FROM pg_settings WHERE name='server_version_num'")
version = cur.fetchone()[0]
Expand Down Expand Up @@ -92,7 +92,7 @@ def fetch_current_lsn(conn_config):
if version < 90400:
raise Exception('Logical replication not supported before PostgreSQL 9.4')

with post_db.open_connection(conn_config, False) as conn:
with post_db.open_connection(conn_config, False, True) as conn:
with conn.cursor() as cur:
# Use version specific lsn command
if version >= 100000:
Expand Down Expand Up @@ -137,7 +137,7 @@ def create_hstore_elem_query(elem):


def create_hstore_elem(conn_info, elem):
with post_db.open_connection(conn_info) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
query = create_hstore_elem_query(elem)
cur.execute(query)
Expand All @@ -150,7 +150,7 @@ def create_array_elem(elem, sql_datatype, conn_info):
if elem is None:
return None

with post_db.open_connection(conn_info) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
if sql_datatype == 'bit[]':
cast_datatype = 'boolean[]'
Expand Down Expand Up @@ -469,6 +469,25 @@ def consume_message(streams, state, msg, time_extracted, conn_info):
stream_md_map,
conn_info)

toast_columns = set()

if payload["kind"] in {"update"}:
desired_columns_without_automatic_columns = [
column for column in desired_columns if not column.startswith("_sdc")
]

toast_columns = set(desired_columns_without_automatic_columns).difference(
payload["columnnames"]
)

if toast_columns:
# TODO: also log the row ID that's affected, requires knowing the primary key per stream
LOGGER.info(
"Found toast columns %s for stream %s",
toast_columns,
target_stream["tap_stream_id"],
)

singer.write_message(record_message)
state = singer.write_bookmark(state, target_stream['tap_stream_id'], 'lsn', lsn)

Expand Down Expand Up @@ -516,7 +535,7 @@ def locate_replication_slot_by_cur(cursor, dbname, tap_id=None):


def locate_replication_slot(conn_info):
with post_db.open_connection(conn_info, False) as conn:
with post_db.open_connection(conn_info, False, True) as conn:
with conn.cursor() as cur:
return locate_replication_slot_by_cur(cur, conn_info['dbname'], conn_info['tap_id'])

Expand Down Expand Up @@ -575,7 +594,7 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file):
version = get_pg_version(conn_info)

# Create replication connection and cursor
conn = post_db.open_connection(conn_info, True)
conn = post_db.open_connection(conn_info, True, True)
cur = conn.cursor()

# Set session wal_sender_timeout for PG12 and above
Expand Down