-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgui.py
108 lines (83 loc) · 2.95 KB
/
gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from tkinter import *
from tkinter import filedialog
from PIL import ImageTk,Image
import torch
import numpy as np
import monai.transforms as mt
from monai.networks.nets import UNet
from monai.networks.layers import Norm, Act
from monai.visualize.utils import blend_images
root = Tk()
root.geometry("600x800")
root.title("Teeth segmentation")

# Compute device: prefer Apple's MPS backend (the device this GUI originally
# hard-coded), then CUDA, and fall back to CPU so the app still starts on
# machines without MPS.
if torch.backends.mps.is_available():
    device = "mps:0"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# 2-D UNet: 1 input channel (the grayscale radiograph), 2 output channels
# (one score map per class; the prediction takes the argmax over them).
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
    act=Act.LEAKYRELU,
).to(device)

# map_location remaps tensors that were saved on another device (e.g. a CUDA
# training machine) onto the device selected above.
checkpoint = torch.load("./models/checkpoint1/model.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
# Inference only: BatchNorm layers must use their running statistics rather
# than the statistics of each single-image batch.
model.eval()
def mask_img(img):
    """Multiply ``img`` by the binary mask stored for the currently opened file.

    The mask is read from ``./Segmentation/maxillomandibular/<stem>.jpg``,
    where ``<stem>`` is the file name of ``root.filename`` without its
    extension. It is binarized with PIL's ``convert("1")``. Presumably it
    marks the maxillomandibular (jaw) region; confirm against the dataset.

    Args:
        img: image to mask (a PIL image or array-like); converted to float32.

    Returns:
        ``img * mask`` as a float32 array with a new leading axis (axis 0),
        i.e. channel-first for the MONAI transforms.
    """
    img = np.asarray(img, dtype=np.float32)
    # splitext only strips the final extension, so dotted names such as
    # "a.b.png" map to "a.b" (the old split(".")[0] truncated them to "a").
    stem = os.path.splitext(os.path.basename(root.filename))[0]
    msk_path = os.path.join("./Segmentation/maxillomandibular", stem + ".jpg")
    # The context manager closes the mask file as soon as its pixels are read.
    with Image.open(msk_path) as msk_img:
        msk = np.asarray(msk_img.convert("1"), dtype="float32")
    return np.expand_dims(img * msk, axis=0)
def save_img(img_pil):
    """Save ``img_pil`` as ``./Predictions/<stem>_pred.png`` and show the path.

    ``<stem>`` is the file name of ``root.filename`` without its extension.
    A label reporting the output path is appended to the main window.

    Args:
        img_pil: the PIL image to write (the prediction overlay).
    """
    preds_dir = "./Predictions"
    # exist_ok makes this a no-op when the directory is already there,
    # replacing the check-then-create sequence.
    os.makedirs(preds_dir, exist_ok=True)
    # splitext keeps dots inside the stem ("a.b.png" -> "a.b").
    stem = os.path.splitext(os.path.basename(root.filename))[0]
    out_name = os.path.join(preds_dir, stem + "_pred.png")
    img_pil.save(out_name)
    Label(root, text=f"image saved in {out_name}").pack()
def _segment_overlay(img_pil):
    """Run the segmentation model on ``img_pil`` and blend the labels onto it.

    Steps: mask the image with ``mask_img``, normalize it and resize it to
    256x512 (H x W), then run ``model`` on ``device``. The per-pixel argmax
    class is blended over the image resized to the same 512x256 (W x H) grid.

    Returns:
        A PIL image built from the blended array (channels last, uint8).
    """
    transforms = mt.compose.Compose(
        [
            mt.NormalizeIntensity(nonzero=True, channel_wise=True),
            mt.Resize(spatial_size=(256, 512), mode="bilinear"),
            mt.ToTensor(dtype=torch.float32),
        ]
    )
    img_preproc = transforms(mask_img(img_pil)).to(device)
    # eval() makes BatchNorm use running statistics instead of this
    # single-image batch's; no_grad() skips building an autograd graph that
    # inference never uses.
    model.eval()
    with torch.no_grad():
        logits = model(img_preproc.unsqueeze(0))
    # Softmax is monotonic, so argmax over the raw logits picks the same class;
    # dim=1 is the class axis (out_channels of the model) after the batch axis.
    pred = torch.argmax(logits, dim=1).cpu().numpy()
    # Background scaled from 8-bit to [0, 1] on the prediction's grid, with a
    # leading channel axis to match the channel-first label map.
    background = np.asarray(img_pil.resize((512, 256))) / 255
    background = np.expand_dims(background, 0)
    blended = blend_images(background, pred)
    # blend_images output is channel-first; PIL needs channels last in 0..255.
    blended = np.transpose(blended, (1, 2, 0)) * 255
    return Image.fromarray(blended.astype(np.uint8))


def predict(img_pil):
    """Segment ``img_pil``, display the overlay, and add a "save" button.

    Args:
        img_pil: the grayscale PIL image loaded by ``open``.
    """
    global overlay
    overlay_pil = _segment_overlay(img_pil)
    # The PhotoImage lives in a module-level global: Tkinter does not keep its
    # own Python reference, so a local would be garbage-collected and the
    # label would render blank.
    overlay = ImageTk.PhotoImage(overlay_pil)
    Label(image=overlay).pack()
    Button(root, text="save", command=lambda: save_img(overlay_pil)).pack()
def open():
    """Ask the user for a radiograph, display it, and add a "predict" button.

    The chosen path is stored in ``root.filename``, which ``mask_img`` and
    ``save_img`` read to derive the mask and output names. Note that this
    function shadows the builtin ``open`` in this module. The name is kept
    because the "open file" button is wired to it.
    """
    global image
    filename = filedialog.askopenfilename(
        initialdir="./Radiographs",
        title="Radiographs directory",
    )
    # A cancelled dialog yields an empty result. Return before touching
    # root.filename, so a previously opened image keeps its matching mask and
    # output names and Image.open("") is never attempted.
    if not filename:
        return
    root.filename = filename
    # convert() returns an independent image, so the file can be closed here.
    with Image.open(filename) as src:
        image_pil = src.convert("L")
    # Module-level global keeps the PhotoImage alive for Tkinter to draw.
    image = ImageTk.PhotoImage(image_pil.resize((512, 256)))
    Label(image=image).pack()

    def run_prediction():
        # mask_img/save_img read root.filename. Re-point it at this button's
        # image in case another file was opened after this button was created.
        root.filename = filename
        predict(image_pil)

    Button(root, text="predict", command=run_prediction).pack()
# Entry point of the workflow: "open file" -> "predict" -> "save".
# pack() returns None, so create the widget first to keep a real reference.
open_button = Button(root, text="open file", command=open)
open_button.pack()
root.mainloop()