Skip to content

Commit

Permalink
feat: Implement automatic execution of rollback scripts
Browse files Browse the repository at this point in the history
See #37 and #6 (comment) for details

Tests -
- Added test cases for rollbacker.py

Refactor -
- Split the rollbacker.rollback() function

Style -
- Formatted main.py, os_utils.py, rollbacker.py, and rollbacker_test.py using isort and black
  - `isort --python-version 310 --profile google --line-length 80 --src-path "src/code"`
  - `black --line-length 80 --target-version py310 --target-version py311 --target-version py312`
  • Loading branch information
Tql-ws1 committed Jan 16, 2025
1 parent 7d00a87 commit 9c44eb8
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 27 deletions.
82 changes: 64 additions & 18 deletions src/code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
from typing import Iterable


from . import batch_deleter
from . import colored_logs
from . import configs
Expand All @@ -38,7 +37,9 @@ def _parse_args() -> argparse.Namespace:
action="store_true",
)
parser.add_argument("--config-file", help="Specify a config file to use.")
parser.add_argument("--source", help="Restrict to config with this source path.")
parser.add_argument(
"--source", help="Restrict to config with this source path."
)
parser.add_argument(
"--dry-run",
help="If passed, will disable all snapshot creation and deletion.",
Expand All @@ -50,23 +51,29 @@ def _parse_args() -> argparse.Namespace:

# title - Shows as [title]: before commands are listed.
# metavar - The string is printed below the title. If None, all commands including hidden ones are printed.
subparsers = parser.add_subparsers(dest="command", title="command", metavar="")
subparsers = parser.add_subparsers(
dest="command", title="command", metavar=""
)

# A message to be added to help for all commands that honor --source or --config-file.
source_message = " Optionally use with --source or --config-file."

# Creates a new config by NAME.
create_config = subparsers.add_parser(
"create-config", help="Bootstrap a config for new filesystem to snapshot."
"create-config",
help="Bootstrap a config for new filesystem to snapshot.",
)
create_config.add_argument(
"config_name", help='Name to be given to config file, e.g. "home".'
)

# User commands.
subparsers.add_parser("list", help="List all managed snaps." + source_message)
subparsers.add_parser(
"list-json", help="Machine readable list of all managed snaps." + source_message
"list", help="List all managed snaps." + source_message
)
subparsers.add_parser(
"list-json",
help="Machine readable list of all managed snaps." + source_message,
)

# Creates an user snapshot.
Expand Down Expand Up @@ -121,11 +128,27 @@ def _parse_args() -> argparse.Namespace:
"rollback-gen",
help="Generate script to rollback one or more snaps." + source_message,
)
rollback.add_argument(
"--execute",
action="store_true",
help="Generate rollback script and execute.",
)

execute_rollback = subparsers.add_parser(
"rollback",
help="Generate rollback script and run. Equivalent to `rollback-gen --execute`",
)
execute_rollback.add_argument(
"--noconfirm",
action="store_true",
help="Execute the rollback script without confirmation.",
)

for command_with_target in [delete, rollback, set_ttl]:
for command_with_target in [delete, rollback, execute_rollback, set_ttl]:
command_with_target.add_argument(
"target_suffix",
help="Datetime string, or full path of a snapshot." + source_message,
help="Datetime string, or full path of a snapshot."
+ source_message,
)

# Internal commands used in scheduling and pacman hook.
Expand All @@ -138,23 +161,27 @@ def _parse_args() -> argparse.Namespace:


