Process image in parallel (#1539)
hnyls2002 authored Sep 30, 2024
1 parent f86c1e6 commit 55b974f
Showing 2 changed files with 204 additions and 147 deletions.
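
The change moves image preprocessing out of the tokenizer loop and into a pool of worker processes, so the CPU-bound work for multiple images in one request runs in parallel while the event loop stays responsive. As a rough illustration of that pattern only (the function names, initializer, and worker count below are placeholders, not sglang APIs):

import asyncio
import concurrent.futures
import multiprocessing as mp


def init_worker():
    # Placeholder per-process setup; the real module loads the HuggingFace
    # processor into a module-level global inside each worker.
    pass


def preprocess_one(item):
    # Placeholder for the CPU-bound per-image preprocessing.
    return item * 2


async def preprocess_all(items):
    loop = asyncio.get_running_loop()
    with concurrent.futures.ProcessPoolExecutor(
        initializer=init_worker,
        mp_context=mp.get_context("fork"),
        max_workers=4,
    ) as pool:
        tasks = [loop.run_in_executor(pool, preprocess_one, item) for item in items]
        return await asyncio.gather(*tasks)


if __name__ == "__main__":
    print(asyncio.run(preprocess_all([1, 2, 3])))  # -> [2, 4, 6]
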
187 changes: 187 additions & 0 deletions python/sglang/srt/managers/image_processor.py
@@ -0,0 +1,187 @@
# TODO: also move pad_input_ids into this module
import asyncio
import concurrent.futures
import logging
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import numpy as np
import transformers

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

global global_processor


def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)


class BaseImageProcessor(ABC):
@abstractmethod
async def process_images_async(self, image_data, **kwargs):
pass


class DummyImageProcessor(BaseImageProcessor):
async def process_images_async(self, *args, **kwargs):
return None


class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
self._image_processor = _image_processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
            # environment variables are strings, so cast before handing to the executor
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)

@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
image_processor = image_processor or global_processor.image_processor

try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
                pixel_values = [frame.astype(np.float16) for frame in pixel_values]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None
and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, image_processor, image_grid_pinpoints
)
else:
pixel_values = image_processor(image)["pixel_values"][0]

if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)

return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())

async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
)

async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
):
if not image_data:
return None

aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)

if isinstance(image_data, list) and len(image_data) > 0:
# Multiple images
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)

if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")

return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities,
}


def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor
) -> BaseImageProcessor:
return LlavaImageProcessor(hf_config, server_args, _image_processor)


def get_dummy_image_processor():
return DummyImageProcessor()
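
For reference, a hedged usage sketch of the new entry points as the tokenizer side might call them; hf_config, server_args, _image_processor, request_obj, and image_data here are placeholders assumed to be constructed elsewhere, not objects defined by this commit:

import asyncio

from sglang.srt.managers.image_processor import get_image_processor


async def run_preprocessing(
    hf_config, server_args, _image_processor, request_obj, image_data
):
    processor = get_image_processor(hf_config, server_args, _image_processor)
    # image_data is a list of URLs, file paths, or raw bytes; the result dict
    # carries pixel_values, image_hashes, image_sizes, and modalities.
    return await processor.process_images_async(image_data, request_obj)
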
