Skip to content

Commit

Permalink
Included support for one more model - densenet121
Browse files Browse the repository at this point in the history
  • Loading branch information
nirajpandkar committed Oct 8, 2018
1 parent e0c4623 commit f343b6d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
9 changes: 4 additions & 5 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ def load_model(model_checkpoint):
checkpoint = torch.load(model_checkpoint)

arch = checkpoint["arch"]
model = None
if arch == "vgg16":
model = models.vgg16(pretrained=True)
elif arch == "vgg13":
model = models.vgg13(pretrained=True)
my_local = dict()
exec("model = models.{}(pretrained=True)".format(arch), globals(), my_local)

model = my_local['model']

for param in model.parameters():
param.requires_grad = False
Expand Down
19 changes: 11 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,16 @@ def load_data_folder(data_folder="data"):
return train_dataloader, valid_dataloader, train_dataset.class_to_idx

def build_model(arch="vgg16", hidden_units=4096, class_idx_mapping=None):
model=None
if arch == "vgg16":
model = models.vgg16(pretrained=True)
else:
exec("model = models.{}(pretrained=True)".format(arch))
my_local = dict()
exec("model = models.{}(pretrained=True)".format(arch), globals(), my_local)

model = my_local['model']
last_child = list(model.children())[-1]

input_features = model.classifier[0].in_features
if type(last_child) == torch.nn.modules.linear.Linear:
input_features = last_child.in_features
elif type(last_child) == torch.nn.modules.container.Sequential:
input_features = last_child[0].in_features

for param in model.parameters():
param.requires_grad = False
Expand All @@ -183,7 +186,7 @@ def main():
ap.add_argument("data_dir", help="Directory containing the dataset.",
default="data", nargs="?")

VALID_ARCH_CHOICES = ("vgg16", "vgg13", "resnet18", "densenet121", "inception_v3")
VALID_ARCH_CHOICES = ("vgg16", "vgg13", "densenet121")
ap.add_argument("--arch", help="Model architecture from 'torchvision.models'. (default: vgg16)", choices=VALID_ARCH_CHOICES,
default=VALID_ARCH_CHOICES[0])

Expand All @@ -203,7 +206,7 @@ def main():
default="models")
args = vars(ap.parse_args())

os.system("mkdir -p models")
os.system("mkdir -p " + args["model_dir"])

(train_dataloader, valid_dataloader, class_idx_mapping) = load_data_folder(data_folder=args["data_dir"])

Expand Down

0 comments on commit f343b6d

Please sign in to comment.