From f343b6d5e36678d49b3d06b358c08d58327b9ad7 Mon Sep 17 00:00:00 2001 From: Niraj Pandkar Date: Mon, 8 Oct 2018 19:02:25 +0530 Subject: [PATCH] Included support for one more model - densenet121 --- predict.py | 9 ++++----- train.py | 19 +++++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/predict.py b/predict.py index d83420a..5d8ea14 100644 --- a/predict.py +++ b/predict.py @@ -54,11 +54,10 @@ def load_model(model_checkpoint): checkpoint = torch.load(model_checkpoint) arch = checkpoint["arch"] - model = None - if arch == "vgg16": - model = models.vgg16(pretrained=True) - elif arch == "vgg13": - model = models.vgg13(pretrained=True) + my_local = dict() + exec("model = models.{}(pretrained=True)".format(arch), globals(), my_local) + + model = my_local['model'] for param in model.parameters(): param.requires_grad = False diff --git a/train.py b/train.py index eb35353..0b2915f 100644 --- a/train.py +++ b/train.py @@ -154,13 +154,16 @@ def load_data_folder(data_folder="data"): return train_dataloader, valid_dataloader, train_dataset.class_to_idx def build_model(arch="vgg16", hidden_units=4096, class_idx_mapping=None): - model=None - if arch == "vgg16": - model = models.vgg16(pretrained=True) - else: - exec("model = models.{}(pretrained=True)".format(arch)) + my_local = dict() + exec("model = models.{}(pretrained=True)".format(arch), globals(), my_local) + + model = my_local['model'] + last_child = list(model.children())[-1] - input_features = model.classifier[0].in_features + if type(last_child) == torch.nn.modules.linear.Linear: + input_features = last_child.in_features + elif type(last_child) == torch.nn.modules.container.Sequential: + input_features = last_child[0].in_features for param in model.parameters(): param.requires_grad = False @@ -183,7 +186,7 @@ def main(): ap.add_argument("data_dir", help="Directory containing the dataset.", default="data", nargs="?") - VALID_ARCH_CHOICES = ("vgg16", "vgg13", "resnet18", "densenet121", "inception_v3") + VALID_ARCH_CHOICES = ("vgg16", "vgg13", "densenet121") ap.add_argument("--arch", help="Model architecture from 'torchvision.models'. (default: vgg16)", choices=VALID_ARCH_CHOICES, default=VALID_ARCH_CHOICES[0]) @@ -203,7 +206,7 @@ def main(): default="models") args = vars(ap.parse_args()) - os.system("mkdir -p models") + os.system("mkdir -p " + args["model_dir"]) (train_dataloader, valid_dataloader, class_idx_mapping) = load_data_folder(data_folder=args["data_dir"])