-
Notifications
You must be signed in to change notification settings - Fork 812
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
204 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
# TODO: also move pad_input_ids into this module | ||
import asyncio | ||
import concurrent.futures | ||
import logging | ||
import multiprocessing as mp | ||
import os | ||
from abc import ABC, abstractmethod | ||
from typing import List, Optional, Union | ||
|
||
import numpy as np | ||
import transformers | ||
|
||
from sglang.srt.hf_transformers_utils import get_processor | ||
from sglang.srt.mm_utils import expand2square, process_anyres_image | ||
from sglang.srt.server_args import ServerArgs | ||
from sglang.srt.utils import load_image | ||
from sglang.utils import get_exception_traceback | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
global global_processor | ||
|
||
|
||
def init_global_processor(server_args: ServerArgs):
    """Build the module-wide processor used for multi-modal models.

    The processor is created with ``get_processor`` from the tokenizer
    settings on ``server_args`` and stored in the module-level
    ``global_processor`` name.
    """
    global global_processor

    # Restrict transformers logging to errors before the processor is built.
    transformers.logging.set_verbosity_error()

    processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    global_processor = processor
|
||
|
||
class BaseImageProcessor(ABC):
    """Abstract interface shared by the image processors in this module."""

    @abstractmethod
    async def process_images_async(self, image_data, **kwargs):
        """Process ``image_data`` asynchronously; subclasses must override."""
|
||
|
||
class DummyImageProcessor(BaseImageProcessor):
    """Processor that accepts any arguments and produces no image data."""

    async def process_images_async(self, *args, **kwargs):
        """Ignore all arguments and return ``None``."""
        return None
|
||
|
||
class LlavaImageProcessor(BaseImageProcessor):
    """Image pre-processor for LLaVA-style models.

    Per-image work is submitted to a ``ProcessPoolExecutor`` that uses the
    "fork" start method; every worker runs ``init_global_processor`` once so
    that ``global_processor`` is available inside the worker.
    """

    def __init__(self, hf_config, server_args, _image_processor):
        """Store the configuration and start the worker pool.

        Args:
            hf_config: Model config; ``image_aspect_ratio`` and
                ``image_grid_pinpoints`` are read from it when present.
            server_args: Forwarded to ``init_global_processor`` in each worker.
            _image_processor: Image processor kept on the instance and used
                by the in-process fallback path.
        """
        self.hf_config = hf_config
        self._image_processor = _image_processor

        # Environment variables are strings, but ProcessPoolExecutor compares
        # ``max_workers`` against an int, so convert explicitly. When the
        # variable is unset, keep passing ``os.cpu_count()`` through as-is.
        cpu_count_env = os.environ.get("SGLANG_CPU_COUNT")
        max_workers = (
            int(cpu_count_env) if cpu_count_env is not None else os.cpu_count()
        )

        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
            mp_context=mp.get_context("fork"),
            initargs=(server_args,),
            max_workers=max_workers,
        )

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        """Load one image (or video) and compute its pixel values.

        Args:
            image_data: Raw input handed to ``load_image``.
            image_aspect_ratio: ``"pad"``, ``"anyres"``, a value containing
                ``"anyres_max"``, or anything else for plain processing.
            image_grid_pinpoints: Passed to ``process_anyres_image`` in the
                anyres modes.
            image_processor: Processor to use; falls back to
                ``global_processor.image_processor`` when falsy.

        Returns:
            A ``(pixel_values, image_hash, image_size)`` tuple.

        Raises:
            Exception: Any processing error is logged and then re-raised so
                the caller sees the real failure instead of a ``None`` result
                that cannot be unpacked.
        """
        image_processor = image_processor or global_processor.image_processor

        try:
            image, image_size = load_image(image_data)
            image_hash = hash(image_data)

            if image_size is not None:
                # It is a video with multiple images: cast every frame to
                # float16 and stack them into one array.
                frames = image_processor(image)["pixel_values"]
                pixel_values = np.stack(
                    [frame.astype(np.float16) for frame in frames], axis=0
                )
                return pixel_values, image_hash, image_size

            # It is an image
            if image_aspect_ratio == "pad":
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in image_processor.image_mean),
                )
                pixel_values = image_processor(image.convert("RGB"))[
                    "pixel_values"
                ][0]
            elif image_aspect_ratio == "anyres" or (
                image_aspect_ratio is not None
                and "anyres_max" in image_aspect_ratio
            ):
                pixel_values = process_anyres_image(
                    image, image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = image_processor(image)["pixel_values"][0]

            if isinstance(pixel_values, np.ndarray):
                pixel_values = pixel_values.astype(np.float16)

            return pixel_values, image_hash, image.size
        except Exception:
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
            # Every caller unpacks a 3-tuple; returning None here would only
            # surface later as an unrelated TypeError.
            raise

    async def _process_single_image(
        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
    ):
        """Process one image, in the worker pool when one is available."""
        if self.executor is None:
            # In-process fallback: ``global_processor`` is only initialised
            # inside pool workers, so hand over the stored processor.
            return self._process_single_image_task(
                image_data, aspect_ratio, grid_pinpoints, self._image_processor
            )

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            LlavaImageProcessor._process_single_image_task,
            image_data,
            aspect_ratio,
            grid_pinpoints,
        )

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], request_obj
    ):
        """Process the images of a request.

        Args:
            image_data: A non-empty list of image inputs or a single string.
            request_obj: Request whose ``modalities`` attribute is copied
                into the result.

        Returns:
            ``None`` when ``image_data`` is empty; otherwise a dict with the
            keys ``pixel_values``, ``image_hashes``, ``image_sizes`` and
            ``modalities``.

        Raises:
            ValueError: If ``image_data`` is neither a list nor a string.
        """
        if not image_data:
            return None

        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        # Guard against a config without ``image_aspect_ratio``: a substring
        # test on None would raise TypeError.
        use_grid_pinpoints = (
            hasattr(self.hf_config, "image_grid_pinpoints")
            and aspect_ratio is not None
            and "anyres" in aspect_ratio
        )
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints if use_grid_pinpoints else None
        )

        if isinstance(image_data, list):
            if len(image_data) > 1:
                # Multiple images
                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                results = await asyncio.gather(
                    *(
                        self._process_single_image(
                            img_data, aspect_ratio, grid_pinpoints
                        )
                        for img_data in image_data
                    )
                )
                pixel_values = [res[0] for res in results]
                image_hashes = [res[1] for res in results]
                image_sizes = [res[2] for res in results]

                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_hashes = [image_hash]
                image_sizes = [image_size]
        elif isinstance(image_data, str):
            # A single image
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data, aspect_ratio, grid_pinpoints
            )
            image_hashes = [image_hash]
            image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities,
        }
|
||
|
||
def get_image_processor(
    hf_config, server_args: ServerArgs, _image_processor
) -> BaseImageProcessor:
    """Create a ``LlavaImageProcessor`` from the given arguments."""
    processor = LlavaImageProcessor(hf_config, server_args, _image_processor)
    return processor
|
||
|
||
def get_dummy_image_processor():
    """Create a ``DummyImageProcessor``."""
    dummy = DummyImageProcessor()
    return dummy
Oops, something went wrong.