diff --git a/HeterogeneousCore/SonicTriton/README.md b/HeterogeneousCore/SonicTriton/README.md index 314b5d4d15986..7eed0f67989fa 100644 --- a/HeterogeneousCore/SonicTriton/README.md +++ b/HeterogeneousCore/SonicTriton/README.md @@ -124,14 +124,18 @@ In a SONIC Triton producer, the basic flow should follow this pattern: ## Services +### `cmsTriton` + A script [`cmsTriton`](./scripts/cmsTriton) is provided to launch and manage local servers. -The script has two operations (`start` and `stop`) and the following options: +The script has three operations (`start`, `stop`, `check`) and the following options: * `-c`: don't cleanup temporary dir (for debugging) +* `-C [dir]`: directory containing Nvidia compatibility drivers (checks CMSSW_BASE by default if available) * `-D`: dry run: print container commands rather than executing them * `-d`: use Docker instead of Apptainer * `-f`: force reuse of (possibly) existing container instance * `-g`: use GPU instead of CPU * `-i` [name]`: server image name (default: fastml/triton-torchgeo:22.07-py3-geometric) +* `-I [num]`: number of model instances (default: 0 -> means no local editing of config files) * `-M [dir]`: model repository (can be given more than once) * `-m [dir]`: specific model directory (can be given more than one) * `-n [name]`: name of container instance, also used for hidden temporary dir (default: triton_server_instance) @@ -148,6 +152,7 @@ Additional details and caveats: * The `start` and `stop` operations for a given container instance should always be executed in the same directory if a relative path is used for the hidden temporary directory (including the default from the container instance name), in order to ensure that everything is properly cleaned up. +* The `check` operation just checks if the server can run on the current system, based on driver compatibility. * A model repository is a folder that contains multiple model directories, while a model directory contains the files for a specific file. (In the example below, `$CMSSW_BASE/src/HeterogeneousCore/SonicTriton/data/models` is a model repository, while `$CMSSW_BASE/src/HeterogeneousCore/SonicTriton/data/models/resnet50_netdef` is a model directory.) @@ -155,6 +160,24 @@ If a model repository is provided, all of the models it contains will be provide * Older versions of Apptainer (Singularity) have a short timeout that may cause launching the server to fail the first time the command is executed. The `-r` (retry) flag exists to work around this issue. +### `cmsTritonConfigTool` + +The `config.pbtxt` files used for model configuration are written in the protobuf text format. +To ease modification of these files, a dedicated Python tool [`cmsTritonConfigTool`](./scripts/cmsTritonConfigTool) is provided. +The tool has several modes of operation (each with its own options, which can be viewed using `--help`): +* `schema`: displays all field names and types for the Triton ModelConfig message class. +* `view`: displays the field values from a provided `config.pbtxt` file. +* `edit`: allows changing any field value in a `config.pbtxt` file. Non-primitive types are specified using JSON format. +* `checksum`: checks and updates checksums for model files (to enforce versioning). +* `versioncheck`: checks and updates checksums for all `config.pbtxt` files in `$CMSSW_SEARCH_PATH`. +* `threadcontrol`: adds job- and ML framework-specific thread control settings. + +The `edit` mode is intended for generic modifications, and only supports overwriting existing values +(not modifying, removing, deleting, etc.). +Additional dedicated modes, like `checksum` and `threadcontrol`, can easily be added for more complicated tasks. + +### `TritonService` + A central `TritonService` is provided to keep track of all available servers and which models they can serve. The servers will automatically be assigned to clients at startup. If some models are not served by any server, the `TritonService` can launch a fallback server using the `cmsTriton` script described above. diff --git a/HeterogeneousCore/SonicTriton/interface/triton_utils.h b/HeterogeneousCore/SonicTriton/interface/triton_utils.h index 159da808edcab..d6c7612a5159c 100644 --- a/HeterogeneousCore/SonicTriton/interface/triton_utils.h +++ b/HeterogeneousCore/SonicTriton/interface/triton_utils.h @@ -83,6 +83,7 @@ extern template std::string triton_utils::printColl(const edm::Span& coll, const std::string& delim); extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); +extern template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); extern template std::string triton_utils::printColl(const std::unordered_set& coll, const std::string& delim); diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTriton b/HeterogeneousCore/SonicTriton/scripts/cmsTriton index addbfb2c247c7..9c84be2b62616 100755 --- a/HeterogeneousCore/SonicTriton/scripts/cmsTriton +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTriton @@ -34,7 +34,7 @@ get_sandbox(){ usage() { ECHO="echo -e" - $ECHO "cmsTriton [options] [start|stop]" + $ECHO "cmsTriton [options] [start|stop|check]" $ECHO $ECHO "Options:" $ECHO "-c \t don't cleanup temporary dir (for debugging)" @@ -338,57 +338,6 @@ wait_server(){ echo "server is ready!" } -edit_model(){ - MODELNAME=$1 - NUMINSTANCES=$2 - - cp -r $MODELNAME $TMPDIR/$LOCALMODELREPO/ - COPY_EXIT=$? - if [ "$COPY_EXIT" -ne 0 ]; then - echo "Could not copy $MODELNAME into $TMPDIR/$LOCALMODELREPO/" - exit "$COPY_EXIT" - fi - IFS='/' read -ra ADDR <<< "$MODELNAME" - CONFIG=$TMPDIR/$LOCALMODELREPO/${ADDR[-1]}/config.pbtxt - - PLATFORM=$(grep -m 1 "^platform:" "$CONFIG") - - if [[ $PLATFORM == *"ensemble"* ]]; then - #recurse over submodels of ensemble model - MODELLOC=$(echo ""${ADDR[@]:0:${#ADDR[@]}-1} | sed "s/ /\//g") - SUBNAME=$(grep "model_name:" "$CONFIG" | sed 's/model_name://; s/"//g') - for SUBMODEL in ${SUBNAME}; do - SUBMODEL=${MODELLOC}/${SUBMODEL} - edit_model $SUBMODEL "$INSTANCES" - done - else - #This is not an ensemble model, so we should edit the config file - cat <> $CONFIG -instance_group [ - { - count: $NUMINSTANCES - kind: KIND_CPU - } -] - -EOF - if [[ $PLATFORM == *"onnx"* ]]; then - cat <> $CONFIG -parameters { key: "intra_op_thread_count" value: { string_value: "1" } } -parameters { key: "inter_op_thread_count" value: { string_value: "1" } } -EOF - elif [[ $PLATFORM == *"tensorflow"* ]]; then - cat <> $CONFIG -parameters { key: "TF_NUM_INTRA_THREADS" value: { string_value: "1" } } -parameters { key: "TF_NUM_INTER_THREADS" value: { string_value: "1" } } -parameters { key: "TF_USE_PER_SESSION_THREADS" value: { string_value: "1" } } -EOF - else - echo "Warning: thread (instance) control not implemented for $PLATFORM" - fi - fi -} - list_models(){ # make list of model repositories LOCALMODELREPO="local_model_repo" @@ -411,7 +360,12 @@ list_models(){ MODEL="$(dirname "$MODEL")" fi if [ "$INSTANCES" -gt 0 ]; then - edit_model $MODEL "$INSTANCES" + $DRYRUN cmsTritonConfigTool threadcontrol -c ${MODEL}/config.pbtxt --copy $TMPDIR/$LOCALMODELREPO --nThreads $INSTANCES + TOOL_EXIT=$? + if [ "$TOOL_EXIT" -ne 0 ]; then + echo "Could not apply threadcontrol to $MODEL" + exit "$TOOL_EXIT" + fi else REPOS+=("$(dirname "$MODEL")") fi diff --git a/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool new file mode 100755 index 0000000000000..9a1ef54c57da6 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 + +import os, sys, json, pathlib, shutil +from collections import OrderedDict +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, RawDescriptionHelpFormatter, Action, Namespace +from enum import Enum +from google.protobuf import text_format, json_format, message, descriptor +from google.protobuf.internal import type_checkers +from tritonclient import grpc + +# convenience definition +# (from ConfigArgParse) +class ArgumentDefaultsRawHelpFormatter( + ArgumentDefaultsHelpFormatter, + RawTextHelpFormatter, + RawDescriptionHelpFormatter): + """HelpFormatter that adds default values AND doesn't do line-wrapping""" +pass + +class DictAction(Action): + val_type = None + def __call__(self, parser, namespace, values, option_string=None): + if self.val_type is None: + self.val_type = self.type + result = {} + if len(values)%2!=0: + parser.error("{} args must come in pairs".format(self.dest)) + for i in range(0, len(values), 2): + result[values[i]] = self.val_type(values[i+1]) + setattr(namespace, self.dest, result) + +class TritonChecksumStatus(Enum): + CORRECT = 0 + MISSING = 1 + INCORRECT = 2 + +message_classes = {cls.__name__ : cls for cls in message.Message.__subclasses__()} + +_FieldDescriptor = descriptor.FieldDescriptor +cpp_to_python = { + _FieldDescriptor.CPPTYPE_INT32: int, + _FieldDescriptor.CPPTYPE_INT64: int, + _FieldDescriptor.CPPTYPE_UINT32: int, + _FieldDescriptor.CPPTYPE_UINT64: int, + _FieldDescriptor.CPPTYPE_DOUBLE: float, + _FieldDescriptor.CPPTYPE_FLOAT: float, + _FieldDescriptor.CPPTYPE_BOOL: bool, + _FieldDescriptor.CPPTYPE_STRING: str, +} +checker_to_type = {val.__class__:cpp_to_python[key] for key,val in type_checkers._VALUE_CHECKERS.items()} +# for some reason, this one is not in the map +checker_to_type[type_checkers.UnicodeValueChecker] = str + +kind_to_int = {v.name:v.number for v in grpc.model_config_pb2._MODELINSTANCEGROUP_KIND.values} +thread_control_parameters = { + "onnx": ["intra_op_thread_count", "inter_op_thread_count"], + "tensorflow": ["TF_NUM_INTRA_THREADS", "TF_NUM_INTER_THREADS", "TF_USE_PER_SESSION_THREADS"], +} + +def get_type(obj): + obj_type = obj.__class__.__name__ + entry_type = None + entry_class = None + if obj_type=="RepeatedCompositeFieldContainer" or obj_type=="MessageMap": + entry_type = obj._message_descriptor.name + entry_class = message_classes[entry_type] + elif obj_type=="RepeatedScalarFieldContainer": + entry_class = checker_to_type[obj._type_checker.__class__] + entry_type = entry_class.__name__ + elif obj_type=="ScalarMap": + entry_class = obj.GetEntryClass()().value.__class__ + entry_type = entry_class.__name__ + return { + "class": obj.__class__, + "type": obj_type+("<"+entry_type+">" if entry_type is not None else ""), + "entry_class": entry_class, + "entry_type": entry_type, + } + +def get_fields(obj, name, level=0, verbose=False): + prefix = ' '*level + obj_info = {"name": name, "fields": []} + obj_info.update(get_type(obj)) + if verbose: print(prefix+obj_info["type"],name) + field_obj = None + if hasattr(obj, "DESCRIPTOR"): + field_obj = obj + elif obj_info["entry_class"] is not None and hasattr(obj_info["entry_class"], "DESCRIPTOR"): + field_obj = obj_info["entry_class"]() + field_list = [] + if field_obj is not None: + field_list = [f.name for f in field_obj.DESCRIPTOR.fields] + for field in field_list: + obj_info["fields"].append(get_fields(getattr(field_obj,field),field,level+1,verbose)) + return obj_info + +def get_model_info(): + return get_fields(grpc.model_config_pb2.ModelConfig(), "ModelConfig") + +def msg_json(val, defaults=False): + return json_format.MessageToJson(val, preserving_proto_field_name=True, including_default_value_fields=defaults, indent=0).replace(",\n",", ").replace("\n","") + +def print_fields(obj, info, level=0, json=False, defaults=False): + def print_subfields(obj,level): + fields = obj.DESCRIPTOR.fields if defaults else [f[0] for f in obj.ListFields()] + for field in fields: + print_fields(getattr(obj,field.name), next(f for f in info["fields"] if f["name"]==field.name), level=level, json=json, defaults=defaults) + + prefix = ' ' + print(prefix*level+info["type"],info["name"]) + if hasattr(obj, "DESCRIPTOR"): + if json and level>0: + print(prefix*(level+1)+msg_json(obj, defaults)) + else: + print_subfields(obj,level+1) + elif info["type"].startswith("RepeatedCompositeFieldContainer"): + if json: + print(prefix*(level+1)+str([msg_json(val, defaults) for val in obj])) + else: + for ientry,entry in enumerate(obj): + print(prefix*(level+1)+"{}: ".format(ientry)) + print_subfields(entry,level+2) + elif info["type"].startswith("MessageMap"): + if json: + print(prefix*(level+1)+str({key:msg_json(val, defaults) for key,val in obj.items()})) + else: + for key,val in obj.items(): + print(prefix*(level+1)+"{}: ".format(key)) + print_subfields(val,level+2) + else: + print(prefix*(level+1)+str(obj)) + +def edit_builtin(model,dest,val): + setattr(model,dest,val) + +def edit_scalar_list(model,dest,val): + item = getattr(model,dest) + item.clear() + item.extend(val) + +def edit_scalar_map(model,dest,val): + item = getattr(model,dest) + item.clear() + item.update(val) + +def edit_msg(model,dest,val): + item = getattr(model,dest) + json_format.ParseDict(val,item) + +def edit_msg_list(model,dest,val): + item = getattr(model,dest) + item.clear() + for v in vals: + m = item.add() + json_format.ParseDict(v,m) + +def edit_msg_map(model,dest,val): + item = getattr(model,dest) + item.clear() + for k,v in vals.items(): + m = item.get_or_create(k) + json_format.ParseDict(v,m) + +def add_edit_args(parser, model_info): + group = parser.add_argument_group("fields", description="ModelConfig fields to edit") + dests = {} + for field in model_info["fields"]: + argname = "--{}".format(field["name"].replace("_","-")) + val_type = None + editor = None + if field["class"].__module__=="builtins": + kwargs = dict(type=field["class"]) + editor = edit_builtin + elif field["type"].startswith("RepeatedScalarFieldContainer"): + kwargs = dict(type=field["entry_class"], nargs='*') + editor = edit_scalar_list + elif field["type"].startswith("ScalarMap"): + kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction) + val_type = field["entry_class"] + editor = edit_scalar_map + elif field["type"].startswith("RepeatedCompositeFieldContainer"): + kwargs = dict(type=json.loads, nargs='*', + help="provide {} values in json format".format(field["entry_type"]) + ) + editor = edit_msg_list + elif field["type"].startswith("MessageMap"): + kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction, + help="provide {} values in json format".format(field["entry_type"]) + ) + editor = edit_msg_map + val_type = json.loads + else: + kwargs = dict(type=json.loads, + help="provide {} values in json format".format(field["type"]) + ) + edit = edit_msg + action = group.add_argument(argname, **kwargs) + if val_type is not None: action.val_type = val_type + dests[action.dest] = editor + return parser, dests + +def get_checksum(filename, chunksize=4096): + import hashlib + with open(filename, 'rb') as f: + file_hash = hashlib.md5() + while chunk := f.read(chunksize): + file_hash.update(chunk) + return file_hash.hexdigest() + +def get_checksum_update_cmd(force=False): + extra_args = ["--update"] + if force: extra_args.append("--force") + extra_args = [arg for arg in extra_args if arg not in sys.argv] + return "{} {}".format(" ".join(sys.argv), " ".join(extra_args)) + +def update_config(args): + # update config path to be output path (in case view is called) + if args.copy: + args.config = "config.pbtxt" + if isinstance(args.copy,str): + args.config = os.path.join(args.copy, args.config) + + with open(args.config,'w') as outfile: + text_format.PrintMessage(args.model, outfile, use_short_repeated_primitives=True) + +def cfg_common(args): + if not hasattr(args,'model_info'): + args.model_info = get_model_info() + args.model = grpc.model_config_pb2.ModelConfig() + if hasattr(args,'config'): + with open(args.config,'r') as infile: + text_format.Parse(infile.read(), args.model) + +def cfg_schema(args): + get_fields(args.model, "ModelConfig", verbose=True) + +def cfg_view(args): + print("Contents of {}".format(args.config)) + print_fields(args.model, args.model_info, json=args.json, defaults=args.defaults) + +def cfg_edit(args): + for dest,editor,val in [(dest,editor,getattr(args,dest)) for dest,editor in args.edit_dests.items() if getattr(args,dest) is not None]: + editor(args.model,dest,val) + + update_config(args) + + if args.view: + cfg_view(args) + +def cfg_checksum(args): + # internal parameter + if not hasattr(args, "should_return"): + args.should_return = False + + agents = args.model.model_repository_agents.agents + checksum_agent = next((agent for agent in agents if agent.name=="checksum"), None) + if checksum_agent is None: + checksum_agent = agents.add(name="checksum") + + incorrect = [] + missing = [] + + from glob import glob + config_dir = os.path.dirname(args.config) + for filename in glob(os.path.join(config_dir,"*/*")): + if os.path.islink(os.path.dirname(filename)): continue + checksum = get_checksum(filename) + # key = algorithm:[filename relative to config.pbtxt dir] + filename = os.path.relpath(filename, config_dir) + filekey = "MD5:{}".format(filename) + if filekey in checksum_agent.parameters and checksum!=checksum_agent.parameters[filekey]: + incorrect.append(filename) + if args.update and args.force: + checksum_agent.parameters[filekey] = checksum + elif filekey not in checksum_agent.parameters: + missing.append(filename) + if args.update: + checksum_agent.parameters[filekey] = checksum + else: + continue + + needs_update = len(missing)>0 + needs_force_update = len(incorrect)>0 + + if not args.quiet: + if needs_update: + print("\n".join(["Missing checksums:"]+missing)) + if needs_force_update: + print("\n".join(["Incorrect checksums:"]+incorrect)) + + if needs_force_update: + if not (args.update and args.force): + if args.should_return: + return TritonChecksumStatus.INCORRECT + else: + raise RuntimeError("\n".join([ + "Incorrect checksum(s) found, indicating existing model file(s) has been changed, which violates policy.", + "To override, run the following command (and provide a justification in your PR):", + get_checksum_update_cmd(force=True) + ])) + else: + update_config(args) + elif needs_update: + if not args.update: + if args.should_return: + return TritonChecksumStatus.MISSING + else: + raise RuntimeError("\n".join([ + "Missing checksum(s) found, indicating new model file(s).", + "To update, run the following command:", + get_checksum_update_cmd(force=False) + ])) + else: + update_config(args) + + if args.view: + cfg_view(args) + + if args.should_return: + return TritonChecksumStatus.CORRECT + +def cfg_versioncheck(args): + incorrect = [] + missing = [] + + for path in os.environ['CMSSW_SEARCH_PATH'].split(':'): + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + if filename=="config.pbtxt": + filepath = os.path.join(dirpath,filename) + checksum_args = Namespace( + config=filepath, should_return=True, + copy=False, json=False, defaults=False, view=False, + update=args.update, force=args.force, quiet=True + ) + cfg_common(checksum_args) + status = cfg_checksum(checksum_args) + if status==TritonChecksumStatus.INCORRECT: + incorrect.append(filepath) + elif status==TritonChecksumStatus.MISSING: + missing.append(filepath) + + msg = [] + instr = [] + if len(missing)>0: + msg.extend(["","The following files have missing checksum(s), indicating new model file(s):"]+missing) + instr.extend(["","To update missing checksums, run the following command:",get_checksum_update_cmd(force=False)]) + if len(incorrect)>0: + msg.extend(["","The following files have incorrect checksum(s), indicating existing model file(s) have been changed, which violates policy:"]+incorrect) + instr.extend(["","To override incorrect checksums, run the following command (and provide a justification in your PR):",get_checksum_update_cmd(force=True)]) + + if len(msg)>0: + raise RuntimeError("\n".join(msg+instr)) + +def cfg_threadcontrol(args): + # copy the entire model, not just config.pbtxt + config_dir = os.path.dirname(args.config) + copy_dir = args.copy + new_config_dir = os.path.join(copy_dir, pathlib.Path(config_dir).name) + shutil.copytree(config_dir, new_config_dir) + + platform = args.model.platform + if platform=="ensemble": + repo_dir = pathlib.Path(config_dir).parent + for step in args.model.ensemble_scheduling.step: + # update args and run recursively + args.config = os.path.join(repo_dir,step.model_name,"config.pbtxt") + args.copy = copy_dir + cfg_common(args) + cfg_threadcontrol(args) + return + + # is it correct to do this even if found_params is false at the end? + args.model.instance_group.add(count=args.nThreads, kind=kind_to_int['KIND_CPU']) + + found_params = False + for key,val in thread_control_parameters.items(): + if key in platform: # partial matching + for param in val: + item = args.model.parameters.get_or_create(key) + item.string_value = "1" + found_params = True + break + if not found_params: + print("Warning: thread (instance) control not implemented for {}".format(platform)) + + args.copy = new_config_dir + update_config(args) + + if args.view: + cfg_view(args) + +if __name__=="__main__": + # initial common operations + model_info = get_model_info() + edit_dests = None + + _parser_common = ArgumentParser(add_help=False) + _parser_common.add_argument("-c", "--config", type=str, default="", required=True, help="path to input config.pbtxt file") + + parser = ArgumentParser(formatter_class=ArgumentDefaultsRawHelpFormatter) + subparsers = parser.add_subparsers(dest="command") + + parser_schema = subparsers.add_parser("schema", help="view ModelConfig schema", + description="""Display all fields in the ModelConfig object, with type information. + (For collection types, the subfields of the entry type are shown.)""", + ) + parser_schema.set_defaults(func=cfg_schema) + + _parser_view_args = ArgumentParser(add_help=False) + _parser_view_args.add_argument("--json", default=False, action="store_true", help="display in json format") + _parser_view_args.add_argument("--defaults", default=False, action="store_true", help="show fields with default values") + + parser_view = subparsers.add_parser("view", parents=[_parser_common, _parser_view_args], help="view config.pbtxt contents") + parser_view.set_defaults(func=cfg_view) + + _parser_copy_view = ArgumentParser(add_help=False) + _parser_copy_view.add_argument("--view", default=False, action="store_true", help="view file after editing") + + _parser_copy = ArgumentParser(add_help=False, parents=[_parser_copy_view]) + _parser_copy.add_argument("--copy", metavar="dir", default=False, const=True, nargs='?', type=str, + help="make a copy of config.pbtxt instead of editing in place ([dir] = output path for copy; if omitted, current directory is used)" + ) + + parser_edit = subparsers.add_parser("edit", parents=[_parser_common, _parser_copy, _parser_view_args], help="edit config.pbtxt contents") + parser_edit, edit_dests = add_edit_args(parser_edit, model_info) + parser_edit.set_defaults(func=cfg_edit) + + _parser_checksum_update = ArgumentParser(add_help=False) + _parser_checksum_update.add_argument("--update", default=False, action="store_true", help="update missing checksums") + _parser_checksum_update.add_argument("--force", default=False, action="store_true", help="force update all checksums") + + parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args, _parser_checksum_update], help="handle model file checksums") + parser_checksum.add_argument("--quiet", default=False, action="store_true", help="suppress printouts") + parser_checksum.set_defaults(func=cfg_checksum) + + parser_versioncheck = subparsers.add_parser("versioncheck", parents=[_parser_checksum_update], help="check all model checksums") + parser_versioncheck.set_defaults(func=cfg_versioncheck) + + _parser_copy_req = ArgumentParser(add_help=False, parents=[_parser_copy_view]) + _parser_copy_req.add_argument("--copy", metavar="dir", type=str, required=True, + help="local model repository directory to copy model(s)" + ) + + parser_threadcontrol = subparsers.add_parser("threadcontrol", parents=[_parser_common, _parser_copy_req, _parser_view_args], help="enable thread controls") + parser_threadcontrol.add_argument("--nThreads", type=int, required=True, help="number of threads") + parser_threadcontrol.set_defaults(func=cfg_threadcontrol) + + args = parser.parse_args() + args.model_info = model_info + if edit_dests is not None: + args.edit_dests = edit_dests + + cfg_common(args) + + args.func(args) diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index c57a8355d07a1..201ad40d35a0e 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -9,11 +9,18 @@ #include "grpc_client.h" #include "grpc_service.pb.h" +#include "model_config.pb.h" -#include +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +#include #include #include +#include +#include #include +#include #include #include @@ -75,22 +82,61 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d //convert seconds to microseconds options_[0].client_timeout_ = params.getUntrackedParameter("timeout") * 1e6; - //config needed for batch size - inference::ModelConfigResponse modelConfigResponse; - TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_), - "TritonClient(): unable to get model config"); - inference::ModelConfig modelConfig(modelConfigResponse.config()); + //get fixed parameters from local config + inference::ModelConfig localModelConfig; + { + const std::string& localModelConfigPath(params.getParameter("modelConfigPath").fullPath()); + int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY); + if (fileDescriptor < 0) + throw TritonException("LocalFailure") + << "TritonClient(): unable to open local model config: " << localModelConfigPath; + google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor); + localModelConfigInput.SetCloseOnDelete(true); + if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig)) + throw TritonException("LocalFailure") + << "TritonClient(): unable to parse local model config: " << localModelConfigPath; + } //check batch size limitations (after i/o setup) //triton uses max batch size = 0 to denote a model that does not support native batching (using the outer dimension) //but for models that do support batching (native or otherwise), a given event may set batch size 0 to indicate no valid input is present //so set the local max to 1 and keep track of "no outer dim" case - maxOuterDim_ = modelConfig.max_batch_size(); + maxOuterDim_ = localModelConfig.max_batch_size(); noOuterDim_ = maxOuterDim_ == 0; maxOuterDim_ = std::max(1u, maxOuterDim_); //propagate batch size setBatchSize(1); + //compare model checksums to remote config to enforce versioning + inference::ModelConfigResponse modelConfigResponse; + TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_), + "TritonClient(): unable to get model config"); + inference::ModelConfig remoteModelConfig(modelConfigResponse.config()); + + std::map> checksums; + size_t fileCounter = 0; + for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) { + const auto& agents = modelConfig.model_repository_agents().agents(); + auto agent = std::find_if(agents.begin(), agents.end(), [](auto const& a) { return a.name() == "checksum"; }); + if (agent != agents.end()) { + const auto& params = agent->parameters(); + for (const auto& [key, val] : params) { + // only check the requested version + if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0) + checksums[key][fileCounter] = val; + } + } + ++fileCounter; + } + std::vector incorrect; + for (const auto& [key, val] : checksums) { + if (checksums[key][0] != checksums[key][1]) + incorrect.push_back(key); + } + if (!incorrect.empty()) + throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: " + << triton_utils::printColl(incorrect, ", "); + //get model info inference::ModelMetadataResponse modelMetadata; TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_), diff --git a/HeterogeneousCore/SonicTriton/src/triton_utils.cc b/HeterogeneousCore/SonicTriton/src/triton_utils.cc index 3dc872d6e1b42..a71190d951e46 100644 --- a/HeterogeneousCore/SonicTriton/src/triton_utils.cc +++ b/HeterogeneousCore/SonicTriton/src/triton_utils.cc @@ -21,4 +21,5 @@ template std::string triton_utils::printColl(const edm::Span& coll, const std::string& delim); template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); +template std::string triton_utils::printColl(const std::vector& coll, const std::string& delim); template std::string triton_utils::printColl(const std::unordered_set& coll, const std::string& delim); diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index 272fba3da2cc8..e4ff7a0bb56f3 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,5 +1,8 @@ - - + + + + +