-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
Copy pathmodel.py
548 lines (458 loc) · 19.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: model.py
import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import (
Conv2D, FullyConnected, GlobalAvgPooling, layer_register, Deconv2D)
from utils.box_ops import pairwise_iou
import config
@under_name_scope()
def clip_boxes(boxes, window, name=None):
    """
    Clip box coordinates so that they fall inside a window.

    Args:
        boxes: nx4, xyxy
        window: [h, w]
        name: optional name for the output op.

    Returns:
        The clipped boxes.
    """
    # Lower bound: no coordinate may be negative.
    non_negative = tf.maximum(boxes, 0.0)
    # Upper bound: [h, w] -> [w, h] -> [w, h, w, h], matching the xyxy layout.
    upper_bound = tf.tile(tf.reverse(window, [0]), [2])  # (4,)
    return tf.minimum(non_negative, tf.to_float(upper_bound), name=name)
@layer_register(log_shape=True)
def rpn_head(featuremap, channel, num_anchors):
    """
    RPN head: one 3x3 conv followed by two sibling 1x1 convs, producing
    per-anchor objectness logits and per-anchor box logits.

    Args:
        featuremap: input to the NCHW convolutions. The leading dimension is
            squeezed away below, so it must be 1.
        channel (int): number of output channels of the 3x3 conv.
        num_anchors (int): number of anchors (NA) per spatial location.

    Returns:
        label_logits: fHxfWxNA
        box_logits: fHxfWxNAx4
    """
    with argscope(Conv2D, data_format='NCHW',
                  W_init=tf.random_normal_initializer(stddev=0.01)):
        shared = Conv2D('conv0', featuremap, channel, 3, nl=tf.nn.relu)
        cls_map = Conv2D('class', shared, num_anchors, 1)
        reg_map = Conv2D('box', shared, 4 * num_anchors, 1)
        # both maps: 1, NA(*4), im/16, im/16 (NCHW)

        # Objectness: NCHW -> NHWC, then drop the batch dimension.
        cls_map_nhwc = tf.transpose(cls_map, [0, 2, 3, 1])  # 1xfHxfWxNA
        label_logits = tf.squeeze(cls_map_nhwc, 0)  # fHxfWxNA

        # Boxes: NCHW -> NHWC, then split the channel axis into (NA, 4).
        reg_shape = tf.shape(reg_map)  # 1x(NAx4)xfHxfW
        reg_map_nhwc = tf.transpose(reg_map, [0, 2, 3, 1])  # 1xfHxfWx(NAx4)
        target_shape = tf.stack([reg_shape[2], reg_shape[3], num_anchors, 4])
        box_logits = tf.reshape(reg_map_nhwc, target_shape)  # fHxfWxNAx4
    return label_logits, box_logits
