
Commit

tmp
hnyls2002 committed Oct 1, 2024
1 parent 64a5ebc commit 8b063e0
Showing 1 changed file with 50 additions and 91 deletions.
python/sglang/srt/models/mllama.py (141 changes: 50 additions, 91 deletions)
@@ -3,8 +3,9 @@
"""PyTorch Mllama model."""
import math
from array import array
-from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union
+from typing import Iterable, List, Mapping, Optional, Tuple, Union

+import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -45,20 +46,6 @@
MLLAMA_IMAGE_TOKEN = "<|image|>"


-class MllamaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: """
-    """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)"""
-    aspect_ratio_ids: torch.Tensor
-    """Shape: `(batch_size, max_num_image)`"""
-    aspect_ratio_mask: torch.Tensor
-    """Shape: `(batch_size, max_num_image, max_num_tiles)`"""
-
-
-# TODO: support LlamaImageEmbeddingInputs
-
-
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
    # move encoder_prompt to prompt
    if llm_inputs.get("prompt") is None:
@@ -296,7 +283,6 @@ def forward(
        return hidden_state


-# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):

    def __init__(self, config: config_mllama.MllamaVisionConfig):
@@ -710,13 +696,13 @@ def __init__(
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

-        # TODO: impl cross attention
        self.attn = RadixAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            layer_id=layer_id,
+            is_cross_attention=True,
        )

    def forward(
@@ -946,6 +932,7 @@ def __init__(
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.image_size = config.vision_config.image_size
+        self.has_cross_attention = True

        self.vision_model = MllamaVisionModel(config.vision_config)
        self.language_model = MllamaForCausalLM(
@@ -967,86 +954,59 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        num_concurrent_media, num_tiles = pixel_values.shape[1:3]
        num_patches = self.vision_model.num_patches
        image_len = num_concurrent_media * num_tiles * num_patches
+        image_inputs.num_image_tokens = image_len

        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

        return pad_ids[:image_len] + input_ids

-    def _parse_and_validate_image_input(self, **kwargs: object):
-        # tensor with the same shape will be batched together by
-        # MultiModalInputs.batch, so pixel_values here can be:
-        #   - List[List[torch.Tensor]]:
-        #       with shape (num_tiles, 3, image_res, image_res)
-        #   - List[torch.Tensor]:
-        #       with shape (num_image, num_tiles, 3, image_res, image_res)
-        #   - torch.Tensor:
-        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
-        pixel_values: Optional[
-            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
-        ] = kwargs.pop("pixel_values", None)
-        image_embeds: Optional[
-            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
-        ] = kwargs.pop("image_embeds", None)
-        aspect_ratio_ids: Optional[
-            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
-        ] = kwargs.pop("aspect_ratio_ids", None)
-        aspect_ratio_mask: Optional[
-            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
-        ] = kwargs.pop("aspect_ratio_mask", None)
-
-        if pixel_values is None and image_embeds is None:
-            return None
+    def _handle_image_inputs(self, forward_batch: ForwardBatch):
+        # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
+        image_inputs = forward_batch.image_inputs
+        image_lens = np.array(
+            [im.image_len if im is not None else 0 for im in image_inputs]
+        )
+        # TODO: need_vision handled here
+        print(image_lens)
+        return

-        if pixel_values is not None and image_embeds is not None:
-            raise ValueError("Both pixel values and image embeds are provided.")
-
-        if pixel_values is not None:
-            assert aspect_ratio_ids is not None
-            assert aspect_ratio_mask is not None
-            max_num_images = max([len(x[0]) for x in pixel_values])
-            if max_num_images == 0:
-                raise ValueError("No images provided.")
-            max_num_tiles = max(max([len(x) for x in y[0]]) for y in pixel_values)
-            device = self.multi_modal_projector.weight.device
-            bsz = len(pixel_values)
-            out_num_tiles = []
-            out_images = torch.zeros(
-                bsz,
-                max_num_images,
-                max_num_tiles,
-                3,
-                self.image_size,
-                self.image_size,
-                dtype=torch.float32,
-                device=device,
-            )
-            out_ar_ids = torch.ones(
-                bsz, max_num_images, dtype=torch.int64, device=device
-            )
-            out_ar_mask = torch.zeros(
-                bsz, max_num_images, max_num_tiles, dtype=torch.int64, device=device
-            )
-            for b in range(len(pixel_values)):
-                _num_tiles = []
-                for i in range(len(pixel_values[b][0])):
-                    img = pixel_values[b][0][i]
-                    out_images[b, i, : img.shape[0]] = img
-                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
-                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]
-                    _num_tiles.append(img.shape[0])
-                out_num_tiles.append(_num_tiles)
-
-            return MllamaImagePixelInputs(
-                type="pixel_values",
-                data=out_images,
-                aspect_ratio_ids=out_ar_ids,
-                aspect_ratio_mask=out_ar_mask,
-            )
+        pixel_values = image_inputs.pixel_values
+        aspect_ratio_ids = image_inputs.aspect_ratio_ids
+        aspect_ratio_mask = image_inputs.aspect_ratio_mask

-        if image_embeds is not None:
-            raise NotImplementedError
+        print(pixel_values.shape)
+        print(aspect_ratio_ids.shape)
+        print(aspect_ratio_mask.shape)

-        raise AssertionError("This line should be unreachable.")
+        if pixel_values is None:
+            return None

+        assert aspect_ratio_ids is not None
+        assert aspect_ratio_mask is not None
+        max_num_images = max([len(x[0]) for x in pixel_values])
+        max_num_tiles = max(max([len(x) for x in y[0]]) for y in pixel_values)
+        device = self.multi_modal_projector.weight.device
+        bsz = len(pixel_values)
+        out_images = torch.zeros(
+            bsz,
+            max_num_images,
+            max_num_tiles,
+            3,
+            self.image_size,
+            self.image_size,
+            dtype=torch.float32,
+            device=device,
+        )
+        out_ar_ids = torch.ones(bsz, max_num_images, dtype=torch.int64, device=device)
+        out_ar_mask = torch.zeros(
+            bsz, max_num_images, max_num_tiles, dtype=torch.int64, device=device
+        )
+        for b in range(len(pixel_values)):
+            for i in range(len(pixel_values[b][0])):
+                img = pixel_values[b][0][i]
+                out_images[b, i, : img.shape[0]] = img
+                out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
+                out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

    def flat_encoder_result(
        self, cross_attention_states: torch.Tensor, forward_batch: ForwardBatch
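
Note on the hunk above: a minimal sketch of the padding arithmetic used by pad_input_ids. The counts and token ids below are illustrative assumptions, not values taken from the real model config or ImageInputs:

# Toy sketch of the pad_input_ids arithmetic (illustrative numbers only).
num_concurrent_media, num_tiles, num_patches = 1, 4, 1025  # assumed example shapes
pad_values = [128256, 128257]  # hypothetical placeholder token ids

image_len = num_concurrent_media * num_tiles * num_patches  # 4100 placeholder slots

# Repeat pad_values until it covers image_len, then trim to exactly image_len.
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
assert len(pad_ids[:image_len]) == image_len

input_ids = [9906, 1917]  # hypothetical text prompt token ids
padded_input_ids = pad_ids[:image_len] + input_ids  # placeholders prepended to the prompt
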
@@ -1095,9 +1055,8 @@ def forward(
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
-        **kwargs: object,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        image_inputs = self._parse_and_validate_image_input(**kwargs)
+        image_inputs = self._handle_image_inputs(forward_batch)
        setattr(
            forward_batch,
            "encoder_seq_lens_tensor",
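
As a side note on the image-batching loop in the pad_input_ids hunk: it packs ragged per-image tile tensors into fixed-size batch tensors, zero-padding unused tile slots and recording valid tiles in an aspect-ratio mask. Below is a self-contained sketch with toy shapes; the shapes, ids, and masks are assumptions for illustration, not the model's real configuration.

import torch

# Two requests, one image each, with 2 and 3 tiles; image_size shrunk to 8 for readability.
image_size, max_num_images, max_num_tiles = 8, 1, 3
pixel_values = [
    [[torch.randn(2, 3, image_size, image_size)]],  # request 0: image with 2 tiles
    [[torch.randn(3, 3, image_size, image_size)]],  # request 1: image with 3 tiles
]
aspect_ratio_ids = [[[torch.tensor(1)]], [[torch.tensor(2)]]]  # assumed ids
aspect_ratio_mask = [[[torch.tensor([1, 1, 0])]], [[torch.tensor([1, 1, 1])]]]

bsz = len(pixel_values)
out_images = torch.zeros(bsz, max_num_images, max_num_tiles, 3, image_size, image_size)
out_ar_ids = torch.ones(bsz, max_num_images, dtype=torch.int64)
out_ar_mask = torch.zeros(bsz, max_num_images, max_num_tiles, dtype=torch.int64)

for b in range(bsz):
    for i in range(len(pixel_values[b][0])):
        img = pixel_values[b][0][i]              # (num_tiles, 3, H, W), ragged per image
        out_images[b, i, : img.shape[0]] = img   # zero-pad the unused tile slots
        out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
        out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

print(out_images.shape)   # torch.Size([2, 1, 3, 3, 8, 8])
print(out_ar_mask)        # rows show which tile slots hold real data
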
