-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
59 lines (39 loc) · 1.54 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os

import cv2
import torch

from models.modelB4 import LDC
from preprocessing.img_processing import transform, pixel_adjust
from preprocessing.save_images import save_image_to_disk
# Use CUDA whenever at least one GPU is visible to torch; otherwise the CPU.
device = torch.device('cuda') if torch.cuda.device_count() else torch.device('cpu')

# Module-level model instance on the selected device; predict_img loads
# weights into this same object and runs it.
model = LDC().to(device)

# Path to the pretrained weights checkpoint.
checkpoint_path = 'weights/BRIND/11/11_model.pth'
def _to_model_input(img):
    """Apply `transform` to an image array and return the network input.

    `transform` returns a pair (images, mean_rgb); only `images` is used here.
    If `images` is a torch tensor it is moved to the module-level `device`
    (a no-op when it is already there), so the forward pass does not hit a
    CPU/CUDA device mismatch. Non-tensor results are returned unchanged,
    since `transform`'s return type is not visible from this module.
    """
    images, _mean_rgb = transform(img)
    if isinstance(images, torch.Tensor):
        images = images.to(device)
    return images


def predict_img(checkpoint_path, file_path, save_path='output/average'):
    """Run the module-level model on one image file and save the prediction.

    Args:
        checkpoint_path: Path of a state-dict file, loaded into `model` on
            every call.
        file_path: Path of the image to read with `cv2.imread`.
        save_path: Output location passed through to `save_image_to_disk`.

    Returns:
        Tuple ``(output, img)``: `output` is whatever `save_image_to_disk`
        returns. `img` is the image array actually fed to the model; it has
        been through `pixel_adjust` if the first forward pass failed.

    Raises:
        OSError: If `cv2.imread` cannot read `file_path`.
    """
    # basename respects the platform's path separator; splitting on '/'
    # kept the whole path for backslash-separated paths.
    file_names = [os.path.basename(file_path)]

    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, which used to surface as an AttributeError.
        raise OSError(f'Could not read image: {file_path}')
    print(img.shape)

    # Original height/width as 1-element tensors, handed to
    # save_image_to_disk (presumably to restore the output size -- TODO confirm).
    image_shape = [torch.tensor([img.shape[0]]), torch.tensor([img.shape[1]])]

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    images = _to_model_input(img)
    # Inference only: no_grad avoids building autograd graphs for either pass.
    with torch.no_grad():
        try:
            preds = model(images)
            print('Prediction Successfull')
        except Exception as e:
            # Any forward-pass failure is retried once on a pixel-adjusted
            # copy of the image (what pixel_adjust changes is defined elsewhere).
            print('Error:', e)
            img = pixel_adjust(img)
            print('Adjusting Pixel')
            preds = model(_to_model_input(img))

    output = save_image_to_disk(preds, file_names, save_path, image_shape)
    return output, img