Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
shenqq377 authored Aug 24, 2022
1 parent 0d64895 commit 2c17b37
Show file tree
Hide file tree
Showing 11 changed files with 1,439 additions and 0 deletions.
96 changes: 96 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Experiment configuration file
Extended from config file from original PANet Repository
"""
import glob
import itertools
import os
import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
from utils import *

# Allow sacred to mutate config entries after capture and disable stdout
# capturing (capture can garble interactive/tqdm output).
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'

# Sacred experiment object; the config, config hook and runs attach to it.
ex = Experiment("QNet")
ex.captured_out_filter = apply_backspaces_and_linefeeds

###### Set up source folder ######
# Snapshot every .py file from these folders into the sacred run directory,
# so each experiment records the exact code it ran with.
source_folders = ['.', './dataloaders', './models', './utils']
sources_to_save = list(itertools.chain.from_iterable(
    [glob.glob(f'{folder}/*.py') for folder in source_folders]))
for source_file in sources_to_save:
    ex.add_source_file(source_file)


@ex.config
def cfg():
    """Default configurations.

    NOTE: this is a sacred config scope — every local variable name below
    becomes a config key, so names must not be changed casually.
    """
    seed = 2021            # global RNG seed for reproducibility
    gpu_id = 0
    num_workers = 0 # 0 for debugging.
    mode = 'train'         # 'train' or 'test' (used in exp_str below)

    ## dataset
    dataset = 'CHAOST2'  # i.e. abdominal MRI - 'CHAOST2'; cardiac MRI - CMR
    exclude_label = None  # None, for not excluding test labels;
    # 1 for Liver, 2 for RK, 3 for LK, 4 for Spleen in 'CHAOST2'
    # Number of supervoxels differs per modality (fewer for cardiac MRI).
    if dataset == 'CMR':
        n_sv = 1000
    else:
        n_sv = 5000
    min_size = 200         # minimum supervoxel size kept
    max_slices = 3
    use_gt = False  # True - use ground truth as training label, False - use supervoxel as training label
    eval_fold = 0  # (0-4) for 5-fold cross-validation
    test_label = [1, 4]  # for evaluation
    supp_idx = 0  # choose which case as the support set for evaluation, (0-4) for 'CHAOST2', (0-7) for 'CMR'
    n_part = 3  # for evaluation, i.e. 3 chunks

    ## training
    n_steps = 1000         # total optimization steps
    batch_size = 1
    n_shot = 1             # few-shot episode: 1-shot
    n_way = 1              # few-shot episode: 1-way
    n_query = 1
    lr_step_gamma = 0.95   # LR decay factor per scheduler step
    bg_wt = 0.1            # background class loss weight
    t_loss_scaler = 0.0    # threshold-loss scaling (0 disables it)
    ignore_label = 255     # label value excluded from the loss
    print_interval = 100
    save_snapshot_every = 1000
    max_iters_per_load = 1000  # epoch size, interval for reloading the dataset

    # Network
    # Path to a checkpoint to resume from; None trains from scratch.
    # reload_model_path = '/home/SQQ/fsmis/ADNet/runs/ADNet_train_CHAOST2_cv0/1/snapshots/1000.pth'
    reload_model_path = None

    optim_type = 'sgd'
    optim = {
        'lr': 1e-3,
        'momentum': 0.9,
        'weight_decay': 0.0005,
    }

    # Human-readable tag used in the run directory name, e.g. 'train_CHAOST2_cv0'.
    exp_str = '_'.join(
        [mode]
        + [dataset, ]
        + [f'cv{eval_fold}'])

    # Data/log locations keyed by dataset name.
    path = {
        'log_dir': './runs',
        'CHAOST2': {'data_dir': './data/CHAOST2'},
        'SABS': {'data_dir': './data/SABS'},
        'CMR': {'data_dir': './data/CMR'},
    }


@ex.config_hook
def add_observer(config, command_name, logger):
    """A hook function to add a FileStorageObserver.

    Builds the run directory name from the experiment path and the
    configured ``exp_str`` (mode/dataset/fold), then attaches a
    ``FileStorageObserver`` under ``path['log_dir']`` so every run is
    logged to its own folder.  Returns the (unmodified) config, as
    required by sacred's config-hook contract.
    """
    exp_name = f'{ex.path}_{config["exp_str"]}'
    observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
    ex.observers.append(observer)
    return config
100 changes: 100 additions & 0 deletions dataloaders/dataset_specifics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Dataset Specifics
Extended from ADNet code by Hansen et al.
"""