@under_name_scope()
def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
    """
    Compute the RPN objectness loss and box-regression loss, and register
    precision/recall summaries at several score thresholds.

    Args:
        anchor_labels: fHxfWxNA. Entries equal to -1 are excluded from the
            label loss; entries equal to 1 are the positive anchors.
        anchor_boxes: fHxfWxNAx4, encoded
        label_logits: fHxfWxNA
        box_logits: fHxfWxNAx4
    Returns:
        label_loss, box_loss
    """
    # NOTE(review): the extents of the device / name scopes below were
    # reconstructed from a copy of this file that had lost its indentation;
    # confirm them against the canonical source.
    with tf.device('/cpu:0'):
        valid_mask = tf.stop_gradient(tf.not_equal(anchor_labels, -1))
        pos_mask = tf.stop_gradient(tf.equal(anchor_labels, 1))
        nr_valid = tf.stop_gradient(tf.count_nonzero(valid_mask, dtype=tf.int32), name='num_valid_anchor')
        nr_pos = tf.count_nonzero(pos_mask, dtype=tf.int32, name='num_pos_anchor')

        valid_anchor_labels = tf.boolean_mask(anchor_labels, valid_mask)
    valid_label_logits = tf.boolean_mask(label_logits, valid_mask)

    with tf.name_scope('label_metrics'):
        valid_label_prob = tf.nn.sigmoid(valid_label_logits)
        summaries = []
        with tf.device('/cpu:0'):
            for th in [0.5, 0.2, 0.1]:
                valid_prediction = tf.cast(valid_label_prob > th, tf.int32)
                nr_pos_prediction = tf.reduce_sum(valid_prediction, name='num_pos_prediction')
                # Number of anchors predicted positive whose label is also positive.
                pos_prediction_corr = tf.count_nonzero(
                    tf.logical_and(
                        valid_label_prob > th,
                        tf.equal(valid_prediction, valid_anchor_labels)),
                    dtype=tf.int32)

                recall = tf.truediv(pos_prediction_corr, nr_pos)
                # With no positive anchors the ratio is 0/0 = NaN, which would
                # then be fed to the moving-average summary. Report 0 instead,
                # mirroring the guard on precision below. zeros_like keeps the
                # dtype of the ratio, so the summary tensor's dtype and op name
                # are the same as before.
                recall = tf.where(
                    tf.equal(nr_pos, 0), tf.zeros_like(recall), recall,
                    name='recall_th{}'.format(th))
                summaries.append(recall)

                precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction))
                # Same guard for the case where nothing is predicted positive.
                precision = tf.where(tf.equal(nr_pos_prediction, 0), 0.0, precision, name='precision_th{}'.format(th))
                summaries.append(precision)
        add_moving_summary(*summaries)

    # Objectness loss: sigmoid cross-entropy averaged over the valid anchors.
    label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
    label_loss = tf.reduce_mean(label_loss, name='label_loss')

    # Box loss: only positive anchors contribute to the sum ...
    pos_anchor_boxes = tf.boolean_mask(anchor_boxes, pos_mask)
    pos_box_logits = tf.boolean_mask(box_logits, pos_mask)
    # Huber loss divided by delta: 0.5 * x^2 / delta for |x| <= delta,
    # and |x| - 0.5 * delta otherwise.
    delta = 1.0 / 9
    box_loss = tf.losses.huber_loss(
        pos_anchor_boxes, pos_box_logits, delta=delta,
        reduction=tf.losses.Reduction.SUM) / delta
    # ... but the sum is normalized by the number of valid anchors.
    box_loss = tf.div(
        box_loss,
        tf.cast(nr_valid, tf.float32), name='box_loss')

    add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
    return label_loss, box_loss
@under_name_scope()
def decode_bbox_target(box_predictions, anchors):
    """
    Turn (tx, ty, tw, th) box predictions back into x1y1x2y2 boxes, using
    the anchors as reference.

    Args:
        box_predictions: (..., 4), logits
        anchors: (..., 4), floatbox. Must have the same shape

    Returns:
        box_decoded: (..., 4), float32. With the same shape.
    """
    out_shape = tf.shape(anchors)

    # View each 4-vector as two 2-vectors: (tx, ty) and (tw, th).
    pred_pairs = tf.reshape(box_predictions, (-1, 2, 2))
    pred_txty, pred_twth = tf.split(pred_pairs, 2, axis=1)
    # each is (...)x1x2

    # Same for the anchors: (x1, y1) and (x2, y2).
    anchor_pairs = tf.reshape(anchors, (-1, 2, 2))
    anchor_min, anchor_max = tf.split(anchor_pairs, 2, axis=1)

    anchor_size = anchor_max - anchor_min
    anchor_center = (anchor_max + anchor_min) * 0.5

    # The log-size prediction is capped before exponentiation.
    capped_twth = tf.minimum(pred_twth, config.BBOX_DECODE_CLIP)
    box_size = tf.exp(capped_twth) * anchor_size
    box_center = pred_txty * anchor_size + anchor_center

    box_min = box_center - box_size * 0.5
    box_max = box_center + box_size * 0.5  # (...)x1x2
    decoded = tf.concat([box_min, box_max], axis=-2)
    return tf.reshape(decoded, out_shape)
