From 55b974f96f14dc359111bb34a47655eec006655f Mon Sep 17 00:00:00 2001
From: Liangsheng Yin <hnyls2002@gmail.com>
Date: Sun, 29 Sep 2024 18:52:43 -0700
Subject: [PATCH] Process image in parallel (#1539)

---
 python/sglang/srt/managers/image_processor.py | 187 ++++++++++++++++++
 .../sglang/srt/managers/tokenizer_manager.py  | 164 ++-------------
 2 files changed, 204 insertions(+), 147 deletions(-)
 create mode 100644 python/sglang/srt/managers/image_processor.py

diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py
new file mode 100644
index 00000000000..e1e54af7fcc
--- /dev/null
+++ b/python/sglang/srt/managers/image_processor.py
@@ -0,0 +1,187 @@
+# TODO: also move pad_input_ids into this module
+import asyncio
+import concurrent.futures
+import logging
+import multiprocessing as mp
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional, Union
+
+import numpy as np
+import transformers
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.mm_utils import expand2square, process_anyres_image
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import load_image
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    """Init the global processor for multi modal models."""
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+class BaseImageProcessor(ABC):
+    @abstractmethod
+    async def process_images_async(self, image_data, **kwargs):
+        pass
+
+
+class DummyImageProcessor(BaseImageProcessor):
+    async def process_images_async(self, *args, **kwargs):
+        return None
+
+
+class LlavaImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _image_processor):
+        self.hf_config = hf_config
+        self._image_processor = _image_processor
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
+        )
+
+    @staticmethod
+    def _process_single_image_task(
+        image_data: Union[str, bytes],
+        image_aspect_ratio: Optional[str] = None,
+        image_grid_pinpoints: Optional[str] = None,
+        image_processor=None,
+    ):
+        image_processor = image_processor or global_processor.image_processor
+
+        try:
+            image, image_size = load_image(image_data)
+            if image_size is not None:
+                # It is a video with multiple images
+                image_hash = hash(image_data)
+                pixel_values = image_processor(image)["pixel_values"]
+                for _ in range(len(pixel_values)):
+                    pixel_values[_] = pixel_values[_].astype(np.float16)
+                pixel_values = np.stack(pixel_values, axis=0)
+                return pixel_values, image_hash, image_size
+            else:
+                # It is an image
+                image_hash = hash(image_data)
+                if image_aspect_ratio == "pad":
+                    image = expand2square(
+                        image,
+                        tuple(int(x * 255) for x in image_processor.image_mean),
+                    )
+                    pixel_values = image_processor(image.convert("RGB"))[
+                        "pixel_values"
+                    ][0]
+                elif image_aspect_ratio == "anyres" or (
+                    image_aspect_ratio is not None
+                    and "anyres_max" in image_aspect_ratio
+                ):
+                    pixel_values = process_anyres_image(
+                        image, image_processor, image_grid_pinpoints
+                    )
+                else:
+                    pixel_values = image_processor(image)["pixel_values"][0]
+
+                if isinstance(pixel_values, np.ndarray):
+                    pixel_values = pixel_values.astype(np.float16)
+
+                return pixel_values, image_hash, image.size
+        except Exception:
+            logger.error("Exception in image processor:\n" + get_exception_traceback())
+
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                LlavaImageProcessor._process_single_image_task,
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+            )
+        else:
+            return self._process_single_image_task(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], request_obj
+    ):
+        if not image_data:
+            return None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
+        grid_pinpoints = (
+            self.hf_config.image_grid_pinpoints
+            if hasattr(self.hf_config, "image_grid_pinpoints")
+            and aspect_ratio is not None and "anyres" in aspect_ratio
+            else None
+        )
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            # Multiple images
+            if len(image_data) > 1:
+                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
+                res = []
+                for img_data in image_data:
+                    res.append(
+                        self._process_single_image(
+                            img_data, aspect_ratio, grid_pinpoints
+                        )
+                    )
+                res = await asyncio.gather(*res)
+                for pixel_v, image_h, image_s in res:
+                    pixel_values.append(pixel_v)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
+            else:
+                # A single image
+                pixel_values, image_hash, image_size = await self._process_single_image(
+                    image_data[0], aspect_ratio, grid_pinpoints
+                )
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
+        elif isinstance(image_data, str):
+            # A single image
+            pixel_values, image_hash, image_size = await self._process_single_image(
+                image_data, aspect_ratio, grid_pinpoints
+            )
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
+        else:
+            raise ValueError(f"Invalid image data: {image_data}")
+
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": request_obj.modalities,
+        }
+
+
+def get_image_processor(
+    hf_config, server_args: ServerArgs, _image_processor
+) -> BaseImageProcessor:
+    return LlavaImageProcessor(hf_config, server_args, _image_processor)
+
+
+def get_dummy_image_processor():
+    return DummyImageProcessor()
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 78ea0d1682f..3103faec830 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -16,17 +16,13 @@
 """TokenizerManager is a process that tokenizes the text."""
 
 import asyncio
