-
Notifications
You must be signed in to change notification settings - Fork 61
/
feature_extraction.py
61 lines (49 loc) · 1.66 KB
/
feature_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
class Rescale(object):
    """Rescale a PIL image to a given size.

    Args:
        *output_size: Desired output size, in one of three forms:

            * ``Rescale(height, width)`` -- two values; the image is resized
              to exactly ``height x width``.
            * ``Rescale((height, width))`` -- the same pair passed as a single
              tuple or list.
            * ``Rescale(size)`` -- a single number; the smaller image edge is
              matched to ``size`` and the aspect ratio is kept.

    Raises:
        ValueError: If the arguments do not describe a size in one of the
            forms above.
    """

    def __init__(self, *output_size):
        # Allow the pair to be passed as one sequence: Rescale((h, w)).
        if len(output_size) == 1 and isinstance(output_size[0], (tuple, list)):
            output_size = tuple(output_size[0])

        if len(output_size) == 1:
            # Single number: aspect-ratio preserving mode.
            output_size = output_size[0]
        elif len(output_size) != 2:
            raise ValueError(
                'Rescale expects (height, width) or a single size, got %d '
                'values' % len(output_size))

        # Either a (height, width) tuple or a single number.
        self.output_size = output_size

    def _target_size(self, width, height):
        """Return the ``(new_width, new_height)`` for an image of the given size.

        Args:
            width: Width of the input image in pixels.
            height: Height of the input image in pixels.
        """
        if isinstance(self.output_size, (tuple, list)):
            # Explicit size: the input dimensions are irrelevant.
            new_h, new_w = self.output_size
        elif height > width:
            # Width is the smaller edge; scale height by the same factor.
            new_h = self.output_size * height / width
            new_w = self.output_size
        else:
            # Height is the smaller (or equal) edge.
            new_h = self.output_size
            new_w = self.output_size * width / height
        return int(new_w), int(new_h)

    def __call__(self, image):
        """Resize ``image`` with bilinear resampling.

        Args:
            image (PIL.Image) : PIL.Image object to rescale

        Returns:
            The resized image returned by ``image.resize``.
        """
        width, height = image.size
        # PIL's resize takes (width, height), the reverse of output_size.
        return image.resize(self._target_size(width, height),
                            resample=Image.BILINEAR)
class ResNetFeature(nn.Module):
    """Feature extractor built from a pretrained torchvision ResNet.

    The network is split into two parts:

    * ``conv5`` -- every child module except the last two, i.e. the
      convolutional trunk that produces the final feature map.
    * ``pool5`` -- the second-to-last child module (the pooling layer that
      precedes the classifier), applied on top of ``conv5``.

    The last child module (the classifier) is dropped.
    """

    def __init__(self, feature='resnet101'):
        """
        Args:
            feature (string): resnet101 or resnet152. Any value other than
                ``'resnet101'`` selects resnet152.
        """
        super(ResNetFeature, self).__init__()
        # NOTE: newer torchvision releases prefer the ``weights=`` argument;
        # ``pretrained=True`` is kept for compatibility with older releases.
        if feature == 'resnet101':
            resnet = models.resnet101(pretrained=True)
        else:
            resnet = models.resnet152(pretrained=True)
        resnet.float()
        # Use the GPU when one is present, but stay usable on CPU-only hosts
        # instead of failing inside ``.cuda()``.
        if torch.cuda.is_available():
            resnet.cuda()
        # Inference mode: freezes batch-norm statistics in the backbone.
        resnet.eval()
        module_list = list(resnet.children())
        self.conv5 = nn.Sequential(*module_list[:-2])
        self.pool5 = module_list[-2]

    def forward(self, x):
        """Extract features from a batch of images.

        Args:
            x: Input batch accepted by the ResNet trunk.

        Returns:
            tuple: ``(res5c, pool5)`` where ``res5c`` is the output of
            ``conv5`` and ``pool5`` is the pooled output flattened to
            ``(batch, features)``.
        """
        res5c = self.conv5(x)
        pool5 = self.pool5(res5c)
        # flatten() also handles an empty batch, where view(..., -1) cannot
        # infer the remaining dimension.
        pool5 = pool5.flatten(start_dim=1)
        return res5c, pool5
# Preprocessing pipeline applied to a PIL image before it is fed to the
# ResNet: resize to 224x224, convert to a tensor, then normalize each channel
# with the given per-channel mean and standard deviation.
_resnet_steps = [
    Rescale(224, 224),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],  # per-channel mean
        [0.229, 0.224, 0.225],  # per-channel standard deviation
    ),
]
resnet_transform = transforms.Compose(_resnet_steps)