@under_name_scope()
def encode_bbox_target(boxes, anchors):
    """
    Encode x1y1x2y2 boxes as (tx, ty, tw, th) regression targets relative
    to the anchors.

    Args:
        boxes: (..., 4), float32
        anchors: (..., 4), float32

    Returns:
        box_encoded: (..., 4), float32 with the same shape.
    """
    # Split every anchor into its (x1, y1) and (x2, y2) corners.
    anchor_pairs = tf.reshape(anchors, (-1, 2, 2))
    anchor_min, anchor_max = tf.split(anchor_pairs, 2, axis=1)
    anchor_size = anchor_max - anchor_min
    anchor_center = (anchor_max + anchor_min) * 0.5

    # Same decomposition for the boxes to encode.
    box_pairs = tf.reshape(boxes, (-1, 2, 2))
    box_min, box_max = tf.split(box_pairs, 2, axis=1)
    box_size = box_max - box_min
    box_center = (box_max + box_min) * 0.5

    # Note that here not all boxes are valid. Some may be zero
    center_offset = (box_center - anchor_center) / anchor_size
    log_scale = tf.log(box_size / anchor_size)  # may contain -inf for invalid boxes
    targets = tf.concat([center_offset, log_scale], axis=1)  # (-1x2x2)
    return tf.reshape(targets, tf.shape(boxes))
@under_name_scope()
def generate_rpn_proposals(boxes, scores, img_shape):
    """
    Select RPN proposals: keep the top-scoring boxes, clip them to the image,
    drop boxes that are too small, then run NMS.

    Args:
        boxes: nx4 float dtype, decoded to floatbox already
        scores: n float, the logits
        img_shape: [h, w]

    Returns:
        boxes: kx4 float
        scores: k logits
    """
    assert boxes.shape.ndims == 2, boxes.shape
    # The top-k budgets differ between training and inference.
    is_training = get_current_tower_context().is_training
    pre_nms_topk = config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK
    post_nms_topk = config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK

    # Keep at most pre_nms_topk candidates, by score.
    num_candidates = tf.minimum(pre_nms_topk, tf.size(scores))
    cand_scores, cand_indices = tf.nn.top_k(scores, k=num_candidates, sorted=False)
    cand_boxes = tf.gather(boxes, cand_indices)
    cand_boxes = clip_boxes(cand_boxes, img_shape)

    # Discard candidates whose width or height is not above RPN_MIN_SIZE.
    cand_corners = tf.reshape(cand_boxes, (-1, 2, 2))
    cand_min, cand_max = tf.split(cand_corners, 2, axis=1)
    # nx1x2 each
    cand_size = tf.squeeze(cand_max - cand_min, axis=1)
    large_enough = tf.reduce_all(cand_size > config.RPN_MIN_SIZE, axis=1)  # n,
    kept_corners = tf.boolean_mask(cand_corners, large_enough)
    kept_scores = tf.boolean_mask(cand_scores, large_enough)

    # Reversing the innermost axis swaps x and y within each corner,
    # giving the y1x1y2x2 layout.
    nms_boxes = tf.reshape(
        tf.reverse(kept_corners, axis=[2]),
        (-1, 4), name='nms_input_boxes')
    keep = tf.image.non_max_suppression(
        nms_boxes,
        kept_scores,
        max_output_size=post_nms_topk,
        iou_threshold=config.RPN_PROPOSAL_NMS_THRESH)

    # Results are gathered from the x1y1x2y2 version of the kept boxes.
    kept_boxes = tf.reshape(kept_corners, (-1, 4))
    final_boxes = tf.gather(kept_boxes, keep, name='boxes')
    final_scores = tf.gather(kept_scores, keep, name='scores')
    tf.sigmoid(final_scores, name='probs')  # for visualization
    return final_boxes, final_scores
