Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement automatic execution of rollback scripts #44

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 64 additions & 18 deletions src/code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
from typing import Iterable


from . import batch_deleter
from . import colored_logs
from . import configs
Expand All @@ -38,7 +37,9 @@ def _parse_args() -> argparse.Namespace:
action="store_true",
)
parser.add_argument("--config-file", help="Specify a config file to use.")
parser.add_argument("--source", help="Restrict to config with this source path.")
parser.add_argument(
"--source", help="Restrict to config with this source path."
)
parser.add_argument(
"--dry-run",
help="If passed, will disable all snapshot creation and deletion.",
Expand All @@ -50,23 +51,29 @@ def _parse_args() -> argparse.Namespace:

# title - Shows as [title]: before commands are listed.
# metavar - The string is printed below the title. If None, all commands including hidden ones are printed.
subparsers = parser.add_subparsers(dest="command", title="command", metavar="")
subparsers = parser.add_subparsers(
dest="command", title="command", metavar=""
)

# A message to be added to help for all commands that honor --source or --config-file.
source_message = " Optionally use with --source or --config-file."

# Creates a new config by NAME.
create_config = subparsers.add_parser(
"create-config", help="Bootstrap a config for new filesystem to snapshot."
"create-config",
help="Bootstrap a config for new filesystem to snapshot.",
)
create_config.add_argument(
"config_name", help='Name to be given to config file, e.g. "home".'
)

# User commands.
subparsers.add_parser("list", help="List all managed snaps." + source_message)
subparsers.add_parser(
"list-json", help="Machine readable list of all managed snaps." + source_message
"list", help="List all managed snaps." + source_message
)
subparsers.add_parser(
"list-json",
help="Machine readable list of all managed snaps." + source_message,
)

# Creates an user snapshot.
Expand Down Expand Up @@ -121,11 +128,27 @@ def _parse_args() -> argparse.Namespace:
"rollback-gen",
help="Generate script to rollback one or more snaps." + source_message,
)
rollback.add_argument(
"--execute",
action="store_true",
help="Generate rollback script and execute.",
)

execute_rollback = subparsers.add_parser(
"rollback",
help="Generate rollback script and run. Equivalent to `rollback-gen --execute`",
)
execute_rollback.add_argument(
"--noconfirm",
action="store_true",
help="Execute the rollback script without confirmation.",
)

for command_with_target in [delete, rollback, set_ttl]:
for command_with_target in [delete, rollback, execute_rollback, set_ttl]:
command_with_target.add_argument(
"target_suffix",
help="Datetime string, or full path of a snapshot." + source_message,
help="Datetime string, or full path of a snapshot."
+ source_message,
)

# Internal commands used in scheduling and pacman hook.
Expand All @@ -138,23 +161,27 @@ def _parse_args() -> argparse.Namespace:


def _sync(configs_to_sync: list[configs.Config]):
paths_to_sync: dict[snap_mechanisms.SnapType, set[str]] = collections.defaultdict(
set
paths_to_sync: dict[snap_mechanisms.SnapType, set[str]] = (
collections.defaultdict(set)
)
for config in configs_to_sync:
paths_to_sync[config.snap_type].add(config.mount_path)
for snap_type, paths in sorted(paths_to_sync.items()):
snap_mechanisms.get(snap_type).sync_paths(paths)


def _set_ttl(configs_iter: Iterable[configs.Config], path_suffix: str, ttl_str: str):
def _set_ttl(
configs_iter: Iterable[configs.Config], path_suffix: str, ttl_str: str
):
for config in configs_iter:
snap = snap_operator.find_target(config, path_suffix)
if snap:
snap.set_ttl(ttl_str, now=datetime.datetime.now())


def _delete_snap(configs_iter: Iterable[configs.Config], path_suffix: str, sync: bool):
def _delete_snap(
configs_iter: Iterable[configs.Config], path_suffix: str, sync: bool
):
to_sync: list[configs.Config] = []
for config in configs_iter:
snap = snap_operator.find_target(config, path_suffix)
Expand Down Expand Up @@ -184,20 +211,27 @@ def _batch_delete_snaps(
filters = batch_deleter.get_filters(args_as_dict)

targets = list(
batch_deleter.apply_snapshot_filters(config_snaps_mapping_tuple, *filters)
batch_deleter.apply_snapshot_filters(
config_snaps_mapping_tuple, *filters
)
)
if sum(len(mapping.snaps) for mapping in targets) == 0:
os_utils.eprint("No snapshots matching the criteria were found.")
return

batch_deleter.show_snapshots_to_be_deleted(targets)

if batch_deleter.interactive_confirm():
snaps = itertools.chain.from_iterable(mapping.snaps for mapping in targets)
msg = "Are you sure you want to delete the above snapshots? [y/N] "
if os_utils.interactive_confirm(msg):
snaps = itertools.chain.from_iterable(
mapping.snaps for mapping in targets
)
batch_deleter.delete_snapshots(snaps)

if sync:
to_sync = batch_deleter.get_to_sync_list(mapping.config for mapping in targets)
to_sync = batch_deleter.get_to_sync_list(
mapping.config for mapping in targets
)
_sync(to_sync)


Expand Down Expand Up @@ -247,7 +281,9 @@ def main():
global_flags.FLAGS.dryrun = True
configs.USER_CONFIG_FILE = args.config_file

colored_logs.setup_logging(level=logging.INFO if args.verbose else logging.WARNING)
colored_logs.setup_logging(
level=logging.INFO if args.verbose else logging.WARNING
)

if configs.is_schedule_enabled() and not os_utils.timer_enabled():
os_utils.eprint(
Expand Down Expand Up @@ -282,10 +318,20 @@ def main():
args=args,
sync=args.sync,
)
elif command == "rollback":
rollbacker.rollback(
configs.iterate_configs(source=args.source),
args.target_suffix,
execute=True,
no_confirm=args.noconfirm,
)
elif command == "rollback-gen":
rollbacker.rollback(
configs.iterate_configs(source=args.source), args.target_suffix
configs.iterate_configs(source=args.source),
args.target_suffix,
execute=args.execute,
)
# Does `rollback-gen` subcommand require the optional parameter `--nocofirm` ?
else:
comment = getattr(args, "comment", "")
_config_operation(
Expand Down
15 changes: 13 additions & 2 deletions src/code/os_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import re
import subprocess
import sys

from typing import Any


Expand Down Expand Up @@ -49,7 +48,9 @@ def run_user_script(script_name: str, args: list[str]) -> bool:
logging.warning(f"User script {script_name=} does not exist.")
return False
except subprocess.CalledProcessError:
logging.warning(f"User script {script_name=} with {args=} resulted in error.")
logging.warning(
f"User script {script_name=} with {args=} resulted in error."
)
return False
return True

Expand Down Expand Up @@ -80,3 +81,13 @@ def timer_enabled() -> bool:
def eprint(*args: Any, **kwargs: Any) -> None:
"""Notifications meant for user, but not for redirection to any file."""
print(*args, file=sys.stderr, **kwargs)


def interactive_confirm(msg: str) -> bool:
user_choice = input(msg)
match user_choice:
case "y" | "Y" | "yes" | "Yes" | "YES":
return True
case _:
print("Aborted.")
return False
84 changes: 77 additions & 7 deletions src/code/rollbacker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,93 @@
# limitations under the License.

import collections
import logging
import os
import subprocess
from typing import Iterable

from . import configs
from . import os_utils
from . import snap_operator
from .mechanisms import snap_mechanisms

from typing import Iterable

def _find_and_categorize_snaps(
configs_iter: Iterable[configs.Config], path_suffix: str
) -> dict[snap_mechanisms.SnapType, list[tuple[str, str]]]:
"""Find snapshots in the configuration file and categorize them based on their type."""
source_dests_by_snaptype: dict[
snap_mechanisms.SnapType, list[tuple[str, str]]
] = collections.defaultdict(list)

def rollback(configs_iter: Iterable[configs.Config], path_suffix: str):
source_dests_by_snaptype: dict[snap_mechanisms.SnapType, list[tuple[str, str]]] = (
collections.defaultdict(list)
)
for config in configs_iter:
snap = snap_operator.find_target(config, path_suffix)
if snap:
source_dests_by_snaptype[config.snap_type].append(
(snap.metadata.source, snap.target)
)
for snap_type, source_dests in sorted(source_dests_by_snaptype.items()):
print("\n".join(snap_mechanisms.get(snap_type).rollback_gen(source_dests)))

return source_dests_by_snaptype


_ROLLBACK_SCRIPT_FILEPATH = "/tmp/rollback.sh"


def rollback(
configs_iter: Iterable[configs.Config],
path_suffix: str,
*,
execute: bool = False,
no_confirm: bool = False,
) -> None:
source_dests_by_snaptype = _find_and_categorize_snaps(
configs_iter, path_suffix
)
contents = _show_and_return_rollback_gen(source_dests_by_snaptype)
_create_and_chmod_script(contents)

if execute and no_confirm:
subprocess.run([".", _ROLLBACK_SCRIPT_FILEPATH])
return

if execute is True:
msg = "Review the code and enter 'y' to confirm execution. [y/N] "
confirm = os_utils.interactive_confirm(msg)
if confirm:
subprocess.run([".", _ROLLBACK_SCRIPT_FILEPATH])


def _create_and_chmod_script(contents: list[str]) -> None:
with open(_ROLLBACK_SCRIPT_FILEPATH, mode="w", encoding="utf_8") as fp:
fp.writelines(contents)
logging.info(
f"The rollback script is saved in {_ROLLBACK_SCRIPT_FILEPATH} ."
)

os.chmod(_ROLLBACK_SCRIPT_FILEPATH, mode=0o700)
if os.access(_ROLLBACK_SCRIPT_FILEPATH, mode=os.X_OK) is True:
logging.info(
f"Execution permissions have been granted for {_ROLLBACK_SCRIPT_FILEPATH} ."
)
else:
logging.warning(
f"Manual execution permissions need to be added for {_ROLLBACK_SCRIPT_FILEPATH} ."
)


def _show_and_return_rollback_gen(
source_dests_by_snaptype: dict[
snap_mechanisms.SnapType, list[tuple[str, str]]
],
) -> list[str]:
contents = [
"\n".join(snap_mechanisms.get(snap_type).rollback_gen(source_dests))
for snap_type, source_dests in sorted(source_dests_by_snaptype.items())
]

print("=== THE FOLLOWING IS THE ROLLBACK CODE ===")
for content in contents:
print(content)
print("=== THE ABOVE IS THE ROLLBACK CODE ===")

return contents
22 changes: 22 additions & 0 deletions src/code/rollbacker_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import unittest

from . import rollbacker

# For testing, we can access private methods.
# pyright: reportPrivateUsage=false


class TestCreateAndChmodScript(unittest.TestCase):
def test_create_and_chmod(self):
contents = [
"This file is used for unit testing the yabsnap rollbacker module, generated from yabsnap."
]
rollbacker._create_and_chmod_script(contents)
self.assertTrue(os.path.exists(rollbacker._ROLLBACK_SCRIPT_FILEPATH))
self.assertTrue(
os.access(rollbacker._ROLLBACK_SCRIPT_FILEPATH, mode=os.X_OK)
)

def tearDown(self):
os.remove(rollbacker._ROLLBACK_SCRIPT_FILEPATH)
Loading