Skip to content

Commit

Permalink
Merge branch 'master' into integrate_object_detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Cartucho authored Jun 1, 2019
2 parents 5137f3c + e4a02ac commit 77c1ace
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 28 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
*.pyc
.idea/
venv/
main/input/people_walking_mp4
main/output
object_detection/models
*.pyc
object_detection/models
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "main/DaSiamRPN"]
path = main/DaSiamRPN
url = https://github.com/foolwood/DaSiamRPN/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Image labeling in multiple annotation formats:

## Latest Features

- May 2019: [ECCV2018] Distractor-aware Siamese Networks for Visual Object Tracking
- Jan 2019: easy and quick bounding-box resizing!
- Jan 2019: video object tracking with OpenCV trackers!
- TODO: Label photos via Google drive to allow "team online labeling".
Expand Down
1 change: 1 addition & 0 deletions main/DaSiamRPN
Submodule DaSiamRPN added at 167d0d
102 changes: 102 additions & 0 deletions main/dasiamrpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
Author : Will Stone
Date : 190407
Desc : Wrapper class for the DaSiamRPN tracking method. This class has the
methods required to interface with the tracking class implemented
in main.py within the OpenLabeling package.
"""
import torch
import numpy as np
import sys
from os.path import realpath, dirname, join, exists
# DaSiamRPN is expected as a git submodule in the 'DaSiamRPN' directory next to
# this file. Every path below is resolved against this file's directory rather
# than the current working directory, so the checks give the same answer no
# matter where the program is launched from.
_DASIAMRPN_DIR = join(realpath(dirname(__file__)), 'DaSiamRPN')
_DASIAMRPN_CODE_DIR = join(_DASIAMRPN_DIR, 'code')

try:
    from DaSiamRPN.code.run_SiamRPN import SiamRPN_init, SiamRPN_track
except ImportError:
    # check if the user has downloaded the submodules
    if not exists(join(_DASIAMRPN_CODE_DIR, 'net.py')):
        print('Error: DaSiamRPN files not found. Please run the following command:')
        print('\tgit submodule update --init')
        # sys.exit is always available (the bare exit() helper is injected by
        # the `site` module); a non-zero status reports the failure to callers.
        sys.exit(1)

    if sys.version_info >= (3, 0):
        # Python 3: make the submodule's code directory importable directly.
        # NOTE(review): presumably needed because the submodule's files import
        # each other by bare module name -- confirm against DaSiamRPN/code.
        if _DASIAMRPN_CODE_DIR not in sys.path:
            sys.path.append(_DASIAMRPN_CODE_DIR)
    else:
        # Python 2: check if the __init__.py files exist (otherwise create
        # them) so that 'DaSiamRPN' and 'DaSiamRPN.code' are regular packages.
        for path_temp in (join(_DASIAMRPN_CODE_DIR, '__init__.py'),
                          join(_DASIAMRPN_DIR, '__init__.py')):
            if not exists(path_temp):
                open(path_temp, 'w').close()

    # try to import again
    from DaSiamRPN.code.run_SiamRPN import SiamRPN_init, SiamRPN_track
from DaSiamRPN.code.net import SiamRPNvot
from DaSiamRPN.code.utils import get_axis_aligned_bbox, cxy_wh_2_rect

class dasiamrpn(object):
    """
    Wrapper class for incorporating DaSiamRPN into OpenLabeling
    (https://github.com/foolwood/DaSiamRPN,
    https://github.com/Cartucho/OpenLabeling)

    The wrapper exposes `init(frame, bbox)` and `update(frame)`. Bounding
    boxes are exchanged as (xmin, ymin, width, height); internally they are
    converted to the (center position, size) pair handed to DaSiamRPN.
    """

    def __init__(self):
        """
        Create the SiamRPNvot network, load the pre-trained weights from
        'DaSiamRPN/code/SiamRPNVOT.model' (resolved next to this file) and
        switch the network to evaluation mode on the GPU.

        Exits the program with a non-zero status (SystemExit) and prints
        download instructions when the model file is missing.
        """
        model_path = join(realpath(dirname(__file__)), 'DaSiamRPN', 'code', 'SiamRPNVOT.model')
        # Check for the weights first so a missing file is reported before
        # any network is constructed.
        if not exists(model_path):
            print('\nError: model not found. Please download the pre-trained model and copy it to the directory \'DaSiamRPN/code/\'\n')
            print('\texpected location: ' + model_path)
            print('\tdownload link: https://drive.google.com/file/d/1-vNVZxfbIplXHrqMHiJJYWXYWsOIvGsf/view')
            sys.exit(1)
        self.net = SiamRPNvot()
        self.net.load_state_dict(torch.load(model_path))
        # .cuda() moves the network to the GPU, so a CUDA device is required.
        self.net.eval().cuda()

    def init(self, init_frame, initial_bbox):
        """
        Initialize DaSiamRPN tracker with initial frame and bounding box.

        init_frame   -- first frame, passed unchanged to SiamRPN_init
        initial_bbox -- (xmin, ymin, width, height) of the object to track

        The tracker state returned by SiamRPN_init is kept in self.state.
        """
        target_pos, target_sz = self.bbox_to_pos(initial_bbox)
        self.state = SiamRPN_init(
            init_frame, target_pos, target_sz, self.net)

    def update(self, next_image):
        """
        Update bounding box position and size on next_image.

        Returns a (success, bbox) pair where bbox is
        (xmin, ymin, width, height). success is always True because
        tracking is terminated based on number of frames predicted in
        OpenLabeling, not based on feedback from the tracking algorithm
        (unlike the opencv tracking algorithms).
        """
        self.state = SiamRPN_track(self.state, next_image)
        target_pos = self.state["target_pos"]
        target_sz = self.state["target_sz"]
        bbox = self.pos_to_bbox(target_pos, target_sz)

        return True, bbox

    def bbox_to_pos(self, initial_bbox):
        """
        Convert bounding box format from a tuple format containing
        xmin, ymin, width, and height to a tuple of two arrays which contain
        the x and y coordinates of the center of the box and its width and
        height respectively.

        The center coordinates are truncated to integers.
        """
        xmin, ymin, w, h = initial_bbox
        cx = int(xmin + w/2)
        cy = int(ymin + h/2)
        target_pos = np.array([cx, cy])
        target_sz = np.array([w, h])

        return target_pos, target_sz

    def pos_to_bbox(self, target_pos, target_sz):
        """
        Invert the conversion done by bbox_to_pos: turn a center position
        and a (width, height) size back into (xmin, ymin, width, height).

        xmin and ymin are truncated to integers; width and height are
        returned as given.
        """
        w = target_sz[0]
        h = target_sz[1]
        xmin = int(target_pos[0] - w/2)
        ymin = int(target_pos[1] - h/2)

        return xmin, ymin, w, h
68 changes: 42 additions & 26 deletions main/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lxml import etree
import xml.etree.cElementTree as ET

from dasiamrpn import dasiamrpn

DELAY = 20 # keyboard delay (in milliseconds)
WITH_QT = False
Expand All @@ -28,18 +29,22 @@
parser.add_argument('-i', '--input_dir', default='input', type=str, help='Path to input directory')
parser.add_argument('-o', '--output_dir', default='output', type=str, help='Path to output directory')
parser.add_argument('-t', '--thickness', default='1', type=int, help='Bounding box and cross line thickness')
parser.add_argument('--tracker', default='KCF', type=str, help='Type of tracker being used')
parser.add_argument('-n', '--n_frames', default='50', type=int, help='number of frames to track object for')
args = parser.parse_args()

class_index = 0
img_index = 0
img = None
img_objects = []

INPUT_DIR = args.input_dir
INPUT_DIR = args.input_dir
OUTPUT_DIR = args.output_dir
N_FRAMES = args.n_frames
TRACKER_TYPE = args.tracker

WINDOW_NAME = 'OpenLabeling'
TRACKBAR_IMG = 'Image'
WINDOW_NAME = 'OpenLabeling'
TRACKBAR_IMG = 'Image'
TRACKBAR_CLASS = 'Class'

annotation_formats = {'PASCAL_VOC' : '.xml', 'YOLO_darknet' : '.txt'}
Expand Down Expand Up @@ -569,7 +574,7 @@ def mouse_listener(event, x, y, flags, param):
point_1 = (x, y)
else:
# minimal size for bounding box to avoid errors
threshold = 10
threshold = 5
if abs(x - point_1[0]) > threshold or abs(y - point_1[1]) > threshold:
# second click
point_2 = (x, y)
Expand Down Expand Up @@ -790,7 +795,7 @@ class LabelTracker():
# TODO: press ESC to stop the tracking process

def __init__(self, tracker_type, init_frame, next_frame_path_list):
tracker_types = ['CSRT', 'KCF','MOSSE', 'MIL', 'BOOSTING', 'MEDIANFLOW', 'TLD', 'GOTURN']
tracker_types = ['CSRT', 'KCF','MOSSE', 'MIL', 'BOOSTING', 'MEDIANFLOW', 'TLD', 'GOTURN', 'DASIAMRPN']
''' Recomended tracker_type:
KCF -> KCF is usually very good (minimum OpenCV 3.1.0)
CSRT -> More accurate than KCF but slightly slower (minimum OpenCV 3.4.2)
Expand All @@ -809,27 +814,36 @@ def __init__(self, tracker_type, init_frame, next_frame_path_list):


def call_tracker_constructor(self, tracker_type):
# -- TODO: remove this if I assume OpenCV version > 3.4.0
if int(self.major_ver == 3) and int(self.minor_ver) < 3:
tracker = cv2.Tracker_create(tracker_type)
# --
if tracker_type == 'DASIAMRPN':
tracker = dasiamrpn()
else:
if tracker_type == 'CSRT':
tracker = cv2.TrackerCSRT_create()
elif tracker_type == 'KCF':
tracker = cv2.TrackerKCF_create()
elif tracker_type == 'MOSSE':
tracker = cv2.TrackerMOSSE_create()
elif tracker_type == 'MIL':
tracker = cv2.TrackerMIL_create()
elif tracker_type == 'BOOSTING':
tracker = cv2.TrackerBoosting_create()
elif tracker_type == 'MEDIANFLOW':
tracker = cv2.TrackerMedianFlow_create()
elif tracker_type == 'TLD':
tracker = cv2.TrackerTLD_create()
elif tracker_type == 'GOTURN':
tracker = cv2.TrackerGOTURN_create()
# -- TODO: remove this if I assume OpenCV version > 3.4.0
if int(self.major_ver == 3) and int(self.minor_ver) < 3:
#tracker = cv2.Tracker_create(tracker_type)
pass
# --
else:
try:
tracker = cv2.TrackerKCF_create()
except AttributeError as error:
print(error)
print('\nMake sure that OpenCV contribute is installed: opencv-contrib-python\n')
if tracker_type == 'CSRT':
tracker = cv2.TrackerCSRT_create()
elif tracker_type == 'KCF':
tracker = cv2.TrackerKCF_create()
elif tracker_type == 'MOSSE':
tracker = cv2.TrackerMOSSE_create()
elif tracker_type == 'MIL':
tracker = cv2.TrackerMIL_create()
elif tracker_type == 'BOOSTING':
tracker = cv2.TrackerBoosting_create()
elif tracker_type == 'MEDIANFLOW':
tracker = cv2.TrackerMedianFlow_create()
elif tracker_type == 'TLD':
tracker = cv2.TrackerTLD_create()
elif tracker_type == 'GOTURN':
tracker = cv2.TrackerGOTURN_create()
return tracker


Expand All @@ -848,6 +862,8 @@ def start_tracker(self, json_file_data, json_file_path, img_path, obj, color, an
next_image = cv2.imread(frame_path)
# get the new bbox prediction of the object
success, bbox = tracker.update(next_image.copy())
if pred_counter >= N_FRAMES:
success = False
if success:
pred_counter += 1
xmin, ymin, w, h = map(int, bbox)
Expand Down Expand Up @@ -1080,7 +1096,7 @@ def complement_bgr(color):
next_frame_path_list = get_next_frame_path_list(video_name, img_path)
# initial frame
init_frame = img.copy()
label_tracker = LabelTracker('KCF', init_frame, next_frame_path_list) # TODO: replace 'KCF' by 'CSRT'
label_tracker = LabelTracker(TRACKER_TYPE, init_frame, next_frame_path_list)
for obj in object_list:
class_index = obj[0]
color = class_rgb[class_index].tolist()
Expand Down

0 comments on commit 77c1ace

Please sign in to comment.