diff --git a/.changelog/4658.yml b/.changelog/4658.yml new file mode 100644 index 00000000000..aa735a1003b --- /dev/null +++ b/.changelog/4658.yml @@ -0,0 +1,4 @@ +changes: +- description: Added the **test-use-case** command to test use-case flows on cloud machines. + type: internal +pr_number: 4658 diff --git a/demisto_sdk/__main__.py b/demisto_sdk/__main__.py index 3ee58d873e2..13ace9aff2e 100644 --- a/demisto_sdk/__main__.py +++ b/demisto_sdk/__main__.py @@ -493,6 +493,18 @@ def register_commands(_args: list[str] = []): # noqa: C901 help="This command generates a test playbook from integration/script YAML arguments.", )(generate_test_playbook) + if command_name == "test-use-case" or register_all: + from demisto_sdk.commands.test_content.test_use_case.test_use_case_setup import ( + run_test_use_case, + ) + + app.command( + name="test-use-case", + hidden=True, + no_args_is_help=True, + help="Test Use Cases.", + )(run_test_use_case) + # Register relevant commands to Demisto-SDK app based on command-line arguments. args = sys.argv[1:] diff --git a/demisto_sdk/commands/common/clients/__init__.py b/demisto_sdk/commands/common/clients/__init__.py index 08266ff329b..e5a84d63764 100644 --- a/demisto_sdk/commands/common/clients/__init__.py +++ b/demisto_sdk/commands/common/clients/__init__.py @@ -2,6 +2,7 @@ from functools import lru_cache from typing import Optional +from _pytest.fixtures import SubRequest from urllib3.exceptions import MaxRetryError from demisto_sdk.commands.common.clients.configs import ( @@ -27,6 +28,7 @@ DEMISTO_PASSWORD, DEMISTO_USERNAME, DEMISTO_VERIFY_SSL, + PROJECT_ID, MarketplaceVersions, ) from demisto_sdk.commands.common.logger import logger @@ -132,6 +134,7 @@ def get_client_from_server_type( auth_id: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, + project_id: Optional[str] = None, verify_ssl: Optional[bool] = None, raise_if_server_not_healthy: bool = True, ) -> XsoarClient: @@ -144,6 +147,7 @@ def get_client_from_server_type( auth_id: the auth ID, if not provided will take from XSIAM_AUTH_ID env var username: the username to authenticate, relevant only for xsoar on prem password: the password to authenticate, relevant only for xsoar on prem + project_id: the project id of the current cloud machine. verify_ssl: whether in each request SSL should be verified, True if yes, False if not, if verify_ssl = None, will take the SSL verification from DEMISTO_VERIFY_SSL env var raise_if_server_not_healthy: whether to raise an exception if the server is not healthy @@ -156,6 +160,7 @@ def get_client_from_server_type( _auth_id = auth_id or os.getenv(AUTH_ID) _username = username or os.getenv(DEMISTO_USERNAME, "") _password = password or os.getenv(DEMISTO_PASSWORD, "") + _project_id = project_id or os.getenv(PROJECT_ID, "") _verify_ssl = ( verify_ssl if verify_ssl is not None @@ -188,6 +193,7 @@ def get_client_from_server_type( api_key=_api_key, auth_id=_auth_id, verify_ssl=_verify_ssl, + project_id=_project_id, ), should_validate_server_type=should_validate_server_type, raise_if_server_not_healthy=raise_if_server_not_healthy, @@ -202,6 +208,7 @@ def get_client_from_server_type( api_key=_api_key, auth_id=_auth_id, verify_ssl=_verify_ssl, + project_id=_project_id, ), should_validate_server_type=should_validate_server_type, raise_if_server_not_healthy=raise_if_server_not_healthy, @@ -232,3 +239,31 @@ def get_client_from_server_type( f"make sure the {DEMISTO_BASE_URL}, {DEMISTO_KEY}, {AUTH_ID} are defined properly" ) raise + + +# =================== Playbook Flow Tests ================= + + +def parse_str_to_dict(input_str: str): + """Internal function to convert a string representing a dictionary into an actual dictionary. + + Args: + input_str (str): A string in the format 'key1=value1,key2=value2'. + + Returns: + dict: A dictionary with the parsed key-value pairs. + """ + x = dict(pair.split("=") for pair in input_str.split(",") if "=" in pair) + logger.info(x.get("base_url", "no base url")) + return dict(pair.split("=") for pair in input_str.split(",") if "=" in pair) + + +def get_client_conf_from_pytest_request(request: SubRequest): + # Manually parse command-line argument + for arg in request.config.invocation_params.args: + if isinstance(arg, str) and arg.startswith("--client_conf="): + logger.debug("Parsing --client_conf argument") + client_conf = arg.replace("--client_conf=", "") + return parse_str_to_dict(client_conf) + # If a client data was not provided, we proceed to use default. + return None diff --git a/demisto_sdk/commands/common/clients/configs.py b/demisto_sdk/commands/common/clients/configs.py index 4efd8ffb118..b30cf94d5f3 100644 --- a/demisto_sdk/commands/common/clients/configs.py +++ b/demisto_sdk/commands/common/clients/configs.py @@ -10,6 +10,7 @@ DEMISTO_PASSWORD, DEMISTO_USERNAME, DEMISTO_VERIFY_SSL, + PROJECT_ID, XSIAM_COLLECTOR_TOKEN, XSIAM_TOKEN, ) @@ -72,6 +73,9 @@ def __eq__(self, other): class XsoarSaasClientConfig(XsoarClientConfig): auth_id: str = Field(default=os.getenv(AUTH_ID), description="XSOAR/XSIAM Auth ID") + project_id: str = Field( + default=os.getenv(PROJECT_ID), description="XSOAR/XSIAM Project ID" + ) @root_validator() def validate_auth_params(cls, values: Dict[str, Any]): diff --git a/demisto_sdk/commands/common/clients/xsiam/xsiam_api_client.py b/demisto_sdk/commands/common/clients/xsiam/xsiam_api_client.py index 66995ecd9f7..36c4f95be71 100644 --- a/demisto_sdk/commands/common/clients/xsiam/xsiam_api_client.py +++ b/demisto_sdk/commands/common/clients/xsiam/xsiam_api_client.py @@ -89,11 +89,9 @@ def push_to_dataset( endpoint = urljoin(self.server_config.base_api_url, "logs/v1/event") additional_headers = { "authorization": self.server_config.collector_token, - "content-type": ( - "application/json" - if data_format.casefold == "json" - else "text/plain" - ), + "content-type": "application/json" + if data_format.casefold == "json" + else "text/plain", "content-encoding": "gzip", } token_type = "collector_token" @@ -199,3 +197,92 @@ def get_ioc_rules(self): ) return response + + """ + ############################# + Alerts related methods + ############################# + """ + + def create_alert_from_json(self, json_content: dict) -> int: + alert_payload = {"request_data": {"alert": json_content}} + endpoint = urljoin( + self.server_config.base_api_url, "/public_api/v1/alerts/create_alert" + ) + res = self._xdr_client.post(endpoint, json=alert_payload) + alert_data = self._process_response(res, res.status_code, 200) + return alert_data["reply"] + + def get_internal_alert_id(self, alert_external_id: str) -> int: + data = self.search_alerts( + filters=[ + { + "field": "external_id_list", + "operator": "in", + "value": [alert_external_id], + } + ] + ) + return data["alerts"][0]["alert_id"] + + def update_alert(self, alert_id: Union[str, list[str]], updated_data: dict) -> dict: + """ + Args: + alert_id (str | list[str]): alert ids to edit. + updated_data (dict): The data to update the alerts with. https://cortex-panw.stoplight.io/docs/cortex-xsiam-1/rpt3p1ne2bwfe-update-alerts + """ + alert_payload = { + "request_data": {"update_data": updated_data, "alert_id_list": alert_id} + } + endpoint = urljoin( + self.server_config.base_api_url, "/public_api/v1/alerts/update_alerts" + ) + res = self._xdr_client.post(endpoint, json=alert_payload) + alert_data = self._process_response(res, res.status_code, 200) + return alert_data + + def search_alerts( + self, + filters: list = None, + search_from: int = None, + search_to: int = None, + sort: dict = None, + ) -> dict: + """ + filters should be a list of dicts contains field, operator, value. + For example: + [{field: alert_id_list, operator: in, value: [1,2,3,4]}] + Allowed values for fields - alert_id_list, alert_source, severity, creation_time + """ + body = { + "request_data": { + "filters": filters, + "search_from": search_from, + "search_to": search_to, + "sort": sort, + } + } + endpoint = urljoin( + self.server_config.base_api_url, "/public_api/v1/alerts/get_alerts/" + ) + res = self._xdr_client.post(endpoint, json=body) + return self._process_response(res, res.status_code, 200)["reply"] + + def search_alerts_by_uuid(self, alert_uuids: list = None, filters: list = None): + alert_uuids = alert_uuids or [] + alert_ids: list = [] + res = self.search_alerts(filters=filters) + alerts: list = res.get("alerts") # type: ignore + count: int = res.get("result_count") # type: ignore + + while len(alerts) > 0 and len(alert_uuids) > len(alert_ids): + for alert in alerts: + for uuid in alert_uuids: + if alert.get("description").endswith(uuid): + alert_ids.append(alert.get("alert_id")) + + res = self.search_alerts(filters=filters, search_from=count) + alerts = res.get("alerts") # type: ignore + count = res.get("result_count") # type: ignore + + return alert_ids diff --git a/demisto_sdk/commands/common/clients/xsoar/xsoar_api_client.py b/demisto_sdk/commands/common/clients/xsoar/xsoar_api_client.py index fe746775b8b..c7293ac1aa3 100644 --- a/demisto_sdk/commands/common/clients/xsoar/xsoar_api_client.py +++ b/demisto_sdk/commands/common/clients/xsoar/xsoar_api_client.py @@ -206,6 +206,28 @@ def external_base_url(self) -> str: # url that its purpose is to expose apis of integrations outside from xsoar/xsiam return self.server_config.config.base_api_url + """ + ############################# + Helper methods + ############################# + """ + + def _process_response(self, response, status_code, expected_status=200): + """Process the HTTP response coming from the XSOAR client.""" + if status_code == expected_status: + if response: + try: + return response.json() + except json.JSONDecodeError: + error = response.text + err_msg = f"Failed to parse json response - with status code {response.status_code}" + err_msg += f"\n{error}" if error else "" + logger.error(err_msg) + response.raise_for_status() + else: + error_message = f"Expected status {expected_status}, but got {status_code}. Response: {response}" + raise Exception(error_message) + """ ############################# marketplace related methods @@ -1306,3 +1328,18 @@ def poll_playbook_state( else None ), ) + + def get_playbook_data(self, playbook_id: int) -> dict: + playbook_endpoint = f"/playbook/{playbook_id}" + + response, status_code, _ = self._xsoar_client.generic_request( + playbook_endpoint, method="GET", accept="application/json" + ) + return self._process_response(response, status_code, 200) + + def update_playbook_input(self, playbook_id: str, new_inputs: dict): + saving_inputs_path = f"/playbook/inputs/{playbook_id}" + response, status_code, _ = self._xsoar_client.generic_request( + saving_inputs_path, method="POST", body={"inputs": new_inputs} + ) + return self._process_response(response, status_code, 200) diff --git a/demisto_sdk/commands/common/clients/xsoar_saas/xsoar_saas_api_client.py b/demisto_sdk/commands/common/clients/xsoar_saas/xsoar_saas_api_client.py index e0860891544..01e8d0a830e 100644 --- a/demisto_sdk/commands/common/clients/xsoar_saas/xsoar_saas_api_client.py +++ b/demisto_sdk/commands/common/clients/xsoar_saas/xsoar_saas_api_client.py @@ -42,6 +42,7 @@ def __init__( "Content-Type": "application/json", } ) + self.project_id = config.project_id super().__init__( config, client=client, diff --git a/demisto_sdk/commands/common/constants.py b/demisto_sdk/commands/common/constants.py index f80db3a87c9..ef1a053ab3c 100644 --- a/demisto_sdk/commands/common/constants.py +++ b/demisto_sdk/commands/common/constants.py @@ -52,6 +52,7 @@ XSIAM_TOKEN = "XSIAM_TOKEN" XSIAM_COLLECTOR_TOKEN = "XSIAM_COLLECTOR_TOKEN" DEMISTO_VERIFY_SSL = "DEMISTO_VERIFY_SSL" +PROJECT_ID = "PROJECT_ID" # Logging DEMISTO_SDK_LOG_FILE_PATH = "DEMISTO_SDK_LOG_FILE_PATH" @@ -2218,6 +2219,7 @@ class PlaybookTaskType(StrEnum): # Test types: TEST_PLAYBOOKS = "TestPlaybooks" TEST_MODELING_RULES = "TestModelingRules" +TEST_USE_CASES = "TestUseCases" PB_RELEASE_NOTES_FORMAT = { "This playbook addresses the following alerts:": 5, diff --git a/demisto_sdk/commands/common/git_util.py b/demisto_sdk/commands/common/git_util.py index ffb2c6e45aa..0b212ed3360 100644 --- a/demisto_sdk/commands/common/git_util.py +++ b/demisto_sdk/commands/common/git_util.py @@ -325,7 +325,7 @@ def modified_files( # if remote does not exist we are checking against the commit sha1 else: committed = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.commit(rev=branch) .diff(current_branch_or_hash) .iter_change_type("M") @@ -352,7 +352,7 @@ def modified_files( # get all the files that are staged on the branch and identified as modified. staged = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.head.commit.diff().iter_change_type("M") }.union(untracked).union(untrue_rename_staged) @@ -373,7 +373,7 @@ def modified_files( # if remote does not exist we are checking against the commit sha1 else: committed_added = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.commit(rev=branch) .diff(current_branch_or_hash) .iter_change_type("A") @@ -446,7 +446,7 @@ def added_files( # if remote does not exist we are checking against the commit sha1 else: committed = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.commit(rev=branch) .diff(current_branch_or_hash) .iter_change_type("A") @@ -477,7 +477,7 @@ def added_files( # get all the files that are staged on the branch and identified as added. staged = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.head.commit.diff().iter_change_type("A") }.union(untrue_rename_staged) @@ -487,7 +487,7 @@ def added_files( # so will added it from the staged added files. # same goes to untracked files - can be identified as modified but are actually added against prev_ver committed_added_locally_modified = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.head.commit.diff().iter_change_type("M") }.intersection(committed) untracked = untracked_added.union(untracked_modified.intersection(committed)) @@ -550,7 +550,7 @@ def deleted_files( # if remote does not exist we are checking against the commit sha1 else: committed = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.commit(rev=branch) .diff(current_branch_or_hash) .iter_change_type("D") @@ -571,7 +571,7 @@ def deleted_files( # get all the files that are staged on the branch and identified as added. staged = { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.head.commit.diff().iter_change_type("D") }.union(untracked) @@ -631,7 +631,7 @@ def renamed_files( # if remote does not exist we are checking against the commit sha1 else: committed = { - (Path(item.a_path), Path(item.b_path)) + (Path(item.a_path), Path(item.b_path)) # type: ignore for item in self.repo.commit(rev=branch) .diff(current_branch_or_hash) .iter_change_type("R") @@ -667,7 +667,7 @@ def renamed_files( # get all the files that are staged on the branch and identified as renamed and are with 100% score. staged = { - (Path(item.a_path), Path(item.b_path)) + (Path(item.a_path), Path(item.b_path)) # type: ignore for item in self.repo.head.commit.diff().iter_change_type("R") if item.score == 100 }.union(untracked) @@ -793,7 +793,7 @@ def _only_last_commit( try: if requested_status != "R": return { - Path(os.path.join(item.a_path)) + Path(os.path.join(item.a_path)) # type: ignore for item in self.repo.commit("HEAD~1") .diff() .iter_change_type(requested_status) @@ -801,7 +801,7 @@ def _only_last_commit( } else: return { - (Path(item.a_path), Path(item.b_path)) + (Path(item.a_path), Path(item.b_path)) # type: ignore for item in self.repo.commit("HEAD~1") .diff() .iter_change_type(requested_status) @@ -940,9 +940,9 @@ def handle_wrong_renamed_status( if staged_only: return { - Path(item.b_path) + Path(item.b_path) # type: ignore for item in self.repo.head.commit.diff().iter_change_type("R") - if item.score < 100 + if item.score < 100 # type: ignore and self._check_file_status( file_path=str(item.b_path), remote=remote, branch=branch ) @@ -965,11 +965,11 @@ def handle_wrong_renamed_status( # if remote does not exist we are checking against the commit sha1 return { - Path(item.b_path) + Path(item.b_path) # type: ignore for item in self.repo.commit(rev=branch) .diff(current_branch_or_hash) .iter_change_type("R") - if item.score < 100 + if item.score < 100 # type: ignore and self._check_file_status( file_path=str(item.b_path), remote=remote, branch=branch ) diff --git a/demisto_sdk/commands/test_content/test_modeling_rule/test_modeling_rule.py b/demisto_sdk/commands/test_content/test_modeling_rule/test_modeling_rule.py index d82fe5b540a..a6be2ebd2f3 100644 --- a/demisto_sdk/commands/test_content/test_modeling_rule/test_modeling_rule.py +++ b/demisto_sdk/commands/test_content/test_modeling_rule/test_modeling_rule.py @@ -19,10 +19,6 @@ from tabulate import tabulate from tenacity import ( Retrying, - before_sleep_log, - retry_if_exception_type, - stop_after_attempt, - wait_fixed, ) from typer.main import get_command_from_info @@ -34,7 +30,6 @@ ModelingRule, SingleModelingRule, ) -from demisto_sdk.commands.common.content_constant_paths import CONTENT_PATH from demisto_sdk.commands.common.handlers import DEFAULT_JSON_HANDLER as json from demisto_sdk.commands.common.logger import ( handle_deprecated_args, @@ -45,7 +40,6 @@ get_file, get_json_file, is_epoch_datetime, - parse_int_or_default, string_to_bool, ) from demisto_sdk.commands.test_content.ParallelLoggingManager import ( @@ -59,7 +53,20 @@ TIME_ZONE_WARNING, XQL_QUERY_ERROR_EXPLANATION, ) -from demisto_sdk.commands.test_content.tools import get_ui_url +from demisto_sdk.commands.test_content.tools import ( + XSIAM_CLIENT_RETRY_ATTEMPTS, + XSIAM_CLIENT_SLEEP_INTERVAL, + create_retrying_caller, + day_suffix, + duration_since_start_time, + get_relative_path_to_content, + get_type_pretty_name, + get_ui_url, + get_utc_now, + logs_token_cb, + tenant_config_cb, + xsiam_get_installed_packs, +) from demisto_sdk.commands.test_content.xsiam_tools.test_data import ( TestData, Validations, @@ -72,27 +79,9 @@ from demisto_sdk.utils.utils import get_containing_pack CI_PIPELINE_ID = os.environ.get("CI_PIPELINE_ID") -XSIAM_CLIENT_SLEEP_INTERVAL = 60 -XSIAM_CLIENT_RETRY_ATTEMPTS = 5 - -app = typer.Typer() - - -def get_utc_now() -> datetime: - """Get the current time in UTC, with timezone aware.""" - return datetime.now(tz=pytz.UTC) - -def duration_since_start_time(start_time: datetime) -> float: - """Get the duration since the given start time, in seconds. - Args: - start_time (datetime): Start time. - - Returns: - float: Duration since the given start time, in seconds. - """ - return (get_utc_now() - start_time).total_seconds() +app = typer.Typer() def create_table(expected: Dict[str, Any], received: Dict[str, Any]) -> str: @@ -114,39 +103,6 @@ def create_table(expected: Dict[str, Any], received: Dict[str, Any]) -> str: ) -def day_suffix(day: int) -> str: - """ - Returns a suffix string base on the day of the month. - for 1, 21, 31 => st - for 2, 22 => nd - for 3, 23 => rd - for to all the others => th - - see here for more details: https://en.wikipedia.org/wiki/English_numerals#Ordinal_numbers - - Args: - day: The day of the month represented by a number. - - Returns: - suffix string (st, nd, rd, th). - """ - return "th" if 11 <= day <= 13 else {1: "st", 2: "nd", 3: "rd"}.get(day % 10, "th") - - -def get_relative_path_to_content(path: Path) -> str: - """Get the relative path to the content directory. - - Args: - path: The path to the content item. - - Returns: - Path: The relative path to the content directory. - """ - if path.is_absolute() and path.as_posix().startswith(CONTENT_PATH.as_posix()): - return path.as_posix().replace(f"{CONTENT_PATH.as_posix()}{os.path.sep}", "") - return path.as_posix() - - def convert_epoch_time_to_string_time( epoch_time: int, with_ms: bool = False, tenant_timezone: str = "UTC" ) -> str: @@ -170,30 +126,6 @@ def convert_epoch_time_to_string_time( return datetime_object.strftime(time_format) -def get_type_pretty_name(obj: Any) -> str: - """Get the pretty name of the type of the given object. - - Args: - obj (Any): The object to get the type name for. - - Returns: - str: The pretty name of the type of the given object. - """ - return { - type(None): "null", - list: "list", - dict: "dict", - tuple: "tuple", - set: "set", - UUID: "UUID", - str: "string", - int: "int", - float: "float", - bool: "boolean", - datetime: "datetime", - }.get(type(obj), str(type(obj))) - - def sanitize_received_value_by_expected_type( received_value: Any, expected_type: str ) -> Tuple[str, Any]: @@ -218,20 +150,6 @@ def sanitize_received_value_by_expected_type( return received_value_type, received_value -def create_retrying_caller(retry_attempts: int, sleep_interval: int) -> Retrying: - """Create a Retrying object with the given retry_attempts and sleep_interval.""" - sleep_interval = parse_int_or_default(sleep_interval, XSIAM_CLIENT_SLEEP_INTERVAL) - retry_attempts = parse_int_or_default(retry_attempts, XSIAM_CLIENT_RETRY_ATTEMPTS) - retry_params: Dict[str, Any] = { - "reraise": True, - "before_sleep": before_sleep_log(logging.getLogger(), logging.DEBUG), - "retry": retry_if_exception_type(requests.exceptions.RequestException), - "stop": stop_after_attempt(retry_attempts), - "wait": wait_fixed(sleep_interval), - } - return Retrying(**retry_params) - - def xsiam_execute_query(xsiam_client: XsiamApiClient, query: str) -> List[dict]: """Execute an XQL query and return the results. Wrapper for XsiamApiClient.execute_query() with retry logic. @@ -249,13 +167,6 @@ def xsiam_push_to_dataset( return xsiam_client.push_to_dataset(events_test_data, rule.vendor, rule.product) -def xsiam_get_installed_packs(xsiam_client: XsiamApiClient) -> List[Dict[str, Any]]: - """Get the list of installed packs from the XSIAM tenant. - Wrapper for XsiamApiClient.get_installed_packs() with retry logic. - """ - return xsiam_client.installed_packs - - def verify_results( modeling_rule: ModelingRule, tested_dataset: str, @@ -1357,37 +1268,6 @@ def add_result_to_test_case( # ====================== test-modeling-rule ====================== # -def tenant_config_cb( - ctx: typer.Context, param: typer.CallbackParam, value: Optional[str] -): - if ctx.resilient_parsing: - return - # Only check the params if the machine_assignment is not set. - if param.value_is_missing(value) and not ctx.params.get("machine_assignment"): - err_str = ( - f"{param.name} must be set either via the environment variable " - f'"{param.envvar}" or passed explicitly when running the command' - ) - raise typer.BadParameter(err_str) - return value - - -def logs_token_cb(ctx: typer.Context, param: typer.CallbackParam, value: Optional[str]): - if ctx.resilient_parsing: - return - # Only check the params if the machine_assignment is not set. - if param.value_is_missing(value) and not ctx.params.get("machine_assignment"): - parameter_to_check = "xsiam_token" - other_token = ctx.params.get(parameter_to_check) - if not other_token: - err_str = ( - f"One of {param.name} or {parameter_to_check} must be set either via it's associated" - " environment variable or passed explicitly when running the command" - ) - raise typer.BadParameter(err_str) - return value - - class TestResults: def __init__( self, diff --git a/demisto_sdk/commands/test_content/test_modeling_rule/tests/test_modeling_rule_test.py b/demisto_sdk/commands/test_content/test_modeling_rule/tests/test_modeling_rule_test.py index 780aba20fac..726d1d46aab 100644 --- a/demisto_sdk/commands/test_content/test_modeling_rule/tests/test_modeling_rule_test.py +++ b/demisto_sdk/commands/test_content/test_modeling_rule/tests/test_modeling_rule_test.py @@ -165,47 +165,6 @@ def test_convert_epoch_time_to_string_time(epoc_time, with_ms, human_readable_ti ) -@pytest.mark.parametrize( - "day, suffix", - [ - (1, "st"), - (2, "nd"), - (3, "rd"), - (4, "th"), - (10, "th"), - (11, "th"), - (12, "th"), - (21, "st"), - (31, "st"), - ], -) -def test_day_suffix(day, suffix): - """ - Given: - - A day of a month. - case-1: 1 => st. - case-2: 2 => nd. - case-3: 3 => rd. - case-4: 4 => th. - case-5: 10 => th. - case-6: 11 => th. - case-7: 12 => th. - case-8: 21 => st. - case-9: 31 => st. - - When: - - The day_suffix function is running. - - Then: - - Verify we get the expected results. - """ - from demisto_sdk.commands.test_content.test_modeling_rule.test_modeling_rule import ( - day_suffix, - ) - - assert day_suffix(day) == suffix - - @pytest.mark.parametrize( "mr_text, expected_result", [ diff --git a/demisto_sdk/commands/test_content/test_use_case/template_file.py b/demisto_sdk/commands/test_content/test_use_case/template_file.py new file mode 100644 index 00000000000..a6e78a54390 --- /dev/null +++ b/demisto_sdk/commands/test_content/test_use_case/template_file.py @@ -0,0 +1,72 @@ +""" +{ + "additional_needed_packs": { + "PackOne": "instance_name1", + "PackTwo": "" + } +} +""" + +import pytest + +from demisto_sdk.commands.common.clients import ( + XsiamClient, + get_client_conf_from_pytest_request, + get_client_from_server_type, +) + +# Any additional imports your tests require + + +@pytest.fixture(scope="class") +def client_conf(request): + # Manually parse command-line arguments + return get_client_conf_from_pytest_request(request) + + +@pytest.fixture(scope="class") +def api_client(client_conf: dict): + if client_conf: # Running from external pipeline + client_obj = get_client_from_server_type(**client_conf) + + else: # Running manually using pytest. + client_obj = get_client_from_server_type() + return client_obj + + +class TestExample: + @classmethod + def setup_class(self): + """Run once for the class before *all* tests""" + pass + + def some_helper_function(self, method): + pass + + @classmethod + def teardown_class(self): + """Run once for the class after all tests""" + pass + + # PLAYBOOK X CHECKING VALID alert + def test_feature_one_manual_true(self, api_client: XsiamClient): + """Test feature one""" + a = api_client.list_indicators() + + assert a is not None, "list_indicators should not be None" + + def test_feature_two(self, api_client: XsiamClient): + """ + Given: Describe the given inputs or the given situation prior the use case. + When: Describe the use case + Then: Describe the desired outcome of the use case. + """ + # Test another aspect of your application + api_client.run_cli_command( + investigation_id="INCIDENT-1", command="!Set key=test value=A" + ) + assert False # replace with actual assertions for your application + + +if __name__ == "__main__": + pytest.main() diff --git a/demisto_sdk/commands/test_content/test_use_case/test_use_case.py b/demisto_sdk/commands/test_content/test_use_case/test_use_case.py new file mode 100644 index 00000000000..f7b2b2dd922 --- /dev/null +++ b/demisto_sdk/commands/test_content/test_use_case/test_use_case.py @@ -0,0 +1,539 @@ +import logging # noqa: TID251 # specific case, passed as argument to 3rd party +import os +import re +import shutil +import subprocess +from pathlib import Path +from threading import Thread +from typing import Any, List, Optional, Tuple, Union + +import demisto_client +import pytest +import typer +from google.cloud import storage # type: ignore[attr-defined] +from junitparser import JUnitXml, TestCase, TestSuite +from junitparser.junitparser import Failure, Skipped + +from demisto_sdk.commands.common.clients import get_client_from_server_type +from demisto_sdk.commands.common.clients.xsoar.xsoar_api_client import XsoarClient +from demisto_sdk.commands.common.constants import ( + TEST_USE_CASES, + XSIAM_SERVER_TYPE, +) +from demisto_sdk.commands.common.content_constant_paths import CONTENT_PATH +from demisto_sdk.commands.common.logger import ( + handle_deprecated_args, + logger, + logging_setup, +) +from demisto_sdk.commands.common.tools import ( + get_json_file, + get_pack_name, + string_to_bool, +) +from demisto_sdk.commands.test_content.ParallelLoggingManager import ( + ParallelLoggingManager, +) +from demisto_sdk.commands.test_content.tools import ( + duration_since_start_time, + get_relative_path_to_content, + get_ui_url, + get_utc_now, +) + +CI_PIPELINE_ID = os.environ.get("CI_PIPELINE_ID") + +app = typer.Typer() + + +def copy_conftest(test_dir): + """ + copy content's conftest.py file into the use case directory in order to be able to pass new custom + pytest argument (client_conf) + """ + source_conftest = Path(CONTENT_PATH) / "Tests/scripts/dev_envs/pytest/conftest.py" + dest_conftest = test_dir / "conftest.py" + + shutil.copy(source_conftest, dest_conftest) + + +def run_command(command): + """Run a shell command and capture the output.""" + try: + result = subprocess.run( + command, + shell=True, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return result.stdout.decode("utf-8").strip() + except subprocess.CalledProcessError as e: + logging.error( + f"Error executing command: {e}\nCommand: {command}\nOutput: {e.output.decode('utf-8')}\nError: {e.stderr.decode('utf-8')}" + ) + return None + + +# ============================================== Classes ============================================ # +class TestResultCapture: + """ + This class is used to store the pytest results in test suite + """ + + def __init__(self, junit_testsuite): + self.junit_testsuite = junit_testsuite + + def pytest_runtest_logreport(self, report): + if report.when == "call": # Only capture results of test calls + test_case = TestCase(report.nodeid) + test_case.classname = report.location[0] # Test's module or class + test_case.time = report.duration # Add the test duration + if report.outcome == "passed": + self.junit_testsuite.add_testcase(test_case) + elif report.outcome == "failed": + error_text = self._sanitize_sensitive_data( + report.longreprtext if report.longrepr else "Test failed" + ) + failure = Failure(error_text) + test_case.result = failure + self.junit_testsuite.add_testcase(test_case) + elif report.outcome == "skipped": + skipped = Skipped("Test skipped") + test_case.result = skipped + self.junit_testsuite.add_testcase(test_case) + + def _sanitize_sensitive_data(self, text): + """ + Remove or redact sensitive data from the given text. + + Args: + text (str): The text to sanitize. + + Returns: + str: The sanitized text with sensitive data removed or redacted. + """ + + pattern = r"('Authorization':\s*')([^']+)(')" + # Replace the sensitive part with '[REDACTED]' + sanitized_text = re.sub(pattern, r"\1[REDACTED]\3", text) + + return sanitized_text + + +class TestResults: + def __init__( + self, + service_account: str = None, + artifacts_bucket: str = None, + ): + self.test_results_xml_file = JUnitXml() + self.errors = False + self.service_account = service_account + self.artifacts_bucket = artifacts_bucket + + def upload_result_json_to_bucket( + self, + repository_name: str, + file_name, + original_file_path: Path, + logging_module: Union[Any, ParallelLoggingManager] = logging, + ): + """Uploads a JSON object to a specified path in the GCP bucket. + + Args: + original_file_path: The path to the JSON file to upload. + repository_name: The name of the repository within the bucket. + file_name: The desired filename for the uploaded JSON data. + logging_module: Logging module to use for upload_result_json_to_bucket. + """ + logging_module.debug("Start uploading test use case results file to bucket") + + storage_client = storage.Client.from_service_account_json(self.service_account) + storage_bucket = storage_client.bucket(self.artifacts_bucket) + + blob = storage_bucket.blob( + f"content-test-use-case/{repository_name}/{file_name}" + ) + blob.upload_from_filename( + original_file_path.as_posix(), + content_type="application/xml", + ) + + logging_module.debug("Finished uploading test use case results file to bucket") + + +class BuildContext: + def __init__( + self, + nightly: bool, + build_number: Optional[str], + logging_module: ParallelLoggingManager, + cloud_servers_path: str, + cloud_servers_api_keys: str, + service_account: Optional[str], + artifacts_bucket: Optional[str], + cloud_url: Optional[str], + api_key: Optional[str], + auth_id: Optional[str], + inputs: Optional[List[Path]], + machine_assignment: str, + project_id: str, + ctx: typer.Context, + ): + self.logging_module: ParallelLoggingManager = logging_module + self.ctx = ctx + + # --------------------------- overall build configuration ------------------------------- + self.is_nightly = nightly + self.build_number = build_number + self.project_id = project_id + + # -------------------------- Manual run on a single instance -------------------------- + self.cloud_url = cloud_url + self.api_key = api_key + self.auth_id = auth_id + self.inputs = inputs + + # --------------------------- Machine preparation ------------------------------- + + self.cloud_servers_path_json = get_json_file(cloud_servers_path) + self.cloud_servers_api_keys_json = get_json_file(cloud_servers_api_keys) + self.machine_assignment_json = get_json_file(machine_assignment) + + # --------------------------- Testing preparation ------------------------------- + + self.tests_data_keeper = TestResults( + service_account, + artifacts_bucket, + ) + + # --------------------------- Machine preparation logic ------------------------------- + + self.servers = self.create_servers() + + def create_servers(self): + """ + Create servers object based on build type. + """ + # If cloud_url is provided we assume it's a run on a single server. + if self.cloud_url: + return [ + CloudServerContext( + self, + base_url=self.cloud_url, + api_key=self.api_key, # type: ignore[arg-type] + auth_id=self.auth_id, # type: ignore[arg-type] + ui_url=get_ui_url(self.cloud_url), + tests=[Path(test) for test in self.inputs] if self.inputs else [], + ) + ] + servers_list = [] + for machine, assignment in self.machine_assignment_json.items(): + tests = [ + Path(test) + for test in assignment.get("tests", {}).get(TEST_USE_CASES, []) + ] + if not tests: + logger.info(f"No test use cases found for machine {machine}") + continue + servers_list.append( + CloudServerContext( + self, + base_url=self.cloud_servers_path_json.get(machine, {}).get( + "base_url", "" + ), + ui_url=self.cloud_servers_path_json.get(machine, {}).get( + "ui_url", "" + ), + tests=tests, + api_key=self.cloud_servers_api_keys_json.get(machine, {}).get( + "api-key" + ), + auth_id=self.cloud_servers_api_keys_json.get(machine, {}).get( + "x-xdr-auth-id" + ), + ) + ) + return servers_list + + +class CloudServerContext: + def __init__( + self, + build_context: BuildContext, + base_url: str, + api_key: str, + auth_id: str, + ui_url: str, + tests: List[Path], + ): + self.build_context = build_context + self.client = None + self.base_url = base_url + self.api_key = api_key + self.auth_id = auth_id + os.environ.pop( + "DEMISTO_USERNAME", None + ) # we use client without demisto username + self.configure_new_client() + self.ui_url = ui_url + self.tests = tests + + def configure_new_client(self): + if self.client: + self.client.api_client.pool.close() + self.client.api_client.pool.terminate() + del self.client + self.client = demisto_client.configure( + base_url=self.base_url, + api_key=self.api_key, + auth_id=self.auth_id, + verify_ssl=False, + ) + + def execute_tests(self): + try: + self.build_context.logging_module.info( + f"Starts tests with server url - {get_ui_url(self.ui_url)}", + real_time=True, + ) + start_time = get_utc_now() + self.build_context.logging_module.info( + f"Running the following tests: {self.tests}", + real_time=True, + ) + + cloud_client = get_client_from_server_type( + base_url=self.base_url, api_key=self.api_key, auth_id=self.auth_id + ) + + for i, test_use_case_directory in enumerate(self.tests, start=1): + logger.info( + f"[{i}/{len(self.tests)}] test use cases: {get_relative_path_to_content(test_use_case_directory)}", + ) + + success, test_use_case_test_suite = run_test_use_case_pytest( + test_use_case_directory, + cloud_client=cloud_client, + project_id=self.build_context.project_id, + ) + + if success: + logger.info( + f"Test use case {get_relative_path_to_content(test_use_case_directory)} passed", + ) + else: + self.build_context.tests_data_keeper.errors = True + logger.error( + f"Test use case {get_relative_path_to_content(test_use_case_directory)} failed", + ) + if test_use_case_test_suite: + test_use_case_test_suite.add_property( + "start_time", + start_time, # type:ignore[arg-type] + ) + self.build_context.tests_data_keeper.test_results_xml_file.add_testsuite( + test_use_case_test_suite + ) + + self.build_context.logging_module.info( + f"Finished tests with server url - " f"{self.ui_url}", + real_time=True, + ) + duration = duration_since_start_time(start_time) + self.build_context.logging_module.info( + f"Finished tests with server url - {self.ui_url}, Took: {duration} seconds", + real_time=True, + ) + except Exception: + self.build_context.logging_module.exception("~~ Thread failed ~~") + self.build_context.tests_data_keeper.errors = True + finally: + self.build_context.logging_module.execute_logs() + + +# ============================================== Command logic ============================================ # + + +def run_test_use_case_pytest( + test_use_case_directory: Path, + cloud_client: XsoarClient, + durations: int = 5, + project_id: str = None, +) -> Tuple[bool, Union[TestSuite, None]]: + """Runs a test use case + + Args: + test_use_case_directory (Path): Path to the test use case directory. + durations (int): Number of slow tests to show durations for. + cloud_client (XsoarClient): The XSIAM client used to do API calls to the tenant. + """ + # Creating an instance of your results collector + test_use_case_suite = TestSuite("Test Use Case") + containing_pack = get_pack_name(test_use_case_directory) + + test_use_case_suite.add_property("file_name", str(test_use_case_directory.name)) + test_use_case_suite.add_property("pack_id", containing_pack) + if CI_PIPELINE_ID: + test_use_case_suite.add_property("ci_pipeline_id", CI_PIPELINE_ID) + + test_dir = test_use_case_directory.parent + copy_conftest(test_dir) + + logger.debug(f"before sending pytest {str(cloud_client.base_url)}") + pytest_args = [ + f"--client_conf=base_url={str(cloud_client.server_config.base_api_url)}," + f"api_key={cloud_client.server_config.api_key.get_secret_value()}," + f"auth_id={cloud_client.server_config.auth_id}," + f"project_id={project_id}", + str(test_use_case_directory), + f"--durations={str(durations)}", + "--log-cli-level=CRITICAL", + ] + + logger.info(f"Running pytest for file {test_use_case_directory}") + + # Running pytest + result_capture = TestResultCapture(test_use_case_suite) + status_code = pytest.main(pytest_args, plugins=[result_capture]) + + if status_code == pytest.ExitCode.OK: + logger.info( + f"Pytest run tests in {test_use_case_directory} successfully" + ) + return True, test_use_case_suite + elif status_code == pytest.ExitCode.TESTS_FAILED: + logger.error( + f"Pytest failed with statsu {status_code}", + ) + return False, test_use_case_suite + else: + raise Exception(f"Pytest failed with {status_code=}") + + +def run_test_use_case( + ctx: typer.Context, + inputs: List[Path], + xsiam_url: Optional[str], + api_key: Optional[str], + auth_id: Optional[str], + output_junit_file: Optional[Path], + service_account: Optional[str], + cloud_servers_path: str, + cloud_servers_api_keys: str, + machine_assignment: str, + build_number: str, + nightly: str, + artifacts_bucket: str, + project_id: str, + console_log_threshold: str, + file_log_threshold: str, + log_file_path: Optional[str], + **kwargs, +): + """ + Test a test use case against an XSIAM tenant + """ + logging_setup( + console_threshold=console_log_threshold, # type: ignore[arg-type] + file_threshold=file_log_threshold, # type: ignore[arg-type] + path=log_file_path, + calling_function=__name__, + ) + handle_deprecated_args(ctx.args) + + logging_module = ParallelLoggingManager( + "test_use_case.log", real_time_logs_only=not nightly + ) + + if machine_assignment: + if inputs: + logger.error( + "You cannot pass both machine_assignment and inputs arguments." + ) + raise typer.Exit(1) + if xsiam_url: + logger.error( + "You cannot pass both machine_assignment and xsiam_url arguments." + ) + raise typer.Exit(1) + + start_time = get_utc_now() + is_nightly = string_to_bool(nightly) + build_context = BuildContext( + nightly=is_nightly, + build_number=build_number, + logging_module=logging_module, + cloud_servers_path=cloud_servers_path, + cloud_servers_api_keys=cloud_servers_api_keys, + service_account=service_account, + artifacts_bucket=artifacts_bucket, + machine_assignment=machine_assignment, + ctx=ctx, + cloud_url=xsiam_url, + api_key=api_key, + auth_id=auth_id, + inputs=inputs, + project_id=project_id, + ) + + logging_module.debug( + "test use cases to test:", + ) + + for build_context_server in build_context.servers: + for test_use_case_directory in build_context_server.tests: + logging_module.info( + f"\tmachine:{build_context_server.base_url} - " + f"{get_relative_path_to_content(test_use_case_directory)}" + ) + + threads_list = [] + for index, server in enumerate(build_context.servers, start=1): + thread_name = f"Thread-{index} (execute_tests)" + threads_list.append(Thread(target=server.execute_tests, name=thread_name)) + + logging_module.info("Finished creating configurations, starting to run tests.") + for thread in threads_list: + thread.start() + + for t in threads_list: + t.join() + + logging_module.info("Finished running tests.") + + if output_junit_file: + logger.info( + f"Writing JUnit XML to {get_relative_path_to_content(output_junit_file)}", + ) + build_context.tests_data_keeper.test_results_xml_file.write( + output_junit_file.as_posix(), pretty=True + ) + if nightly: + if service_account and artifacts_bucket: + build_context.tests_data_keeper.upload_result_json_to_bucket( + XSIAM_SERVER_TYPE, + f"test_use_case_{build_number}.xml", + output_junit_file, + logging_module, + ) + else: + logger.warning( + "Service account or artifacts bucket not provided, skipping uploading JUnit XML to bucket", + ) + else: + logger.info( + "No JUnit XML file path was passed - skipping writing JUnit XML", + ) + + duration = duration_since_start_time(start_time) + if build_context.tests_data_keeper.errors: + logger.error( + f"Test use case: Failed, took:{duration} seconds", + ) + raise typer.Exit(1) + + logger.success( + f"Test use case: Passed, took:{duration} seconds", + ) diff --git a/demisto_sdk/commands/test_content/test_use_case/test_use_case_setup.py b/demisto_sdk/commands/test_content/test_use_case/test_use_case_setup.py new file mode 100644 index 00000000000..87466fb8fee --- /dev/null +++ b/demisto_sdk/commands/test_content/test_use_case/test_use_case_setup.py @@ -0,0 +1,128 @@ +from pathlib import Path +from typing import List, Optional + +import typer + +from demisto_sdk.commands.common.logger import logging_setup_decorator +from demisto_sdk.commands.test_content.tools import tenant_config_cb + + +@logging_setup_decorator +def run_test_use_case( + ctx: typer.Context, + inputs: List[Path] = typer.Argument( + None, + exists=True, + dir_okay=True, + resolve_path=True, + show_default=False, + help="The path to a directory of a test use cases. May pass multiple paths to test multiple test use cases.", + ), + xsiam_url: Optional[str] = typer.Option( + None, + envvar="DEMISTO_BASE_URL", + help="The base url to the cloud tenant.", + rich_help_panel="Cloud Tenant Configuration", + show_default=False, + callback=tenant_config_cb, + ), + api_key: Optional[str] = typer.Option( + None, + envvar="DEMISTO_API_KEY", + help="The api key for the cloud tenant.", + rich_help_panel="XSIAM Tenant Configuration", + show_default=False, + callback=tenant_config_cb, + ), + auth_id: Optional[str] = typer.Option( + None, + envvar="XSIAM_AUTH_ID", + help="The auth id associated with the cloud api key being used.", + rich_help_panel="XSIAM Tenant Configuration", + show_default=False, + callback=tenant_config_cb, + ), + output_junit_file: Optional[Path] = typer.Option( + None, "-jp", "--junit-path", help="Path to the output JUnit XML file." + ), + service_account: Optional[str] = typer.Option( + None, + "-sa", + "--service_account", + envvar="GCP_SERVICE_ACCOUNT", + help="GCP service account.", + show_default=False, + ), + cloud_servers_path: str = typer.Option( + "", + "-csp", + "--cloud_servers_path", + help="Path to secret cloud server metadata file.", + show_default=False, + ), + cloud_servers_api_keys: str = typer.Option( + "", + "-csak", + "--cloud_servers_api_keys", + help="Path to file with cloud Servers api keys.", + show_default=False, + ), + machine_assignment: str = typer.Option( + "", + "-ma", + "--machine_assignment", + help="the path to the machine assignment file.", + show_default=False, + ), + build_number: str = typer.Option( + "", + "-bn", + "--build_number", + help="The build number.", + show_default=True, + ), + nightly: str = typer.Option( + "false", + "--nightly", + "-n", + help="Whether the command is being run in nightly mode.", + ), + artifacts_bucket: str = typer.Option( + None, + "-ab", + "--artifacts_bucket", + help="The artifacts bucket name to upload the results to", + show_default=False, + ), + project_id: str = typer.Option( + None, + "-pi", + "--project_id", + help="The machine project ID", + show_default=False, + ), + console_log_threshold: str = typer.Option( + "INFO", + "-clt", + "--console-log-threshold", + help="Minimum logging threshold for the console logger.", + ), + file_log_threshold: str = typer.Option( + "DEBUG", + "-flt", + "--file-log-threshold", + help="Minimum logging threshold for the file logger.", + ), + log_file_path: Optional[str] = typer.Option( + None, + "-lp", + "--log-file-path", + help="Path to save log files onto.", + ), +): + from demisto_sdk.commands.test_content.test_use_case.test_use_case import ( + run_test_use_case, + ) + + kwargs = locals() + run_test_use_case(**kwargs) diff --git a/demisto_sdk/commands/test_content/test_use_case/tests/test_use_case_test.py b/demisto_sdk/commands/test_content/test_use_case/tests/test_use_case_test.py new file mode 100644 index 00000000000..1a2e23d4068 --- /dev/null +++ b/demisto_sdk/commands/test_content/test_use_case/tests/test_use_case_test.py @@ -0,0 +1,151 @@ +import pytest +from junitparser import TestSuite +from pytest import ExitCode + +import demisto_sdk.commands.test_content.test_use_case.test_use_case as test_use_case +from demisto_sdk.commands.test_content.test_use_case.test_use_case import ( + run_test_use_case_pytest, +) + + +# Mock the dependencies +@pytest.fixture +def mocker_cloud_client(mocker): + # Mock the XsoarClient + cloud_client = mocker.Mock() + cloud_client.server_config.base_api_url = "https://example.com" + cloud_client.server_config.api_key.get_secret_value.return_value = "API_KEY" + cloud_client.server_config.auth_id = "AUTH_ID" + return cloud_client + + +@pytest.fixture +def mocker_test_use_case_directory(mocker): + # Mock the test_use_case_directory + return mocker.Mock() + + +def test_run_test_use_case_pytest( + mocker, mocker_cloud_client, mocker_test_use_case_directory +): + """ + Given: parameters for running the tests. + When: running the test_use_case command. + Then: validate the correct params are used when running the pytest method. + """ + test_result_mocker = mocker.Mock() + mocker.patch.object(test_use_case, "get_pack_name", return_value="/path/to/pack") + mocker.patch.object(test_use_case, "copy_conftest") + mocker.patch.object(test_use_case, "logger") + mocker.patch.object( + test_use_case, "TestResultCapture", return_value=test_result_mocker + ) + mocker.patch("pytest.main", return_value=ExitCode.OK) + + # Call the function to be tested + result, test_use_case_suite = run_test_use_case_pytest( + mocker_test_use_case_directory, mocker_cloud_client, durations=5 + ) + + # Verify the expected behavior and assertions + assert result is True + assert isinstance(test_use_case_suite, TestSuite) + + # Additional assertions for the mocked dependencies + pytest.main.assert_called_once_with( + [ + "--client_conf=base_url=https://example.com," + "api_key=API_KEY," + "auth_id=AUTH_ID," + "project_id=None", + str(mocker_test_use_case_directory), + "--durations=5", + "--log-cli-level=CRITICAL", + ], + plugins=[test_result_mocker], + ) + mocker_cloud_client.server_config.api_key.get_secret_value.assert_called_once() + + +def test_pytest_runtest_logreport_passed(mocker): + """ + When: pytest_runtest_logreport is called with a passing test, + Given: a TestResultCapture instance and a passing report, + Then: Validate the correct testcase is appended the test suite. + + + """ + junit_testsuite = TestSuite("Test Suite") + test_result_capture = test_use_case.TestResultCapture(junit_testsuite) + + report = mocker.Mock() + report.when = "call" + report.nodeid = "test_module.test_function" + report.location = ("test_module",) + report.duration = 0.5 + report.outcome = "passed" + + test_result_capture.pytest_runtest_logreport(report) + + assert len(junit_testsuite) == 1 + + for testcase in junit_testsuite: + assert testcase.name == "test_module.test_function" + assert testcase.classname == "test_module" + assert testcase.time == 0.5 + assert len(testcase.result) == 0 + + +def test_pytest_runtest_logreport_failed(mocker): + """ + When: pytest_runtest_logreport is called with a failing test, + Given: a TestResultCapture instance and a failing report, + Then: Validate the correct testcase is appended the test suite. + """ + junit_testsuite = TestSuite("Test Suite") + test_result_capture = test_use_case.TestResultCapture(junit_testsuite) + + report = mocker.Mock() + report.when = "call" + report.nodeid = "test_module.test_function" + report.location = ("test_module",) + report.duration = 0.5 + report.outcome = "failed" + report.longreprtext = "AssertionError: Expected 1, but got 2" + + test_result_capture.pytest_runtest_logreport(report) + + assert len(junit_testsuite) == 1 + + for testcase in junit_testsuite: + assert testcase.name == "test_module.test_function" + assert testcase.classname == "test_module" + assert testcase.time == 0.5 + assert testcase.result[0].message == "AssertionError: Expected 1, but got 2" + + +def test_pytest_runtest_logreport_skipped(mocker): + """ + When: pytest_runtest_logreport is called with a skipped test, + Given: a TestResultCapture instance and a skipped report, + Then: Validate the correct testcase is appended the test suite. + """ + junit_testsuite = TestSuite("Test Suite") + test_result_capture = test_use_case.TestResultCapture(junit_testsuite) + + report = mocker.Mock() + report.when = "call" + report.nodeid = "test_module.test_function" + report.location = ("test_module",) + report.duration = 0.5 + report.outcome = "skipped" + + test_result_capture.pytest_runtest_logreport(report) + + assert len(junit_testsuite) == 1 + + for testcase in junit_testsuite: + assert testcase.name == "test_module.test_function" + assert testcase.classname == "test_module" + assert testcase.time == 0.5 + assert testcase.result[0].message == "Test skipped" diff --git a/demisto_sdk/commands/test_content/tests/test_tools.py b/demisto_sdk/commands/test_content/tests/test_tools.py index 6108c7546a6..b54fe3a7352 100644 --- a/demisto_sdk/commands/test_content/tests/test_tools.py +++ b/demisto_sdk/commands/test_content/tests/test_tools.py @@ -1,5 +1,7 @@ from subprocess import CalledProcessError +import pytest + from demisto_sdk.commands.test_content.constants import SSH_USER from demisto_sdk.commands.test_content.tools import is_redhat_instance @@ -20,3 +22,44 @@ def test_is_redhat_instance_positive(mocker): def test_is_redhat_instance_negative(mocker): mocker.patch("subprocess.check_output", side_effect=raise_exception) assert not is_redhat_instance("instance_ip") + + +@pytest.mark.parametrize( + "day, suffix", + [ + (1, "st"), + (2, "nd"), + (3, "rd"), + (4, "th"), + (10, "th"), + (11, "th"), + (12, "th"), + (21, "st"), + (31, "st"), + ], +) +def test_day_suffix(day, suffix): + """ + Given: + - A day of a month. + case-1: 1 => st. + case-2: 2 => nd. + case-3: 3 => rd. + case-4: 4 => th. + case-5: 10 => th. + case-6: 11 => th. + case-7: 12 => th. + case-8: 21 => st. + case-9: 31 => st. + + When: + - The day_suffix function is running. + + Then: + - Verify we get the expected results. + """ + from demisto_sdk.commands.test_content.tools import ( + day_suffix, + ) + + assert day_suffix(day) == suffix diff --git a/demisto_sdk/commands/test_content/tools.py b/demisto_sdk/commands/test_content/tools.py index ad9266bfc1c..9d536bd52ec 100644 --- a/demisto_sdk/commands/test_content/tools.py +++ b/demisto_sdk/commands/test_content/tools.py @@ -1,13 +1,34 @@ import ast +import logging # noqa: TID251 # specific case, passed as argument to 3rd party +import os from copy import deepcopy +from datetime import datetime +from pathlib import Path from pprint import pformat from subprocess import STDOUT, CalledProcessError, check_output -from typing import Dict, Optional, Set +from typing import Any, Dict, List, Optional, Set +from uuid import UUID import demisto_client - +import pytz +import requests +import typer +from tenacity import ( + Retrying, + before_sleep_log, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) + +from demisto_sdk.commands.common.content_constant_paths import CONTENT_PATH from demisto_sdk.commands.common.logger import logger +from demisto_sdk.commands.common.tools import parse_int_or_default from demisto_sdk.commands.test_content.constants import SSH_USER +from demisto_sdk.commands.test_content.xsiam_tools.xsiam_client import XsiamApiClient + +XSIAM_CLIENT_SLEEP_INTERVAL = 60 +XSIAM_CLIENT_RETRY_ATTEMPTS = 5 def update_server_configuration( @@ -101,3 +122,132 @@ def get_ui_url(client_host): """ return client_host.replace("https://api-", "https://") + + +# ================= Methods and Classes used in modeling rules and playbook flow commands ================= # + + +def get_utc_now() -> datetime: + """Get the current time in UTC, with timezone aware.""" + return datetime.now(tz=pytz.UTC) + + +def duration_since_start_time(start_time: datetime) -> float: + """Get the duration since the given start time, in seconds. + + Args: + start_time (datetime): Start time. + + Returns: + float: Duration since the given start time, in seconds. + """ + return (get_utc_now() - start_time).total_seconds() + + +def day_suffix(day: int) -> str: + """ + Returns a suffix string base on the day of the month. + for 1, 21, 31 => st + for 2, 22 => nd + for 3, 23 => rd + for to all the others => th + + see here for more details: https://en.wikipedia.org/wiki/English_numerals#Ordinal_numbers + + Args: + day: The day of the month represented by a number. + + Returns: + suffix string (st, nd, rd, th). + """ + return "th" if 11 <= day <= 13 else {1: "st", 2: "nd", 3: "rd"}.get(day % 10, "th") + + +def get_relative_path_to_content(path: Path) -> str: + """Get the relative path to the content directory. + + Args: + path: The path to the content item. + + Returns: + Path: The relative path to the content directory. + """ + if path.is_absolute() and path.as_posix().startswith(CONTENT_PATH.as_posix()): + return path.as_posix().replace(f"{CONTENT_PATH.as_posix()}{os.path.sep}", "") + return path.as_posix() + + +def get_type_pretty_name(obj: Any) -> str: + """Get the pretty name of the type of the given object. + + Args: + obj (Any): The object to get the type name for. + + Returns: + str: The pretty name of the type of the given object. + """ + return { + type(None): "null", + list: "list", + dict: "dict", + tuple: "tuple", + set: "set", + UUID: "UUID", + str: "string", + int: "int", + float: "float", + bool: "boolean", + datetime: "datetime", + }.get(type(obj), str(type(obj))) + + +def create_retrying_caller(retry_attempts: int, sleep_interval: int) -> Retrying: + """Create a Retrying object with the given retry_attempts and sleep_interval.""" + sleep_interval = parse_int_or_default(sleep_interval, XSIAM_CLIENT_SLEEP_INTERVAL) + retry_attempts = parse_int_or_default(retry_attempts, XSIAM_CLIENT_RETRY_ATTEMPTS) + retry_params: Dict[str, Any] = { + "reraise": True, + "before_sleep": before_sleep_log(logging.getLogger(), logging.DEBUG), + "retry": retry_if_exception_type(requests.exceptions.RequestException), + "stop": stop_after_attempt(retry_attempts), + "wait": wait_fixed(sleep_interval), + } + return Retrying(**retry_params) + + +def xsiam_get_installed_packs(xsiam_client: XsiamApiClient) -> List[Dict[str, Any]]: + """Get the list of installed packs from the XSIAM tenant. + Wrapper for XsiamApiClient.get_installed_packs() with retry logic. + """ + return xsiam_client.installed_packs + + +def tenant_config_cb( + ctx: typer.Context, param: typer.CallbackParam, value: Optional[str] +): + if ctx.resilient_parsing: + return + # Only check the params if the machine_assignment is not set. + if param.value_is_missing(value) and not ctx.params.get("machine_assignment"): + err_str = ( + f"{param.name} must be set either via the environment variable " + f'"{param.envvar}" or passed explicitly when running the command' + ) + raise typer.BadParameter(err_str) + return value + + +def logs_token_cb(ctx: typer.Context, param: typer.CallbackParam, value: Optional[str]): + if ctx.resilient_parsing: + return + # Only check the params if the machine_assignment is not set. + if param.value_is_missing(value) and not ctx.params.get("machine_assignment"): + parameter_to_check = "xsiam_token" + other_token = ctx.params.get(parameter_to_check) + if not other_token: + err_str = ( + f"One of {param.name} or {parameter_to_check} must be set either via it's associated" + " environment variable or passed explicitly when running the command" + ) + raise typer.BadParameter(err_str) + return value