diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index 21e2ca2751f8..52d5b3dbff3e 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -15,6 +15,8 @@ use_query_llm = True try: - from nemo.deploy.nlp.query_llm import NemoQueryLLM + from nemo.deploy.nlp.query_llm import NemoTritonQueryLLMTensorRT except Exception: use_query_llm = False + +from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py new file mode 100644 index 000000000000..c27bbbd0102b --- /dev/null +++ b/nemo/deploy/nlp/megatronllm_deployable.py @@ -0,0 +1,316 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from enum import IntEnum, auto +from pathlib import Path + +import numpy as np +import torch +import wrapt +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.text_generation_utils import ( + OutputType, + get_default_length_params, + get_default_sampling_params, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.deploy import ITritonDeployable +from nemo.deploy.utils import cast_output, str_ndarray2list + + +@wrapt.decorator +def noop_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +use_pytriton = True +batch = noop_decorator +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor +except Exception: + use_pytriton = False + +LOGGER = logging.getLogger("NeMo") + + +def GetTensorShape(pyvalue): + """ + utility function to get Triton Tensor shape from a python value + assume that lists are shape -1 and all others are scalars with shape 1 + """ + return (-1 if type(pyvalue) == list else 1,) + + +def GetNumpyDtype(pyvalue): + """ + utility function to get numpy dtype of a python value + e.g. bool -> np.bool_ + """ + ''' + manually defining the mapping of python type -> numpy type for now + is there a better way to do it? 
tried np.array(pyvalue).dtype, but that doesn't seem to work + ''' + py_to_numpy_mapping = {str: bytes, bool: np.bool_, float: np.single, int: np.int_} + python_type = type(pyvalue) + # for lists, return the type of the internal elements + if python_type == list: + python_type = type(pyvalue[0]) + numpy_type = py_to_numpy_mapping[python_type] + return numpy_type + + +class ServerSync(IntEnum): + """Enum for synchronization messages using torch.distributed""" + + WAIT = auto() + SIGNAL = auto() + + def to_long_tensor(self): + return torch.tensor([self], dtype=torch.long, device='cuda') + + +class MegatronLLMDeployable(ITritonDeployable): + """Triton inference server compatible deploy class for a .nemo model file""" + + def __init__( + self, + nemo_checkpoint_filepath: str = None, + num_devices: int = 1, + num_nodes: int = 1, + existing_model: MegatronGPTModel = None, + ): + if nemo_checkpoint_filepath is None and existing_model is None: + raise ValueError( + "MegatronLLMDeployable requires either a .nemo checkpoint filepath or an existing MegatronGPTModel, but both provided were None" + ) + if num_devices > 1: + LOGGER.warning( + "Creating a MegatronLLMDeployable with num_devices>1 will assume running with a PyTorch Lightning DDP-variant strategy, which will run the main script once per device. Make sure any user code is compatible with multiple executions!" + ) + + # if both existing_model and nemo_checkpoint_filepath are provided, existing_model will take precedence + if existing_model is not None: + self.model = existing_model + else: + self._load_from_nemo_checkpoint(nemo_checkpoint_filepath, num_devices, num_nodes) + + self.model.eval() + # helper threads spawned by torch.multiprocessing should loop inside this helper function + self._helper_thread_evaluation_loop() + + def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices: int, num_nodes: int): + if Path(nemo_checkpoint_filepath).exists(): + trainer = Trainer( + strategy=NLPDDPStrategy(), + devices=num_devices, + num_nodes=num_nodes, + ) + + custom_config = MegatronGPTModel.restore_from( + nemo_checkpoint_filepath, trainer=trainer, return_config=True + ) + # transformer_engine should always be true according to EricH, but GPT-2B model will fail if it is enabled + custom_config.transformer_engine = True + # using multi-gpu for tensor parallelism directly for now, could do pipeline parallel instead or a combination + custom_config.tensor_model_parallel_size = num_devices + # had to override these to make Nemotron3-22B work, see sample_sequence_batch() in text_generation_utils.py + custom_config.activations_checkpoint_granularity = None + custom_config.activations_checkpoint_method = None + + self.model = MegatronGPTModel.restore_from( + nemo_checkpoint_filepath, trainer=trainer, override_config_path=custom_config + ) + + def _helper_thread_evaluation_loop(self): + # only deploy the server on main thread, other threads enter this evaluation loop + if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: + while True: + wait_value = ServerSync.WAIT.to_long_tensor() + torch.distributed.broadcast(wait_value, 0) + if wait_value.item() == ServerSync.SIGNAL: + self.model.generate(inputs=[""], length_params=None) + + _INPUT_PARAMETER_FIELDS = { + "prompts": (-1, bytes, False), + } + + ''' + there is no get_default equivalent for OutputType like there is for SamplingParameters and LengthParameters + but we still want to generate output using a real OutputType TypedDict for static type checking + ''' 
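+    # note: each field's first element below doubles as a template value: get_triton_output() uses it to
+    # infer the numpy dtype of the corresponding Triton output, and triton_infer_fn() uses it to populate
+    # any output fields the model did not return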
+ _BLANK_OUTPUTTYPE: OutputType = { + 'sentences': [""], + 'tokens': [[""]], + 'logprob': [[0.0]], + 'full_logprob': [[0.0]], + 'token_ids': [[0]], + 'offsets': [[0]], + } + + @property + def get_triton_input(self): + input_parameters = tuple( + Tensor(name=name, shape=(shape,), dtype=dtype, optional=optional) + for name, (shape, dtype, optional) in self._INPUT_PARAMETER_FIELDS.items() + ) + ''' + in theory, would like to use typedict2tensor() function to generate Tensors, but it purposely ignores 1D arrays + asked JakubK why on 2024-04-26, but he doesn't know who owns the code + sampling_parameters = typedict2tensor(SamplingParam) + length_parameters = typedict2tensor(LengthParam) + ''' + default_sampling_params: SamplingParam = get_default_sampling_params() + sampling_parameters = tuple( + Tensor( + name=parameter_name, + shape=GetTensorShape(parameter_value), + dtype=GetNumpyDtype(parameter_value), + optional=True, + ) + for parameter_name, parameter_value in default_sampling_params.items() + ) + default_length_params: LengthParam = get_default_length_params() + length_parameters = tuple( + Tensor( + name=parameter_name, + shape=GetTensorShape(parameter_value), + dtype=GetNumpyDtype(parameter_value), + optional=True, + ) + for parameter_name, parameter_value in default_length_params.items() + ) + + inputs = input_parameters + sampling_parameters + length_parameters + return inputs + + @property + def get_triton_output(self): + # outputs are defined by the fields of OutputType + outputs = [ + Tensor( + name=parameter_name, + shape=GetTensorShape(parameter_value), + dtype=GetNumpyDtype(parameter_value[0]), + ) + for parameter_name, parameter_value in MegatronLLMDeployable._BLANK_OUTPUTTYPE.items() + ] + return outputs + + @staticmethod + def _sampling_params_from_triton_inputs(**inputs: np.ndarray): + """Extract SamplingParam fields from triton input dict""" + sampling_params: SamplingParam = get_default_sampling_params() + for sampling_param_field in sampling_params.keys(): + if sampling_param_field in inputs: + sampling_params[sampling_param_field] = inputs.pop(sampling_param_field)[0][0] + return sampling_params + + @staticmethod + def _length_params_from_triton_inputs(**inputs: np.ndarray): + """Extract LengthParam fields from triton input dict""" + length_params: LengthParam = get_default_length_params() + for length_param_field in length_params.keys(): + if length_param_field in inputs: + length_params[length_param_field] = inputs.pop(length_param_field)[0][0] + return length_params + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + """Triton server inference function that actually runs the model""" + if torch.distributed.is_initialized(): + distributed_rank = torch.distributed.get_rank() + if distributed_rank != 0: + raise ValueError( + f"Triton inference function should not be called on a thread with torch.distributed rank != 0, but this thread is rank {distributed_rank}" + ) + signal_value = ServerSync.SIGNAL.to_long_tensor() + torch.distributed.broadcast(signal_value, 0) + + input_strings = str_ndarray2list(inputs.pop("prompts")) + sampling_params = self._sampling_params_from_triton_inputs(**inputs) + length_params = self._length_params_from_triton_inputs(**inputs) + + model_output = self.model.generate( + inputs=input_strings, length_params=length_params, sampling_params=sampling_params + ) + ''' + model_output['sentences'] will be a list of strings (one per prompt) + other fields will either be a list of lists (tokens, for example) + or a list of pytorch Tensor + 
''' + + triton_output = {} + _OUTPUT_FILLER_VALUES = { + 'tokens': "", + 'logprob': 0.0, + 'full_logprob': 0.0, + 'token_ids': -1, + 'offsets': -1, + } + for model_output_field, value in model_output.items(): + + if model_output_field != 'sentences' and value is not None: + # find length of longest non-sentence output item + field_longest_output_item = 0 + for item in value: + field_longest_output_item = max(field_longest_output_item, len(item)) + # then pad shorter items to match this length + for index, item in enumerate(value): + num_pad_values = field_longest_output_item - len(item) + if num_pad_values > 0: + pad_value = _OUTPUT_FILLER_VALUES[model_output_field] + if isinstance(item, torch.Tensor): + pad_tensor = torch.full( + (num_pad_values, item.size(1)) if item.dim() > 1 else (num_pad_values,), + pad_value, + dtype=item.dtype, + device='cuda', + ) + padded_item = torch.cat((item, pad_tensor)) + value[index] = padded_item + else: + pad_list = [pad_value] * num_pad_values + padded_item = item + pad_list + value[index] = padded_item + + field_dtype = GetNumpyDtype(MegatronLLMDeployable._BLANK_OUTPUTTYPE[model_output_field][0]) + if value is None: + # triton does not allow for optional output parameters, so need to populate them if they don't exist + triton_output[model_output_field] = np.full( + # 'sentences' should always have a valid value, so use that for the output shape + np.shape(model_output['sentences']), + MegatronLLMDeployable._BLANK_OUTPUTTYPE[model_output_field][0], + dtype=field_dtype, + ) + elif field_dtype == bytes: + # strings are cast to bytes + triton_output[model_output_field] = cast_output(value, field_dtype) + elif isinstance(value[0], torch.Tensor): + if value[0].dtype == torch.bfloat16: + # numpy currently does not support bfloat16, so need to manually convert it + triton_output[model_output_field] = np.array([tensor.cpu().float().numpy() for tensor in value]) + else: + triton_output[model_output_field] = np.array([tensor.cpu().numpy() for tensor in value]) + else: + # non-strings are output as-is (in numpy format) + triton_output[model_output_field] = np.array(value) + return triton_output diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index 0f7866e57cda..835ff46dd5fe 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -19,9 +19,9 @@ from pathlib import Path from nemo.deploy import DeployPyTriton +from nemo.deploy.nlp import MegatronLLMDeployable from nemo.export import TensorRTLLM - LOGGER = logging.getLogger("NeMo") @@ -31,6 +31,13 @@ def get_args(argv): description=f"Deploy nemo models to Triton", ) parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file") + parser.add_argument( + "-dsn", + "--direct_serve_nemo", + default=False, + action='store_true', + help="Serve the nemo model directly instead of exporting to TRTLLM first. 
Will ignore other TRTLLM-specific arguments.", + ) parser.add_argument( "-ptnc", "--ptuning_nemo_checkpoint", @@ -146,18 +153,7 @@ def get_args(argv): return args -def nemo_deploy(argv): - args = get_args(argv) - - if args.debug_mode: - loglevel = logging.DEBUG - else: - loglevel = logging.INFO - - LOGGER.setLevel(loglevel) - LOGGER.info("Logging level set to {}".format(loglevel)) - LOGGER.info(args) - +def get_trtllm_deployable(args): if args.triton_model_repository is None: trt_llm_path = "/tmp/trt_llm_model_dir/" LOGGER.info( @@ -170,28 +166,24 @@ def nemo_deploy(argv): trt_llm_path = args.triton_model_repository if args.nemo_checkpoint is None and args.triton_model_repository is None: - LOGGER.error( + raise ValueError( "The provided model repository is not a valid TensorRT-LLM model " "directory. Please provide a --nemo_checkpoint." ) - return if args.nemo_checkpoint is None and not os.path.isdir(args.triton_model_repository): - LOGGER.error( + raise ValueError( "The provided model repository is not a valid TensorRT-LLM model " "directory. Please provide a --nemo_checkpoint." ) - return if args.nemo_checkpoint is not None and args.model_type is None: - LOGGER.error("Model type is required to be defined if a nemo checkpoint is provided.") - return + raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") ptuning_tables_files = [] if not args.ptuning_nemo_checkpoint is None: if args.max_prompt_embedding_table_size is None: - LOGGER.error("max_prompt_embedding_table_size parameter is needed for the prompt tuning table(s).") - return + raise ValueError("max_prompt_embedding_table_size parameter is needed for the prompt tuning table(s).") for pt_checkpoint in args.ptuning_nemo_checkpoint: ptuning_nemo_checkpoint_path = Path(pt_checkpoint) @@ -199,19 +191,16 @@ def nemo_deploy(argv): if ptuning_nemo_checkpoint_path.is_file(): ptuning_tables_files.append(pt_checkpoint) else: - LOGGER.error("Could not read the prompt tuning tables from {0}".format(pt_checkpoint)) - return + raise IsADirectoryError("Could not read the prompt tuning tables from {0}".format(pt_checkpoint)) else: - LOGGER.error("File or directory {0} does not exist.".format(pt_checkpoint)) - return + raise FileNotFoundError("File or directory {0} does not exist.".format(pt_checkpoint)) if args.task_ids is not None: if len(ptuning_tables_files) != len(args.task_ids): - LOGGER.error( + raise RuntimeError( "Number of task ids and prompt embedding tables have to match. " "There are {0} tables and {1} task ids.".format(len(ptuning_tables_files), len(args.task_ids)) ) - return trt_llm_exporter = TensorRTLLM( model_dir=trt_llm_path, @@ -245,8 +234,7 @@ def nemo_deploy(argv): save_nemo_model_config=True, ) except Exception as error: - LOGGER.error("An error has occurred during the model export. Error message: " + str(error)) - return + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) try: for i, prompt_embeddings_checkpoint_path in enumerate(ptuning_tables_files): @@ -265,12 +253,35 @@ def nemo_deploy(argv): prompt_embeddings_checkpoint_path=prompt_embeddings_checkpoint_path, ) except Exception as error: - LOGGER.error("An error has occurred during adding the prompt embedding table(s). Error message: " + str(error)) - return + raise RuntimeError( + "An error has occurred during adding the prompt embedding table(s). 
Error message: " + str(error) + ) + return trt_llm_exporter + + +def get_nemo_deployable(args): + if args.nemo_checkpoint is None: + raise ValueError("Direct serve requires a .nemo checkpoint") + return MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus) + + +def nemo_deploy(argv): + args = get_args(argv) + + if args.debug_mode: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + + LOGGER.setLevel(loglevel) + LOGGER.info("Logging level set to {}".format(loglevel)) + LOGGER.info(args) + + triton_deployable = get_nemo_deployable(args) if args.direct_serve_nemo else get_trtllm_deployable(args) try: nm = DeployPyTriton( - model=trt_llm_exporter, + model=triton_deployable, triton_model_name=args.triton_model_name, triton_model_version=args.triton_model_version, max_batch_size=args.max_batch_size, diff --git a/tests/deploy/pytriton_deploy.py b/tests/deploy/pytriton_deploy.py new file mode 100644 index 000000000000..3b722d2d7fec --- /dev/null +++ b/tests/deploy/pytriton_deploy.py @@ -0,0 +1,136 @@ +import argparse + +import numpy as np +from pytriton.client import ModelClient + +from nemo.deploy.deploy_pytriton import DeployPyTriton +from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +from nemo.deploy.nlp.query_llm import NemoTritonQueryLLMPyTorch + + +def test_triton_deployable(args): + megatron_deployable = MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus) + + prompts = ["What is the biggest planet in the solar system?", "What is the fastest steam locomotive in history?"] + url = "localhost:8000" + model_name = args.model_name + init_timeout = 600.0 + + nm = DeployPyTriton( + model=megatron_deployable, + triton_model_name=model_name, + triton_model_version=1, + max_batch_size=8, + port=8000, + address="0.0.0.0", + streaming=False, + ) + nm.deploy() + nm.run() + + # run once with NemoTritonQueryLLMPyTorch + nemo_triton_query = NemoTritonQueryLLMPyTorch(url, model_name) + + result_dict = nemo_triton_query.query_llm( + prompts, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + max_length=args.max_output_token, + init_timeout=init_timeout, + ) + print("NemoTritonQueryLLMPyTriton result:") + print(result_dict) + + # run once with ModelClient, the results should be identical + str_ndarray = np.array(prompts)[..., np.newaxis] + prompts = np.char.encode(str_ndarray, "utf-8") + max_output_token = np.full(prompts.shape, args.max_output_token, dtype=np.int_) + top_k = np.full(prompts.shape, args.top_k, dtype=np.int_) + top_p = np.full(prompts.shape, args.top_p, dtype=np.single) + temperature = np.full(prompts.shape, args.temperature, dtype=np.single) + + with ModelClient(url, model_name, init_timeout_s=init_timeout) as client: + result_dict = client.infer_batch( + prompts=prompts, + max_length=max_output_token, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + print("ModelClient result:") + print(result_dict) + + # test logprobs generation + # right now we don't support batches where output data is inconsistent in size, so submitting each prompt individually + all_probs = np.full(prompts.shape, True, dtype=np.bool_) + compute_logprob = np.full(prompts.shape, True, dtype=np.bool_) + with ModelClient(url, model_name, init_timeout_s=init_timeout) as client: + logprob_results = client.infer_batch( + prompts=prompts, + max_length=max_output_token, + top_k=top_k, + top_p=top_p, + temperature=temperature, + all_probs=all_probs, + compute_logprob=compute_logprob, + ) + print("Logprob results:") + print(logprob_results) + + 
nm.stop() + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Deploy nemo models to Triton and benchmark the models", + ) + + parser.add_argument( + "--model_name", + type=str, + required=True, + ) + parser.add_argument( + "--num_gpus", + type=int, + default=1, + ) + parser.add_argument( + "--nemo_checkpoint", + type=str, + required=True, + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=8, + ) + parser.add_argument( + "--max_output_token", + type=int, + default=128, + ) + parser.add_argument( + "--top_k", + type=int, + default=1, + ) + parser.add_argument( + "--top_p", + type=float, + default=0.0, + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + ) + + return parser.parse_args() + + +if __name__ == '__main__': + args = get_args() + test_triton_deployable(args)
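For reference, a minimal sketch of exercising the direct-serve path added in this change outside the test harness. The checkpoint path and model name are placeholders, and options not shown (model version, address, streaming) are assumed to keep their DeployPyTriton defaults; only classes and arguments that appear in this diff are used.

    from nemo.deploy import DeployPyTriton
    from nemo.deploy.nlp import MegatronLLMDeployable
    from nemo.deploy.nlp.query_llm import NemoTritonQueryLLMPyTorch

    # load a .nemo checkpoint onto a single GPU and serve it with PyTriton in the background
    deployable = MegatronLLMDeployable("/path/to/model.nemo", num_devices=1)
    nm = DeployPyTriton(model=deployable, triton_model_name="megatron_llm", max_batch_size=8, port=8000)
    nm.deploy()
    nm.run()

    # query the running server, then shut it down
    query = NemoTritonQueryLLMPyTorch("localhost:8000", "megatron_llm")
    results = query.query_llm(
        ["What is the fastest steam locomotive in history?"],
        top_k=1, top_p=0.0, temperature=1.0, max_length=32, init_timeout=600.0,
    )
    print(results)
    nm.stop()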