@under_name_scope()
def proposal_metrics(iou):
    """
    Add summaries for RPN proposals: the mean best IoU per ground-truth box,
    and the recall at IoU thresholds 0.3 and 0.5.

    Args:
        iou: nxm, #proposal x #gt
    """
    # find best roi for each gt, for summary only
    best_iou_per_gt = tf.reduce_max(iou, axis=0)
    metrics = [tf.reduce_mean(best_iou_per_gt, name='best_iou_per_gt')]
    # NOTE(review): the extent of this device scope was reconstructed from a
    # copy of the file that had lost its indentation; confirm against the
    # canonical source.
    with tf.device('/cpu:0'):
        for th in [0.3, 0.5]:
            # Fraction of ground-truth boxes whose best proposal reaches the threshold.
            num_hit = tf.count_nonzero(best_iou_per_gt >= th)
            num_gt = tf.size(best_iou_per_gt, out_type=tf.int64)
            metrics.append(
                tf.truediv(num_hit, num_gt, name='recall_iou{}'.format(th)))
    add_moving_summary(*metrics)
@under_name_scope()
def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
    """
    Sample some ROIs from all proposals for training.

    Args:
        boxes: nx4 region proposals, floatbox
        gt_boxes: mx4, floatbox
        gt_labels: m, integer labels. They are concatenated below with int64
            zeros, so presumably int64 -- confirm against the caller.

    Returns:
        sampled_boxes: tx4 floatbox, the rois
        sampled_labels: t labels, in [0, #class-1]. Positive means foreground.
        fg_inds_wrt_gt: #fg indices, each in range [0, m-1].
            It contains the matching GT of each foreground roi.
    """
    iou = pairwise_iou(boxes, gt_boxes)  # nxm
    proposal_metrics(iou)

    # add ground truth as proposals as well
    boxes = tf.concat([boxes, gt_boxes], axis=0)  # (n+m) x 4
    # Each appended ground-truth box gets IoU 1 with itself and 0 with the others.
    iou = tf.concat([iou, tf.eye(tf.shape(gt_boxes)[0])], axis=0)  # (n+m) x m
    # #proposal=n+m from now on

    # Foreground: proposals whose best IoU reaches the threshold.
    is_fg = tf.reduce_max(iou, axis=1) >= config.FASTRCNN_FG_THRESH
    fg_candidates = tf.reshape(tf.where(is_fg), [-1])
    max_fg = int(config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO)
    num_fg = tf.minimum(max_fg, tf.size(fg_candidates), name='num_fg')
    fg_inds = tf.random_shuffle(fg_candidates)[:num_fg]

    # Background: everything else, filling what remains of the batch.
    bg_candidates = tf.reshape(tf.where(tf.logical_not(is_fg)), [-1])
    num_bg = tf.minimum(
        config.FASTRCNN_BATCH_PER_IM - num_fg,
        tf.size(bg_candidates), name='num_bg')
    bg_inds = tf.random_shuffle(bg_candidates)[:num_bg]

    add_moving_summary(num_fg, num_bg)
    # fg,bg indices w.r.t proposals

    matched_gt = tf.argmax(iou, axis=1)  # #proposal, each in 0~m-1
    fg_inds_wrt_gt = tf.gather(matched_gt, fg_inds)  # num_fg

    sampled_inds = tf.concat([fg_inds, bg_inds], axis=0)  # indices w.r.t all n+m proposal boxes
    ret_boxes = tf.gather(boxes, sampled_inds, name='sampled_proposal_boxes')

    # Foreground rois take the label of their matched GT; background rois get 0.
    fg_labels = tf.gather(gt_labels, fg_inds_wrt_gt)
    bg_labels = tf.zeros_like(bg_inds, dtype=tf.int64)
    ret_labels = tf.concat([fg_labels, bg_labels], axis=0, name='sampled_labels')
    # stop the gradient -- they are meant to be ground-truth
    return tf.stop_gradient(ret_boxes), tf.stop_gradient(ret_labels), fg_inds_wrt_gt
