diff --git a/tools/scripts/codegen/legacy_c2j_cpp_gen.py b/tools/scripts/codegen/legacy_c2j_cpp_gen.py index a21ef8f7012..50bfa9aa3ae 100644 --- a/tools/scripts/codegen/legacy_c2j_cpp_gen.py +++ b/tools/scripts/codegen/legacy_c2j_cpp_gen.py @@ -172,7 +172,7 @@ def _init_common_java_cli(self, run_command += ["--service", service_name] run_command += ["--outputfile", output_filename] - if service_name in SMITHY_SUPPORTED_CLIENTS: + if service_name in SMITHY_SUPPORTED_CLIENTS or model_files.use_smithy: run_command += ["--use-smithy-client"] for key, val in kwargs.items(): diff --git a/tools/scripts/codegen/model_utils.py b/tools/scripts/codegen/model_utils.py index ccc285f5c7b..e8cdf8e0771 100644 --- a/tools/scripts/codegen/model_utils.py +++ b/tools/scripts/codegen/model_utils.py @@ -7,6 +7,7 @@ A set of utils to go through c2j models and their corresponding endpoint rules """ import datetime +import json import os import re @@ -18,6 +19,29 @@ "transcribe-streaming": "transcribestreaming", "streams.dynamodb": "dynamodbstreams"} +SMITHY_EXCLUSION_CLIENTS = { + # multi auth + "eventbridge" + , "cloudfront-keyvaluestore" + , "cognito-identity" + , "cognito-idp" + # customization + , "machinelearning" + , "apigatewayv2" + , "apigateway" + , "eventbridge" + , "glacier" + , "lambda" + , "polly" + , "sqs" + # bearer token + # ,"codecatalyst" + # bidirectional streaming + , "lexv2-runtime" + , "qbusiness" + , "transcribestreaming" +} + # Regexp to parse C2J model filename to extract service name and date version SERVICE_MODEL_FILENAME_PATTERN = re.compile( "^" @@ -29,12 +53,13 @@ class ServiceModel(object): # A helper class to store C2j model info and metadata (endpoint rules and tests) - def __init__(self, service_id, c2j_model, endpoint_rule_set, endpoint_tests): + def __init__(self, service_id: str, c2j_model: str, endpoint_rule_set: str, endpoint_tests: str, use_smithy: bool): self.service_id = service_id # For debugging purposes, not used atm # only filenames, no filesystem path self.c2j_model = c2j_model self.endpoint_rule_set = endpoint_rule_set self.endpoint_tests = endpoint_tests + self.use_smithy = use_smithy class ModelUtils(object): @@ -113,7 +138,8 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict: # fetch endpoint-rules filename which is based on ServiceId in c2j models: try: - service_name_to_model_filename[key] = ModelUtils._build_service_model(endpoint_rules_dir, + service_name_to_model_filename[key] = ModelUtils._build_service_model(models_dir, + endpoint_rules_dir, model_file_date[0]) if key == "s3": @@ -125,7 +151,8 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict: service_name_to_model_filename[key] = ServiceModel(service_id=key, c2j_model=model_file_date[0], endpoint_rule_set=None, - endpoint_tests=None) + endpoint_tests=None, + use_smithy=False) if missing: # TODO: re-enable with endpoints introduction # print(f"Missing endpoints for services: {missing}") @@ -137,7 +164,25 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict: return service_name_to_model_filename @staticmethod - def _build_service_model(endpoint_rules_dir: str, c2j_model_filename) -> ServiceModel: + def is_smithy_enabled(service_id, models_dir, c2j_model_filename): + """Return true if given service id and c2j model file should enable smithy client generation path + + :param service_id: + :param models_dir: + :param c2j_model_filename: + :return: + """ + use_smithy = False + if service_id not in SMITHY_EXCLUSION_CLIENTS: + with open(models_dir + "/" + c2j_model_filename, 'r') as json_file: + model = json.load(json_file) + model_protocol = model.get("metadata", dict()).get("protocol", "UNKNOWN_PROTOCOL") + if model_protocol in {"json", "rest-json"}: + use_smithy = True + return use_smithy + + @staticmethod + def _build_service_model(models_dir: str, endpoint_rules_dir: str, c2j_model_filename) -> ServiceModel: """Return a ServiceModel containing paths to the Service models: C2J model and endpoints (rules and tests). :param models_dir (str): filepath (absolute or relative) to the dir with c2j models @@ -153,8 +198,11 @@ def _build_service_model(endpoint_rules_dir: str, c2j_model_filename) -> Service match = SERVICE_MODEL_FILENAME_PATTERN.match(c2j_model_filename) service_id = match.group("service") + use_smithy = ModelUtils._is_smithy_enabled(service_id, models_dir, c2j_model_filename) + if os.path.exists(endpoint_rules_filepath) and os.path.exists(endpoint_tests_filepath): return ServiceModel(service_id=service_id, c2j_model=c2j_model_filename, endpoint_rule_set=endpoint_rules_filename, - endpoint_tests=endpoint_tests_filename) + endpoint_tests=endpoint_tests_filename, + use_smithy=use_smithy) diff --git a/tools/scripts/codegen/protocol_tests_gen.py b/tools/scripts/codegen/protocol_tests_gen.py index 8ef5076a915..e915cfb351c 100644 --- a/tools/scripts/codegen/protocol_tests_gen.py +++ b/tools/scripts/codegen/protocol_tests_gen.py @@ -12,7 +12,7 @@ from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED, ALL_COMPLETED from codegen.legacy_c2j_cpp_gen import LegacyC2jCppGen -from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel +from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel, ModelUtils PROTOCOL_TESTS_BASE_DIR = "tools/code-generation/protocol-tests" PROTOCOL_TESTS_CLIENT_MODELS = PROTOCOL_TESTS_BASE_DIR + "/api-descriptions" @@ -112,8 +112,9 @@ def _collect_test_client_models(self) -> dict: if service_model_name in UNSUPPORTED_CLIENTS: continue + use_smithy = ModelUtils.is_smithy_enabled(service_model_name, self.client_models_dir, filename) service_models[service_model_name] = ServiceModel(service_model_name, filename, - PROTOCOL_TESTS_ENDPOINT_RULES, None) + PROTOCOL_TESTS_ENDPOINT_RULES, None, use_smithy) return service_models def _generate_tests(self):