diff --git a/lmnet/lmnet/networks/base.py b/lmnet/lmnet/networks/base.py index 4df7a6ba2..c057e78bc 100644 --- a/lmnet/lmnet/networks/base.py +++ b/lmnet/lmnet/networks/base.py @@ -58,7 +58,7 @@ def __init__( self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {'learning_rate': 0.01} self.learning_rate_func = learning_rate_func self.learning_rate_kwargs = learning_rate_kwargs if learning_rate_kwargs is not None else {} - self.classes = classes + self.classes = list(map(lambda _class: _class.replace(' ', '_'), classes)) self.num_classes = len(classes) self.image_size = image_size self.batch_size = batch_size diff --git a/lmnet/lmnet/networks/segmentation/base.py b/lmnet/lmnet/networks/segmentation/base.py index 3129ae526..51878ed57 100755 --- a/lmnet/lmnet/networks/segmentation/base.py +++ b/lmnet/lmnet/networks/segmentation/base.py @@ -115,8 +115,6 @@ def metrics(self, output, labels): pred = tf.equal(output_argmax, i) truth = tf.equal(labels, i) - class_name = class_name.replace(' ', '_') - true_positive, true_positive_update = tf.metrics.true_positives(truth, pred, name=class_name) false_positive, false_positive_update = tf.metrics.false_positives(truth, pred, name=class_name) false_negative, false_negative_update = tf.metrics.false_negatives(truth, pred, name=class_name) diff --git a/lmnet/tests/lmnet_tests/networks_tests/object_detection_tests/test_yolo_v2.py b/lmnet/tests/lmnet_tests/networks_tests/object_detection_tests/test_yolo_v2.py index c660981bb..b2582bee1 100644 --- a/lmnet/tests/lmnet_tests/networks_tests/object_detection_tests/test_yolo_v2.py +++ b/lmnet/tests/lmnet_tests/networks_tests/object_detection_tests/test_yolo_v2.py @@ -736,7 +736,7 @@ def test_yolov2_post_process(): image_size = [96, 64] batch_size = 2 - classes = range(5) + classes = Pascalvoc2007.classes anchors = [(0.1, 0.2), (1.2, 1.1)] data_format = "NHWC" score_threshold = 0.25 diff --git a/lmnet/tests/lmnet_tests/networks_tests/segmentation_tests/test_lm_bisenet.py b/lmnet/tests/lmnet_tests/networks_tests/segmentation_tests/test_lm_bisenet.py index 1ee6566bc..8b714e2f7 100644 --- a/lmnet/tests/lmnet_tests/networks_tests/segmentation_tests/test_lm_bisenet.py +++ b/lmnet/tests/lmnet_tests/networks_tests/segmentation_tests/test_lm_bisenet.py @@ -79,7 +79,7 @@ def test_lm_bisenet_post_process(): image_size = [96, 64] batch_size = 2 - classes = range(5) + classes = Camvid.classes data_format = "NHWC" model = LMBiSeNet(