Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
htylab committed Jun 16, 2024
1 parent 421b027 commit 47c7c55
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 72 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __pycache__/
models/
output/
sample/
backup/
*.py[cod]
*$py.class
.ipynb_checkpoints/
Expand All @@ -13,6 +14,7 @@ sample/
.Python
ignore/
build/

develop-eggs/
dist/
downloads/
Expand Down
33 changes: 31 additions & 2 deletions tigerhx/guitool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import os
import numpy as np
from glob import glob
from os.path import basename, join
from os.path import basename, join, isfile
from scipy.io import savemat, loadmat
import nibabel as nib
import onnxruntime
from tkinter import simpledialog
import tkinter as tk

from tigerhx import lib_tool

nib.Nifti1Header.quaternion_threshold = -100

Expand Down Expand Up @@ -154,3 +154,32 @@ def run_program_gui_interaction(model_path, log_box, root):
slice_select.append(aha4_start)
return files, slice_select

def init_app(application_path):
model_path = join(application_path, 'models')
output_path = join(application_path, 'output')
sample_path = join(application_path, 'sample')
os.makedirs(model_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
os.makedirs(sample_path, exist_ok=True)

model_server = 'https://github.com/htylab/tigerhx/releases/download/modelhub/'

default_models = ['cine4d_v0001_xyz_mms12.onnx',
'cine4d_v0002_xyz_mms12acdc.onnx',
'cine4d_v0003_xy_mms12acdc.onnx']

for m0 in default_models:
model_file = join(model_path, m0)
if not isfile(model_file):
try:
print(f'Downloading model files....')
model_url = model_server + m0
print(model_url, model_file)
lib_tool.download(model_url, model_file)
download_ok = True
print('Download finished...')
except:
download_ok = False

if not download_ok:
raise ValueError('Server error. Please check the model name or internet connection.')
143 changes: 73 additions & 70 deletions tigerhx/tigercinegui.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,37 @@
import tkinter as tk
from tkinter import ttk, simpledialog
from tkinter import ttk
import os
import sys
import threading
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from guitool import *
from os.path import join, isfile
from os.path import join, isfile, basename
from skimage.transform import resize
from tigerhx import lib_tool
import numpy as np
import glob
from scipy.io import loadmat, savemat

# determine if application is a script file or frozen exe
# Determine if the application is a script file or frozen exe
if getattr(sys, 'frozen', False):
application_path = os.path.dirname(sys.executable)
elif __file__:
application_path = os.path.dirname(os.path.abspath(__file__))


print(application_path)

model_path = join(application_path, 'models')
output_path = join(application_path, 'output')
sample_path = join(application_path, 'sample')
os.makedirs(model_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
os.makedirs(sample_path, exist_ok=True)

model_server = 'https://github.com/htylab/tigerhx/releases/download/modelhub/'

default_models = ['cine4d_v0001_xyz_mms12.onnx',
'cine4d_v0002_xyz_mms12acdc.onnx',
'cine4d_v0003_xy_mms12acdc.onnx']



for m0 in default_models:

model_file = join(model_path, m0)

if not isfile(model_file):

try:
print(f'Downloading model files....')
model_url = model_server + m0
print(model_url, model_file)
lib_tool.download(model_url, model_file)
download_ok = True
print('Download finished...')
except:
download_ok = False

if not download_ok:
raise ValueError('Server error. Please check the model name or internet connection.')


init_app(application_path)

# Global variables
log_box = None
root = None
progress_bar = None
display_type_combo = None
data = None # Ensure data is globally accessible
fig, ax = None, None
canvas = None
im = None

def on_go():
global progress_bar
Expand Down Expand Up @@ -120,12 +92,15 @@ def process_files_multithreaded(files, slice_select, model_ff):
root.after(0, update_mat_listbox)


root.after(0, lambda: progress_bar.pack_forget())
#root.after(0, lambda: progress_bar.pack_forget())
log_message(log_box, f'All job finished.........')
progress_bar['value'] = 0
root.update_idletasks() # Ensure the GUI updates
root.after(0, update_mat_listbox)

global seg
def on_mat_select(event):
global seg
global seg, data
widget = event.widget
selection = widget.curselection()
if selection:
Expand All @@ -134,30 +109,52 @@ def on_mat_select(event):
mat_path = os.path.join(output_path, selected_mat)
try:
data = loadmat(mat_path)
if 'Seg' in data:
seg = data['Seg']
selected_type = display_type_combo.get()
if selected_type in data:
seg = data[selected_type]
log_message(log_box, f"Showing {selected_mat}")
log_message(log_box, f"Seg matrix size: {seg.shape}")
log_message(log_box, f"{selected_type} matrix size: {seg.shape}")
show_montage(seg)
update_time_slider(seg) # Adapt the range of the time points
else:
log_message(log_box, f"'Seg' not found in {selected_mat}")
log_message(log_box, f"'{selected_type}' not found in {selected_mat}")

if 'model' in data:
model_name = data['model'][0] #from .mat file , the string stored into a cell array
log_message(log_box, f"Predicted using {model_name}")
except Exception as e:
log_message(log_box, f"An error occurred: {e}")

def on_display_type_change(event):
global seg, data
if seg is not None and data is not None:
selected_type = display_type_combo.get()
if selected_type in data:
seg = data[selected_type]
show_montage(seg)
update_time_slider(seg) # Adapt the range of the time points

def show_montage(emp, time_frame=0):
global fig, ax, canvas, im
plt.close('all') # Close all previous figures

# Initialize the figure and axes if they do not exist
if fig is None or ax is None:
fig, ax = plt.subplots(figsize=(4, 6), dpi=100)
ax.axis('off') # Hide the axes
ax = fig.add_axes([0, 0, 1, 1]) # Remove all margins and padding
ax.set_facecolor('black')
fig.set_facecolor('black')

# Clear previous canvas content
for widget in canvas_frame.winfo_children():
widget.destroy()

# Number of slices in the z-dimension
num_slices = emp.shape[2]

if len(emp.shape) == 3: emp = emp[..., None]

# Determine grid size for the mosaic to match the aspect ratio 400:600
slice_shape = emp[:, :, 0, time_frame].shape
aspect_ratio = 600 / 400
Expand Down Expand Up @@ -191,12 +188,11 @@ def show_montage(emp, time_frame=0):
# Resize the mosaic to 400x600
mosaic_resized = resize(padded_mosaic, (600, 400), anti_aliasing=True)

fig = plt.figure(figsize=(4, 6), dpi=100)
ax = fig.add_axes([0, 0, 1, 1]) # Remove all margins and padding
ax.imshow(mosaic_resized)
ax.axis('off')
ax.set_facecolor('black')
fig.set_facecolor('black')
if im is None:
im = ax.imshow(mosaic_resized)
else:
im.set_data(mosaic_resized)
im.set_clim(vmin=mosaic_resized.min(), vmax=mosaic_resized.max())

canvas = FigureCanvasTkAgg(fig, master=canvas_frame)
canvas.get_tk_widget().pack(fill=tk.BOTH, expand=False)
Expand All @@ -207,6 +203,7 @@ def update_montage(time_frame):
show_montage(seg, time_frame)

def update_time_slider(emp):
if len(emp.shape) == 3: emp = emp[..., None]
max_time_frame = emp.shape[3] - 1 # Get the maximum time frame
time_slider.config(to=max_time_frame) # Update the slider's range
time_slider.set(0) # Set initial value to 0
Expand Down Expand Up @@ -238,9 +235,8 @@ def list_and_log_sample_files(sample_dir):
root.title("TigerHx GUI")

# Set default font size
default_font = ("Arial", 12)
default_font = ("Arial", 11)

# Apply default font to all widgets
root.option_add("*Font", default_font)

# Adjust the size of the main window
Expand Down Expand Up @@ -273,6 +269,7 @@ def list_and_log_sample_files(sample_dir):
# Create a "Go" button
go_button = tk.Button(combo_frame, text="Go", command=on_go)
go_button.pack(side=tk.LEFT, padx=5)

# Create a log box to display messages
log_frame = tk.Frame(frame)
log_frame.pack(padx=10, pady=10)
Expand All @@ -284,12 +281,33 @@ def list_and_log_sample_files(sample_dir):
log_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
log_box.config(yscrollcommand=log_scrollbar.set)

# Create a frame for the listbox and display type combo box
listbox_frame = tk.Frame(frame)
listbox_frame.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

# Create a frame for the display type label and combo box
display_type_frame = tk.Frame(listbox_frame)
display_type_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

# Create a label for the display type combo box
display_type_label = tk.Label(display_type_frame, text="Figure type")
display_type_label.pack(side=tk.LEFT)

# Create a combo box for selecting display type
display_types = ['Seg', 'input', 'LV', 'LVM', 'RV']
display_type_combo = ttk.Combobox(display_type_frame, values=display_types, width=10)
display_type_combo.pack(side=tk.LEFT, padx=5)
display_type_combo.current(0) # Set default display type to 'Seg'
display_type_combo.bind("<<ComboboxSelected>>", on_display_type_change)

# Create a listbox for the .mat files
mat_listbox = tk.Listbox(listbox_frame, width=40, height=10)
mat_listbox.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
mat_listbox.bind('<<ListboxSelect>>', on_mat_select)

# Create a progress bar
progress_bar = ttk.Progressbar(frame, mode='determinate')
# Don't pack the progress bar yet, only when needed
progress_bar.pack(side=tk.TOP, fill=tk.BOTH)

# Create a frame for the canvas and slider
canvas_slider_frame = tk.Frame(root)
Expand All @@ -308,26 +326,11 @@ def list_and_log_sample_files(sample_dir):
command=lambda val: update_montage(int(val)))
time_slider.pack(side=tk.RIGHT, fill=tk.Y)

# Create a listbox for the .mat files
mat_listbox_frame = tk.Frame(frame)
mat_listbox_frame.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

mat_listbox_label = tk.Label(mat_listbox_frame, text="Prediction Files")
mat_listbox_label.pack()

mat_listbox = tk.Listbox(mat_listbox_frame, width=40, height=10)
mat_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
mat_listbox.bind('<<ListboxSelect>>', on_mat_select)

mat_scrollbar = tk.Scrollbar(mat_listbox_frame, orient=tk.VERTICAL)
mat_scrollbar.config(command=mat_listbox.yview)
mat_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
mat_listbox.config(yscrollcommand=mat_scrollbar.set)

# Initial update of the .mat listbox
update_mat_listbox()

list_and_log_sample_files('./sample')

# Run the application
root.mainloop()

0 comments on commit 47c7c55

Please sign in to comment.