Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Update video read/write process in demos #2192

Merged
merged 8 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions demo/bottomup_demo.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import tempfile
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.utils import track_iter_progress

from mmpose.apis import inference_bottomup, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import split_instances


def process_one_image(args, img_path, pose_estimator, visualizer,
show_interval):
def process_one_image(args, img, pose_estimator, visualizer, show_interval):
"""Visualize predicted keypoints (and heatmaps) of one image."""

# inference a single image
batch_results = inference_bottomup(pose_estimator, img_path)
batch_results = inference_bottomup(pose_estimator, img)
results = batch_results[0]

# show the results
img = mmcv.imread(img_path, channel_order='rgb')

out_file = None
if args.output_root:
out_file = f'{args.output_root}/{os.path.basename(img_path)}'
if isinstance(img, str):
img = mmcv.imread(img, channel_order='rgb')
elif isinstance(img, np.ndarray):
img = mmcv.bgr2rgb(img)

visualizer.add_datasample(
'result',
Expand All @@ -38,8 +38,7 @@ def process_one_image(args, img_path, pose_estimator, visualizer,
show_kpt_idx=args.show_kpt_idx,
show=args.show,
wait_time=show_interval,
out_file=out_file,
kpt_score_thr=args.kpt_thr)
kpt_thr=args.kpt_thr)

return results.pred_instances

Expand Down Expand Up @@ -97,8 +96,11 @@ def main():
args = parse_args()
assert args.show or (args.output_root != '')
assert args.input != ''
output_file = None
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
output_file = os.path.join(args.output_root,
os.path.basename(args.input))
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
Expand Down Expand Up @@ -128,36 +130,40 @@ def main():
args, args.input, model, visualizer, show_interval=0)
pred_instances_list = split_instances(pred_instances)

if output_file:
img_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

elif input_type == 'video':
tmp_folder = tempfile.TemporaryDirectory()
video = mmcv.VideoReader(args.input)
progressbar = mmengine.ProgressBar(len(video))
video.cvt2frames(tmp_folder.name, show_progress=False)
output_root = args.output_root
args.output_root = tmp_folder.name
video_reader = mmcv.VideoReader(args.input)
video_writer = None

pred_instances_list = []

for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
for frame_id, frame in enumerate(track_iter_progress(video_reader)):
pred_instances = process_one_image(
args,
f'{tmp_folder.name}/{img_fname}',
model,
visualizer,
show_interval=1)
progressbar.update()
args, frame, model, visualizer, show_interval=0.001)

pred_instances_list.append(
dict(
frame_id=frame_id,
instances=split_instances(pred_instances)))

if output_root:
mmcv.frames2video(
tmp_folder.name,
f'{output_root}/{os.path.basename(args.input)}',
fps=video.fps,
fourcc='mp4v',
show_progress=False)
tmp_folder.cleanup()
if output_file:
frame_vis = visualizer.get_image()
if video_writer is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# the size of the image with visualization may vary
# depending on the presence of heatmaps
video_writer = cv2.VideoWriter(output_file, fourcc,
video_reader.fps,
(frame_vis.shape[1],
frame_vis.shape[0]))

video_writer.write(mmcv.rgb2bgr(frame_vis))

if video_writer:
video_writer.release()

else:
args.save_predictions = False
Expand Down
6 changes: 6 additions & 0 deletions demo/inferencer_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def parse_args():
'--draw-bbox',
action='store_true',
help='Whether to draw the bounding boxes.')
parser.add_argument(
'--draw-heatmap',
action='store_true',
default=False,
help='Whether to draw the predicted heatmaps.')
parser.add_argument(
'--bbox-thr',
type=float,
Expand Down Expand Up @@ -104,6 +109,7 @@ def parse_args():
'det_weights', 'det_cat_ids'
]
init_args = {}
init_args['output_heatmaps'] = call_args.pop('draw_heatmap')
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)

