From aefe8b610e91723a708cddf14a40205f0be1977e Mon Sep 17 00:00:00 2001 From: Prashant Mital Date: Wed, 19 Feb 2020 18:14:11 -0800 Subject: [PATCH] Implement astrolabe - a framework for testing drivers against MongoDB Atlas --- .gitignore | 2 + astrolabe/cli.py | 454 +++++++++++++++++++++++++++++++++++++ astrolabe/commands.py | 138 +++++++++++ astrolabe/configuration.py | 47 ++++ astrolabe/exceptions.py | 10 + astrolabe/poller.py | 31 +++ astrolabe/spec_runner.py | 318 ++++++++++++++++++++++++++ astrolabe/utils.py | 132 +++++++++++ atlasclient/client.py | 4 +- atlasclient/utils.py | 12 +- setup.py | 1 + 11 files changed, 1145 insertions(+), 4 deletions(-) create mode 100644 .gitignore create mode 100644 astrolabe/cli.py create mode 100644 astrolabe/commands.py create mode 100644 astrolabe/configuration.py create mode 100644 astrolabe/exceptions.py create mode 100644 astrolabe/poller.py create mode 100644 astrolabe/spec_runner.py create mode 100644 astrolabe/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..4fd492eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +**/__pycache__ +.idea/ diff --git a/astrolabe/cli.py b/astrolabe/cli.py new file mode 100644 index 00000000..402790db --- /dev/null +++ b/astrolabe/cli.py @@ -0,0 +1,454 @@ +from itertools import chain +import json +from pprint import pprint +import subprocess +import sys +from textwrap import dedent +from time import sleep + +import click +from tabulate import tabulate + +import astrolabe.commands +from atlasclient import AtlasClient, AtlasApiError +from atlasclient.configuration import CONFIG_DEFAULTS as CL_DEFAULTS +from astrolabe.spec_runner import MultiTestRunner +from astrolabe.configuration import ( + CLI_OPTION_NAMES as OPTNAMES, + CONFIG_DEFAULTS as DEFAULTS, + CONFIG_ENVVARS as ENVVARS, + TestCaseConfiguration) +from astrolabe.utils import Timer + + +# Define CLI options used in multiple commands for easy re-use. 
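+# Each of the objects below is a standard click decorator, so a single
+# definition can be stacked onto any number of commands. A minimal sketch of
+# the intended usage (the command shown is hypothetical, not part of this
+# module):
+#
+#     @cli.command()
+#     @DBUSERNAME_OPTION
+#     @DBPASSWORD_OPTION
+#     def example(db_username, db_password):
+#         ...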
+DBUSERNAME_OPTION = click.option( + OPTNAMES.DB_USERNAME, type=click.STRING, default=DEFAULTS.DB_USERNAME, + help='Database username on the MongoDB instance.') + +DBPASSWORD_OPTION = click.option( + OPTNAMES.DB_PASSWORD, type=click.STRING, default=DEFAULTS.DB_PASSWORD, + help='Database password on the MongoDB instance.') + +ATLASORGANIZATIONNAME_OPTION = click.option( + OPTNAMES.ORGANIZATION_NAME, type=click.STRING, + default=DEFAULTS.ORGANIZATION_NAME, + required=True, help='Name of the Atlas Organization.') + +ATLASCLUSTERNAME_OPTION = click.option( + '--cluster-name', required=True, type=click.STRING, + help='Name of the Atlas Cluster.') + +ATLASGROUPNAME_OPTION = click.option( + OPTNAMES.PROJECT_NAME, required=True, type=click.STRING, + envvar=ENVVARS.PROJECT_NAME, help='Name of the Atlas Project.') + +POLLINGTIMEOUT_OPTION = click.option( + OPTNAMES.POLLING_TIMEOUT, type=click.FLOAT, + envvar=ENVVARS.POLLING_TIMEOUT, default=DEFAULTS.POLLING_TIMEOUT, + help="Maximum time (in s) to poll API endpoints.") + +POLLINGFREQUENCY_OPTION = click.option( + OPTNAMES.POLLING_FREQUENCY, type=click.FLOAT, + envvar=ENVVARS.POLLING_FREQUENCY, default=DEFAULTS.POLLING_FREQUENCY, + help='Frequency (in Hz) at which to poll API endpoints.') + + +@click.group() +@click.option(OPTNAMES.BASE_URL, envvar=ENVVARS.BASE_URL, + default=CL_DEFAULTS.BASE_URL, + type=click.STRING, help='Base URL of the Atlas API.') +@click.option('-u', '--atlas-api-username', required=True, + envvar=ENVVARS.API_USERNAME, type=click.STRING, + help='HTTP-Digest username (Atlas API public-key).') +@click.option('-p', '--atlas-api-password', required=True, + envvar=ENVVARS.API_PASSWORD, type=click.STRING, + help='HTTP-Digest password (Atlas API private-key).') +@click.option(OPTNAMES.HTTP_TIMEOUT, envvar=ENVVARS.HTTP_TIMEOUT, + default=CL_DEFAULTS.HTTP_TIMEOUT, type=click.FLOAT, + help='Time (in s) after which HTTP requests should timeout.') +@click.option('-v', '--verbose', count=True, default=False, + help="Set the logging level. Default: off.") +@click.version_option() +@click.pass_context +def cli(ctx, atlas_base_url, atlas_api_username, + atlas_api_password, http_timeout, verbose): + """ + Astrolabe is a command-line application for running automated driver + tests against a MongoDB Atlas cluster undergoing maintenance. + """ + client = AtlasClient( + base_url=atlas_base_url, + username=atlas_api_username, + password=atlas_api_password, + timeout=http_timeout, + verbose=verbose) + ctx.obj = client + + +@cli.command() +@click.pass_context +def check_connection(ctx): + """Command to verify validity of Atlas API credentials.""" + pprint(ctx.obj.root.get().data) + +# +# @cli.command() +# @click.option('-c', '--config', multiple=True, type=_JsonDotNotationType()) +# def cluster_config_option(config): +# pprint(_merge_dictionaries(config)) + + +@cli.group('organizations') +def atlas_organizations(): + """Commands related to Atlas Organizations.""" + pass + + +@atlas_organizations.command('list') +@click.pass_context +def list_all_organizations(ctx): + """List all Atlas Organizations (limited to first 100).""" + pprint(ctx.obj.orgs.get().data) + + +@atlas_organizations.command('get-one') +@ATLASORGANIZATIONNAME_OPTION +@click.pass_context +def get_one_organization_by_name(ctx, org_name): + """Get one Atlas Organization by name. 
Prints "None" if no organization
+    bearing the given name exists."""
+    pprint(astrolabe.commands.get_one_organization_by_name(
+        client=ctx.obj, organization_name=org_name))
+
+
+@cli.group('projects')
+def atlas_projects():
+    """Commands related to Atlas Projects."""
+    pass
+
+
+@atlas_projects.command('create')
+@ATLASORGANIZATIONNAME_OPTION
+@ATLASGROUPNAME_OPTION
+@click.pass_context
+def create_project(ctx, org_name, group_name):
+    """Create a new Atlas Project."""
+    org = astrolabe.commands.get_one_organization_by_name(
+        client=ctx.obj, organization_name=org_name)
+    response = ctx.obj.groups.post(name=group_name, orgId=org.id)
+    pprint(response.data)
+
+
+@atlas_projects.command('list')
+@click.pass_context
+def list_projects(ctx):
+    """List all Atlas Projects (limited to first 100)."""
+    pprint(ctx.obj.groups.get().data)
+
+
+@atlas_projects.command('get-one')
+@ATLASGROUPNAME_OPTION
+@click.pass_context
+def get_one_project_by_name(ctx, group_name):
+    """Get one Atlas Project."""
+    pprint(ctx.obj.groups.byName[group_name].get().data)
+
+
+@atlas_projects.command('enable-anywhere-access')
+@ATLASGROUPNAME_OPTION
+@click.pass_context
+def enable_project_access_from_anywhere(ctx, group_name):
+    """Add 0.0.0.0/0 to the IP whitelist of the Atlas Project."""
+    group = ctx.obj.groups.byName[group_name].get().data
+    astrolabe.commands.ensure_connect_from_anywhere(
+        client=ctx.obj, group_id=group.id)
+
+
+@cli.group('users')
+def atlas_users():
+    """Commands related to Atlas Users."""
+    pass
+
+
+@atlas_users.command('create-admin-user')
+@DBUSERNAME_OPTION
+@DBPASSWORD_OPTION
+@ATLASGROUPNAME_OPTION
+@click.pass_context
+def create_user(ctx, db_username, db_password, group_name):
+    """Create an Atlas User with admin privileges. Modifies user
+    permissions if the user already exists."""
+    group = ctx.obj.groups.byName[group_name].get().data
+    user = astrolabe.commands.ensure_admin_user(
+        client=ctx.obj, group_id=group.id, username=db_username,
+        password=db_password)
+    pprint(user)
+
+
+@atlas_users.command('list')
+@ATLASGROUPNAME_OPTION
+@click.pass_context
+def list_users(ctx, group_name):
+    """List all Atlas Users."""
+    project = ctx.obj.groups.byName[group_name].get().data
+    pprint(ctx.obj.groups[project.id].databaseUsers.get().data)
+
+
+@cli.group('clusters')
+def atlas_clusters():
+    """Commands related to Atlas Clusters."""
+    pass
+
+
+@atlas_clusters.command('create-dedicated')
+@ATLASGROUPNAME_OPTION
+@ATLASCLUSTERNAME_OPTION
+@click.option('-s', '--instance-size-name', required=True,
+              type=click.Choice(["M10", "M20"]),
+              help="Name of AWS Cluster Tier to provision.")
+@click.pass_context
+def create_cluster(ctx, group_name, cluster_name, instance_size_name):
+    """Create a new dedicated-tier Atlas Cluster."""
+    project = ctx.obj.groups.byName[group_name].get().data
+
+    cluster_config = {
+        'name': cluster_name,
+        'clusterType': 'REPLICASET',
+        'providerSettings': {
+            'providerName': 'AWS',
+            'regionName': 'US_WEST_1',
+            'instanceSizeName': instance_size_name}}
+
+    cluster = ctx.obj.groups[project.id].clusters.post(**cluster_config)
+    pprint(cluster.data)
+
+
+@atlas_clusters.command('get-one')
+@ATLASCLUSTERNAME_OPTION
+@ATLASGROUPNAME_OPTION
+@click.pass_context
+def get_one_cluster_by_name(ctx, cluster_name, group_name):
+    """Get one Atlas Cluster."""
+    project = ctx.obj.groups.byName[group_name].get().data
+    cluster = ctx.obj.groups[project.id].clusters[cluster_name].get()
+    pprint(cluster.data)
+
+
+@atlas_clusters.command('resize-dedicated')
+@ATLASGROUPNAME_OPTION
+@ATLASCLUSTERNAME_OPTION
+@click.option('-s', '--instance-size-name',
required=True, + type=click.Choice(["M10", "M20"]), + help="Target AWS Cluster Tier.") +@click.pass_context +def resize_cluster(ctx, group_name, cluster_name, instance_size_name): + """Resize an existing dedicated-tier Atlas Cluster.""" + project = ctx.obj.groups.byName[group_name].get().data + + new_cluster_config = { + 'clusterType': 'REPLICASET', + 'providerSettings': { + 'providerName': 'AWS', + 'regionName': 'US_WEST_1', + 'instanceSizeName': instance_size_name}} + + cluster = ctx.obj.groups[project.id].clusters[cluster_name].patch( + **new_cluster_config) + pprint(cluster.data) + + +@atlas_clusters.command('toggle-js') +@ATLASGROUPNAME_OPTION +@ATLASCLUSTERNAME_OPTION +@click.pass_context +def toggle_cluster_javascript(ctx, group_name, cluster_name): + """Enable/disable server-side javascript for an existing Atlas Cluster.""" + project = ctx.obj.groups.byName[group_name].get().data + + # Alias to reduce verbosity. + pargs = ctx.obj.groups[project.id].clusters[cluster_name].processArgs + + initial_process_args = pargs.get() + target_js_value = not initial_process_args.data.javascriptEnabled + + cluster = pargs.patch(javascriptEnabled=target_js_value) + pprint(cluster.data) + + +@atlas_clusters.command('list') +@ATLASGROUPNAME_OPTION +@click.pass_context +def list_clusters(ctx, group_name): + """List all Atlas Clusters.""" + project = ctx.obj.groups.byName[group_name].get().data + clusters = ctx.obj.groups[project.id].clusters.get() + pprint(clusters.data) + + +@atlas_clusters.command('isready') +@ATLASGROUPNAME_OPTION +@ATLASCLUSTERNAME_OPTION +@click.pass_context +def isready_cluster(ctx, group_name, cluster_name): + """Check if the Atlas Cluster is 'IDLE'.""" + project = ctx.obj.groups.byName[group_name].get().data + state = ctx.obj.groups[project.id].clusters[cluster_name].get().data.stateName + + if state == "IDLE": + print("True") + exit(0) + print("False") + exit(1) + + +@atlas_clusters.command('delete') +@ATLASGROUPNAME_OPTION +@ATLASCLUSTERNAME_OPTION +@click.pass_context +def delete_cluster(ctx, group_name, cluster_name): + """Delete the Atlas Cluster.""" + project = ctx.obj.groups.byName[group_name].get().data + ctx.obj.groups[project.id].clusters[cluster_name].delete().data + print("DONE!") + + +# @atlas_clusters.command('getlogs') + + + +# @cli.group('run-debug') +# @click.option('-f', '--test') +# def debug_test(): +# """Command group for running orchestrating tests.""" +# pass +# +# + +@cli.group('help') +def help_topics(): + """Help topics for astrolabe users.""" + pass + + +@help_topics.command('environment-variables') +def help_environment_variables(): + """Environment variables used to configure astrolabe.""" + helptext = dedent("""\ + Many of astrolabe's configuration options can be set at runtime using + environment variables. Specification of the HTTP Digest Authentication + credentials for MongoDB Atlas API access via the command-line is + HIGHLY DISCOURAGED due to security reasons. 
+
+        {}
+        """)
+    tabledata = []
+    for internal_id, envvar_name in ENVVARS.items():
+        tabledata.append([internal_id, OPTNAMES[internal_id], envvar_name])
+    headers = ["Internal ID", "CLI Option Name",
+               "Environment Variable"]
+    tabletext = tabulate(tabledata, headers=headers, tablefmt="fancy_grid")
+    click.echo_via_pager(helptext.format(tabletext))
+
+
+@help_topics.command('default-values')
+def help_default_values():
+    """Default values of configuration options."""
+    helptext = dedent("""\
+        Default values of configuration options are:
+        {}
+        """)
+    tabledata = []
+    for internal_id, default_value in chain(
+            CL_DEFAULTS.items(), DEFAULTS.items()):
+        if internal_id in OPTNAMES:
+            tabledata.append(
+                [internal_id, OPTNAMES[internal_id], default_value])
+    headers = ["Internal ID", "CLI Option Name", "Default Value"]
+    tabletext = tabulate(tabledata, headers=headers, tablefmt="fancy_grid")
+    click.echo_via_pager(helptext.format(tabletext))
+
+
+@cli.group('spec-tests')
+def spec_tests():
+    """Commands related to running APM spec-tests."""
+    pass
+
+
+@spec_tests.command('run-one')
+@click.argument("spec_test_file", type=click.Path(
+    exists=True, file_okay=True, dir_okay=False, resolve_path=True))
+@click.option('-e', '--workload-executor', required=True, type=click.Path(
+    exists=True, file_okay=True, dir_okay=False, resolve_path=True),
+    help='Absolute or relative path to the workload-executor')
+@click.option('--log-dir', required=True, default="logs",
+              type=click.Path(resolve_path=True))
+@DBUSERNAME_OPTION
+@DBPASSWORD_OPTION
+@click.pass_context
+def run_one_test(ctx, spec_test_file, workload_executor, log_dir,
+                 db_username, db_password):
+    """Run a single APM test. (Not yet implemented.)"""
+    pass
+
+
+@spec_tests.command('run')
+@click.argument("spec_tests_directory", type=click.Path(
+    exists=True, file_okay=False, dir_okay=True, resolve_path=True))
+@click.option('-e', '--workload-executor', required=True, type=click.Path(
+    exists=True, file_okay=True, dir_okay=False, resolve_path=True),
+    help='Absolute or relative path to the workload-executor')
+@DBUSERNAME_OPTION
+@DBPASSWORD_OPTION
+@ATLASORGANIZATIONNAME_OPTION
+@ATLASGROUPNAME_OPTION
+@click.option(OPTNAMES.CLUSTER_NAME_SALT, type=click.STRING, required=True,
+              envvar=ENVVARS.CLUSTER_NAME_SALT,
+              help='Salt for generating unique hashes.')
+@POLLINGTIMEOUT_OPTION
+@POLLINGFREQUENCY_OPTION
+@click.option('--xunit-output', type=click.STRING, default="xunit-output",
+              help='Name of the folder in which to write the XUnit XML files.')
+@click.pass_context
+def run_headless(ctx, spec_tests_directory, workload_executor, db_username,
+                 db_password, org_name, group_name, cluster_name_salt,
+                 polling_timeout, polling_frequency, xunit_output):
+    """
+    Main entry point for running APM tests in headless environments.
+    This command runs all tests found in the SPEC_TESTS_DIRECTORY
+    sequentially on an Atlas cluster.
+    """
+    # Construct test configuration object.
+    config = TestCaseConfiguration(
+        organization_name=org_name,
+        group_name=group_name,
+        name_salt=cluster_name_salt,
+        polling_timeout=polling_timeout,
+        polling_frequency=polling_frequency,
+        database_username=db_username,
+        database_password=db_password,
+        workload_executor=workload_executor)
+
+    # Step-0: print configuration.
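+    # The table below is rendered with tabulate's "fancy_grid" format and
+    # comes out roughly as follows (values are illustrative):
+    #
+    #     Configuration option        Value
+    #     Atlas organization name     MongoDB
+    #     Atlas group/project name    my-project
+    #     ...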
+ table_data = [["Atlas organization name", config.organization_name], + ["Atlas group/project name", config.group_name], + ["Salt for cluster names", config.name_salt], + ["Polling frequency (Hz)", config.polling_frequency], + ["Polling timeout (s)", config.polling_timeout]] + click.echo(tabulate(table_data, headers=["Configuration option", "Value"], + tablefmt="fancy_grid")) + + # Step-1: create the Test-Runner. + runner = MultiTestRunner(client=ctx.obj, + spec_tests_directory=spec_tests_directory, + configuration=config, + xunit_output=xunit_output) + click.echo("---------------- Test Plan ---------------- ") + click.echo(runner.get_printable_test_plan()) + + # Step-2: run the tests. + failed = runner.run() + + if failed: + exit(1) + else: + exit(0) + + +if __name__ == '__main__': + cli() diff --git a/astrolabe/commands.py b/astrolabe/commands.py new file mode 100644 index 00000000..34e735d9 --- /dev/null +++ b/astrolabe/commands.py @@ -0,0 +1,138 @@ +from time import time, sleep + +from astrolabe.utils import assert_subset +from atlasclient import AtlasApiError + + +def get_one_organization_by_name(*, client, organization_name): + all_orgs = client.orgs.get().data + for org in all_orgs.results: + if org.name == organization_name: + return org + raise AtlasApiError('Resource not found.') + + +def ensure_project(*, client, group_name, organization_id): + try: + return client.groups.post( + name=group_name, orgId=organization_id).data + except AtlasApiError as exc: + if exc.error_code == 'GROUP_ALREADY_EXISTS': + return client.groups.byName[group_name].get().data + else: + raise + + +def ensure_admin_user(*, client, group_id, username, password): + user_details = { + "groupId": group_id, + "databaseName": "admin", + "roles": [{ + "databaseName": "admin", + "roleName": "atlasAdmin"}], + "username": username, + "password": password} + + try: + return client.groups[group_id].databaseUsers.post(**user_details).data + except AtlasApiError as exc: + if exc.error_code == "USER_ALREADY_EXISTS": + username = user_details.pop("username") + return client.groups[group_id].databaseUsers.admin[username].patch( + **user_details).data + else: + raise + + +def ensure_connect_from_anywhere(*, client, group_id,): + ip_details_list = [{"cidrBlock": "0.0.0.0/0"}] + client.groups[group_id].whitelist.post(json=ip_details_list) + + +def get_cluster_state(client, group_name, cluster_name): + project = client.groups.byName[group_name].get().data + cluster = client.groups[project.id].clusters[cluster_name].get().data + return cluster.stateName + + +def is_cluster_state(client, group_name, cluster_name, target_state): + project = client.groups.byName[group_name].get().data + cluster = client.groups[project.id].clusters[cluster_name].get().data + return cluster.stateName == target_state + + +def wait_until_cluster_state(client, group_name, cluster_name, target_state, + polling_frequency, polling_timeout): + if is_cluster_state(client, group_name, cluster_name, + target_state): + return True + + start_time = time() + sleep_interval = 1 / polling_frequency + while (time() - start_time) < polling_timeout: + if is_cluster_state(client, group_name, cluster_name, target_state): + return True + sleep(sleep_interval) + + return False + + +def select_callback(callback, args, kwargs, frequency, timeout): + start_time = time() + interval = 1 / frequency + while (time() - start_time) < timeout: + return_value = callback(*args, **kwargs) + if return_value is not None: + return return_value + print("Waiting {} seconds before 
retrying".format(interval))
+        sleep(interval)
+    raise RuntimeError  # TODO make new error type for polling timeout
+
+
+def get_ready_test_plan(client, group_id, test_plans):
+    clusters = client.groups[group_id].clusters
+    for test_case in test_plans:
+        cluster_resource = clusters[test_case.cluster_name]
+        cluster = cluster_resource.get().data
+        if cluster.stateName == "IDLE":
+            # Verification
+            assert_subset(
+                cluster,
+                test_case.spec["maintenancePlan"]["initial"]["basicConfiguration"])
+            processArgs = cluster_resource.processArgs.get().data
+            assert_subset(
+                processArgs,
+                test_case.spec["maintenancePlan"]["initial"]["processArgs"])
+            print("Cluster {} is ready!".format(test_case.cluster_name))
+            return test_case, cluster
+        else:
+            print("Cluster {} is not ready!".format(test_case.cluster_name))
+    return None
+
+
+def get_executor_args(test_case, username, password, plain_srv_address):
+    prefix, suffix = plain_srv_address.split("//")
+
+    srv_address = (
+        prefix + "//" + username + ":" + password + "@" + suffix + "/?")
+
+    uri_options = test_case.spec["maintenancePlan"]["uriOptions"]
+
+    from urllib.parse import urlencode
+    srv_address = srv_address + urlencode(uri_options)
+
+    return srv_address, test_case.spec["driverWorkload"]
+
+
+def run_maintenance(client, test_case, group_id):
+    final_config = test_case.spec["maintenancePlan"]["final"]
+
+    basic_conf = final_config["basicConfiguration"]
+    process_args = final_config["processArgs"]
+
+    if not basic_conf and not process_args:
+        raise RuntimeError(
+            "invalid maintenance plan - both configs cannot be blank")
+
+    cluster = client.groups[group_id].clusters[test_case.cluster_name]
+    if basic_conf:
+        cluster.patch(**basic_conf)
+
+    if process_args:
+        cluster.processArgs.patch(**process_args)
+
+    print("Maintenance has been started!")
\ No newline at end of file
diff --git a/astrolabe/configuration.py b/astrolabe/configuration.py
new file mode 100644
index 00000000..9da6c09b
--- /dev/null
+++ b/astrolabe/configuration.py
@@ -0,0 +1,47 @@
+from collections import namedtuple
+
+from atlasclient.utils import JSONObject
+
+CONFIG_DEFAULTS = JSONObject({
+    "ORGANIZATION_NAME": "MongoDB",
+    "DB_USERNAME": "atlasuser",
+    "DB_PASSWORD": "mypassword123",
+    "POLLING_TIMEOUT": 600.0,
+    "POLLING_FREQUENCY": 1.0,
+})
+
+
+CONFIG_ENVVARS = JSONObject({
+    "PROJECT_NAME": "EVERGREEN_PROJECT_ID",      # ${project} in EVG
+    "CLUSTER_NAME_SALT": "EVERGREEN_BUILD_ID",   # ${build_id} in EVG
+    "POLLING_TIMEOUT": "ATLAS_POLLING_TIMEOUT",
+    "POLLING_FREQUENCY": "ATLAS_POLLING_FREQUENCY",
+    "BASE_URL": "ATLAS_API_BASE_URL",
+    "API_USERNAME": "ATLAS_API_USERNAME",
+    "API_PASSWORD": "ATLAS_API_PASSWORD",
+    "HTTP_TIMEOUT": "ATLAS_HTTP_TIMEOUT",
+})
+
+
+CLI_OPTION_NAMES = JSONObject({
+    "PROJECT_NAME": "--group-name",
+    "CLUSTER_NAME_SALT": "--cluster-name-salt",
+    "POLLING_TIMEOUT": "--polling-timeout",
+    "POLLING_FREQUENCY": "--polling-frequency",
+    "BASE_URL": "--atlas-base-url",
+    "API_USERNAME": "--atlas-api-username",
+    "API_PASSWORD": "--atlas-api-password",
+    "HTTP_TIMEOUT": "--http-timeout",
+    "DB_USERNAME": "--db-username",
+    "DB_PASSWORD": "--db-password",
+    "ORGANIZATION_NAME": "--org-name"})
+
+
+# Convenience container for the settings that parameterize a test run.
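+# A minimal construction sketch (every value below is illustrative):
+#
+#     config = TestCaseConfiguration(
+#         organization_name="MongoDB", group_name="my-project",
+#         name_salt="build-id-1234", polling_timeout=600.0,
+#         polling_frequency=1.0, database_username="atlasuser",
+#         database_password="mypassword123",
+#         workload_executor="/path/to/workload-executor")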
+TestCaseConfiguration = namedtuple( + "AtlasPlannedMaintenanceTestConfiguration", + ["organization_name", "group_name", "name_salt", "polling_timeout", + "polling_frequency", "database_username", "database_password", + "workload_executor"]) diff --git a/astrolabe/exceptions.py b/astrolabe/exceptions.py new file mode 100644 index 00000000..7ecedfbc --- /dev/null +++ b/astrolabe/exceptions.py @@ -0,0 +1,10 @@ +class AstrolabeBaseError(Exception): + pass + + +class AstrolabeTestCaseError(AstrolabeBaseError): + pass + + +class PollingTimeoutError(AstrolabeBaseError): + pass \ No newline at end of file diff --git a/astrolabe/poller.py b/astrolabe/poller.py new file mode 100644 index 00000000..1211761c --- /dev/null +++ b/astrolabe/poller.py @@ -0,0 +1,31 @@ +from time import sleep + +from astrolabe.exceptions import PollingTimeoutError +from astrolabe.utils import Timer + + +class SelectBase: + def __init__(self, *, frequency, timeout): + self.interval = 1.0 / frequency + self.timeout = timeout + + @staticmethod + def poll(obj, attribute, args, kwargs): + raise NotImplementedError + + def select(self, objects, *, attribute, args, kwargs): + timer = Timer() + timer.start() + while timer.elapsed < self.timeout: + for obj in objects: + return_value = self.poll(obj, attribute, args, kwargs) + if return_value: + return obj + sleep(self.interval) + raise PollingTimeoutError + + +class BooleanCallableSelector(SelectBase): + @staticmethod + def poll(obj, attribute, args=(), kwargs={}): + return bool(getattr(obj, attribute)(*args, **kwargs)) diff --git a/astrolabe/spec_runner.py b/astrolabe/spec_runner.py new file mode 100644 index 00000000..aa602eb0 --- /dev/null +++ b/astrolabe/spec_runner.py @@ -0,0 +1,318 @@ +from hashlib import sha256 +import json +import os +import signal +import subprocess +import sys +from time import sleep +from urllib.parse import urlencode + +from pymongo import MongoClient +from tabulate import tabulate +import junitparser +import yaml + +from atlasclient import AtlasApiError, JSONObject +from astrolabe.commands import ( + get_one_organization_by_name, ensure_project, ensure_admin_user, + ensure_connect_from_anywhere) +from astrolabe.exceptions import AstrolabeTestCaseError +from astrolabe.poller import BooleanCallableSelector +from astrolabe.utils import ( + assert_subset, cached_property, encode_cdata, SingleTestXUnitLogger, + Timer) + + +class AtlasTestCase: + def __init__(self, *, client, test_name, cluster_name, specification, + configuration): + # Initialize. + self.client = client + self.id = test_name + self.cluster_name = cluster_name + self.spec = specification + self.config = configuration + self.failed = False + + # Account for platform-specific interrupt signals. + if sys.platform != 'win32': + self.sigint = signal.SIGINT + else: + self.sigint = signal.CTRL_C_EVENT + + # Validate organization and group. 
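+        # Calling the memoized getters eagerly makes configuration errors
+        # (e.g. a nonexistent organization name) surface here, at test-case
+        # construction time, rather than midway through a test run.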
+        self.get_organization()
+        self.get_group()
+
+    @cached_property
+    def get_organization(self):
+        return get_one_organization_by_name(
+            client=self.client,
+            organization_name=self.config.organization_name)
+
+    @cached_property
+    def get_group(self):
+        return ensure_project(
+            client=self.client, group_name=self.config.group_name,
+            organization_id=self.get_organization().id)
+
+    @property
+    def cluster_url(self):
+        return self.client.groups[self.get_group().id].clusters[
+            self.cluster_name]
+
+    @cached_property
+    def get_connection_string(self):
+        cluster = self.cluster_url.get().data
+        prefix, suffix = cluster.srvAddress.split("//")
+        uri_options = self.spec.maintenancePlan.uriOptions.copy()
+
+        # Boolean options must be converted to lowercase strings.
+        for key, value in uri_options.items():
+            if isinstance(value, bool):
+                uri_options[key] = str(value).lower()
+
+        connection_string = (prefix + "//" + self.config.database_username +
+                             ":" + self.config.database_password + "@" +
+                             suffix + "/?")
+        connection_string += urlencode(uri_options)
+        return connection_string
+
+    def is_cluster_state(self, goal_state):
+        cluster_info = self.cluster_url.get().data
+        return cluster_info.stateName.lower() == goal_state.lower()
+
+    def verify_cluster_configuration_matches(self, state):
+        """Verify that the cluster config is what we expect it to be (based
+        on maintenance status). Raises AssertionError."""
+        state = state.lower()
+        if state not in ("initial", "final"):
+            raise AstrolabeTestCaseError(
+                "State must be either 'initial' or 'final'.")
+        cluster_config = self.cluster_url.get().data
+        assert_subset(
+            cluster_config,
+            self.spec.maintenancePlan[state].clusterConfiguration)
+        process_args = self.cluster_url.processArgs.get().data
+        assert_subset(
+            process_args, self.spec.maintenancePlan[state].processArgs)
+
+    def initialize(self):
+        """
+        Initialize a cluster with the configuration required by the test
+        specification.
+        """
+        # Create a cluster of the desired name with the given config.
+        cluster_config = self.spec.maintenancePlan.initial.\
+            clusterConfiguration.copy()
+        cluster_config["name"] = self.cluster_name
+        try:
+            self.client.groups[self.get_group().id].clusters.post(
+                **cluster_config)
+        except AtlasApiError as exc:
+            if exc.error_code == 'DUPLICATE_CLUSTER_NAME':
+                # Cluster already exists. Simply re-configure it.
+                # Cannot send cluster name when updating existing cluster.
+                cluster_config.pop("name")
+                self.client.groups[self.get_group().id].\
+                    clusters[self.cluster_name].patch(**cluster_config)
+            else:
+                raise
+
+        # Apply processArgs if provided.
+        process_args = self.spec.maintenancePlan.initial.processArgs
+        if process_args:
+            self.client.groups[self.get_group().id].\
+                clusters[self.cluster_name].processArgs.patch(**process_args)
+
+    def run(self):
+        # Step-0: sanity-check the cluster configuration.
+        self.verify_cluster_configuration_matches("initial")
+
+        print("Running test {} on cluster {}.".format(
+            self.id, self.cluster_name))
+
+        # Start the test timer.
+        timer = Timer()
+        timer.start()
+
+        # Step-1: load test data.
+        test_data = self.spec.driverWorkload.get('testData')
+        if test_data:
+            connection_string = self.get_connection_string()
+            client = MongoClient(connection_string, w="majority")
+            coll = client.get_database(
+                self.spec.driverWorkload.database).get_collection(
+                self.spec.driverWorkload.collection)
+            coll.drop()
+            coll.insert_many(test_data)
+
+        # Step-2: run driver workload.
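+        # The workload executor is launched as a child process and invoked
+        # roughly as:
+        #
+        #     python <workload-executor> <connection-string> <json-workload>
+        #
+        # It is expected to run the workload until it receives SIGINT
+        # (CTRL_C_EVENT on Windows); see Step-5 below.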
+        connection_string = self.get_connection_string()
+        driver_workload = json.dumps(self.spec.driverWorkload)
+        worker_subprocess = subprocess.Popen([
+            sys.executable, self.config.workload_executor, connection_string,
+            driver_workload], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+            universal_newlines=True)
+
+        print("Workload executor running. PID: {}".format(
+            worker_subprocess.pid))
+
+        # Step-3: begin maintenance routine.
+        final_config = self.spec.maintenancePlan.final
+        cluster_config = final_config.clusterConfiguration
+        process_args = final_config.processArgs
+
+        if not cluster_config and not process_args:
+            raise RuntimeError("invalid maintenance plan.")
+
+        if cluster_config:
+            self.cluster_url.patch(**cluster_config)
+
+        if process_args:
+            self.cluster_url.processArgs.patch(**process_args)
+
+        # Sleep before polling to avoid "missing" cluster.stateName change.
+        sleep(3)
+
+        print("Waiting for maintenance to complete.")
+
+        # Step-4: wait until maintenance completes (cluster is IDLE).
+        selector = BooleanCallableSelector(
+            frequency=self.config.polling_frequency,
+            timeout=self.config.polling_timeout)
+        selector.select([self], attribute="is_cluster_state", args=("IDLE",),
+                        kwargs={})
+        self.verify_cluster_configuration_matches("final")
+
+        # Step-5: interrupt driver workload and capture streams.
+        os.kill(worker_subprocess.pid, self.sigint)
+        stdout, stderr = worker_subprocess.communicate()
+
+        # Stop the timer.
+        timer.stop()
+
+        print("Maintenance is complete.")
+        print(stdout)
+        print("-------------------------")
+        print(stderr)
+
+        # Step-6: compute xunit entry.
+        junit_test = junitparser.TestCase(self.id)
+        junit_test.time = timer.elapsed
+        if worker_subprocess.returncode != 0:
+            self.failed = True
+            errmsg = """
+            Number of errors: {numErrors}
+            Number of failures: {numFailures}
+            """
+            try:
+                err_info = json.loads(stderr)
+                junit_test.result = junitparser.Failure(
+                    errmsg.format(**err_info))
+            except json.JSONDecodeError:
+                junit_test.result = junitparser.Error(encode_cdata(stderr))
+        junit_test.system_err = encode_cdata(stderr)
+        junit_test.system_out = encode_cdata(stdout)
+
+        # TODO
+        # download logs and delete cluster asynchronously
+        # cleanup_queue.put(cluster_name)  # cleanup queue downloads logs and deletes cluster
+        # atlas.Clusters(ctx.obj).delete(group.json()["id"], test_case.cluster_name)
+        # Step-8: make a zipfile with all logs
+        # import shutil
+        # shutil.make_archive(...)
+
+        return junit_test
+
+
+class MultiTestRunner:
+    def __init__(self, *, client, spec_tests_directory, configuration,
+                 xunit_output):
+        self.cases = []
+        self.client = client
+        self.config = configuration
+        self.xunit_logger = SingleTestXUnitLogger(
+            output_directory=xunit_output)
+
+        # Scan directory and create tests.
+        for root, dirs, files in os.walk(spec_tests_directory):
+            for file in files:
+                full_path = os.path.join(root, file)
+                if (os.path.isfile(full_path) and
+                        file.lower().endswith(('.yml', '.yaml'))):
+                    # Step-1: load test specification.
+                    with open(full_path, 'r') as spec_file:
+                        test_spec = JSONObject(
+                            yaml.load(spec_file, Loader=yaml.FullLoader))
+
+                    # Step-2: generate test name.
+                    _, filename = os.path.split(full_path)
+                    test_name = os.path.splitext(filename)[0].replace(
+                        '-', '_')
+
+                    # Step-3: generate unique cluster name.
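+                    # The name is the first 10 hex digits of
+                    # sha256(test_name + name_salt); e.g. (digest shown is
+                    # illustrative):
+                    #
+                    #     sha256(b"retryable_writes" + b"build123")
+                    #         .hexdigest()[:10]  ->  "3f9c61a0b2"
+                    #
+                    # Salting with a per-build value keeps cluster names
+                    # unique across concurrent builds while staying well
+                    # within Atlas name-length limits.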
+                    name_hash = sha256(test_name.encode('utf-8'))
+                    name_hash.update(self.config.name_salt.encode('utf-8'))
+                    cluster_name = name_hash.hexdigest()[:10]
+
+                    self.cases.append(
+                        AtlasTestCase(client=self.client,
+                                      test_name=test_name,
+                                      cluster_name=cluster_name,
+                                      specification=test_spec,
+                                      configuration=self.config))
+
+        # Set up Atlas for tests.
+        # Step-1: ensure validity of the organization.
+        # Note: organizations can only be created via the web UI.
+        org = get_one_organization_by_name(
+            client=self.client,
+            organization_name=self.config.organization_name)
+
+        # Step-2: check that the project exists or else create it.
+        group = ensure_project(
+            client=self.client, group_name=self.config.group_name,
+            organization_id=org.id)
+
+        # Step-3: create a user under the project.
+        # Note: all test operations will be run as this user.
+        ensure_admin_user(
+            client=self.client, group_id=group.id,
+            username=self.config.database_username,
+            password=self.config.database_password)
+
+        # Step-4: populate project IP whitelist to allow access from anywhere.
+        ensure_connect_from_anywhere(
+            client=self.client, group_id=group.id)
+
+    def get_printable_test_plan(self):
+        table_data = []
+        for test_case in self.cases:
+            table_data.append([test_case.id, test_case.cluster_name])
+        return tabulate(
+            table_data, headers=["Test name", "Atlas cluster name"],
+            showindex="always", tablefmt="fancy_grid")
+
+    def run(self):
+        # Step-0: sentinel flag to track failure/success.
+        failed = False
+
+        # Step-1: initialize test clusters.
+        for case in self.cases:
+            case.initialize()
+
+        # Step-2: run tests round-robin until all have been run.
+        remaining_test_cases = self.cases.copy()
+        while remaining_test_cases:
+            selector = BooleanCallableSelector(
+                frequency=self.config.polling_frequency,
+                timeout=self.config.polling_timeout)
+            # Select a case whose cluster is ready.
+            active_case = selector.select(
+                remaining_test_cases, attribute="is_cluster_state",
+                args=("IDLE",), kwargs={})
+            # Run the case.
+            xunit_test = active_case.run()
+            # Write xunit entry for case.
+            self.xunit_logger.write_xml(
+                test_case=xunit_test,
+                filename=active_case.id)
+            # Remove completed case from list.
+            remaining_test_cases.remove(active_case)
+            # Update tracker.
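+            # A single failing case marks the entire run as failed, but the
+            # loop continues so that every remaining case still runs and
+            # gets its own XUnit report.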
+            failed = failed or active_case.failed
+
+        return failed
diff --git a/astrolabe/utils.py b/astrolabe/utils.py
new file mode 100644
index 00000000..eb629897
--- /dev/null
+++ b/astrolabe/utils.py
@@ -0,0 +1,132 @@
+from collections import defaultdict
+import json
+import os
+from time import monotonic
+
+import click
+import junitparser
+
+
+def assert_subset(dict1, dict2):
+    """Utility that asserts that `dict2` is a subset of `dict1`, while
+    accounting for nested fields."""
+    for key, value in dict2.items():
+        if key not in dict1:
+            raise AssertionError("not a subset")
+        if isinstance(value, dict):
+            assert_subset(dict1[key], value)
+        else:
+            assert dict1[key] == value
+
+
+class Timer:
+    """Class to simplify timing operations."""
+    def __init__(self):
+        self._start = None
+        self._end = None
+
+    def reset(self):
+        self.__init__()
+
+    def start(self):
+        self._start = monotonic()
+        self._end = None
+
+    def stop(self):
+        self._end = monotonic()
+
+    @property
+    def elapsed(self):
+        if self._end is None:
+            return monotonic() - self._start
+        return self._end - self._start
+
+
+def cached_property(func):
+    """Decorator to memoize a class method that accepts no args/kwargs.
+    The memoized value is stored on the instance so that each instance
+    gets its own cache."""
+    attr_name = '_cached_' + func.__name__
+
+    def memoized_function(self, *args, **kwargs):
+        if args or kwargs:
+            raise RuntimeError("cannot memoize methods that accept arguments")
+        if not hasattr(self, attr_name):
+            setattr(self, attr_name, func(self))
+        return getattr(self, attr_name)
+
+    return memoized_function
+
+
+def encode_cdata(data):
+    """Wrap `data` in an XML CDATA section."""
+    return "<![CDATA[{data}]]>".format(data=data)
+
+
+class SingleTestXUnitLogger:
+    def __init__(self, *, output_directory):
+        self._output_directory = os.path.realpath(os.path.join(
+            os.getcwd(), output_directory))
+
+        # Ensure folder exists.
+        try:
+            os.mkdir(self._output_directory)
+        except FileExistsError:
+            pass
+
+    def write_xml(self, test_case, filename):
+        filename += '.xml'
+        xml_path = os.path.join(self._output_directory, filename)
+
+        # Remove existing file if applicable.
+        try:
+            os.unlink(xml_path)
+        except FileNotFoundError:
+            pass
+
+        # Use filename as the suite name.
+        suite = junitparser.TestSuite(filename)
+        suite.add_testcase(test_case)
+
+        xml = junitparser.JUnitXml()
+        xml.add_testsuite(suite)
+        xml.write(xml_path)
+
+
+def _nested_defaultdict():
+    """An infinitely nested defaultdict type."""
+    return defaultdict(_nested_defaultdict)
+
+
+def _merge_dictionaries(dicts):
+    """Utility function to merge a list of dictionaries.
+    Last observed value prevails."""
+    result = {}
+    for d in dicts:
+        result.update(d)
+    return result
+
+
+class _JsonDotNotationType(click.ParamType):
+    """Custom Click-type for user-friendly JSON input."""
+    def convert(self, value, param, ctx):
+        # Return None and target type without change.
+        if value is None or isinstance(value, dict):
+            return value
+
+        # Parse the input (of type path.to.namespace=value).
+        ns, config_value = value.split("=")
+        ns_path = ns.split(".")
+        return_value = _nested_defaultdict()
+
+        # Construct dictionary from parsed option.
+        pointer = return_value
+        for key in ns_path:
+            if key == ns_path[-1]:
+                pointer[key] = config_value
+            else:
+                pointer = pointer[key]
+
+        # Convert nested defaultdict into vanilla dictionary.
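+        # The JSON round-trip turns every nested defaultdict into a plain
+        # dict. For example, the input "a.b.c=1" comes back as
+        # {"a": {"b": {"c": "1"}}} -- note that the value remains a string
+        # because no type coercion is attempted.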
+ return json.loads(json.dumps(return_value)) diff --git a/atlasclient/client.py b/atlasclient/client.py index c3c650a4..dea2cfb5 100644 --- a/atlasclient/client.py +++ b/atlasclient/client.py @@ -74,14 +74,14 @@ class _ApiResponse: def __init__(self, response, request_method, json_data): self.resource_url = response.url self.headers = response.headers - self.status = response.status_code + self.status_code = response.status_code self.request_method = request_method self.data = json_data def __repr__(self): return '<{}: {} {}, [HTTP status code: {}]>'.format( self.__class__.__name__, self.request_method, - self.resource_url, self.status) + self.resource_url, self.status_code) class AtlasClient: diff --git a/atlasclient/utils.py b/atlasclient/utils.py index 8b63931c..4363f0da 100644 --- a/atlasclient/utils.py +++ b/atlasclient/utils.py @@ -21,12 +21,20 @@ class JSONObject(dict): """Dictionary with dot-notation read access.""" + def __coerce_dict(self, value): + if isinstance(value, dict): + return type(self)(value) + return value + def __getattr__(self, name): if name in self: - return self[name] - raise AttributeError('{} has no property named {}}.'.format( + return self.__coerce_dict(self[name]) + raise AttributeError('{} has no property named {}.'.format( self.__class__.__name__, name)) + def __getitem__(self, item): + return self.__coerce_dict(super().__getitem__(item)) + def enable_http_logging(loglevel): """Enables logging of all HTTP requests.""" diff --git a/setup.py b/setup.py index 4b44d633..aed86271 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ 'pymongo>=3.10,<4', 'dnspython>=1.16,<2', 'pyyaml>=5,<6', + 'tabulate>=0.8,<0.9', 'junitparser>=1,<2'], entry_points={ 'console_scripts': ['astrolabe=astrolabe.cli:cli']},