-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify-single.py
executable file
·72 lines (63 loc) · 2.2 KB
/
classify-single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/opt/homebrew/Caskroom/miniforge/base/envs/pytorch/bin/python
# Script to classify a singular MRI as having a tumor or not having one
import getopt
from os import error
import sys
import time
from PIL import Image
import torch
from torchvision import transforms as T
import torch.cuda as cuda
# Usage banner shared by every argument-error path.
_USAGE = 'Usage: classify_image.py -m <model> -i <image>'

# Human-readable label for each class index the model can emit
# (index 0 -> no tumor, index 1 -> tumor).
_LABELS = ('tumor not detected', 'tumor detected')


def _parse_args(argv):
    """Return ``(model_path, image_path)`` parsed from *argv*.

    Accepts ``-m``/``--model`` and ``-i``/``--image``. ``-h``/``--help``
    prints the usage banner and exits with status 0.

    Raises:
        SystemExit: with status 2 on a parse error, a missing option, or a
            stray positional argument (after printing the usage banner).
    """
    try:
        opts, args = getopt.getopt(argv, 'hm:i:', ['help', 'model=', 'image='])
    except getopt.GetoptError:
        print(_USAGE)
        sys.exit(2)
    model_path = image_path = None
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(_USAGE)
            sys.exit()
        elif opt in ('-m', '--model'):
            model_path = arg
        elif opt in ('-i', '--image'):
            image_path = arg
    # Both inputs are required, and nothing else is accepted. This is checked
    # before any model/image is loaded, so bad invocations fail fast.
    if model_path is None or image_path is None or args:
        print(_USAGE)
        sys.exit(2)
    return model_path, image_path


def _select_device():
    """Return ``cuda:0`` when CUDA is available, else the CPU, and report it."""
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        print('Running on ' + cuda.get_device_name(device) + '...')
    else:
        device = torch.device('cpu')
        print('Running on the CPU...')
    return device


def _load_model(path, device):
    """Load a whole pickled model from *path* onto *device* in eval mode."""
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load model files from trusted sources.
    # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which rejects whole
    # pickled models; pass weights_only=False if loading fails on a trusted file.
    model = torch.load(path, map_location=device)
    model.eval()
    return model


def _load_image(path, device):
    """Return the image at *path* as a (1, 3, 224, 224) tensor on *device*."""
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
    ])
    # convert('RGB') produces an independent in-memory image, so the file
    # handle can be closed right away (the original leaked it).
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    # unsqueeze(0) adds the batch dimension.
    return transform(rgb).unsqueeze(0).to(device)


def main(argv):
    """Classify a single MRI image as tumor / no tumor and print the result.

    Args:
        argv: Command-line arguments without the program name, e.g.
            ``['-m', 'model.pt', '-i', 'scan.png']``.

    Raises:
        SystemExit: status 0 for ``-h``/``--help``; status 2 for invalid
            arguments; status 1 if the model predicts an unknown class index.
    """
    model_path, image_path = _parse_args(argv)
    device = _select_device()
    model = _load_model(model_path, device)
    image = _load_image(image_path, device)
    # Feed the image through the network without tracking gradients.
    with torch.no_grad():
        pred = model(image)
    # Flat argmax, as before. int() copies the scalar back to the host, so
    # this works for CUDA outputs too (pred.numpy() raised on GPU tensors).
    index = int(pred.argmax())
    if not 0 <= index < len(_LABELS):
        # Previously this left `output` unbound and crashed with NameError.
        print('Unexpected class index: ' + str(index))
        sys.exit(1)
    print('Predicted output: ' + _LABELS[index])
if __name__ == '__main__':
    # Time the whole run. perf_counter is monotonic and high-resolution, so
    # the delta is immune to wall-clock adjustments that can skew time.time().
    # (No runtime is printed if main() exits via sys.exit.)
    start_time = time.perf_counter()
    main(sys.argv[1:])
    print('Runtime:', time.perf_counter() - start_time, 'seconds')