Expand Down
65 changes: 37 additions & 28 deletions demo/topdown_demo_with_mmdet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import tempfile
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.utils import track_iter_progress

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
Expand All @@ -23,12 +24,12 @@
has_mmdet = False


def process_one_image(args, img_path, detector, pose_estimator, visualizer,
def process_one_image(args, img, detector, pose_estimator, visualizer,
show_interval):
"""Visualize predicted keypoints (and heatmaps) of one image."""

# predict bbox
det_result = inference_detector(detector, img_path)
det_result = inference_detector(detector, img)
pred_instance = det_result.pred_instances.cpu().numpy()
bboxes = np.concatenate(
(pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
Expand All @@ -37,15 +38,14 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

# predict keypoints
pose_results = inference_topdown(pose_estimator, img_path, bboxes)
pose_results = inference_topdown(pose_estimator, img, bboxes)
data_samples = merge_data_samples(pose_results)

# show the results
img = mmcv.imread(img_path, channel_order='rgb')

out_file = None
if args.output_root:
out_file = f'{args.output_root}/{os.path.basename(img_path)}'
if isinstance(img, str):
img = mmcv.imread(img, channel_order='rgb')
elif isinstance(img, np.ndarray):
img = mmcv.bgr2rgb(img)

visualizer.add_datasample(
'result',
Expand All @@ -58,7 +58,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
skeleton_style=args.skeleton_style,
show=args.show,
wait_time=show_interval,
out_file=out_file,
kpt_thr=args.kpt_thr)

# if there is no instance detected, return None
Expand Down Expand Up @@ -154,8 +153,11 @@ def main():
assert args.input != ''
assert args.det_config is not None
assert args.det_checkpoint is not None
output_file = None
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
output_file = os.path.join(args.output_root,
os.path.basename(args.input))
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
Expand Down Expand Up @@ -196,38 +198,45 @@ def main():
show_interval=0)
pred_instances_list = split_instances(pred_instances)

if output_file:
img_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

elif input_type == 'video':
tmp_folder = tempfile.TemporaryDirectory()
video = mmcv.VideoReader(args.input)
progressbar = mmengine.ProgressBar(len(video))
video.cvt2frames(tmp_folder.name, show_progress=False)
output_root = args.output_root
args.output_root = tmp_folder.name
video_reader = mmcv.VideoReader(args.input)
video_writer = None

pred_instances_list = []

for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
for frame_id, frame in enumerate(track_iter_progress(video_reader)):
pred_instances = process_one_image(
args,
f'{tmp_folder.name}/{img_fname}',
frame,
detector,
pose_estimator,
visualizer,
show_interval=1)
show_interval=0.001)

progressbar.update()
pred_instances_list.append(
dict(
frame_id=frame_id,
instances=split_instances(pred_instances)))

if output_root:
mmcv.frames2video(
tmp_folder.name,
f'{output_root}/{os.path.basename(args.input)}',
fps=video.fps,
fourcc='mp4v',
show_progress=False)
tmp_folder.cleanup()
if output_file:
frame_vis = visualizer.get_image()
if video_writer is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# the size of the image with visualization may vary
# depending on the presence of heatmaps
video_writer = cv2.VideoWriter(output_file, fourcc,
video_reader.fps,
(frame_vis.shape[1],
frame_vis.shape[0]))

video_writer.write(mmcv.rgb2bgr(frame_vis))

if video_writer:
video_writer.release()

else:
args.save_predictions = False
Expand Down
3 changes: 1 addition & 2 deletions demo/webcam_api_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
def parse_args():
parser = ArgumentParser('Webcam executor configs')
parser.add_argument(
'--config', type=str, default='demo/webcam_cfg/pose_estimation.py')

'--config', type=str, default='demo/webcam_cfg/human_pose.py')
parser.add_argument(
'--cfg-options',
nargs='+',
Expand Down
102 changes: 102 additions & 0 deletions demo/webcam_cfg/human_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Executor configuration for the webcam pose-estimation demo.
executor_cfg = {
    # Basic executor settings.
    'name': 'Pose Estimation',
    'camera_id': 0,
    # Node definitions.
    # A node config usually includes:
    #   1. 'type': the node class name
    #   2. 'name': the node name
    #   3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): names of the
    #      input and output buffers; which ones exist depends on the node
    #      class
    #   4. 'enable_key': a hot-key that toggles this node on/off, depending
    #      on the node class
    #   5. other class-specific arguments
    'nodes': [
        # 'DetectorNode':
        # Performs object detection on the frame image with an MMDetection
        # model.
        {
            'type': 'DetectorNode',
            'name': 'detector',
            'model_config': 'projects/rtmpose/rtmdet/person/'
                            'rtmdet_nano_320-8xb32_coco-person.py',
            'model_checkpoint':
            'https://download.openmmlab.com/mmpose/v1/'
            'projects/rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth',  # noqa
            # `_input_` is an executor-reserved buffer
            'input_buffer': '_input_',
            'output_buffer': 'det_result',
        },
        # 'TopdownPoseEstimatorNode':
        # Performs keypoint detection on the frame image with an MMPose
        # top-down model; detection results are required.
        {
            'type': 'TopdownPoseEstimatorNode',
            'name': 'human pose estimator',
            'model_config': 'projects/rtmpose/rtmpose/body_2d_keypoint/'
                            'rtmpose-t_8xb256-420e_coco-256x192.py',
            'model_checkpoint':
            'https://download.openmmlab.com/mmpose/v1/'
            'projects/rtmpose/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth',  # noqa
            'labels': ['person'],
            'input_buffer': 'det_result',
            'output_buffer': 'human_pose',
        },
        # 'ObjectAssignerNode':
        # Binds the latest model inference result with the current frame
        # (so the frame image and the inference result may be
        # asynchronous).
        {
            'type': 'ObjectAssignerNode',
            'name': 'object assigner',
            # `_frame_` is an executor-reserved buffer
            'frame_buffer': '_frame_',
            'object_buffer': 'human_pose',
            'output_buffer': 'frame',
        },
        # 'ObjectVisualizerNode':
        # Draws the pose visualization result on the frame image; pose
        # results are required.
        {
            'type': 'ObjectVisualizerNode',
            'name': 'object visualizer',
            'enable_key': 'v',
            'enable': True,
            'show_bbox': True,
            'must_have_keypoint': False,
            'show_keypoint': True,
            'input_buffer': 'frame',
            'output_buffer': 'vis',
        },
        # 'NoticeBoardNode':
        # Shows a notice board with the given content, e.g. help
        # information.
        {
            'type': 'NoticeBoardNode',
            'name': 'instruction',
            'enable_key': 'h',
            'enable': True,
            'input_buffer': 'vis',
            'output_buffer': 'vis_notice',
            'content_lines': [
                'This is a demo for pose visualization and simple image '
                'effects. Have fun!',
                '',
                'Hot-keys:',
                '"v": Pose estimation result visualization',
                '"h": Show help information',
                '"m": Show diagnostic information',
                '"q": Exit',
            ],
        },
        # 'MonitorNode':
        # Shows diagnostic information on the frame image; useful for
        # debugging or for monitoring system resource status.
        {
            'type': 'MonitorNode',
            'name': 'monitor',
            'enable_key': 'm',
            'enable': False,
            'input_buffer': 'vis_notice',
            'output_buffer': 'display',
        },
        # 'RecorderNode':
        # Saves the output video into a file.
        {
            'type': 'RecorderNode',
            'name': 'recorder',
            'out_video_file': 'webcam_api_demo.mp4',
            'input_buffer': 'display',
            # `_display_` is an executor-reserved buffer
            'output_buffer': '_display_',
        },
    ],
}
Loading