@under_name_scope()
def crop_and_resize(image, boxes, box_ind, crop_size):
    """
    Better-aligned version of tf.image.crop_and_resize, following our definition of floating point boxes.

    Args:
        image: NCHW
        boxes: nx4, x1y1x2y2
        box_ind: (n,)
        crop_size (int):

    Returns:
        n,C,size,size
    """
    assert isinstance(crop_size, int), crop_size

    @under_name_scope()
    def transform_fpcoor_for_tf(boxes, image_shape, crop_shape):
        """
        Convert fpcoor boxes into the normalized boxes expected by
        tf.image.crop_and_resize.

        The way tf.image.crop_and_resize works (with normalized box):
        Initial point (the value of output[0]): x0_box * (W_img - 1)
        Spacing: w_box * (W_img - 1) / (W_crop - 1)
        Use the above grid to bilinear sample.

        However, what we want is (with fpcoor box):
        Spacing: w_box / W_crop
        Initial point: x0_box + spacing/2 - 0.5
        (-0.5 because bilinear sample assumes floating point coordinate (0.0, 0.0) is the same as pixel value (0, 0))

        Returns:
            y1x1y2x2
        """
        left, top, right, bottom = tf.split(boxes, 4, axis=1)

        # Distance between neighbouring sample points, in image pixels.
        step_x = (right - left) / tf.to_float(crop_shape[1])
        step_y = (bottom - top) / tf.to_float(crop_shape[0])

        # Normalized position of the first sample point.
        norm_x0 = (left + step_x / 2 - 0.5) / tf.to_float(image_shape[1] - 1)
        norm_y0 = (top + step_y / 2 - 0.5) / tf.to_float(image_shape[0] - 1)

        # Normalized extent covered by the sample grid.
        norm_w = step_x * tf.to_float(crop_shape[1] - 1) / tf.to_float(image_shape[1] - 1)
        norm_h = step_y * tf.to_float(crop_shape[0] - 1) / tf.to_float(image_shape[0] - 1)

        return tf.concat([norm_y0, norm_x0, norm_y0 + norm_h, norm_x0 + norm_w], axis=1)

    spatial_shape = tf.shape(image)[2:]
    tf_boxes = transform_fpcoor_for_tf(boxes, spatial_shape, [crop_size, crop_size])
    # tf.image.crop_and_resize is given NHWC input, so transpose in and out.
    image_nhwc = tf.transpose(image, [0, 2, 3, 1])  # 1hwc
    crops_nhwc = tf.image.crop_and_resize(
        image_nhwc, tf_boxes, box_ind,
        crop_size=[crop_size, crop_size])
    return tf.transpose(crops_nhwc, [0, 3, 1, 2])  # ncss
@under_name_scope()
def roi_align(featuremap, boxes, output_shape):
    """
    Crop each box from the feature map at twice the requested resolution,
    then 2x2 average-pool down to the requested size.

    Args:
        featuremap: 1xCxHxW
        boxes: Nx4 floatbox
        output_shape: int

    Returns:
        NxCxoHxoW
    """
    boxes = tf.stop_gradient(boxes)  # TODO
    # Every box is cropped from image 0 of the feature map.
    box_ind = tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32)
    # sample 4 locations per roi bin
    dense_crops = crop_and_resize(featuremap, boxes, box_ind, output_shape * 2)
    return tf.nn.avg_pool(
        dense_crops, [1, 1, 2, 2], [1, 1, 2, 2],
        padding='SAME', data_format='NCHW')
@layer_register(log_shape=True)
def fastrcnn_head(feature, num_classes):
    """
    Fast R-CNN head: global average pooling followed by two sibling
    fully-connected layers for classification and box regression.

    Args:
        feature (NxCx7x7):
        num_classes(int): num_category + 1

    Returns:
        cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
    """
    num_category = num_classes - 1
    pooled = GlobalAvgPooling('gap', feature, data_format='NCHW')

    cls_logits = FullyConnected(
        'class', pooled, num_classes,
        W_init=tf.random_normal_initializer(stddev=0.01))

    # One 4-vector per category; there is no box output for class 0.
    reg_flat = FullyConnected(
        'box', pooled, num_category * 4,
        W_init=tf.random_normal_initializer(stddev=0.001))
    reg_logits = tf.reshape(reg_flat, (-1, num_category, 4))
    return cls_logits, reg_logits