import torch
import random


def get_label_names(dataset):
    """Map integer class ids to anatomical label names for a dataset.

    Index 0 is always background ('BG').  Unknown dataset names yield an
    empty mapping.
    """
    if dataset == 'CMR':
        names = ['BG', 'LV-MYO', 'LV-BP', 'RV']
    elif dataset == 'CHAOST2':
        names = ['BG', 'LIVER', 'RK', 'LK', 'SPLEEN']
    elif dataset == 'SABS':
        # IVC = inferior vena cava; PS_VEIN = portal and splenic veins;
        # AG_R / AG_L = right / left adrenal gland.
        names = ['BG', 'SPLEEN', 'RK', 'LK', 'GALLBLADDER', 'ESOPHAGUS',
                 'LIVER', 'STOMACH', 'AORTA', 'IVC', 'PS_VEIN', 'PANCREAS',
                 'AG_R', 'AG_L']
    else:
        names = []
    return dict(enumerate(names))


def get_folds(dataset):
    """Return the 5-fold cross-validation split of case indices.

    Each fold i (0-4) is a set of consecutive case indices; adjacent folds
    deliberately overlap by one case, and the last fold additionally
    includes case 0 (wrap-around overlap), reproducing the original
    hand-written splits.

    Raises:
        ValueError: if *dataset* is not one of 'CMR', 'CHAOST2', 'SABS'.
    """
    # Fold boundaries: fold i spans [edges[i], edges[i+1]] inclusive
    # (except the last fold, which stops one short of edges[5]).
    boundaries = {
        'CMR': (0, 7, 14, 21, 28, 35),
        'CHAOST2': (0, 4, 8, 12, 16, 20),
        'SABS': (0, 6, 12, 18, 24, 30),
    }
    if dataset not in boundaries:
        raise ValueError(f'Dataset: {dataset} not found')

    edges = boundaries[dataset]
    folds = {}
    for i in range(5):
        stop = edges[i + 1] + 1 if i < 4 else edges[i + 1]
        folds[i] = set(range(edges[i], stop))
    folds[4].add(0)
    return folds


def _sample_axis(lo, hi, k, b, img_size):
    """Sample a crop start coordinate along one axis.

    *lo*/*hi* are the min/max foreground indices on the axis, *k* the
    maximum random jitter, *b* the crop side length and *img_size* the
    image extent along the axis.  When the foreground span fits inside
    the crop (b > hi - lo), any start in [hi - b, lo] keeps the whole
    span inside; otherwise the crop is anchored near the span's lower
    end.  NOTE(review): assumes the resulting randint interval is
    non-empty — holds for the defaults (b=215 < img_size=256), but a
    pathological b/img_size combination would raise ValueError, exactly
    as the original code did.
    """
    span = hi - lo
    if b > span:
        # Crop covers the span; jitter the start by up to kk pixels.
        kk = min(k, int(span / 2))
        low = max(hi - b - kk, 0)
        high = min(lo + kk, img_size - b - 1)
    else:
        # Foreground wider than the crop; start near its lower edge.
        kk = min(k, int(b / 2))
        low = max(lo - kk, 0)
        high = min(hi - b + kk, img_size - b - 1)
    return random.randint(low, high)


def sample_xy(spr, k=0, b=215, img_size=256):
    """Sample the top-left corner of a b x b crop around a mask's foreground.

    The identical per-axis sampling logic previously duplicated for the
    horizontal and vertical axes is factored into ``_sample_axis``, and
    the formerly hard-coded image size 256 is now the ``img_size``
    parameter (default keeps the old behavior).

    Args:
        spr: 3-D tensor (e.g. slice, H, W); nonzero entries mark the
            foreground (consumed via ``torch.where``).
        k: maximum random jitter added to the crop position.
        b: crop side length in pixels.
        img_size: image side length in pixels (previously fixed at 256).

    Returns:
        (horizontal, vertical) crop start indices; (0, 0) when the mask
        has no foreground.
    """
    _, h, v = torch.where(spr)

    if len(h) == 0 or len(v) == 0:
        return 0, 0

    horizontal = _sample_axis(min(h), max(h), k, b, img_size)
    vertical = _sample_axis(min(v), max(v), k, b, img_size)
    return horizontal, vertical
Loading

0 comments on commit 2c17b37

Please sign in to comment.