diff --git a/inference/benchmarks/sam_h/README.md b/inference/benchmarks/sam_h/README.md
new file mode 100644
index 000000000..02d6cf352
--- /dev/null
+++ b/inference/benchmarks/sam_h/README.md
@@ -0,0 +1,62 @@
+### 1. Inference Dataset
+
+* A single image, together with the masks produced by running SamModel (on CPU) on it as in the original Hugging Face documentation, is used as the inference dataset and evaluation reference
+* Dataset download
+    * Downloaded automatically in dataloader.py
+* Ground-truth preparation
+    * Run the SamModel inference example from the Hugging Face documentation, save the three image-sized boolean mask matrices it computes as sam_gt_0.pt through sam_gt_2.pt, ordered from the smallest mask to the largest, and place them under data_dir
+
+### 2. Model and Weights
+
+* Model implementation
+    * pytorch: transformers.SamModel
+* Weight download
+    * pytorch: from_pretrained("facebook/sam_vit_huge") (Hugging Face)
+
+### 3. Reference Hardware/Software Configuration and Run Information
+
+#### 3.1 Nvidia A100
+
+- ##### Hardware environment
+
+    - Machine and accelerator model: NVIDIA_A100-SXM4-40GB
+    - Multi-node network type and bandwidth: InfiniBand, 200 Gb/s
+
+- ##### Software environment
+
+    - OS version: Ubuntu 20.04
+    - OS kernel version: 5.4.0-113-generic
+    - Accelerator driver version: 470.129.06
+    - Docker version: 20.10.16
+    - Training framework version: pytorch-2.1.0a0+4136153
+    - Dependency versions:
+      - cuda: 12.1
+
+- Inference toolkit
+
+    - TensorRT 8.6.1
+
+### 4. Run Results
+
+* Metric list
+
+| Metric name | Metric key | Notes |
+| ------------------ | ---------------- | -------------------------------------------- |
+| Numerical precision | precision | fp32 or fp16 |
+| Batch size | bs | |
+| Device memory usage | mem | commonly called "VRAM", in GiB |
+| End-to-end time | e2e_time | total time, including Perf initialization and other overhead |
+| Overall validation throughput | p_val_whole | number of validated images divided by total validation time |
+| Validation compute throughput | p_val_core | excludes I/O time |
+| Overall inference throughput | p_infer_whole | number of inferred images divided by total inference time |
+| **Inference compute throughput** | **\*p_infer_core** | excludes I/O time |
+| **Accelerator utilization** | **\*MFU** | model FLOPs utilization |
+| Inference result | acc (inference/validation) | mean per-pixel agreement with the stored reference masks |
+
+* Metric values
+
+| Inference tool | precision | bs | e2e_time | p_val_whole | p_val_core | p_infer_whole | \*p_infer_core | \*MFU | acc | mem |
+| ----------- | --------- | ---- | ---- | -------- | ----------- | ---------- | ------------- | ------------ | ----------- | ----------- |
+| tensorrt | fp16 | 4 | 1895.1 | 9.3 | 10.7 | 7.9 | 11.8 | 11.8% | 0.89/1.0 | 23.7/40.0 |
+| tensorrt | fp32 | 2 | 1895.1 | 6.8 | 7.5 | 5.5 | 7.0 | 13.9% | 1.0/1.0 | 18.1/40.0 |
+
diff --git a/inference/benchmarks/sam_h/pytorch/__init__.py b/inference/benchmarks/sam_h/pytorch/__init__.py
new file mode 100644
index 000000000..1f6cdf49b
--- /dev/null
+++ b/inference/benchmarks/sam_h/pytorch/__init__.py
@@ -0,0 +1,5 @@
+from .dataloader import build_dataloader
+from .model import create_model
+from .export import export_model
+from .evaluator import evaluator
+from .forward import model_forward, engine_forward
diff --git a/inference/benchmarks/sam_h/pytorch/dataloader.py b/inference/benchmarks/sam_h/pytorch/dataloader.py
new file mode 100644
index 000000000..57d64b67d
--- /dev/null
+++ b/inference/benchmarks/sam_h/pytorch/dataloader.py
@@ -0,0 +1,52 @@
+import torchvision as tv
+from torch.utils.data import DataLoader as dl
+from torch.utils.data import Dataset
+import torch
+from PIL import Image
+import requests
+from transformers import SamProcessor
+import tqdm
+
+
+class SamInferDataset(Dataset):
+
+    def __init__(self, config):
+        processor = SamProcessor.from_pretrained(config.data_dir + "/" +
+                                                 config.weights)
+
+        img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+        raw_image = Image.open(requests.get(img_url,
+                                            stream=True).raw).convert("RGB")
+        input_points = [[[450, 600]]]
+
+        inputs = processor(raw_image,
+                           input_points=input_points,
+                           return_tensors="pt")
+        self.img = inputs["pixel_values"][0]
+        self.points = inputs["input_points"][0]
+        self.osize = inputs["original_sizes"][0]
+        self.dsize = inputs["reshaped_input_sizes"][0]
+        self.length = config.datasize
+
+    def __len__(self):
+        return self.length
+
+    def 
__getitem__(self, idx): + return self.img.clone().float(), self.points.clone().float( + ), self.osize.clone(), self.dsize.clone() + + +def build_dataset(config): + dataset = SamInferDataset(config) + return dataset + + +def build_dataloader(config): + dataset = build_dataset(config) + loader = dl(dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=True, + num_workers=config.num_workers) + + return loader diff --git a/inference/benchmarks/sam_h/pytorch/evaluator.py b/inference/benchmarks/sam_h/pytorch/evaluator.py new file mode 100644 index 000000000..e475ac1ff --- /dev/null +++ b/inference/benchmarks/sam_h/pytorch/evaluator.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def evaluator(pred, gt_path): + batch_size = len(pred) + + result = [] + for i in range(batch_size): + mask = pred[i][0] + + num_mask = len(mask) + + for j in range(num_mask): + mask_img = mask[j] + + gt = torch.load(gt_path + '_' + str(j) + ".pt") + gt = gt.to(mask_img.device) + iou = torch.eq(mask_img, gt).sum().item() / mask_img.numel() + result.append(float(iou)) + + return np.mean(result) diff --git a/inference/benchmarks/sam_h/pytorch/export.py b/inference/benchmarks/sam_h/pytorch/export.py new file mode 100644 index 000000000..9b16ece00 --- /dev/null +++ b/inference/benchmarks/sam_h/pytorch/export.py @@ -0,0 +1,33 @@ +import torch +import os + + +def export_model(model, config): + if config.exist_onnx_path is not None: + return config.exist_onnx_path + + filename = config.case + "_bs" + str(config.batch_size) + filename = filename + "_" + str(config.framework) + filename = filename + "_fp16" + str(config.fp16) + filename = "onnxs/" + filename + ".onnx" + onnx_path = config.perf_dir + "/" + filename + + img = torch.randn(config.batch_size, 3, 1024, 1024).cuda() + points = torch.ones(config.batch_size, 1, 1, 2).cuda() + + if config.fp16: + img = img.half() + dummy_input = (img, points) + + dir_onnx_path = os.path.dirname(onnx_path) + os.makedirs(dir_onnx_path, exist_ok=True) + + with torch.no_grad(): + torch.onnx.export(model, + dummy_input, + onnx_path, + verbose=False, + training=torch.onnx.TrainingMode.EVAL, + do_constant_folding=True) + + return onnx_path diff --git a/inference/benchmarks/sam_h/pytorch/forward.py b/inference/benchmarks/sam_h/pytorch/forward.py new file mode 100644 index 000000000..df61177fa --- /dev/null +++ b/inference/benchmarks/sam_h/pytorch/forward.py @@ -0,0 +1,121 @@ +from loguru import logger +import torch +import numpy as np +import time +from tools import torch_sync +from transformers import SamProcessor + + +def cal_perf(config, dataloader_len, duration, core_time, str_prefix): + model_forward_perf = config.repeat * dataloader_len * config.batch_size / duration + logger.info(str_prefix + "(" + config.framework + ") Perf: " + + str(model_forward_perf) + " ips") + model_forward_core_perf = config.repeat * dataloader_len * config.batch_size / core_time + logger.info(str_prefix + "(" + config.framework + ") core Perf: " + + str(model_forward_core_perf) + " ips") + return round(model_forward_perf, 3), round(model_forward_core_perf, 3) + + +def model_forward(model, dataloader, evaluator, config): + if config.no_validation: + return None, None, None + processor = SamProcessor.from_pretrained(config.data_dir + "/" + + config.weights) + gt_path = config.data_dir + "/" + config.ground_truth + start = time.time() + core_time = 0.0 + scores = [] + + for times in range(config.repeat): + + logger.debug("Repeat: " + str(times + 1)) + + local_scores = [] + for step, (x, y, 
osize, dsize) in enumerate(dataloader): + if config.fp16: + x = x.to(torch.float16) + y = y.to(torch.float16) + torch_sync(config) + core_time_start = time.time() + + if step % config.log_freq == 0: + logger.debug("Step: " + str(step) + " / " + + str(len(dataloader))) + + with torch.no_grad(): + + x = x.cuda() + y = y.cuda() + + pred = model(x, y)[1] + torch_sync(config) + core_time += time.time() - core_time_start + + pred = processor.post_process_masks(pred, osize, dsize) + score = evaluator(pred, gt_path) + local_scores.append(score) + + scores.append(np.mean(local_scores)) + + logger.info("Top1 Acc: " + str(scores)) + + duration = time.time() - start + model_forward_perf, model_forward_core_perf = cal_perf( + config, len(dataloader), duration, core_time, "Validation") + + return model_forward_perf, model_forward_core_perf, round( + float(np.mean(scores)), 3) + + +def engine_forward(model, dataloader, evaluator, config): + processor = SamProcessor.from_pretrained(config.data_dir + "/" + + config.weights) + gt_path = config.data_dir + "/" + config.ground_truth + start = time.time() + core_time = 0.0 + foo_time = 0.0 + scores = [] + + for times in range(config.repeat): + + logger.debug("Repeat: " + str(times + 1)) + + local_scores = [] + for step, (x, y, osize, dsize) in enumerate(dataloader): + if config.fp16: + x = x.to(torch.float16) + y = y.to(torch.float16) + torch_sync(config) + core_time_start = time.time() + + if step % config.log_freq == 0: + logger.debug("Step: " + str(step) + " / " + + str(len(dataloader))) + + with torch.no_grad(): + + outputs = model([x, y]) + pred = outputs[0] + foo_time += outputs[1] + + torch_sync(config) + core_time += time.time() - core_time_start + + pred = pred[0] + pred = pred.reshape(config.batch_size, 1, 3, 256, 256).float() + pred = pred.cpu() + + pred = processor.post_process_masks(pred, osize, dsize) + score = evaluator(pred, gt_path) + local_scores.append(score) + + scores.append(np.mean(local_scores)) + + logger.info("Top1 Acc: " + str(scores)) + + duration = time.time() - start - foo_time + model_forward_perf, model_forward_core_perf = cal_perf( + config, len(dataloader), duration, core_time - foo_time, "Inference") + + return model_forward_perf, model_forward_core_perf, round( + float(np.mean(scores)), 3) diff --git a/inference/benchmarks/sam_h/pytorch/model.py b/inference/benchmarks/sam_h/pytorch/model.py new file mode 100644 index 000000000..57f2fd5e4 --- /dev/null +++ b/inference/benchmarks/sam_h/pytorch/model.py @@ -0,0 +1,14 @@ +from .model_utils.sam import SamModel + + +def create_model(config): + if config.no_validation: + assert config.exist_onnx_path is not None + return None + model = SamModel.from_pretrained(config.data_dir + "/" + config.weights) + model.cuda() + model.eval() + if config.fp16: + model.half() + + return model diff --git a/inference/benchmarks/sam_h/pytorch/model_utils/sam.py b/inference/benchmarks/sam_h/pytorch/model_utils/sam.py new file mode 100755 index 000000000..5c7494d7e --- /dev/null +++ b/inference/benchmarks/sam_h/pytorch/model_utils/sam.py @@ -0,0 +1,1507 @@ +# coding=utf-8 +# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch SAM model.""" + +import collections +import math +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import Tensor, nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.models.sam.configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SamConfig" +_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge" + +SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/sam-vit-huge", + "facebook/sam-vit-large", + "facebook/sam-vit-base", + # See all SAM models at https://huggingface.co/models?filter=sam +] + + +@dataclass +class SamVisionEncoderOutput(ModelOutput): + """ + Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SamImageSegmentationOutput(ModelOutput): + """ + Base class for Segment-Anything model's output + + Args: + iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. 
Needs to be post-processed by the processor + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: torch.FloatTensor = None + pred_masks: torch.FloatTensor = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SamPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance( + image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance( + patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // + patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, + hidden_size, + kernel_size=patch_size, + stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class SamMLPBlock(nn.Module): + + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam +class SamLayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, + normalized_shape, + eps=1e-6, + data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + self.normalized_shape = (normalized_shape, ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class SamAttention(nn.Module): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def _separate_heads(self, hidden_states: Tensor, + num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, + n_tokens, num_attention_heads, + c_per_head) + return hidden_states.transpose(1, 2) + + def _recombine_heads(self, hidden_states: Tensor, + point_batch_size: int) -> Tensor: + batch, n_heads, n_tokens, c_per_head = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(batch // point_batch_size, + point_batch_size, n_tokens, + n_heads * c_per_head) + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_similarity: Tensor = None) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = query.shape + attn = query @ key.permute( + 0, 1, 3, 2 + ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(float(c_per_head)) + attn = torch.softmax(attn, dim=-1) + + if attention_similarity is not None: + attn = attn + attention_similarity + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ value + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +class SamTwoWayAttentionBlock(nn.Module): + + def __init__(self, + config, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse 
inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`SamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. + """ + super().__init__() + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = SamAttention(config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(self.hidden_size, + eps=self.layer_norm_eps) + + self.cross_attn_token_to_image = SamAttention( + config, downsample_rate=attention_downsample_rate) + self.layer_norm2 = nn.LayerNorm(self.hidden_size, + eps=self.layer_norm_eps) + + self.mlp = SamMLPBlock(config) + self.layer_norm3 = nn.LayerNorm(self.hidden_size, + eps=self.layer_norm_eps) + + self.layer_norm4 = nn.LayerNorm(self.hidden_size, + eps=self.layer_norm_eps) + self.cross_attn_image_to_token = SamAttention( + config, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image( + query=query, + key=key, + value=keys, + attention_similarity=attention_similarity) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, + key=query, + value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out, ) + else: + outputs = outputs + (None, ) + + return outputs + + +class SamTwoWayTransformer(nn.Module): + + def __init__(self, config: SamMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append( + SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = SamAttention(config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, + 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten( + 2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs, ) + + # Apply the final attenion layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, + key=key, + value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + + +class SamFeedForward(nn.Module): + + def __init__(self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False): + super().__init__() + self.num_layers = num_layers + self.activation = nn.ReLU() + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList( + [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +class SamMaskDecoder(nn.Module): + + def __init__(self, config: SamMaskDecoderConfig): + super().__init__() + + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = SamTwoWayTransformer(config) + + # should we create a new class for this? 
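+        # The two transposed convolutions below upsample the 64x64 image embedding
+        # by 4x to the 256x256 resolution of the predicted low-resolution masks;
+        # the hypernetwork MLPs further down map each mask token to weights that
+        # are applied to this upscaled embedding.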
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, + self.hidden_size // 4, + kernel_size=2, + stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, + self.hidden_size // 8, + kernel_size=2, + stride=2) + self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, + data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [ + SamFeedForward(self.hidden_size, self.hidden_size, + self.hidden_size // 8, 3) + ] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + + self.iou_prediction_head = SamFeedForward(self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth) + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + output_attentions: Optional[bool] = None, + attention_similarity: torch.Tensor = None, + target_embedding: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + the embeddings from the image encoder + image_positional_embedding (`torch.Tensor`): + positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes + dense_prompt_embeddings (`torch.Tensor`): + the embeddings of the mask inputs + multimask_output (bool): + Whether to return multiple masks or a single mask. + output_attentions (bool, *optional*): + Whether or not to return the attentions tensors of all attention layers. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, + 1) + + if sparse_prompt_embeddings.sum().item() != 0: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), + dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-point + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1) + image_positional_embeddings = image_positional_embeddings.repeat( + point_batch_size, 1, 1, 1) + + # Run the transformer, image_positional_embedding are consumed + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, + 1:(1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation( + self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation( + self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in 
range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = torch.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape( + batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).reshape( + batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions, ) + else: + outputs = outputs + (None, ) + + return outputs + + +class SamPositionalEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.scale = config.hidden_size // 2 + self.register_buffer( + "positional_embedding", + self.scale * torch.randn((2, config.num_pos_feats))) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... 
x d_n x channel shape + return torch.cat([torch.sin(coordinates), + torch.cos(coordinates)], + dim=-1) + + +class SamMaskEmbedding(nn.Module): + + def __init__(self, config: SamPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, + self.mask_input_channels, + kernel_size=2, + stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, + config.mask_input_channels, + kernel_size=2, + stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, + config.hidden_size, + kernel_size=1) + self.layer_norm1 = SamLayerNorm(self.mask_input_channels, + eps=config.layer_norm_eps, + data_format="channels_first") + self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, + eps=config.layer_norm_eps, + data_format="channels_first") + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class SamPromptEncoder(nn.Module): + + def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding): + super().__init__() + self.shared_embedding = shared_patch_embedding + self.mask_embed = SamMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embedding_size, + config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = nn.ModuleList([ + nn.Embedding(1, config.hidden_size) + for i in range(config.num_point_embeddings) + ]) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, + pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 1, + points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = torch.zeros(target_point_shape, + device=points.device) + padding_label = -torch.ones(target_labels_shape, + device=labels.device) + points = torch.cat([points, padding_point], dim=2) + labels = torch.cat([labels, padding_label], dim=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, + self.not_a_point_embed.weight, + point_embedding) + + # This is required for the ONNX export. 
The dtype, device need to be explicitely + # specificed as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.tensor(0.0, + dtype=point_embedding.dtype, + device=point_embedding.device), + ) + + point_embedding = torch.where( + (labels == 0)[:, :, :, None], + point_embedding + self.point_embed[0].weight[None, None, :, :], + point_embedding, + ) + + point_embedding = torch.where( + (labels == 1)[:, :, :, None], + point_embedding + self.point_embed[1].weight[None, None, :, :], + point_embedding, + ) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = boxes.reshape(batch_size, nb_boxes, 2, 2) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding[:, :, 0, :] += self.point_embed[2].weight + corner_embedding[:, :, 1, :] += self.point_embed[3].weight + return corner_embedding + + def forward( + self, + input_points: Optional[Tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. + boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + target_device = self.shared_embedding.positional_embedding.device + if input_points is not None: + batch_size, point_batch_size = input_points.shape[:2] + if input_labels is None: + raise ValueError( + "If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, + input_labels, + pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat( + [sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape( + 1, -1, 1, 1).expand(batch_size, -1, + self.image_embedding_size[0], + self.image_embedding_size[1]) + + if sparse_embeddings is None: + sparse_embeddings = torch.zeros( + (batch_size, 1, 1, self.hidden_size), device=target_device) + + return sparse_embeddings, dense_embeddings + + +class SamVisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ((config.image_size // config.patch_size, + config.image_size // + config.patch_size) if window_size == 0 else + (window_size, window_size)) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, + config.hidden_size * 3, + bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: 
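+            # Decomposed relative position embeddings (as in MViTv2): one learnable
+            # table per spatial axis, each of length 2 * size - 1, added to the
+            # attention logits in add_decomposed_rel_pos() below.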
+ if input_size is None: + raise ValueError( + "Input size must be provided if using relative positional encoding." + ) + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, + rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * q_size - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, + max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] + k_coords = torch.arange(k_size)[None, :] + relative_coords = ( + q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def add_decomposed_rel_pos( + self, + attn: torch.Tensor, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. 
+ """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, + rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, + rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, + dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, + relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, + relative_position_width) + attn = attn.reshape(batch_size, query_height, query_width, key_height, + key_width) + attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn = attn.reshape(batch_size, query_height * query_width, + key_height * key_width) + return attn + + def forward(self, + hidden_states: torch.Tensor, + output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, + self.num_attention_heads, + -1).permute(2, 0, 3, 1, 4)) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, + batch_size * self.num_attention_heads, + height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos(attn_weights, query, + self.rel_pos_h, + self.rel_pos_w, + (height, width), + (height, width)) + + attn_weights = torch.nn.functional.softmax(attn_weights, + dtype=torch.float32, + dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, + self.num_attention_heads, + height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, + 4).reshape(batch_size, height, width, + -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class SamVisionLayer(nn.Module): + + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attn = SamVisionAttention(config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SamMLPBlock(config) + self.window_size = window_size + + def window_partition( + self, hidden_states: torch.Tensor, + window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. 
+ (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape(batch_size, + pad_height // window_size, + window_size, + pad_width // window_size, + window_size, channel) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape( + -1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition(self, windows: torch.Tensor, window_size: int, + padding_shape: Tuple[int, int], + original_shape: Tuple[int, int]) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // + window_size // window_size) + hidden_states = windows.reshape(batch_size, pad_height // window_size, + pad_width // window_size, window_size, + window_size, -1) + hidden_states = (hidden_states.permute(0, 1, 3, 2, 4, + 5).contiguous().reshape( + batch_size, pad_height, + pad_width, -1)) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition( + hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, + self.window_size, + padding_shape, + (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states, ) + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class SamVisionNeck(nn.Module): + + def __init__(self, config: SamVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, + config.output_channels, + kernel_size=1, + bias=False) + self.layer_norm1 = SamLayerNorm(config.output_channels, + data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, + config.output_channels, + kernel_size=3, + padding=1, + bias=False) + self.layer_norm2 = SamLayerNorm(config.output_channels, + data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + 
hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +class SamVisionEncoder(nn.Module): + + def __init__(self, config: SamVisionConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + + self.patch_embed = SamPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + )) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = SamVisionLayer( + config, + window_size=config.window_size + if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = SamVisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SamVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + ) + else: + layer_outputs = layer_module( + hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + ( + layer_outputs[1], ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states, ) + if output_hidden_states: + outputs = outputs + (all_hidden_states, ) + if output_attentions: + outputs = outputs + (all_self_attentions, ) + return outputs + + return SamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SamPreTrainedModel(PreTrainedModel): + config_class = SamConfig + base_model_prefix = "sam" + main_input_name = "pixel_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if 
module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +SAM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SamConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SAM_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and botton right point of the box. 
+ In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + SAM_START_DOCSTRING, +) +class SamModel(SamPreTrainedModel): + _tied_weights_keys = [ + "prompt_encoder.shared_embedding.positional_embedding" + ] + + def __init__(self, config): + super().__init__(config) + self.shared_image_embedding = SamPositionalEmbedding( + config.vision_config) + + self.vision_encoder = SamVisionEncoder(config.vision_config) + self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, + self.shared_image_embedding) + self.mask_decoder = SamMaskDecoder(config.mask_decoder_config) + + self.post_init() + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones((size, size), + device=target_device, + dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding( + torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze( + 0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + return image_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. 
+            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+                processor; users can also pass the input boxes manually.
+            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+                Optional input masks for the prompt encoder.
+        """
+        prompt_output = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+        return prompt_output
+
+    @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        input_points: Optional[torch.FloatTensor] = None,
+        input_labels: Optional[torch.LongTensor] = None,
+        input_boxes: Optional[torch.FloatTensor] = None,
+        input_masks: Optional[torch.LongTensor] = None,
+        image_embeddings: Optional[torch.FloatTensor] = None,
+        multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict=None,
+        **kwargs,
+    ) -> List[Dict[str, torch.Tensor]]:
+        r"""
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoModel, AutoProcessor
+
+        >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
+        >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
+
+        >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+        >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
+        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+        >>> # Get iou scores and low-resolution segmentation masks (this forward returns a plain tuple)
+        >>> iou_predictions, low_res_masks = model(**inputs)
+
+        >>> # Postprocess masks
+        >>> masks = processor.post_process_masks(
+        ...     low_res_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+        ... )
+        ```
+        """
+        image_positional_embeddings = self.get_image_wide_positional_embeddings(
+        )
+        # repeat the positional embeddings along the batch dimension
+        batch_size = pixel_values.shape[0]
+        image_positional_embeddings = image_positional_embeddings.repeat(
+            batch_size, 1, 1, 1)
+
+        vision_attentions = None
+        vision_hidden_states = None
+
+        vision_outputs = self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        image_embeddings = vision_outputs[0]
+
+        input_labels = torch.ones_like(input_points[:, :, :, 0],
+                                       dtype=torch.int,
+                                       device=input_points.device)
+
+        sparse_embeddings, dense_embeddings = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+
+        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+        return iou_predictions, low_res_masks
diff --git a/inference/benchmarks/sam_h/pytorch/requirements.txt b/inference/benchmarks/sam_h/pytorch/requirements.txt
new file mode 100644
index 000000000..976a2b1f3
--- /dev/null
+++ b/inference/benchmarks/sam_h/pytorch/requirements.txt
@@ -0,0 +1 @@
+transformers
diff --git a/inference/configs/sam_h/configurations.yaml b/inference/configs/sam_h/configurations.yaml
new file mode 100644
index 000000000..efd9c8eb3
--- /dev/null
+++ b/inference/configs/sam_h/configurations.yaml
@@ -0,0 +1,16 @@
+batch_size: 4
+# flops for 1 item (e.g. 1 sequence or 1 image)
+# Attention! For a transformer model like bert, 1 token costs 2*params flops, so we need 2*length*params, like 2*512*0.33B here
+# format: a_1*a_2*...*a_n, like 2*512*0.33e9 (bert) or 4.12e9 (resnet50)
+flops: 3.1e12
+fp16: true
+compiler: tensorrt
+num_workers: 8
+log_freq: 30
+repeat: 5
+# skip validation (will also skip create_model and onnx export); requires exist_onnx_path != null
+no_validation: false
+# set a real onnx_path to reuse an existing onnx file, or set it to anything but null to skip exporting onnx manually (e.g. torch-tensorrt)
+exist_onnx_path: null
+# set an existing engine file path, e.g. resnet50.trt/resnet50.plan/resnet50.engine
+exist_compiler_path: null
\ No newline at end of file
diff --git a/inference/configs/sam_h/parameters.yaml b/inference/configs/sam_h/parameters.yaml
new file mode 100644
index 000000000..9e0c95b8d
--- /dev/null
+++ b/inference/configs/sam_h/parameters.yaml
@@ -0,0 +1,3 @@
+weights: "sam_h_weights"
+datasize: 1000
+ground_truth: "sam_gt"
\ No newline at end of file
diff --git a/inference/configs/sam_h/vendor_config/nvidia_configurations.yaml b/inference/configs/sam_h/vendor_config/nvidia_configurations.yaml
new file mode 100644
index 000000000..9f0b7e394
--- /dev/null
+++ b/inference/configs/sam_h/vendor_config/nvidia_configurations.yaml
@@ -0,0 +1,3 @@
+trt_tmp_path: nvidia_tmp/sam.trt
+has_dynamic_axis: false
+torchtrt_full_compile: true
\ No newline at end of file
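The trimmed `forward` above returns a plain `(iou_predictions, low_res_masks)` tuple and derives `input_labels` internally from `input_points`, so only `pixel_values` and `input_points` are strictly required. A minimal usage sketch under those assumptions follows; the checkpoint name, image URL and point coordinates mirror the docstring example, and loading this patched class directly via `from_pretrained` is an assumption for illustration, not something the diff itself demonstrates.

```python
import requests
import torch
from PIL import Image
from transformers import SamProcessor

# Assumes the patched SamModel defined above is importable in the current scope;
# checkpoint, image URL and point are illustrative, taken from the docstring example.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base").eval()

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
inputs = processor(images=raw_image, input_points=[[[400, 650]]], return_tensors="pt")

with torch.no_grad():
    # The trimmed forward builds input_labels itself, so pixel_values and
    # input_points are enough; it returns a tuple rather than a ModelOutput.
    iou_predictions, low_res_masks = model(
        pixel_values=inputs["pixel_values"],
        input_points=inputs["input_points"],
    )

masks = processor.post_process_masks(
    low_res_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
```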
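The `flops` entry in `configurations.yaml` records the per-image forward flops of the model; combined with a measured throughput it gives an estimate of achieved compute. A small illustrative sketch is below; the throughput and peak-throughput values are placeholders, not results from this diff.

```python
# Illustrative only: turn the per-item flops from configurations.yaml into an
# achieved-compute estimate. Throughput and peak values below are placeholders.
PER_ITEM_FLOPS = 3.1e12   # flops per image, from configurations.yaml
measured_ips = 10.0       # hypothetical images/second reported by the benchmark
peak_flops = 312e12       # assumed accelerator peak throughput (illustrative)

achieved = measured_ips * PER_ITEM_FLOPS
print(f"achieved: {achieved / 1e12:.1f} TFLOPS "
      f"({achieved / peak_flops:.1%} of the assumed peak)")
```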