def _sync(configs_to_sync: list[configs.Config]):
paths_to_sync: dict[snap_mechanisms.SnapType, set[str]] = collections.defaultdict(
set
paths_to_sync: dict[snap_mechanisms.SnapType, set[str]] = (
collections.defaultdict(set)
)
for config in configs_to_sync:
paths_to_sync[config.snap_type].add(config.mount_path)
for snap_type, paths in sorted(paths_to_sync.items()):
snap_mechanisms.get(snap_type).sync_paths(paths)


def _set_ttl(configs_iter: Iterable[configs.Config], path_suffix: str, ttl_str: str):
def _set_ttl(
configs_iter: Iterable[configs.Config], path_suffix: str, ttl_str: str
):
for config in configs_iter:
snap = snap_operator.find_target(config, path_suffix)
if snap:
snap.set_ttl(ttl_str, now=datetime.datetime.now())


def _delete_snap(configs_iter: Iterable[configs.Config], path_suffix: str, sync: bool):
def _delete_snap(
configs_iter: Iterable[configs.Config], path_suffix: str, sync: bool
):
to_sync: list[configs.Config] = []
for config in configs_iter:
snap = snap_operator.find_target(config, path_suffix)
Expand Down Expand Up @@ -184,20 +211,27 @@ def _batch_delete_snaps(
filters = batch_deleter.get_filters(args_as_dict)

targets = list(
batch_deleter.apply_snapshot_filters(config_snaps_mapping_tuple, *filters)
batch_deleter.apply_snapshot_filters(
config_snaps_mapping_tuple, *filters
)
)
if sum(len(mapping.snaps) for mapping in targets) == 0:
os_utils.eprint("No snapshots matching the criteria were found.")
return

batch_deleter.show_snapshots_to_be_deleted(targets)

if batch_deleter.interactive_confirm():
snaps = itertools.chain.from_iterable(mapping.snaps for mapping in targets)
msg = "Are you sure you want to delete the above snapshots? [y/N] "
if os_utils.interactive_confirm(msg):
snaps = itertools.chain.from_iterable(
mapping.snaps for mapping in targets
)
batch_deleter.delete_snapshots(snaps)

if sync:
to_sync = batch_deleter.get_to_sync_list(mapping.config for mapping in targets)
to_sync = batch_deleter.get_to_sync_list(
mapping.config for mapping in targets
)
_sync(to_sync)


Expand Down Expand Up @@ -247,7 +281,9 @@ def main():
global_flags.FLAGS.dryrun = True
configs.USER_CONFIG_FILE = args.config_file

colored_logs.setup_logging(level=logging.INFO if args.verbose else logging.WARNING)
colored_logs.setup_logging(
level=logging.INFO if args.verbose else logging.WARNING
)

if configs.is_schedule_enabled() and not os_utils.timer_enabled():
os_utils.eprint(
Expand Down Expand Up @@ -282,10 +318,20 @@ def main():
args=args,
sync=args.sync,
)
elif command == "rollback":
rollbacker.rollback(
configs.iterate_configs(source=args.source),
args.target_suffix,
execute=True,
no_confirm=args.noconfirm,
)
elif command == "rollback-gen":
rollbacker.rollback(
configs.iterate_configs(source=args.source), args.target_suffix
configs.iterate_configs(source=args.source),
args.target_suffix,
execute=args.execute,
)
# Does `rollback-gen` subcommand require the optional parameter `--nocofirm` ?
else:
comment = getattr(args, "comment", "")
_config_operation(
Expand Down
15 changes: 13 additions & 2 deletions src/code/os_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import re
import subprocess
import sys

from typing import Any


Expand Down Expand Up @@ -49,7 +48,9 @@ def run_user_script(script_name: str, args: list[str]) -> bool:
logging.warning(f"User script {script_name=} does not exist.")
return False
except subprocess.CalledProcessError:
logging.warning(f"User script {script_name=} with {args=} resulted in error.")
logging.warning(
f"User script {script_name=} with {args=} resulted in error."
)
return False
return True

Expand Down Expand Up @@ -80,3 +81,13 @@ def timer_enabled() -> bool:
def eprint(*args: Any, **kwargs: Any) -> None:
"""Notifications meant for user, but not for redirection to any file."""
print(*args, file=sys.stderr, **kwargs)


def interactive_confirm(msg: str) -> bool:
user_choice = input(msg)
match user_choice:
case "y" | "Y" | "yes" | "Yes" | "YES":
return True
case _:
print("Aborted.")
return False
84 changes: 77 additions & 7 deletions src/code/rollbacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,93 @@
# limitations under the License.

import collections
import logging
import os
import subprocess
from typing import Iterable

from . import configs
from . import os_utils
from . import snap_operator
from .mechanisms import snap_mechanisms

from typing import Iterable

def _find_and_categorize_snaps(
configs_iter: Iterable[configs.Config], path_suffix: str
) -> dict[snap_mechanisms.SnapType, list[tuple[str, str]]]:
"""Find snapshots in the configuration file and categorize them based on their type."""
source_dests_by_snaptype: dict[
snap_mechanisms.SnapType, list[tuple[str, str]]
] = collections.defaultdict(list)

def rollback(configs_iter: Iterable[configs.Config], path_suffix: str):
source_dests_by_snaptype: dict[snap_mechanisms.SnapType, list[tuple[str, str]]] = (
collections.defaultdict(list)
)
for config in configs_iter:
snap = snap_operator.find_target(config, path_suffix)
if snap:
source_dests_by_snaptype[config.snap_type].append(
(snap.metadata.source, snap.target)
)
for snap_type, source_dests in sorted(source_dests_by_snaptype.items()):
print("\n".join(snap_mechanisms.get(snap_type).rollback_gen(source_dests)))

return source_dests_by_snaptype


_ROLLBACK_SCRIPT_FILEPATH = "/tmp/rollback.sh"


def rollback(
configs_iter: Iterable[configs.Config],
path_suffix: str,
*,
execute: bool = False,
no_confirm: bool = False,
) -> None:
source_dests_by_snaptype = _find_and_categorize_snaps(
configs_iter, path_suffix
)
contents = _show_and_return_rollback_gen(source_dests_by_snaptype)
_create_and_chmod_script(contents)

if execute and no_confirm:
subprocess.run([".", _ROLLBACK_SCRIPT_FILEPATH])
return

if execute is True:
msg = "Review the code and enter 'y' to confirm execution. [y/N] "
confirm = os_utils.interactive_confirm(msg)
if confirm:
subprocess.run([".", _ROLLBACK_SCRIPT_FILEPATH])


def _create_and_chmod_script(contents: list[str]) -> None:
with open(_ROLLBACK_SCRIPT_FILEPATH, mode="w", encoding="utf_8") as fp:
fp.writelines(contents)
logging.info(
f"The rollback script is saved in {_ROLLBACK_SCRIPT_FILEPATH} ."
)

os.chmod(_ROLLBACK_SCRIPT_FILEPATH, mode=0o700)
if os.access(_ROLLBACK_SCRIPT_FILEPATH, mode=os.X_OK) is True:
logging.info(
f"Execution permissions have been granted for {_ROLLBACK_SCRIPT_FILEPATH} ."
)
else:
logging.warning(
f"Manual execution permissions need to be added for {_ROLLBACK_SCRIPT_FILEPATH} ."
)


def _show_and_return_rollback_gen(
source_dests_by_snaptype: dict[
snap_mechanisms.SnapType, list[tuple[str, str]]
],
) -> list[str]:
contents = [
"\n".join(snap_mechanisms.get(snap_type).rollback_gen(source_dests))
for snap_type, source_dests in sorted(source_dests_by_snaptype.items())
]

print("=== THE FOLLOWING IS THE ROLLBACK CODE ===")
for content in contents:
print(content)
print("=== THE ABOVE IS THE ROLLBACK CODE ===")

return contents
22 changes: 22 additions & 0 deletions src/code/rollbacker_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import unittest

from . import rollbacker

# For testing, we can access private methods.
# pyright: reportPrivateUsage=false


class TestCreateAndChmodScript(unittest.TestCase):
def test_create_and_chmod(self):
contents = [
"This file is used for unit testing the yabsnap rollbacker module, generated from yabsnap."
]
rollbacker._create_and_chmod_script(contents)
self.assertTrue(os.path.exists(rollbacker._ROLLBACK_SCRIPT_FILEPATH))
self.assertTrue(
os.access(rollbacker._ROLLBACK_SCRIPT_FILEPATH, mode=os.X_OK)
)

def tearDown(self):
os.remove(rollbacker._ROLLBACK_SCRIPT_FILEPATH)

0 comments on commit 9c44eb8

Please sign in to comment.