diff --git a/Makefile b/Makefile index bb02e4c..e8b4ede 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ long_ver = $(shell git describe --long 2>/dev/null || echo $(short_ver)-0-unknow generated = pglookout/version.py # Only include files that have been typed. -typed = pglookout/__main__.py +typed = pglookout/pgutil.py test/test_pgutil.py # Flake8 ignores: # E722: https://www.flake8rules.com/rules/E722.html Do not use bare except, specify exception instead diff --git a/pglookout/pgutil.py b/pglookout/pgutil.py index 3b8562b..6a8c2b1 100644 --- a/pglookout/pgutil.py +++ b/pglookout/pgutil.py @@ -5,16 +5,80 @@ Copyright (c) 2015 Ohmu Ltd See LICENSE for details """ +from __future__ import annotations + +from typing import cast, Literal, TypedDict from urllib.parse import parse_qs, urlparse # pylint: disable=no-name-in-module, import-error import psycopg2.extensions -def create_connection_string(connection_info): +class DsnDictBase(TypedDict, total=False): + user: str + password: str + host: str + port: str | int + + +class DsnDict(DsnDictBase, total=False): + dbname: str + + +class DsnDictDeprecated(DsnDictBase, total=False): + database: str + + +class ConnectionParameterKeywords(TypedDict, total=False): + """Parameter Keywords for Connection. + + See: + https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS + """ + + host: str + hostaddr: str + port: str + dbname: str + user: str + password: str + passfile: str + channel_binding: Literal["require", "prefer", "disable"] + connect_timeout: str + client_encoding: str + options: str + application_name: str + fallback_application_name: str + keepalives: Literal["0", "1"] + keepalives_idle: str + keepalives_interval: str + keepalives_count: str + tcp_user_timeout: str + replication: Literal["true", "on", "yes", "1", "database", "false", "off", "no", "0"] + gssencmode: Literal["disable", "prefer", "require"] + sslmode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"] + requiressl: Literal["0", "1"] + sslcompression: Literal["0", "1"] + sslcert: str + sslkey: str + sslpassword: str + sslrootcert: str + sslcrl: str + sslcrldir: str + sslsni: Literal["0", "1"] + requirepeer: str + ssl_min_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"] + ssl_max_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"] + krbsrvname: str + gsslib: str + service: str + target_session_attrs: Literal["any", "read-write", "read-only", "primary", "standby", "prefer-standby"] + + +def create_connection_string(connection_info: DsnDict | DsnDictDeprecated | ConnectionParameterKeywords) -> str: return psycopg2.extensions.make_dsn(**connection_info) -def mask_connection_info(info): +def mask_connection_info(info: str) -> str: masked_info = get_connection_info(info) password = masked_info.pop("password", None) connection_string = create_connection_string(masked_info) @@ -22,24 +86,28 @@ def mask_connection_info(info): return f"{connection_string}; {message}" -def get_connection_info_from_config_line(line): +def get_connection_info_from_config_line(line: str) -> ConnectionParameterKeywords: _, value = line.split("=", 1) value = value.strip()[1:-1].replace("''", "'") return get_connection_info(value) -def get_connection_info(info): +def get_connection_info( + info: str | DsnDict | DsnDictDeprecated | ConnectionParameterKeywords, +) -> ConnectionParameterKeywords: """turn a connection info object into a dict or return it if it was a dict already. supports both the traditional libpq format and the new url format""" if isinstance(info, dict): - return info.copy() + # Potentially, we might clean deprecated DSN dicts: `database` -> `dbname`. + # Also, psycopg2 will validate the keys and values. + return parse_connection_string_libpq(create_connection_string(info)) if info.startswith("postgres://") or info.startswith("postgresql://"): return parse_connection_string_url(info) return parse_connection_string_libpq(info) -def parse_connection_string_url(url): +def parse_connection_string_url(url: str) -> ConnectionParameterKeywords: # drop scheme from the url as some versions of urlparse don't handle # query and path properly for urls with a non-http scheme schemeless_url = url.split(":", 1)[1] @@ -57,10 +125,10 @@ def parse_connection_string_url(url): fields["dbname"] = p.path[1:] for k, v in parse_qs(p.query).items(): fields[k] = v[-1] - return fields + return cast(ConnectionParameterKeywords, fields) -def parse_connection_string_libpq(connection_string): +def parse_connection_string_libpq(connection_string: str) -> ConnectionParameterKeywords: """parse a postgresql connection string as defined in http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING""" fields = {} @@ -92,5 +160,8 @@ def parse_connection_string_libpq(connection_string): value, connection_string = res else: value, connection_string = rem, "" + # This one is case-insensitive. To continue benefiting from mypy, we make it lowercase. + if key == "replication": + value = value.lower() fields[key] = value - return fields + return cast(ConnectionParameterKeywords, fields) diff --git a/test/test_pgutil.py b/test/test_pgutil.py index d70e6f5..7dc969f 100644 --- a/test/test_pgutil.py +++ b/test/test_pgutil.py @@ -6,14 +6,14 @@ See LICENSE for details """ -from pglookout.pgutil import create_connection_string, get_connection_info, mask_connection_info +from pglookout.pgutil import ConnectionParameterKeywords, create_connection_string, get_connection_info, mask_connection_info from pytest import raises -def test_connection_info(): +def test_connection_info() -> None: url = "postgres://hannu:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require" cs = "host=dbhost.local user='hannu' dbname='abc'\nreplication=true password=secret sslmode=require port=5555" - ci = { + ci: ConnectionParameterKeywords = { "host": "dbhost.local", "port": "5555", "user": "hannu", @@ -39,7 +39,7 @@ def test_connection_info(): get_connection_info("foo=bar bar='x") -def test_mask_connection_info(): +def test_mask_connection_info() -> None: url = "postgres://michael:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require" cs = "host=dbhost.local user='michael' dbname='abc'\nreplication=true password=secret sslmode=require port=5555" ci = get_connection_info(cs)