@under_name_scope()
def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
    """
    Compute the Fast R-CNN classification and box-regression losses, and
    register accuracy summaries.

    Args:
        labels: n,
        label_logits: nxC
        fg_boxes: nfgx4, encoded
        fg_box_logits: nfgx(C-1)x4

    Returns:
        label_loss, box_loss
    """
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=label_logits)
    label_loss = tf.reduce_mean(xent, name='label_loss')

    # Foreground samples are those with a positive label.
    fg_inds = tf.where(labels > 0)[:, 0]
    fg_labels = tf.gather(labels, fg_inds)
    num_fg = tf.size(fg_inds)
    # For the i-th foreground roi, pick the box predicted for its own
    # category (label - 1, since class 0 has no box output).
    row_ids = tf.range(num_fg)
    category_ids = tf.to_int32(fg_labels) - 1
    pick = tf.stack([row_ids, category_ids], axis=1)  # #fgx2
    fg_box_logits = tf.gather_nd(fg_box_logits, pick)

    # NOTE(review): the extent of this scope was reconstructed from a copy of
    # the file that had lost its indentation; confirm against the canonical
    # source.
    with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
        prediction = tf.argmax(label_logits, axis=1, name='label_prediction')
        correct = tf.to_float(tf.equal(prediction, labels))  # boolean/integer gather is unavailable on GPU
        accuracy = tf.reduce_mean(correct, name='accuracy')
        # Fraction of foreground rois that were predicted as class 0.
        fg_label_pred = tf.argmax(tf.gather(label_logits, fg_inds), axis=1)
        predicted_bg = tf.to_int32(tf.equal(fg_label_pred, 0))
        num_zero = tf.reduce_sum(predicted_bg, name='num_zero')
        false_negative = tf.truediv(num_zero, num_fg, name='false_negative')
        fg_accuracy = tf.reduce_mean(
            tf.gather(correct, fg_inds), name='fg_accuracy')

    # The box loss is summed over foreground rois but normalized by the
    # total number of sampled rois.
    box_loss_sum = tf.losses.huber_loss(
        fg_boxes, fg_box_logits, reduction=tf.losses.Reduction.SUM)
    box_loss = tf.truediv(
        box_loss_sum, tf.to_float(tf.shape(labels)[0]), name='box_loss')

    add_moving_summary(label_loss, box_loss, accuracy, fg_accuracy, false_negative)
    return label_loss, box_loss
