Skip to content

Commit

Permalink
fix heatmap bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed Apr 18, 2022
1 parent fa32e9f commit 1448e84
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,56 @@ def detect_image(self, image, crop = False, count = False):
del draw

return image

def get_FPS(self, image, test_interval):
image_shape = np.array(np.shape(image)[0:2])
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = decode_outputs(outputs, self.input_shape)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

t1 = time.time()
for _ in range(test_interval):
with torch.no_grad():
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = decode_outputs(outputs, self.input_shape)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

t2 = time.time()
tact_time = (t2 - t1) / test_interval
return tact_time

def detect_heatmap(self, image, heatmap_save_path):
import cv2
Expand Down Expand Up @@ -265,56 +315,6 @@ def sigmoid(x):
plt.savefig(heatmap_save_path, dpi=200)
print("Save to the " + heatmap_save_path)
plt.cla()

def get_FPS(self, image, test_interval):
image_shape = np.array(np.shape(image)[0:2])
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = decode_outputs(outputs, self.input_shape)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

t1 = time.time()
for _ in range(test_interval):
with torch.no_grad():
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs = self.net(images)
outputs = decode_outputs(outputs, self.input_shape)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

t2 = time.time()
tact_time = (t2 - t1) / test_interval
return tact_time

def convert_to_onnx(self, simplify, model_path):
import onnx
Expand Down

0 comments on commit 1448e84

Please sign in to comment.