-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
1,439 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Experiment configuration file | ||
Extended from config file from original PANet Repository | ||
""" | ||
import glob | ||
import itertools | ||
import os | ||
import sacred | ||
from sacred import Experiment | ||
from sacred.observers import FileStorageObserver | ||
from sacred.utils import apply_backspaces_and_linefeeds | ||
from utils import * | ||
|
||
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False | ||
sacred.SETTINGS.CAPTURE_MODE = 'no' | ||
|
||
ex = Experiment("QNet") | ||
ex.captured_out_filter = apply_backspaces_and_linefeeds | ||
|
||
###### Set up source folder ###### | ||
source_folders = ['.', './dataloaders', './models', './utils'] | ||
sources_to_save = list(itertools.chain.from_iterable( | ||
[glob.glob(f'{folder}/*.py') for folder in source_folders])) | ||
for source_file in sources_to_save: | ||
ex.add_source_file(source_file) | ||
|
||
|
||
@ex.config | ||
def cfg(): | ||
"""Default configurations""" | ||
seed = 2021 | ||
gpu_id = 0 | ||
num_workers = 0 # 0 for debugging. | ||
mode = 'train' | ||
|
||
## dataset | ||
dataset = 'CHAOST2' # i.e. abdominal MRI - 'CHAOST2'; cardiac MRI - CMR | ||
exclude_label = None # None, for not excluding test labels; | ||
# 1 for Liver, 2 for RK, 3 for LK, 4 for Spleen in 'CHAOST2' | ||
if dataset == 'CMR': | ||
n_sv = 1000 | ||
else: | ||
n_sv = 5000 | ||
min_size = 200 | ||
max_slices = 3 | ||
use_gt = False # True - use ground truth as training label, False - use supervoxel as training label | ||
eval_fold = 0 # (0-4) for 5-fold cross-validation | ||
test_label = [1, 4] # for evaluation | ||
supp_idx = 0 # choose which case as the support set for evaluation, (0-4) for 'CHAOST2', (0-7) for 'CMR' | ||
n_part = 3 # for evaluation, i.e. 3 chunks | ||
|
||
## training | ||
n_steps = 1000 | ||
batch_size = 1 | ||
n_shot = 1 | ||
n_way = 1 | ||
n_query = 1 | ||
lr_step_gamma = 0.95 | ||
bg_wt = 0.1 | ||
t_loss_scaler = 0.0 | ||
ignore_label = 255 | ||
print_interval = 100 | ||
save_snapshot_every = 1000 | ||
max_iters_per_load = 1000 # epoch size, interval for reloading the dataset | ||
|
||
# Network | ||
# reload_model_path = '/home/SQQ/fsmis/ADNet/runs/ADNet_train_CHAOST2_cv0/1/snapshots/1000.pth' | ||
reload_model_path = None | ||
|
||
optim_type = 'sgd' | ||
optim = { | ||
'lr': 1e-3, | ||
'momentum': 0.9, | ||
'weight_decay': 0.0005, | ||
} | ||
|
||
exp_str = '_'.join( | ||
[mode] | ||
+ [dataset, ] | ||
+ [f'cv{eval_fold}']) | ||
|
||
path = { | ||
'log_dir': './runs', | ||
'CHAOST2': {'data_dir': './data/CHAOST2'}, | ||
'SABS': {'data_dir': './data/SABS'}, | ||
'CMR': {'data_dir': './data/CMR'}, | ||
} | ||
|
||
|
||
@ex.config_hook | ||
def add_observer(config, command_name, logger): | ||
"""A hook fucntion to add observer""" | ||
exp_name = f'{ex.path}_{config["exp_str"]}' | ||
observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name)) | ||
ex.observers.append(observer) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
Dataset Specifics | ||
Extended from ADNet code by Hansen et al. | ||
""" | ||
|
||
import torch | ||
import random | ||
|
||
|
||
def get_label_names(dataset): | ||
label_names = {} | ||
if dataset == 'CMR': | ||
label_names[0] = 'BG' | ||
label_names[1] = 'LV-MYO' | ||
label_names[2] = 'LV-BP' | ||
label_names[3] = 'RV' | ||
|
||
elif dataset == 'CHAOST2': | ||
label_names[0] = 'BG' | ||
label_names[1] = 'LIVER' | ||
label_names[2] = 'RK' | ||
label_names[3] = 'LK' | ||
label_names[4] = 'SPLEEN' | ||
elif dataset == 'SABS': | ||
label_names[0] = 'BG' | ||
label_names[1] = 'SPLEEN' | ||
label_names[2] = 'RK' | ||
label_names[3] = 'LK' | ||
label_names[4] = 'GALLBLADDER' | ||
label_names[5] = 'ESOPHAGUS' | ||
label_names[6] = 'LIVER' | ||
label_names[7] = 'STOMACH' | ||
label_names[8] = 'AORTA' | ||
label_names[9] = 'IVC' # Inferior vena cava | ||
label_names[10] = 'PS_VEIN' # portal vein and splenic vein | ||
label_names[11] = 'PANCREAS' | ||
label_names[12] = 'AG_R' # right adrenal gland | ||
label_names[13] = 'AG_L' # left adrenal gland | ||
|
||
return label_names | ||
|
||
|
||
def get_folds(dataset): | ||
FOLD = {} | ||
if dataset == 'CMR': | ||
FOLD[0] = set(range(0, 8)) | ||
FOLD[1] = set(range(7, 15)) | ||
FOLD[2] = set(range(14, 22)) | ||
FOLD[3] = set(range(21, 29)) | ||
FOLD[4] = set(range(28, 35)) | ||
FOLD[4].update([0]) | ||
return FOLD | ||
|
||
elif dataset == 'CHAOST2': | ||
FOLD[0] = set(range(0, 5)) | ||
FOLD[1] = set(range(4, 9)) | ||
FOLD[2] = set(range(8, 13)) | ||
FOLD[3] = set(range(12, 17)) | ||
FOLD[4] = set(range(16, 20)) | ||
FOLD[4].update([0]) | ||
return FOLD | ||
elif dataset == 'SABS': | ||
FOLD[0] = set(range(0, 7)) | ||
FOLD[1] = set(range(6, 13)) | ||
FOLD[2] = set(range(12, 19)) | ||
FOLD[3] = set(range(18, 25)) | ||
FOLD[4] = set(range(24, 30)) | ||
FOLD[4].update([0]) | ||
return FOLD | ||
else: | ||
raise ValueError(f'Dataset: {dataset} not found') | ||
|
||
|
||
def sample_xy(spr, k=0, b=215): | ||
_, h, v = torch.where(spr) | ||
|
||
if len(h) == 0 or len(v) == 0: | ||
horizontal = 0 | ||
vertical = 0 | ||
else: | ||
|
||
h_min = min(h) | ||
h_max = max(h) | ||
if b > (h_max - h_min): | ||
kk = min(k, int((h_max - h_min) / 2)) | ||
horizontal = random.randint(max(h_max - b - kk, 0), min(h_min + kk, 256 - b - 1)) | ||
else: | ||
kk = min(k, int(b / 2)) | ||
horizontal = random.randint(max(h_min - kk, 0), min(h_max - b + kk, 256 - b - 1)) | ||
|
||
v_min = min(v) | ||
v_max = max(v) | ||
if b > (v_max - v_min): | ||
kk = min(k, int((v_max - v_min) / 2)) | ||
vertical = random.randint(max(v_max - b - kk, 0), min(v_min + kk, 256 - b - 1)) | ||
else: | ||
kk = min(k, int(b / 2)) | ||
vertical = random.randint(max(v_min - kk, 0), min(v_max - b + kk, 256 - b - 1)) | ||
|
||
return horizontal, vertical |
Oops, something went wrong.