@under_name_scope()
def fastrcnn_predictions(boxes, probs):
    """
    Generate final results from predictions of all proposals.

    Args:
        boxes: n#catx4 floatbox in float32
        probs: nx#class

    Returns:
        filtered_selection: #selection x 2 indices; after the final reverse
            each row is (box_id, cat_id).
        topk_probs: the probabilities of the selected entries.
    """
    assert boxes.shape[1] == config.NUM_CLASS - 1
    assert probs.shape[1] == config.NUM_CLASS
    # Put the category axis first so that map_fn iterates over categories.
    boxes = tf.transpose(boxes, [1, 0, 2])  # #catxnx4
    probs = tf.transpose(probs[:, 1:], [1, 0])  # #catxn

    def select_for_class(X):
        """
        prob: n probabilities
        box: nx4 boxes
        Returns: n boolean, the selection
        """
        prob, box = X
        mask_shape = tf.shape(prob)
        # filter by score threshold
        above_thresh = tf.reshape(tf.where(prob > config.RESULT_SCORE_THRESH), [-1])
        prob = tf.gather(prob, above_thresh)
        box = tf.gather(box, above_thresh)
        # NMS within each class
        kept = tf.image.non_max_suppression(
            box, prob, config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
        # Map NMS results back to indices w.r.t. the n original proposals.
        kept = tf.to_int32(tf.gather(above_thresh, kept))
        # sort available in TF>1.4.0
        # sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
        # Ascending sort: top_k of the negated values is descending, so
        # negating the result again gives ascending order.
        kept_sorted = -tf.nn.top_k(-kept, k=tf.size(kept))[0]
        return tf.sparse_to_dense(
            sparse_indices=kept_sorted,
            output_shape=mask_shape,
            sparse_values=True,
            default_value=False)

    masks = tf.map_fn(select_for_class, (probs, boxes), dtype=tf.bool,
                      parallel_iterations=10)  # #cat x N
    selected_indices = tf.where(masks)  # #selection x 2, each is (cat_id, box_id)
    probs = tf.boolean_mask(probs, masks)

    # filter again by sorting scores
    num_results = tf.minimum(config.RESULTS_PER_IM, tf.size(probs))
    topk_probs, topk_indices = tf.nn.top_k(probs, num_results, sorted=False)
    filtered_selection = tf.gather(selected_indices, topk_indices)
    # (cat_id, box_id) -> (box_id, cat_id)
    filtered_selection = tf.reverse(filtered_selection, axis=[1], name='filtered_indices')
    return filtered_selection, topk_probs
@layer_register(log_shape=True)
def maskrcnn_head(feature, num_class):
    """
    Mask head: a stride-2 deconvolution followed by a 1x1 conv that outputs
    one mask per category.

    Args:
        feature (NxCx7x7):
        num_class(int): num_category + 1

    Returns:
        mask_logits (N x num_category x 14 x 14):
    """
    initializer = tf.variance_scaling_initializer(
        scale=2.0, mode='fan_in', distribution='normal')
    with argscope([Conv2D, Deconv2D], data_format='NCHW', W_init=initializer):
        upsampled = Deconv2D('deconv', feature, 256, 2, stride=2, nl=tf.nn.relu)
        # No mask is predicted for class 0, hence num_class - 1 channels.
        mask_logits = Conv2D('conv', upsampled, num_class - 1, 1)
    return mask_logits
@under_name_scope()
def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
    """
    Compute the per-pixel sigmoid cross-entropy mask loss, and register
    visualization and accuracy summaries.

    Args:
        mask_logits: #fg x #category x14x14
        fg_labels: #fg, in 1~#class
        fg_target_masks: #fgx14x14. They are concatenated with sigmoid
            outputs and used as cross-entropy labels below, so presumably
            float -- confirm against the caller.

    Returns:
        The mask loss, a scalar.
    """
    num_fg = tf.size(fg_labels)
    # For each foreground roi keep only the mask predicted for its own
    # category (label - 1, since class 0 has no mask output).
    row_ids = tf.range(num_fg)
    category_ids = tf.to_int32(fg_labels) - 1
    pick = tf.stack([row_ids, category_ids], axis=1)  # #fgx2
    mask_logits = tf.gather_nd(mask_logits, pick)  # #fgx14x14
    mask_probs = tf.sigmoid(mask_logits)

    # add some training visualizations to tensorboard
    with tf.name_scope('mask_viz'):
        # Target and prediction are stacked along axis 1 of each image.
        stacked = tf.concat([fg_target_masks, mask_probs], axis=1)
        stacked = tf.expand_dims(stacked, 3)
        stacked = tf.cast(stacked * 255, tf.uint8, name='viz')
        tf.summary.image('mask_truth|pred', stacked, max_outputs=10)

    pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=fg_target_masks, logits=mask_logits)
    loss = tf.reduce_mean(pixel_loss, name='maskrcnn_loss')

    # Binarize prediction and target at 0.5 for the accuracy metrics.
    pred_label = mask_probs > 0.5
    truth_label = fg_target_masks > 0.5
    accuracy = tf.reduce_mean(
        tf.to_float(tf.equal(pred_label, truth_label)),
        name='accuracy')
    # Pixels that are positive in the target and predicted correctly;
    # the mean below is taken over all pixels.
    correct_pos = tf.logical_and(
        tf.equal(pred_label, truth_label),
        tf.equal(truth_label, True))
    pos_accuracy = tf.reduce_mean(tf.to_float(correct_pos), name='pos_accuracy')
    fg_pixel_ratio = tf.reduce_mean(tf.to_float(truth_label), name='fg_pixel_ratio')

    add_moving_summary(loss, accuracy, fg_pixel_ratio, pos_accuracy)
    return loss