-import concurrent.futures
 import dataclasses
 import json
 import logging
-import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union
 
 import fastapi
-import numpy as np
-import transformers
 import uvloop
 import zmq
 import zmq.asyncio
@@ -38,6 +34,10 @@
     get_processor,
     get_tokenizer,
 )
+from sglang.srt.managers.image_processor import (
+    get_dummy_image_processor,
+    get_image_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -53,11 +53,9 @@
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
-from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
-from sglang.utils import get_exception_traceback
+from sglang.srt.utils import is_generation_model, is_multimodal_model
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -105,6 +103,8 @@ def __init__(
         self.context_len = server_args.context_length or get_context_length(
             self.hf_config
         )
+        # Create image processor placeholder
+        self.image_processor = get_dummy_image_processor()
 
         # Create tokenizer
         if server_args.skip_tokenizer_init:
@@ -119,13 +119,9 @@ def __init__(
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-                # We want to parallelize the image pre-processing so we
-                # create an executor for it
-                self.executor = concurrent.futures.ProcessPoolExecutor(
-                    initializer=init_global_processor,
-                    mp_context=mp.get_context("fork"),
-                    initargs=(server_args,),
-                    max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+                # We want to parallelize the image pre-processing so we create an executor for it
+                self.image_processor = get_image_processor(
+                    self.hf_config, server_args, self.processor.image_processor
                 )
             else:
                 self.tokenizer = get_tokenizer(
@@ -194,8 +190,8 @@ async def _handle_single_request(
             )
 
             if self.is_generation:
-                image_inputs = await self._get_image_inputs(
-                    obj, obj.image_data if not_use_index else obj.image_data[index]
+                image_inputs = await self.image_processor.process_images_async(
+                    obj.image_data if not_use_index else obj.image_data[index], obj
                 )
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -247,7 +243,9 @@ async def _handle_single_request(
 
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
-            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data[0], obj
+            )
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -362,8 +360,8 @@ async def _handle_batch_request(
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
 
                 if self.is_generation:
-                    image_inputs = await self._get_image_inputs(
-                        obj, obj.image_data[index]
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data[index], obj
                     )
 
                     tokenized_obj = TokenizedGenerateReqInput(
@@ -686,131 +684,3 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
                     token_top_logprobs, decode_to_text
                 )
         return top_logprobs
-
-    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
-        if not image_data:
-            return None
-
-        # TODO: move this into a processor for each vision architecture
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-            and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            # Multiple images
-            if len(image_data) > 1:
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                pixel_values, image_hashes, image_sizes = [], [], []
-                for img_data in image_data:
-                    pixel_v, image_h, image_s = await self._process_single_image(
-                        img_data, aspect_ratio, grid_pinpoints
-                    )
-                    pixel_values.append(pixel_v)
-                    image_hashes.append(image_h)
-                    image_sizes.append(image_s)
-
-                if isinstance(pixel_values[0], np.ndarray):
-                    pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                # A single image
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_hashes = [image_hash]
-                image_sizes = [image_size]
-        elif isinstance(image_data, str):
-            # A single image
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hashes = [image_hash]
-            image_sizes = [image_size]
-        else:
-            raise ValueError(f"Invalid image data: {image_data}")
-
-        return {
-            "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
-            "image_sizes": image_sizes,
-            "modalities": obj.modalities,
-        }
-
-    async def _process_single_image(
-        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
-    ):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                _process_single_image_task,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-            )
-        else:
-            return _process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints, self.processor
-            )
-
-
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    """Init the global processor for multi modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def _process_single_image_task(
-    image_data: Union[str, bytes],
-    image_aspect_ratio: Optional[str] = None,
-    image_grid_pinpoints: Optional[str] = None,
-    processor=None,
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size is not None:
-            # It is a video with multiple images
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            # It is an image
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image.convert("RGB"))[
-                    "pixel_values"
-                ][0]
-            elif image_aspect_ratio == "anyres" or (
-                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
-            ):
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-
-            if isinstance(pixel_values, np.ndarray):
-                pixel_values = pixel_values.astype(np.float16)
-
-            return pixel_values, image_hash, image.size
-    except Exception:
-        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())