diff --git a/ogion/backup_targets/postgresql.py b/ogion/backup_targets/postgresql.py index f238306..2fe80a6 100644 --- a/ogion/backup_targets/postgresql.py +++ b/ogion/backup_targets/postgresql.py @@ -62,10 +62,20 @@ def _get_escaped_conn_uri(self) -> str: pgpass_file = self._init_pgpass_file() encoded_user = urllib.parse.quote_plus(self.target_model.user) encoded_db = urllib.parse.quote_plus(self.target_model.db) + + params = {"passfile": pgpass_file} + if self.target_model.model_extra is not None: + for param, value in self.target_model.model_extra.items(): + params[param] = value + + log.debug("psql connection params: %s", params) + uri = ( f"postgresql://{encoded_user}@{self.target_model.host}:{self.target_model.port}/{encoded_db}?" - f"passfile={pgpass_file}" - ) + ) + urllib.parse.urlencode(params) + + log.debug("psql connection url: %s", uri) + escaped_uri = shlex.quote(uri) return escaped_uri diff --git a/ogion/core.py b/ogion/core.py index 83cdd9e..c3d0050 100644 --- a/ogion/core.py +++ b/ogion/core.py @@ -21,6 +21,7 @@ SAFE_LETTER_PATTERN = re.compile(r"[^A-Za-z0-9_]*") DATETIME_BACKUP_FILE_PATTERN = re.compile(r"_[0-9]{8}_[0-9]{4}_") +MODEL_SPLIT_EQUATION_PATTERN = re.compile(r"( \w*\=|^\w*\=)") _BM = TypeVar("_BM", bound=BaseModel) @@ -116,26 +117,24 @@ def _validate_model( env_name: str, env_value: str, target: type[_BM], - value_whitespace_split: bool = False, ) -> _BM: target_name: str = target.__name__.lower() log.info("validating %s variable: `%s`", target_name, env_name) log.debug("%s=%s", target_name, env_value) try: env_value_parts = env_value.strip() + fields_matches = [ + match.group() + for match in MODEL_SPLIT_EQUATION_PATTERN.finditer(env_value_parts) + ] target_kwargs: dict[str, Any] = {"env_name": env_name} - for field_name in target.model_fields.keys(): - if env_value_parts.startswith(f"{field_name}="): - f = f"{field_name}=" - else: - f = f" {field_name}=" - if f in env_value_parts: - _, val = env_value_parts.split(f, maxsplit=1) - for other_field in target.model_fields.keys(): - val = val.split(f" {other_field}=")[0] - if value_whitespace_split: - val = val.split()[0] - target_kwargs[field_name] = val + + while fields_matches: + field_match = fields_matches.pop() + rest, value = env_value_parts.split(field_match, maxsplit=1) + env_value_parts = rest.rstrip() + target_kwargs[field_match.removesuffix("=").strip()] = value + log.debug("calculated arguments: %s", target_kwargs) validated_target = target.model_validate(target_kwargs) except Exception: @@ -173,7 +172,6 @@ def create_provider_model() -> upload_provider_models.ProviderModel: "backup_provider", config.options.BACKUP_PROVIDER, upload_provider_models.ProviderModel, - value_whitespace_split=True, ) target_model_cls = provider_map[base_provider.name] return _validate_model( diff --git a/ogion/models/backup_target_models.py b/ogion/models/backup_target_models.py index af10c4d..66accca 100644 --- a/ogion/models/backup_target_models.py +++ b/ogion/models/backup_target_models.py @@ -45,6 +45,10 @@ class PostgreSQLTargetModel(TargetModel): db: str = "postgres" password: SecretStr + model_config = ConfigDict( + extra="allow", + ) + class MariaDBTargetModel(TargetModel): name: config.BackupTargetEnum = config.BackupTargetEnum.MARIADB @@ -54,6 +58,10 @@ class MariaDBTargetModel(TargetModel): db: str = "mariadb" password: SecretStr + model_config = ConfigDict( + extra="allow", + ) + class SingleFileTargetModel(TargetModel): name: config.BackupTargetEnum = config.BackupTargetEnum.FILE diff --git a/tests/test_core.py b/tests/test_core.py index d526348..f024a05 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -195,6 +195,16 @@ def test_run_create_age_archive_can_be_decrypted( ], True, ), + ( + [ + ( + "POSTGRESQL_FIRST_DB", + "host=localhost port5432 password=secret " + "cron_rule=* * * * * ssl_mode=require", + ), + ], + True, + ), ]