From 55b974f96f14dc359111bb34a47655eec006655f Mon Sep 17 00:00:00 2001
From: Liangsheng Yin <hnyls2002@gmail.com>
Date: Sun, 29 Sep 2024 18:52:43 -0700
Subject: [PATCH] Process image in parallel (#1539)

---
 python/sglang/srt/managers/image_processor.py | 187 ++++++++++++++++++
 .../sglang/srt/managers/tokenizer_manager.py  | 164 ++-------------
 2 files changed, 204 insertions(+), 147 deletions(-)
 create mode 100644 python/sglang/srt/managers/image_processor.py

diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py
new file mode 100644
index 00000000000..e1e54af7fcc
--- /dev/null
+++ b/python/sglang/srt/managers/image_processor.py
@@ -0,0 +1,187 @@
+# TODO: also move pad_input_ids into this module
+import asyncio
+import concurrent.futures
+import logging
+import multiprocessing as mp
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional, Union
+
+import numpy as np
+import transformers
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.mm_utils import expand2square, process_anyres_image
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import load_image
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    """Init the global processor for multimodal models."""
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+class BaseImageProcessor(ABC):
+    @abstractmethod
+    async def process_images_async(self, image_data, **kwargs):
+        pass
+
+
+class DummyImageProcessor(BaseImageProcessor):
+    async def process_images_async(self, *args, **kwargs):
+        return None
+
+
+class LlavaImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _image_processor):
+        self.hf_config = hf_config
+        self._image_processor = _image_processor
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
+        )
+
+    @staticmethod
+    def _process_single_image_task(
+        image_data: Union[str, bytes],
+        image_aspect_ratio: Optional[str] = None,
+        image_grid_pinpoints: Optional[str] = None,
+        image_processor=None,
+    ):
+        image_processor = image_processor or global_processor.image_processor
+
+        try:
+            image, image_size = load_image(image_data)
+            if image_size is not None:
+                # It is a video with multiple images
+                image_hash = hash(image_data)
+                pixel_values = image_processor(image)["pixel_values"]
+                for i in range(len(pixel_values)):
+                    pixel_values[i] = pixel_values[i].astype(np.float16)
+                pixel_values = np.stack(pixel_values, axis=0)
+                return pixel_values, image_hash, image_size
+            else:
+                # It is an image
+                image_hash = hash(image_data)
+                if image_aspect_ratio == "pad":
+                    image = expand2square(
+                        image,
+                        tuple(int(x * 255) for x in image_processor.image_mean),
+                    )
+                    pixel_values = image_processor(image.convert("RGB"))[
+                        "pixel_values"
+                    ][0]
+                elif image_aspect_ratio == "anyres" or (
+                    image_aspect_ratio is not None
+                    and "anyres_max" in image_aspect_ratio
+                ):
+                    pixel_values = process_anyres_image(
+                        image, image_processor, image_grid_pinpoints
+                    )
+                else:
+                    pixel_values = image_processor(image)["pixel_values"][0]
+
+                if isinstance(pixel_values, np.ndarray):
+                    pixel_values = pixel_values.astype(np.float16)
+
+                return pixel_values, image_hash, image.size
+        except Exception:
+            logger.error("Exception in LlavaImageProcessor:\n" + get_exception_traceback())
+
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                LlavaImageProcessor._process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+            )
+        else:
+            return self._process_single_image_task(
+                image_data, aspect_ratio, grid_pinpoints, self._image_processor
+            )
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], request_obj
+    ):
+        if not image_data:
+            return None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and aspect_ratio and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            # Multiple images
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision: more than one image means interleaved-image or video mode; anyres is not used
+                pixel_values, image_hashes, image_sizes = [], [], []
+                res = []
+                for img_data in image_data:
+                    res.append(
+                        self._process_single_image(
+                            img_data, aspect_ratio, grid_pinpoints
+                        )
+                    )
+                res = await asyncio.gather(*res)
+                for pixel_v, image_h, image_s in res:
+                    pixel_values.append(pixel_v)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                # A single image
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
+        elif isinstance(image_data, str):
+            # A single image
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
+        else:
+            raise ValueError(f"Invalid image data: {image_data}")
+
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": request_obj.modalities,
+        }
+
+
+def get_image_processor(
+    hf_config, server_args: ServerArgs, _image_processor
+) -> BaseImageProcessor:
+    return LlavaImageProcessor(hf_config, server_args, _image_processor)
+
+
+def get_dummy_image_processor():
+    return DummyImageProcessor()
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 78ea0d1682f..3103faec830 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -16,17 +16,13 @@
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
-import concurrent.futures
 import dataclasses
 import json
 import logging
-import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
-import numpy as np
-import transformers
 import uvloop
 import zmq
 import zmq.asyncio
@@ -38,6 +34,10 @@
     get_processor,
     get_tokenizer,
 )
+from sglang.srt.managers.image_processor import (
+    get_dummy_image_processor,
+    get_image_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -53,11 +53,9 @@
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
-from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
-from sglang.utils import get_exception_traceback
+from sglang.srt.utils import is_generation_model, is_multimodal_model
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -105,6 +103,8 @@ def __init__(
         self.context_len = server_args.context_length or get_context_length(
             self.hf_config
         )
+        # Create an image processor placeholder (replaced below for multimodal models)
+        self.image_processor = get_dummy_image_processor()
 
         # Create tokenizer
         if server_args.skip_tokenizer_init:
@@ -119,13 +119,9 @@ def __init__(
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-                # We want to parallelize the image pre-processing so we
-                # create an executor for it
-                self.executor = concurrent.futures.ProcessPoolExecutor(
-                    initializer=init_global_processor,
-                    mp_context=mp.get_context("fork"),
-                    initargs=(server_args,),
-                    max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+                # To parallelize image pre-processing, the image processor owns a process-pool executor
+                self.image_processor = get_image_processor(
+                    self.hf_config, server_args, self.processor.image_processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
@@ -194,8 +190,8 @@ async def _handle_single_request(
             )
             if self.is_generation:
-                image_inputs = await self._get_image_inputs(
-                    obj, obj.image_data if not_use_index else obj.image_data[index]
+                image_inputs = await self.image_processor.process_images_async(
+                    obj.image_data if not_use_index else obj.image_data[index], obj
                 )
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
                 )
@@ -247,7 +243,9 @@ async def _handle_single_request(
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
 
-            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data[0], obj
+            )
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -362,8 +360,8 @@ async def _handle_batch_request(
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
 
             if self.is_generation:
-                image_inputs = await self._get_image_inputs(
-                    obj, obj.image_data[index]
+                image_inputs = await self.image_processor.process_images_async(
+                    obj.image_data[index], obj
                 )
 
                 tokenized_obj = TokenizedGenerateReqInput(
@@ -686,131 +684,3 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
                     token_top_logprobs, decode_to_text
                 )
         return top_logprobs
-
-    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
-        if not image_data:
-            return None
-
-        # TODO: move this into a processor for each vision architecture
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-            and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            # Multiple images
-            if len(image_data) > 1:
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                pixel_values, image_hashes, image_sizes = [], [], []
-                for img_data in image_data:
-                    pixel_v, image_h, image_s = await self._process_single_image(
-                        img_data, aspect_ratio, grid_pinpoints
-                    )
-                    pixel_values.append(pixel_v)
-                    image_hashes.append(image_h)
-                    image_sizes.append(image_s)
-
-                if isinstance(pixel_values[0], np.ndarray):
-                    pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                # A single image
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_hashes = [image_hash]
-                image_sizes = [image_size]
-        elif isinstance(image_data, str):
-            # A single image
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hashes = [image_hash]
-            image_sizes = [image_size]
-        else:
-            raise ValueError(f"Invalid image data: {image_data}")
-
-        return {
-            "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
-            "image_sizes": image_sizes,
-            "modalities": obj.modalities,
-        }
-
-    async def _process_single_image(
-        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
-    ):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                _process_single_image_task,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-            )
-        else:
-            return _process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints, self.processor
-            )
-
-
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    """Init the global processor for multi modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def _process_single_image_task(
-    image_data: Union[str, bytes],
-    image_aspect_ratio: Optional[str] = None,
-    image_grid_pinpoints: Optional[str] = None,
-    processor=None,
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size is not None:
-            # It is a video with multiple images
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            # It is an image
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image.convert("RGB"))[
-                    "pixel_values"
-                ][0]
-            elif image_aspect_ratio == "anyres" or (
-                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
-            ):
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-
-            if isinstance(pixel_values, np.ndarray):
-                pixel_values = pixel_values.astype(np.float16)
-
-            return pixel_values, image_hash, image.size
-    except Exception:
-        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
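
Note on the pattern introduced above: the patch moves CPU-bound image pre-processing behind a ProcessPoolExecutor whose workers build their own HuggingFace processor in an initializer (so the heavyweight processor object is constructed once per worker rather than pickled per task), while the event loop fans requests out with loop.run_in_executor and collects them with asyncio.gather. Below is a minimal, self-contained sketch of that pattern. All names in it (_init_worker, _preprocess, preprocess_all) are illustrative stand-ins, not sglang APIs; only the SGLANG_CPU_COUNT environment variable comes from the patch itself.

    import asyncio
    import concurrent.futures
    import os

    _worker_state = None  # built once per worker process by the initializer


    def _init_worker(config):
        # Runs once in each worker at startup, so heavyweight or
        # un-picklable state never crosses the process boundary per task.
        global _worker_state
        _worker_state = {"scale": config["scale"]}


    def _preprocess(item):
        # Stand-in for the CPU-bound image transform executed in a worker.
        return item * _worker_state["scale"]


    async def preprocess_all(items):
        loop = asyncio.get_event_loop()
        with concurrent.futures.ProcessPoolExecutor(
            initializer=_init_worker,
            initargs=({"scale": 2.0},),
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
        ) as pool:
            # Fan out every item first, then await them together -- the same
            # submit-then-gather shape as process_images_async() above.
            futures = [loop.run_in_executor(pool, _preprocess, x) for x in items]
            return await asyncio.gather(*futures)


    if __name__ == "__main__":
        print(asyncio.run(preprocess_all(range(8))))

The patch additionally pins mp_context=mp.get_context("fork") so worker start-up stays cheap on Linux; the sketch keeps the platform default start method so it also runs where fork is unavailable.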