# model.py
# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
import base64
import gc
import json
import os
import queue
import threading
from io import BytesIO
from typing import Dict, List
import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from PIL import Image
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.lora.request import LoRARequest
from vllm.utils import random_uuid
from utils.metrics import VllmStatLogger
from utils.vllm_backend_utils import TritonSamplingParams
_VLLM_ENGINE_ARGS_FILENAME = "model.json"
_MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
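# A minimal sketch of what `model.json` may contain: the file is expanded
# directly into vllm.engine.arg_utils.AsyncEngineArgs, so any AsyncEngineArgs
# field is a valid key. The model and values below are illustrative, not
# defaults:
#   {
#     "model": "facebook/opt-125m",
#     "gpu_memory_utilization": 0.5,
#     "disable_log_requests": true
#   }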
class TritonPythonModel:
@classmethod
def auto_complete_config(cls, auto_complete_model_config):
# Add inputs/outputs to the model config.
cls._auto_complete_inputs_and_outputs(auto_complete_model_config)
        # We need to use the decoupled transaction policy to saturate the
        # vLLM engine for max throughput.
        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
        # restriction for cases where there is exactly a single response to
        # a single request.
auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
        # Disable batching in Triton; let vLLM handle the batching on its own.
auto_complete_model_config.set_max_batch_size(0)
return auto_complete_model_config
@staticmethod
def _auto_complete_inputs_and_outputs(auto_complete_model_config):
# Inputs expected by the backend.
inputs = [
{"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
{
"name": "image",
"data_type": "TYPE_STRING",
"dims": [-1], # can be multiple images as separate elements
"optional": True,
},
{
"name": "stream",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "sampling_parameters",
"data_type": "TYPE_STRING",
"dims": [1],
"optional": True,
},
{
"name": "exclude_input_in_output",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "return_finish_reason",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "return_cumulative_logprob",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "return_logprobs",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "return_num_input_tokens",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "return_num_output_tokens",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
]
# Outputs expected by the backend.
outputs = [
{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
{"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
{"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]},
{"name": "logprobs", "data_type": "TYPE_STRING", "dims": [-1]},
{"name": "num_input_tokens", "data_type": "TYPE_UINT32", "dims": [1]},
{"name": "num_output_tokens", "data_type": "TYPE_UINT32", "dims": [-1]},
]
# Collect input and output names from the provided model config.
config = auto_complete_model_config.as_dict()
input_names = []
output_names = []
for input in config["input"]:
input_names.append(input["name"])
for output in config["output"]:
output_names.append(output["name"])
# Add missing inputs and outputs to the model config.
for input in inputs:
if input["name"] not in input_names:
auto_complete_model_config.add_input(input)
for output in outputs:
if output["name"] not in output_names:
auto_complete_model_config.add_output(output)
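    # For illustration, a client would populate these tensors roughly as sketched
    # below (tritonclient names/types are assumptions based on its public API;
    # since the model is decoupled, requests are sent over a gRPC stream):
    #   inp = grpcclient.InferInput("text_input", [1], "BYTES")
    #   inp.set_data_from_numpy(np.array([b"What is Triton?"], dtype=np.object_))
    #   # Optional inputs such as "stream", "sampling_parameters" or the
    #   # "return_*" flags are added the same way; images are passed as
    #   # base64-encoded strings in the "image" input.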
def initialize(self, args):
self.args = args
self.logger = pb_utils.Logger
self.model_config = json.loads(args["model_config"])
output_config = pb_utils.get_output_config_by_name(
self.model_config, "text_output"
)
self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
# Setup vLLM engine health check
self._enable_health_check = self._get_bool_config_param(
"ENABLE_VLLM_HEALTH_CHECK"
)
self._is_healthy = True
# Initialize engine arguments
# TODO: Move this into _init_engine(), after moving check metrics enabled.
self._init_engine_args()
# Check if metrics are enabled. The ZMQ process cannot be used when metrics are
# enabled.
# TODO: Move the check into _setup_metrics().
self._enable_metrics = (
self._get_bool_config_param("REPORT_CUSTOM_METRICS")
            and not self._async_engine_args.disable_log_stats
)
# Starting the vLLM engine and its event thread running the AsyncIO event loop.
self._init_engine()
# Setup vLLM metrics
self._setup_metrics()
# Starting the response thread. It allows vLLM to keep making progress while
# response sender(s) are sending responses to server frontend.
self._response_queue = queue.Queue()
self._response_thread = threading.Thread(target=self._response_loop)
self._response_thread.start()
def _init_engine_args(self):
        # Currently, Triton needs to use the decoupled policy to asynchronously
        # forward requests to the vLLM engine, so assert it.
self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
self.model_config
)
assert (
self.using_decoupled
), "vLLM Triton backend must be configured to use decoupled model transaction policy"
engine_args_filepath = os.path.join(
pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME
)
assert os.path.isfile(
engine_args_filepath
), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
with open(engine_args_filepath) as file:
self.vllm_engine_config = json.load(file)
        # Validate that the device and multi-processing settings are consistent
        # with the model instance configuration.
self._validate_device_config()
# Check for LoRA config and set it up if enabled
self._setup_lora()
# Create an AsyncEngineArgs from the config from JSON
        self._async_engine_args = AsyncEngineArgs(**self.vllm_engine_config)
def _init_engine(self):
# Run the engine in a separate thread running the AsyncIO event loop.
self._llm_engine = None
self._llm_engine_start_cv = threading.Condition()
self._llm_engine_shutdown_event = asyncio.Event()
self._event_thread = threading.Thread(
target=asyncio.run, args=(self._run_llm_engine(),)
)
self._event_thread.start()
with self._llm_engine_start_cv:
while self._llm_engine is None:
self._llm_engine_start_cv.wait()
        # 'threading.Thread()' will not raise the exception here should the engine
        # fail to start, so the exception is passed back via the engine variable.
if isinstance(self._llm_engine, Exception):
e = self._llm_engine
self.logger.log_error(f"[vllm] Failed to start engine: {e}")
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
raise e
async def _run_llm_engine(self):
        # Counter to keep track of ongoing requests.
self._ongoing_request_count = 0
try:
# Start the vLLM engine. The engine lives for the scope of this with
# statement.
# TODO: Metrics should work with ZMQ enabled.
async with build_async_engine_client_from_engine_args(
                engine_args=self._async_engine_args,
disable_frontend_multiprocessing=self._enable_metrics,
) as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
# Signal the engine is started and make it visible to other threads.
with self._llm_engine_start_cv:
self._llm_engine = engine
self._llm_engine_start_cv.notify_all()
# Wait for the engine shutdown signal.
await self._llm_engine_shutdown_event.wait()
# Wait for the ongoing requests to complete.
while self._ongoing_request_count > 0:
self.logger.log_info(
"[vllm] Awaiting remaining {} requests".format(
self._ongoing_request_count
)
)
await asyncio.sleep(1)
# Cancel all tasks in the event loop.
for task in asyncio.all_tasks(loop=self._event_loop):
if task is not asyncio.current_task():
task.cancel()
except Exception as e:
# Signal and pass the exception back via the engine variable if the engine
# failed to start. If the engine has started, re-raise the exception.
with self._llm_engine_start_cv:
if self._llm_engine is None:
self._llm_engine = e
self._llm_engine_start_cv.notify_all()
return
raise e
self._llm_engine = None
self.logger.log_info("[vllm] Shutdown complete")
def _validate_device_config(self):
triton_kind = self.args["model_instance_kind"]
triton_device_id = int(self.args["model_instance_device_id"])
triton_instance = f"{self.args['model_name']}_{triton_device_id}"
# Triton's current definition of KIND_GPU makes assumptions that
# models only use a single GPU. For multi-GPU models, the recommendation
# is to specify KIND_MODEL to acknowledge that the model will take control
# of the devices made available to it.
# NOTE: Consider other parameters that would indicate multi-GPU in the future.
tp_size = int(self.vllm_engine_config.get("tensor_parallel_size", 1))
if tp_size > 1 and triton_kind == "GPU":
raise ValueError(
"KIND_GPU is currently for single-GPU models, please specify KIND_MODEL "
"in the model's config.pbtxt for multi-GPU models"
)
        # If KIND_GPU is specified, set the device to the ID assigned by Triton to
        # ensure that multiple model instances do not oversubscribe the same
        # default device.
if triton_kind == "GPU" and triton_device_id >= 0:
self.logger.log_info(
f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
)
# vLLM doesn't currently (v0.4.2) expose device selection in the APIs
torch.cuda.set_device(triton_device_id)
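    # For example, a multi-GPU (tensor-parallel) deployment would typically set
    # "tensor_parallel_size" > 1 in model.json and use KIND_MODEL in
    # config.pbtxt; an illustrative instance_group entry:
    #   instance_group [
    #     {
    #       count: 1
    #       kind: KIND_MODEL
    #     }
    #   ]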
def _setup_lora(self):
self.enable_lora = False
        # Check if the `enable_lora` field is present in `model.json` and, if it
        # is, read its value, which can be a string or a bool.
if (
"enable_lora" in self.vllm_engine_config.keys()
and str(self.vllm_engine_config["enable_lora"]).lower() == "true"
):
# create Triton LoRA weights repository
multi_lora_args_filepath = os.path.join(
pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME
)
try:
with open(multi_lora_args_filepath) as lora_file:
lora_repository: Dict[str, str] = json.load(lora_file)
self.lora_repository = lora_repository
self.supported_loras: List[str] = list(self.lora_repository.keys())
self.supported_loras_len = len(self.supported_loras)
self.enable_lora = True
except FileNotFoundError:
raise FileNotFoundError(
f"Triton backend cannot find {multi_lora_args_filepath}."
)
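    # A minimal sketch of `multi_lora.json` as consumed above: a flat mapping from
    # a LoRA name (referenced later via `lora_name` in sampling_parameters) to the
    # local path of its weights. Names and paths are illustrative:
    #   {
    #     "lora_a": "/path/to/lora_a",
    #     "lora_b": "/path/to/lora_b"
    #   }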
def _setup_metrics(self):
self._vllm_metrics = None
# TODO: Do not read metrics directly from the vLLM engine, read from prometheus
# client to allow the use of ZMQ process when metrics are enabled. See
# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245
if self._enable_metrics:
try:
labels = {
"model": self.args["model_name"],
"version": self.args["model_version"],
}
# Add vLLM custom metrics
engine_config = self._llm_engine.engine.model_config
self._vllm_metrics = VllmStatLogger(
labels, engine_config.max_model_len, self.logger
)
self._llm_engine.add_logger("triton", self._vllm_metrics)
except pb_utils.TritonModelException as e:
if "metrics not supported" in str(e):
# Metrics are disabled at the server
self.logger.log_info("[vllm] Metrics not supported")
else:
raise e
def _get_bool_config_param(self, param_name: str) -> bool:
return (param_name in self.model_config["parameters"]) and (
self.model_config["parameters"][param_name]["string_value"].lower()
== "true"
)
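    # The boolean parameters read above (e.g. REPORT_CUSTOM_METRICS,
    # ENABLE_VLLM_HEALTH_CHECK) come from the `parameters` section of the model's
    # config.pbtxt; an illustrative entry:
    #   parameters: {
    #     key: "REPORT_CUSTOM_METRICS"
    #     value: { string_value: "true" }
    #   }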
def _response_loop(self):
while True:
item = self._response_queue.get()
            # To signal shutdown, a None item will be added to the queue.
if item is None:
break
response_state, response, response_flag = item
response_sender = response_state["response_sender"]
try:
response_sender.send(response, response_flag)
# Stop checking for cancellation if the last response is generated.
if not response_state["last_response_generated"]:
response_state["is_cancelled"] = response_sender.is_cancelled()
except Exception as e:
self.logger.log_error(
f"An error occurred while sending a response: {e}"
)
finally:
if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
self._ongoing_request_count -= 1
def execute(self, requests):
if self._enable_health_check and not self._check_health(requests):
return None
for request in requests:
request = self._verify_loras(request)
if request is not None:
assert (
self._llm_engine_shutdown_event.is_set() is False
), "Cannot create tasks after shutdown has been requested"
coro = self._generate(request)
asyncio.run_coroutine_threadsafe(coro, self._event_loop)
return None
async def _generate(self, request):
response_sender = request.get_response_sender()
response_state = {
"response_sender": response_sender,
"is_cancelled": False,
"last_response_generated": False, # last response ready but not yet sent
}
self._ongoing_request_count += 1
decrement_ongoing_request_count = True
try:
request_id = random_uuid()
(
prompt,
stream,
prepend_input,
parameters,
additional_outputs,
) = self._get_input_tensors(request)
sampling_params = TritonSamplingParams.from_dict(parameters, self.logger)
lora_name = sampling_params.lora_name
lora_request = None
if lora_name is not None:
lora_id = str(self.supported_loras.index(lora_name) + 1)
lora_int_id = int(lora_id)
lora_local_path = self.lora_repository[lora_name]
lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
response_iterator = self._llm_engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
)
request_output_state = {}
async for request_output in response_iterator:
# Cancellation state will be checked by the response loop and written to
# the response state if streaming. If not streaming, cancellation state
# needs to be checked here.
is_cancelled = response_state["is_cancelled"]
if not stream:
is_cancelled = response_sender.is_cancelled()
if is_cancelled:
self.logger.log_info("[vllm] Cancelling the request")
await self._llm_engine.abort(request_id)
self.logger.log_info("[vllm] Successfully cancelled the request")
if stream:
# Add cancelled final response to response loop.
response_state["last_response_generated"] = True
response = pb_utils.InferenceResponse(
error=pb_utils.TritonError(
message="Request was cancelled",
code=pb_utils.TritonError.CANCELLED,
)
)
flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
decrement_ongoing_request_count = False
self._response_queue.put_nowait(
(response_state, response, flags)
)
break
# Send each response if streaming.
if stream:
response = self._create_response(
request_output_state,
request_output,
prepend_input=False,
additional_outputs=additional_outputs,
)
flags = 0
if request_output.finished:
response_state["last_response_generated"] = True
flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
decrement_ongoing_request_count = False
self._response_queue.put_nowait((response_state, response, flags))
# Send the last response which contains all the outputs if not streaming.
if not stream:
response_sender.send(
self._create_response(
request_output_state={},
request_output=request_output,
prepend_input=prepend_input,
additional_outputs=additional_outputs,
),
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
)
except Exception as e:
self.logger.log_error(f"[vllm] Error generating stream: {e}")
error = pb_utils.TritonError(f"Error generating stream: {e}")
text_output_tensor = pb_utils.Tensor(
"text_output", np.asarray(["N/A"], dtype=self.output_dtype)
)
response = pb_utils.InferenceResponse(
output_tensors=[text_output_tensor], error=error
)
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
raise e
finally:
if decrement_ongoing_request_count:
self._ongoing_request_count -= 1
def _get_input_tensors(self, request):
# prompt
prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0]
if isinstance(prompt, bytes):
prompt = prompt.decode("utf-8")
# image
images = pb_utils.get_input_tensor_by_name(request, "image")
if images:
images_vllm = []
for image_np in images.as_numpy():
image_b = base64.b64decode(image_np.decode("utf-8"))
image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
images_vllm.append(image_rgb)
if len(images_vllm) > 0:
prompt = {
"prompt": prompt,
"multi_modal_data": {"image": images_vllm},
}
# stream
stream = pb_utils.get_input_tensor_by_name(request, "stream")
if stream:
stream = stream.as_numpy()[0]
else:
stream = False
# prepend_input / exclude_input_in_output
prepend_input = pb_utils.get_input_tensor_by_name(
request, "exclude_input_in_output"
)
if prepend_input:
# When `exclude_input_in_output` is False, we want to prepend input prompt
# to output, thus prepend_input should be True, and vice versa.
prepend_input = not prepend_input.as_numpy()[0]
elif prepend_input is None and stream:
prepend_input = False
else:
prepend_input = True
if prepend_input and stream:
raise ValueError(
"When streaming, `exclude_input_in_output` = False is not allowed."
)
# parameters / sampling_parameters
# An alternative mechanism to receive serialized parameters as an input
# tensor, because request parameters are not yet supported via BLS.
sampling_parameters = pb_utils.get_input_tensor_by_name(
request, "sampling_parameters"
)
if sampling_parameters:
parameters = sampling_parameters.as_numpy()[0].decode("utf-8")
else:
parameters = request.parameters()
# additional outputs
additional_outputs = {
"return_finish_reason": None,
"return_cumulative_logprob": None,
"return_logprobs": None,
"return_num_input_tokens": None,
"return_num_output_tokens": None,
}
for tensor_name in additional_outputs.keys():
tensor = pb_utils.get_input_tensor_by_name(request, tensor_name)
if tensor:
tensor = bool(tensor.as_numpy()[0])
else:
tensor = False
additional_outputs[tensor_name] = tensor
return prompt, stream, prepend_input, parameters, additional_outputs
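    # A sketch of the `sampling_parameters` JSON parsed above. Accepted keys are
    # defined by TritonSamplingParams / vLLM's SamplingParams plus the
    # backend-specific "lora_name"; the values shown are illustrative:
    #   {
    #     "temperature": 0.7,
    #     "top_p": 0.9,
    #     "max_tokens": 128,
    #     "lora_name": "lora_a"
    #   }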
def _create_response(
self, request_output_state, request_output, prepend_input, additional_outputs
):
output_tensors = []
# text_output
prepend_prompt = ""
if "prev_lens_text_output" not in request_output_state:
# this is the first response
if prepend_input:
prepend_prompt = request_output.prompt
request_output_state["prev_lens_text_output"] = [0] * len(
request_output.outputs
)
prev_lens = request_output_state["prev_lens_text_output"]
text_output = [
(prepend_prompt + output.text[prev_len:]).encode("utf-8")
for output, prev_len in zip(request_output.outputs, prev_lens)
]
request_output_state["prev_lens_text_output"] = [
len(output.text) for output in request_output.outputs
]
output_tensors.append(
pb_utils.Tensor(
"text_output", np.asarray(text_output, dtype=self.output_dtype)
)
)
# finish_reason
if additional_outputs["return_finish_reason"]:
finish_reason = [
str(output.finish_reason) for output in request_output.outputs
]
output_tensors.append(
pb_utils.Tensor(
"finish_reason", np.asarray(finish_reason, dtype=np.object_)
)
)
# cumulative_logprob
if additional_outputs["return_cumulative_logprob"]:
cumulative_logprob = [
output.cumulative_logprob for output in request_output.outputs
]
output_tensors.append(
pb_utils.Tensor(
"cumulative_logprob",
np.asarray(cumulative_logprob, dtype=np.float32),
)
)
# logprobs
# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/sequence.py#L37-L58
if additional_outputs["return_logprobs"]:
if "prev_lens_logprobs" not in request_output_state:
request_output_state["prev_lens_logprobs"] = [0] * len(
request_output.outputs
)
logprobs = []
for i in range(len(request_output.outputs)):
output = request_output.outputs[i]
if output.logprobs is None:
logprobs.append("null".encode("utf-8"))
continue
prev_len = request_output_state["prev_lens_logprobs"][i]
request_output_state["prev_lens_logprobs"][i] = len(output.logprobs)
logprobs_py = []
for logprob_d_vllm in output.logprobs[prev_len:]:
logprob_d_py = {}
for token_id, logprob_vllm in logprob_d_vllm.items():
logprob_d_py[token_id] = {
"logprob": logprob_vllm.logprob,
"rank": logprob_vllm.rank,
"decoded_token": logprob_vllm.decoded_token,
}
logprobs_py.append(logprob_d_py)
logprobs.append(json.dumps(logprobs_py).encode("utf-8"))
output_tensors.append(
pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_))
)
# num_input_tokens
if additional_outputs["return_num_input_tokens"]:
num_input_tokens = len(request_output.prompt_token_ids)
output_tensors.append(
pb_utils.Tensor(
"num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32)
)
)
# num_output_tokens
if additional_outputs["return_num_output_tokens"]:
if "prev_lens_num_output_tokens" not in request_output_state:
request_output_state["prev_lens_num_output_tokens"] = [0] * len(
request_output.outputs
)
prev_lens = request_output_state["prev_lens_num_output_tokens"]
num_output_tokens = [
(len(output.token_ids) - prev_len)
for output, prev_len in zip(request_output.outputs, prev_lens)
]
request_output_state["prev_lens_num_output_tokens"] = [
len(output.token_ids) for output in request_output.outputs
]
output_tensors.append(
pb_utils.Tensor(
"num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32)
)
)
return pb_utils.InferenceResponse(output_tensors=output_tensors)
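    # Shape of the serialized `logprobs` output built above: per sequence, a JSON
    # list with one dict per newly generated token, mapping token id to its
    # logprob entry (ids and values shown are made up):
    #   [
    #     {"1234": {"logprob": -0.12, "rank": 1, "decoded_token": " Hello"}}
    #   ]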
def _verify_loras(self, request):
        # Check here whether the requested LoRA exists; if it does not, send a
        # response with `LoRA not found` information so that further processing
        # can be avoided.
verified_request = None
lora_error = None
lora_name = None
parameters_input_tensor = pb_utils.get_input_tensor_by_name(
request, "sampling_parameters"
)
if parameters_input_tensor:
parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
else:
parameters = request.parameters()
lora_name = json.loads(parameters).pop("lora_name", None)
if lora_name is not None:
if not self.enable_lora:
lora_error = pb_utils.TritonError("LoRA feature is not enabled.")
self.logger.log_info(
"[vllm] LoRA is not enabled, please restart the backend with LoRA enabled."
)
elif lora_name not in self.supported_loras:
lora_error = pb_utils.TritonError(
f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}"
)
self.logger.log_info(f"[vllm] LoRA {lora_name} not found.")
if lora_error is not None:
output_tensor = pb_utils.Tensor(
"text_output",
np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype),
)
response = pb_utils.InferenceResponse(
output_tensors=[output_tensor], error=lora_error
)
response_sender = request.get_response_sender()
response_sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
else:
verified_request = request
return verified_request
def _check_health(self, requests):
coro = self._llm_engine.check_health()
future = asyncio.run_coroutine_threadsafe(coro, self._event_loop)
try:
future.result()
except Exception as e:
self.logger.log_error(
f"[vllm] Engine is not healthy and model will be unloaded: {e}"
)
pb_utils.unload_model(self.model_config["name"]) # non-blocking
self._is_healthy = False
if not self._is_healthy:
for request in requests:
request.get_response_sender().send(
pb_utils.InferenceResponse(
error=pb_utils.TritonError(
message="Model is unavailable due to unhealthy vLLM engine",
code=pb_utils.TritonError.UNAVAILABLE,
)
),
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
)
return self._is_healthy
def finalize(self):
self.logger.log_info("[vllm] Issuing finalize to vllm backend")
self._event_loop.call_soon_threadsafe(self._llm_engine_shutdown_event.set)
# Shutdown the event thread.
if self._event_thread is not None:
self._event_thread.join()
self._event_thread = None
# Shutdown the response thread.
self._response_queue.put(None)
if self._response_thread is not None:
self._response_thread.join()
self._response_thread = None
# Shutdown the metrics thread.
if self._vllm_metrics is not None:
self._vllm_metrics.finalize()
        # When using tensor parallelism, the stub process may not shut down due to
        # unreleased references, so manually run the garbage collector once.
self.logger.log_info("[vllm] Running Garbage Collector on finalize...")
gc.collect()
self.logger.log_info("[vllm] Garbage Collector on finalize... done")