Skip to content

Commit

Permalink
Delete absolute mode path from train executor (blue-oil#431)
Browse files Browse the repository at this point in the history
* delete absolute mode class path

* add task type to test config
  • Loading branch information
yasumura-lm authored and ananno committed Sep 27, 2019
1 parent a292e04 commit 4772222
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lmnet/executor/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from lmnet.datasets.base import ObjectDetectionBase
from lmnet.datasets.dataset_iterator import DatasetIterator
from lmnet.datasets.tfds import TFDSClassification, TFDSObjectDetection
from lmnet.common import Tasks


def _save_checkpoint(saver, sess, global_step, step):
Expand Down Expand Up @@ -85,7 +86,7 @@ def start_training(config):

graph = tf.Graph()
with graph.as_default():
if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
if config.TASK == Tasks.OBJECT_DETECTION:
model = ModelClass(
classes=train_dataset.classes,
num_max_boxes=train_dataset.num_max_boxes,
Expand All @@ -105,7 +106,7 @@ def start_training(config):
images_placeholder, labels_placeholder = model.placeholders()

output = model.inference(images_placeholder, is_training_placeholder)
if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
if config.TASK == Tasks.OBJECT_DETECTION:
loss = model.loss(output, labels_placeholder, global_step)
else:
loss = model.loss(output, labels_placeholder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from lmnet.networks.classification.darknet import Darknet
from lmnet.utils.executor import prepare_dirs
from lmnet.pre_processor import Resize
from lmnet.common import Tasks
from executor.train import start_training


Expand Down Expand Up @@ -50,6 +51,7 @@ def test_training():
config.KEEP_CHECKPOINT_MAX = 5
config.SUMMARISE_STEPS = 1
config.IS_PRETRAIN = False
config.TASK = Tasks.CLASSIFICATION

# network model config
config.NETWORK = EasyDict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
binary_mean_scaling_quantizer,
linear_mid_tread_half_quantizer,
)
from lmnet.common import Tasks
from executor.train import start_training

# Apply reset_default_graph() in conftest.py to all tests in this file.
Expand Down Expand Up @@ -53,6 +54,7 @@ def test_training():
config.KEEP_CHECKPOINT_MAX = 5
config.SUMMARISE_STEPS = 1
config.IS_PRETRAIN = False
config.TASK = Tasks.CLASSIFICATION

# network model config
config.NETWORK = EasyDict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lmnet.networks.object_detection.yolo_v1 import YoloV1
from lmnet.utils.executor import prepare_dirs
from lmnet.pre_processor import ResizeWithGtBoxes
from lmnet.common import Tasks
from executor.train import start_training

# Apply reset_default_graph() in conftest.py to all tests in this file.
Expand Down Expand Up @@ -279,6 +280,7 @@ def test_training():
config.KEEP_CHECKPOINT_MAX = 5
config.SUMMARISE_STEPS = 1
config.IS_PRETRAIN = False
config.TASK = Tasks.OBJECT_DETECTION

# network model config
config.NETWORK = EasyDict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ExcludeLowScoreBox,
NMS,
)
from lmnet.common import Tasks

# Apply reset_default_graph() in conftest.py to all tests in this file.
# Set test environment
Expand Down Expand Up @@ -713,6 +714,7 @@ def test_training():
config.KEEP_CHECKPOINT_MAX = 5
config.SUMMARISE_STEPS = 1
config.IS_PRETRAIN = False
config.TASK = Tasks.OBJECT_DETECTION

# network model config
config.NETWORK = EasyDict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
binary_channel_wise_mean_scaling_quantizer,
linear_mid_tread_half_quantizer,
)
from lmnet.common import Tasks
from executor.train import start_training


Expand All @@ -51,6 +52,7 @@ def test_training():
config.KEEP_CHECKPOINT_MAX = 5
config.SUMMARISE_STEPS = 1
config.IS_PRETRAIN = False
config.TASK = Tasks.OBJECT_DETECTION

# network model config
config.NETWORK = EasyDict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from lmnet.pre_processor import Resize
from lmnet.utils.executor import prepare_dirs
from lmnet.common import Tasks


# Apply reset_default_graph() and set_test_environment() in conftest.py to all tests in this file.
Expand All @@ -56,6 +57,7 @@ def test_training():
config.KEEP_CHECKPOINT_MAX = 5
config.SUMMARISE_STEPS = 1
config.IS_PRETRAIN = False
config.TASK = Tasks.SEMANTIC_SEGMENTATION

# network model config
config.NETWORK = EasyDict()
Expand Down

0 comments on commit 4772222

Please sign in to comment.