Skip to content

Commit

Permalink
Fix convert.py script.
Browse files Browse the repository at this point in the history
It now loads the first conv and BN layers.
  • Loading branch information
NeutrinoXY committed Aug 14, 2020
1 parent 67f60e4 commit 8273111
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tensorflow2.0/deep-sort-yolov4/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def load_yolo(self):

num_anchors = len(self.anchors)
num_classes = len(self.class_names)
print(num_classes)

# Generate colors for drawing bounding boxes.
hsv_tuples = [(x / len(self.class_names), 1., 1.)
Expand Down Expand Up @@ -69,10 +70,16 @@ def load_yolo(self):
bns_to_load = []
for i in range(len(self.yolo4_model.layers)):
layer_name = self.yolo4_model.layers[i].name
if layer_name.startswith('conv2d_'):
convs_to_load.append((int(layer_name[7:]), i))
if layer_name.startswith('batch_normalization_'):
bns_to_load.append((int(layer_name[20:]), i))
if layer_name.startswith('conv2d'):
if layer_name == 'conv2d':
convs_to_load.append((0, i))
else:
convs_to_load.append((int(layer_name[7:]), i))
if layer_name.startswith('batch_normalization'):
if layer_name == 'batch_normalization':
bns_to_load.append((0, i))
else:
bns_to_load.append((int(layer_name[20:]), i))

convs_sorted = sorted(convs_to_load, key=itemgetter(0))
bns_sorted = sorted(bns_to_load, key=itemgetter(0))
Expand Down Expand Up @@ -124,7 +131,6 @@ def load_yolo(self):
bn_weights[2] # running var
]
self.yolo4_model.layers[bns_sorted[bn_index][1]].set_weights(bn_weight_list)

conv_weights = np.ndarray(
shape=darknet_w_shape,
dtype='float32',
Expand Down

0 comments on commit 8273111

Please sign in to comment.