
SAM stdcase #218

Merged
merged 1 commit into from Aug 31, 2023
62 changes: 62 additions & 0 deletions inference/benchmarks/sam_h/README.md
@@ -0,0 +1,62 @@
### 1. Inference Dataset

* A single image, together with the masks produced by running it through SamModel on CPU as in the original Hugging Face docs, serves as both the inference dataset and the evaluation reference
* Dataset download
  * Downloaded automatically in dataloader.py
* Ground-truth construction
  * Run the SamModel inference example from the Hugging Face docs, save the three image-sized boolean mask matrices it computes as sam_gt_0.pt through sam_gt_2.pt in order from the smallest mask to the largest, and place them under data_dir
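The ground-truth step above can be sketched as follows. This is a hedged sketch mirroring the SamModel example from the Hugging Face docs; the Hub id `facebook/sam-vit-huge`, the `data_dir` value, and the smallest-to-largest ordering heuristic are assumptions, not taken verbatim from this PR:

```python
# Hedged sketch of ground-truth generation (requires network access).
# `data_dir` is a placeholder; adjust to the benchmark's data_dir.
import torch
import requests
from PIL import Image
from transformers import SamModel, SamProcessor

data_dir = "./data"  # placeholder path
model = SamModel.from_pretrained("facebook/sam-vit-huge")  # runs on CPU
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
inputs = processor(raw_image, input_points=[[[450, 600]]], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Resize the three low-res masks back to the original image size
masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"],
    inputs["reshaped_input_sizes"])[0][0]  # (3, H, W) boolean masks

# Save smallest-to-largest as sam_gt_0.pt .. sam_gt_2.pt under data_dir
order = sorted(range(len(masks)), key=lambda j: masks[j].sum().item())
for k, j in enumerate(order):
    torch.save(masks[j], f"{data_dir}/sam_gt_{k}.pt")
```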

### 2. Model and Weights

* Model implementation
  * pytorch: transformers.SamModel
* Weight download
  * pytorch: from_pretrained("facebook/sam_vit_huge") (Hugging Face)

### 3. Hardware/Software Configuration and Run Information

#### 3.1 Nvidia A100

- ##### Hardware environment

  - Machine and accelerator model: NVIDIA_A100-SXM4-40GB
  - Inter-node network type and bandwidth: InfiniBand, 200 Gb/s

- ##### Software environment

  - OS version: Ubuntu 20.04
  - OS kernel version: 5.4.0-113-generic
  - Accelerator driver version: 470.129.06
  - Docker version: 20.10.16
  - Training framework version: pytorch-2.1.0a0+4136153
  - Dependency versions:
    - cuda: 12.1

- Inference toolkit

  - TensorRT 8.6.1

### 4. Run Results

* Metric list

| Metric | Value key | Notes |
| ------------------ | ---------------- | -------------------------------------------- |
| Numerical precision | precision | fp32 or fp16 |
| Batch size | bs | |
| Device memory usage | mem | commonly called "VRAM", in GiB |
| End-to-end time | e2e_time | total time plus Perf initialization, etc. |
| Total validation throughput | p_val_whole | images validated divided by total validation time |
| Validation compute throughput | p_val_core | excludes IO time |
| Total inference throughput | p_infer_whole | images inferred divided by total inference time |
| **Inference compute throughput** | **\*p_infer_core** | excludes IO time |
| **Accelerator utilization** | **\*MFU** | model FLOPs utilization |
| Inference result | acc (inference/validation) | per-pixel mask agreement with the CPU ground truth (logged as "Top1 Acc") |

* Metric values

| Inference tool | precision | bs | e2e_time | p_val_whole | p_val_core | p_infer_whole | \*p_infer_core | \*MFU | acc | mem |
| ----------- | --------- | ---- | ---- | -------- | ----------- | ---------- | ------------- | ------------ | ----------- | ----------- |
| tensorrt | fp16 | 4 |1895.1 | 9.3 | 10.7 | 7.9 | 11.8 | 11.8% | 0.89/1.0 | 23.7/40.0 |
| tensorrt | fp32 | 2 | 1895.1 | 6.8 | 7.5 | 5.5 | 7.0 | 13.9% | 1.0/1.0 | 18.1/40.0 |
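The MFU column above is model FLOPs utilization: achieved FLOPs divided by the accelerator's peak FLOPs. A minimal sketch of the arithmetic, with hypothetical numbers (the per-image FLOPs figure is an illustrative assumption, not taken from this PR):

```python
# Hedged sketch of the MFU (model FLOPs utilization) arithmetic.
# All numeric values below are hypothetical placeholders.
def mfu(images_per_second, flops_per_image, peak_flops):
    """Achieved model FLOPs per second divided by hardware peak FLOPs."""
    return images_per_second * flops_per_image / peak_flops

# e.g. 10 img/s at an assumed 4.3 TFLOPs/img on a 312 TFLOPS fp16 part
print(round(mfu(10.0, 4.3e12, 312e12), 3))  # → 0.138
```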

5 changes: 5 additions & 0 deletions inference/benchmarks/sam_h/pytorch/__init__.py
@@ -0,0 +1,5 @@
from .dataloader import build_dataloader
from .model import create_model
from .export import export_model
from .evaluator import evaluator
from .forward import model_forward, engine_forward
52 changes: 52 additions & 0 deletions inference/benchmarks/sam_h/pytorch/dataloader.py
@@ -0,0 +1,52 @@
from torch.utils.data import DataLoader as dl
from torch.utils.data import Dataset
import torch
from PIL import Image
import requests
from transformers import SamProcessor


class SamInferDataset(Dataset):

def __init__(self, config):
processor = SamProcessor.from_pretrained(config.data_dir + "/" +
config.weights)

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url,
stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]

inputs = processor(raw_image,
input_points=input_points,
return_tensors="pt")
self.img = inputs["pixel_values"][0]
self.points = inputs["input_points"][0]
self.osize = inputs["original_sizes"][0]
self.dsize = inputs["reshaped_input_sizes"][0]
self.length = config.datasize

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Every index returns a fresh float copy of the same single sample
        return self.img.clone().float(), self.points.clone().float(
        ), self.osize.clone(), self.dsize.clone()


def build_dataset(config):
dataset = SamInferDataset(config)
return dataset


def build_dataloader(config):
dataset = build_dataset(config)
loader = dl(dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=True,
num_workers=config.num_workers)

return loader
22 changes: 22 additions & 0 deletions inference/benchmarks/sam_h/pytorch/evaluator.py
@@ -0,0 +1,22 @@
import torch
import numpy as np


def evaluator(pred, gt_path):
batch_size = len(pred)

result = []
for i in range(batch_size):
mask = pred[i][0]

num_mask = len(mask)

for j in range(num_mask):
mask_img = mask[j]

            gt = torch.load(gt_path + '_' + str(j) + ".pt")
            gt = gt.to(mask_img.device)
            # Per-pixel agreement with the stored ground truth (despite the
            # variable name, this is pixel accuracy rather than true IoU)
            iou = torch.eq(mask_img, gt).sum().item() / mask_img.numel()
            result.append(float(iou))

return np.mean(result)
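The score computed above compares each predicted boolean mask to its saved ground truth pixel by pixel. A self-contained NumPy illustration of the same computation:

```python
# Hedged NumPy illustration of the evaluator's per-pixel agreement score.
import numpy as np

def pixel_agreement(pred_mask, gt_mask):
    """Fraction of pixels on which two boolean masks agree."""
    return float(np.equal(pred_mask, gt_mask).sum()) / pred_mask.size

pred = np.array([[True, True], [False, False]])
gt = np.array([[True, False], [False, False]])
print(pixel_agreement(pred, gt))  # → 0.75 (3 of 4 pixels agree)
```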
33 changes: 33 additions & 0 deletions inference/benchmarks/sam_h/pytorch/export.py
@@ -0,0 +1,33 @@
import torch
import os


def export_model(model, config):
if config.exist_onnx_path is not None:
return config.exist_onnx_path

filename = config.case + "_bs" + str(config.batch_size)
filename = filename + "_" + str(config.framework)
filename = filename + "_fp16" + str(config.fp16)
filename = "onnxs/" + filename + ".onnx"
onnx_path = config.perf_dir + "/" + filename

    # Dummy inputs: SAM's fixed 1024x1024 pixel input and one prompt point
    img = torch.randn(config.batch_size, 3, 1024, 1024).cuda()
    points = torch.ones(config.batch_size, 1, 1, 2).cuda()

if config.fp16:
img = img.half()
dummy_input = (img, points)

dir_onnx_path = os.path.dirname(onnx_path)
os.makedirs(dir_onnx_path, exist_ok=True)

with torch.no_grad():
torch.onnx.export(model,
dummy_input,
onnx_path,
verbose=False,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True)

return onnx_path
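The exported ONNX filename above encodes the case name, batch size, framework, and precision flag. A standalone sketch of the same naming scheme (the argument values are placeholders for illustration):

```python
# Hedged sketch of the ONNX path naming used in export_model;
# argument values in the example call are placeholders.
def onnx_path(perf_dir, case, batch_size, framework, fp16):
    filename = case + "_bs" + str(batch_size)
    filename = filename + "_" + str(framework)
    filename = filename + "_fp16" + str(fp16)
    return perf_dir + "/onnxs/" + filename + ".onnx"

print(onnx_path("/perf", "sam_h", 4, "pytorch", True))
# → /perf/onnxs/sam_h_bs4_pytorch_fp16True.onnx
```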
121 changes: 121 additions & 0 deletions inference/benchmarks/sam_h/pytorch/forward.py
@@ -0,0 +1,121 @@
from loguru import logger
import torch
import numpy as np
import time
from tools import torch_sync
from transformers import SamProcessor


def cal_perf(config, dataloader_len, duration, core_time, str_prefix):
model_forward_perf = config.repeat * dataloader_len * config.batch_size / duration
logger.info(str_prefix + "(" + config.framework + ") Perf: " +
str(model_forward_perf) + " ips")
model_forward_core_perf = config.repeat * dataloader_len * config.batch_size / core_time
logger.info(str_prefix + "(" + config.framework + ") core Perf: " +
str(model_forward_core_perf) + " ips")
return round(model_forward_perf, 3), round(model_forward_core_perf, 3)


def model_forward(model, dataloader, evaluator, config):
if config.no_validation:
return None, None, None
processor = SamProcessor.from_pretrained(config.data_dir + "/" +
config.weights)
gt_path = config.data_dir + "/" + config.ground_truth
start = time.time()
core_time = 0.0
scores = []

for times in range(config.repeat):

logger.debug("Repeat: " + str(times + 1))

local_scores = []
for step, (x, y, osize, dsize) in enumerate(dataloader):
if config.fp16:
x = x.to(torch.float16)
y = y.to(torch.float16)
torch_sync(config)
core_time_start = time.time()

if step % config.log_freq == 0:
logger.debug("Step: " + str(step) + " / " +
str(len(dataloader)))

with torch.no_grad():

x = x.cuda()
y = y.cuda()

pred = model(x, y)[1]
torch_sync(config)
core_time += time.time() - core_time_start

pred = processor.post_process_masks(pred, osize, dsize)
score = evaluator(pred, gt_path)
local_scores.append(score)

scores.append(np.mean(local_scores))

logger.info("Top1 Acc: " + str(scores))

duration = time.time() - start
model_forward_perf, model_forward_core_perf = cal_perf(
config, len(dataloader), duration, core_time, "Validation")

return model_forward_perf, model_forward_core_perf, round(
float(np.mean(scores)), 3)


def engine_forward(model, dataloader, evaluator, config):
processor = SamProcessor.from_pretrained(config.data_dir + "/" +
config.weights)
gt_path = config.data_dir + "/" + config.ground_truth
start = time.time()
core_time = 0.0
foo_time = 0.0
scores = []

for times in range(config.repeat):

logger.debug("Repeat: " + str(times + 1))

local_scores = []
for step, (x, y, osize, dsize) in enumerate(dataloader):
if config.fp16:
x = x.to(torch.float16)
y = y.to(torch.float16)
torch_sync(config)
core_time_start = time.time()

if step % config.log_freq == 0:
logger.debug("Step: " + str(step) + " / " +
str(len(dataloader)))

with torch.no_grad():

                outputs = model([x, y])
                pred = outputs[0]
                # outputs[1] reports time the engine wrapper spent outside
                # compute; it is subtracted from the timings below
                foo_time += outputs[1]

torch_sync(config)
core_time += time.time() - core_time_start

pred = pred[0]
pred = pred.reshape(config.batch_size, 1, 3, 256, 256).float()
pred = pred.cpu()

pred = processor.post_process_masks(pred, osize, dsize)
score = evaluator(pred, gt_path)
local_scores.append(score)

scores.append(np.mean(local_scores))

logger.info("Top1 Acc: " + str(scores))

duration = time.time() - start - foo_time
model_forward_perf, model_forward_core_perf = cal_perf(
config, len(dataloader), duration, core_time - foo_time, "Inference")

return model_forward_perf, model_forward_core_perf, round(
float(np.mean(scores)), 3)
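The throughput reported by `cal_perf` above is repeats × batches × batch size divided by elapsed seconds. A minimal standalone sketch of that computation:

```python
# Hedged sketch of the ips (images-per-second) computation in cal_perf.
def throughput(repeat, num_batches, batch_size, seconds):
    """Total images processed across all repeats, divided by elapsed time."""
    return repeat * num_batches * batch_size / seconds

# e.g. 2 repeats over 50 batches of 4 images, finishing in 40 seconds
print(throughput(2, 50, 4, 40.0))  # → 10.0 ips
```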
14 changes: 14 additions & 0 deletions inference/benchmarks/sam_h/pytorch/model.py
@@ -0,0 +1,14 @@
from .model_utils.sam import SamModel


def create_model(config):
if config.no_validation:
assert config.exist_onnx_path is not None
return None
model = SamModel.from_pretrained(config.data_dir + "/" + config.weights)
model.cuda()
model.eval()
if config.fp16:
model.half()

return model