From d6c572db58cd90fdf015c99da1275e6bbb501791 Mon Sep 17 00:00:00 2001 From: thR CIrcU5 <141405263+Tql-ws1@users.noreply.github.com> Date: Thu, 16 Jan 2025 21:34:49 +0800 Subject: [PATCH] feat: Implement automatic execution of rollback scripts See #37 and https://github.com/hirak99/yabsnap/issues/6#issuecomment-2585347835 for details Tests - - Added test cases for rollbacker.py Refactor - - Split the rollbacker.rollback() function Style - - Formatted main.py, os_utils.py, rollbacker.py, and rollbacker_test.py using isort and black - `isort --python-version 310 --profile google --line-length 80 --src-path "src/code"` - `black --line-length 80 --target-version py310 --target-version py311 --target-version py312` --- src/code/main.py | 82 ++++++++++++++++++++++++++++-------- src/code/os_utils.py | 15 ++++++- src/code/rollbacker.py | 84 +++++++++++++++++++++++++++++++++---- src/code/rollbacker_test.py | 22 ++++++++++ 4 files changed, 176 insertions(+), 27 deletions(-) create mode 100644 src/code/rollbacker_test.py diff --git a/src/code/main.py b/src/code/main.py index ae56d1b..84bbeb5 100644 --- a/src/code/main.py +++ b/src/code/main.py @@ -19,7 +19,6 @@ import logging from typing import Iterable - from . import batch_deleter from . import colored_logs from . import configs @@ -38,7 +37,9 @@ def _parse_args() -> argparse.Namespace: action="store_true", ) parser.add_argument("--config-file", help="Specify a config file to use.") - parser.add_argument("--source", help="Restrict to config with this source path.") + parser.add_argument( + "--source", help="Restrict to config with this source path." + ) parser.add_argument( "--dry-run", help="If passed, will disable all snapshot creation and deletion.", @@ -50,23 +51,29 @@ def _parse_args() -> argparse.Namespace: # title - Shows as [title]: before commands are listed. # metavar - The string is printed below the title. If None, all commands including hidden ones are printed. - subparsers = parser.add_subparsers(dest="command", title="command", metavar="") + subparsers = parser.add_subparsers( + dest="command", title="command", metavar="" + ) # A message to be added to help for all commands that honor --source or --config-file. source_message = " Optionally use with --source or --config-file." # Creates a new config by NAME. create_config = subparsers.add_parser( - "create-config", help="Bootstrap a config for new filesystem to snapshot." + "create-config", + help="Bootstrap a config for new filesystem to snapshot.", ) create_config.add_argument( "config_name", help='Name to be given to config file, e.g. "home".' ) # User commands. - subparsers.add_parser("list", help="List all managed snaps." + source_message) subparsers.add_parser( - "list-json", help="Machine readable list of all managed snaps." + source_message + "list", help="List all managed snaps." + source_message + ) + subparsers.add_parser( + "list-json", + help="Machine readable list of all managed snaps." + source_message, ) # Creates an user snapshot. @@ -121,11 +128,27 @@ def _parse_args() -> argparse.Namespace: "rollback-gen", help="Generate script to rollback one or more snaps." + source_message, ) + rollback.add_argument( + "--execute", + action="store_true", + help="Generate rollback script and execute.", + ) + + execute_rollback = subparsers.add_parser( + "rollback", + help="Generate rollback script and run. Equivalent to `rollback-gen --execute`", + ) + execute_rollback.add_argument( + "--noconfirm", + action="store_true", + help="Execute the rollback script without confirmation.", + ) - for command_with_target in [delete, rollback, set_ttl]: + for command_with_target in [delete, rollback, execute_rollback, set_ttl]: command_with_target.add_argument( "target_suffix", - help="Datetime string, or full path of a snapshot." + source_message, + help="Datetime string, or full path of a snapshot." + + source_message, ) # Internal commands used in scheduling and pacman hook. @@ -138,8 +161,8 @@ def _parse_args() -> argparse.Namespace: def _sync(configs_to_sync: list[configs.Config]): - paths_to_sync: dict[snap_mechanisms.SnapType, set[str]] = collections.defaultdict( - set + paths_to_sync: dict[snap_mechanisms.SnapType, set[str]] = ( + collections.defaultdict(set) ) for config in configs_to_sync: paths_to_sync[config.snap_type].add(config.mount_path) @@ -147,14 +170,18 @@ def _sync(configs_to_sync: list[configs.Config]): snap_mechanisms.get(snap_type).sync_paths(paths) -def _set_ttl(configs_iter: Iterable[configs.Config], path_suffix: str, ttl_str: str): +def _set_ttl( + configs_iter: Iterable[configs.Config], path_suffix: str, ttl_str: str +): for config in configs_iter: snap = snap_operator.find_target(config, path_suffix) if snap: snap.set_ttl(ttl_str, now=datetime.datetime.now()) -def _delete_snap(configs_iter: Iterable[configs.Config], path_suffix: str, sync: bool): +def _delete_snap( + configs_iter: Iterable[configs.Config], path_suffix: str, sync: bool +): to_sync: list[configs.Config] = [] for config in configs_iter: snap = snap_operator.find_target(config, path_suffix) @@ -184,7 +211,9 @@ def _batch_delete_snaps( filters = batch_deleter.get_filters(args_as_dict) targets = list( - batch_deleter.apply_snapshot_filters(config_snaps_mapping_tuple, *filters) + batch_deleter.apply_snapshot_filters( + config_snaps_mapping_tuple, *filters + ) ) if sum(len(mapping.snaps) for mapping in targets) == 0: os_utils.eprint("No snapshots matching the criteria were found.") @@ -192,12 +221,17 @@ def _batch_delete_snaps( batch_deleter.show_snapshots_to_be_deleted(targets) - if batch_deleter.interactive_confirm(): - snaps = itertools.chain.from_iterable(mapping.snaps for mapping in targets) + msg = "Are you sure you want to delete the above snapshots? [y/N] " + if os_utils.interactive_confirm(msg): + snaps = itertools.chain.from_iterable( + mapping.snaps for mapping in targets + ) batch_deleter.delete_snapshots(snaps) if sync: - to_sync = batch_deleter.get_to_sync_list(mapping.config for mapping in targets) + to_sync = batch_deleter.get_to_sync_list( + mapping.config for mapping in targets + ) _sync(to_sync) @@ -247,7 +281,9 @@ def main(): global_flags.FLAGS.dryrun = True configs.USER_CONFIG_FILE = args.config_file - colored_logs.setup_logging(level=logging.INFO if args.verbose else logging.WARNING) + colored_logs.setup_logging( + level=logging.INFO if args.verbose else logging.WARNING + ) if configs.is_schedule_enabled() and not os_utils.timer_enabled(): os_utils.eprint( @@ -282,10 +318,20 @@ def main(): args=args, sync=args.sync, ) + elif command == "rollback": + rollbacker.rollback( + configs.iterate_configs(source=args.source), + args.target_suffix, + execute=True, + no_confirm=args.noconfirm, + ) elif command == "rollback-gen": rollbacker.rollback( - configs.iterate_configs(source=args.source), args.target_suffix + configs.iterate_configs(source=args.source), + args.target_suffix, + execute=args.execute, ) + # Does `rollback-gen` subcommand require the optional parameter `--nocofirm` ? else: comment = getattr(args, "comment", "") _config_operation( diff --git a/src/code/os_utils.py b/src/code/os_utils.py index 7751a5e..b2cbefe 100644 --- a/src/code/os_utils.py +++ b/src/code/os_utils.py @@ -17,7 +17,6 @@ import re import subprocess import sys - from typing import Any @@ -49,7 +48,9 @@ def run_user_script(script_name: str, args: list[str]) -> bool: logging.warning(f"User script {script_name=} does not exist.") return False except subprocess.CalledProcessError: - logging.warning(f"User script {script_name=} with {args=} resulted in error.") + logging.warning( + f"User script {script_name=} with {args=} resulted in error." + ) return False return True @@ -80,3 +81,13 @@ def timer_enabled() -> bool: def eprint(*args: Any, **kwargs: Any) -> None: """Notifications meant for user, but not for redirection to any file.""" print(*args, file=sys.stderr, **kwargs) + + +def interactive_confirm(msg: str) -> bool: + user_choice = input(msg) + match user_choice: + case "y" | "Y" | "yes" | "Yes" | "YES": + return True + case _: + print("Aborted.") + return False diff --git a/src/code/rollbacker.py b/src/code/rollbacker.py index a914c38..6eb4c8a 100644 --- a/src/code/rollbacker.py +++ b/src/code/rollbacker.py @@ -13,23 +13,93 @@ # limitations under the License. import collections +import logging +import os +import subprocess +from typing import Iterable from . import configs +from . import os_utils from . import snap_operator from .mechanisms import snap_mechanisms -from typing import Iterable +def _find_and_categorize_snaps( + configs_iter: Iterable[configs.Config], path_suffix: str +) -> dict[snap_mechanisms.SnapType, list[tuple[str, str]]]: + """Find snapshots in the configuration file and categorize them based on their type.""" + source_dests_by_snaptype: dict[ + snap_mechanisms.SnapType, list[tuple[str, str]] + ] = collections.defaultdict(list) -def rollback(configs_iter: Iterable[configs.Config], path_suffix: str): - source_dests_by_snaptype: dict[snap_mechanisms.SnapType, list[tuple[str, str]]] = ( - collections.defaultdict(list) - ) for config in configs_iter: snap = snap_operator.find_target(config, path_suffix) if snap: source_dests_by_snaptype[config.snap_type].append( (snap.metadata.source, snap.target) ) - for snap_type, source_dests in sorted(source_dests_by_snaptype.items()): - print("\n".join(snap_mechanisms.get(snap_type).rollback_gen(source_dests))) + + return source_dests_by_snaptype + + +_ROLLBACK_SCRIPT_FILEPATH = "/tmp/rollback.sh" + + +def rollback( + configs_iter: Iterable[configs.Config], + path_suffix: str, + *, + execute: bool = False, + no_confirm: bool = False, +) -> None: + source_dests_by_snaptype = _find_and_categorize_snaps( + configs_iter, path_suffix + ) + contents = _show_and_return_rollback_gen(source_dests_by_snaptype) + _create_and_chmod_script(contents) + + if execute and no_confirm: + subprocess.run([".", _ROLLBACK_SCRIPT_FILEPATH]) + return + + if execute is True: + msg = "Review the code and enter 'y' to confirm execution. [y/N] " + confirm = os_utils.interactive_confirm(msg) + if confirm: + subprocess.run([".", _ROLLBACK_SCRIPT_FILEPATH]) + + +def _create_and_chmod_script(contents: list[str]) -> None: + with open(_ROLLBACK_SCRIPT_FILEPATH, mode="w", encoding="utf_8") as fp: + fp.writelines(contents) + logging.info( + f"The rollback script is saved in {_ROLLBACK_SCRIPT_FILEPATH} ." + ) + + os.chmod(_ROLLBACK_SCRIPT_FILEPATH, mode=0o700) + if os.access(_ROLLBACK_SCRIPT_FILEPATH, mode=os.X_OK) is True: + logging.info( + f"Execution permissions have been granted for {_ROLLBACK_SCRIPT_FILEPATH} ." + ) + else: + logging.warning( + f"Manual execution permissions need to be added for {_ROLLBACK_SCRIPT_FILEPATH} ." + ) + + +def _show_and_return_rollback_gen( + source_dests_by_snaptype: dict[ + snap_mechanisms.SnapType, list[tuple[str, str]] + ], +) -> list[str]: + contents = [ + "\n".join(snap_mechanisms.get(snap_type).rollback_gen(source_dests)) + for snap_type, source_dests in sorted(source_dests_by_snaptype.items()) + ] + + print("=== THE FOLLOWING IS THE ROLLBACK CODE ===") + for content in contents: + print(content) + print("=== THE ABOVE IS THE ROLLBACK CODE ===") + + return contents diff --git a/src/code/rollbacker_test.py b/src/code/rollbacker_test.py new file mode 100644 index 0000000..fd3b819 --- /dev/null +++ b/src/code/rollbacker_test.py @@ -0,0 +1,22 @@ +import os +import unittest + +from . import rollbacker + +# For testing, we can access private methods. +# pyright: reportPrivateUsage=false + + +class TestCreateAndChmodScript(unittest.TestCase): + def test_create_and_chmod(self): + contents = [ + "This file is used for unit testing the yabsnap rollbacker module, generated from yabsnap." + ] + rollbacker._create_and_chmod_script(contents) + self.assertTrue(os.path.exists(rollbacker._ROLLBACK_SCRIPT_FILEPATH)) + self.assertTrue( + os.access(rollbacker._ROLLBACK_SCRIPT_FILEPATH, mode=os.X_OK) + ) + + def tearDown(self): + os.remove(rollbacker._ROLLBACK_SCRIPT_FILEPATH)