Skip to content

Commit

Permalink
allow batch eval/inference flexibility rather than just different model checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed Aug 16, 2024
1 parent 2af0dd5 commit 9eddead
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 34 deletions.
31 changes: 21 additions & 10 deletions dreem/inference/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""
eval_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():
try:
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
index = input("Pod Index Not found! Please choose a pod index: ")

logger.info(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
except KeyError as e:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)

hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)

if eval_cfg.set_hparams(hparams):
logger.info("Updated the following hparams to the following values")
logger.info(hparams)
else:
checkpoint = eval_cfg.cfg.ckpt_path
hparams = {}

checkpoint = eval_cfg.cfg.ckpt_path

logger.info(f"Testing model saved at {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)

model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")

print(model.tracker)
model.metrics["test"] = eval_cfg.cfg.runner.metrics.test
model.persistent_tracking["test"] = eval_cfg.cfg.tracker.get(
Expand Down
34 changes: 23 additions & 11 deletions dreem/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""
pred_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():
try:
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
index = input("Pod Index Not found! Please choose a pod index: ")

logger.info(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
except KeyError as e:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)

hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)

if pred_cfg.set_hparams(hparams):
logger.info("Updated the following hparams to the following values")
logger.info(hparams)
else:
checkpoint = pred_cfg.cfg.ckpt_path
hparams = {}

checkpoint = pred_cfg.cfg.ckpt_path

logger.info(f"Running inference with model from {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)

tracker_cfg = pred_cfg.get_tracker_cfg()
logger.info("Updating tracker hparams")

model.tracker_cfg = tracker_cfg
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")
logger.info(model.tracker)

Expand All @@ -124,12 +134,14 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
os.makedirs(outdir, exist_ok=True)

for label_file, vid_file in zip(labels_files, vid_files):
logger.info(f"Tracking {label_file} - {vid_file}...")
dataset = pred_cfg.get_dataset(
label_files=[label_file], vid_files=[vid_file], mode="test"
)
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = track(model, trainer, dataloader)
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
logger.info(f"Saving results to {outpath}...")
preds.save(outpath)

return preds
Expand Down
40 changes: 30 additions & 10 deletions dreem/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]:
labels_path = f"{dir_cfg.path}/*{labels_suff}"
vid_path = f"{dir_cfg.path}/*{vid_suff}"
logger.debug(f"Searching for labels matching {labels_path}")
label_files = glob.glob(labels_path)
label_files = sorted(glob.glob(labels_path))
logger.debug(f"Searching for videos matching {vid_path}")
vid_files = glob.glob(vid_path)
vid_files = sorted(glob.glob(vid_path))
logger.debug(f"Found {len(label_files)} labels and {len(vid_files)} videos")

else:
Expand All @@ -197,7 +197,7 @@ def get_dataset(
mode: str,
label_files: list[str] | None = None,
vid_files: list[str | list[str]] = None,
) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset":
) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None:
"""Getter for datasets.
Args:
Expand Down Expand Up @@ -230,33 +230,37 @@ def get_dataset(
dataset_params.slp_files = label_files
if vid_files is not None:
dataset_params.video_files = vid_files
return SleapDataset(**dataset_params)
dataset = SleapDataset(**dataset_params)

elif "tracks" in dataset_params or "source" in dataset_params:
if label_files is not None:
dataset_params.tracks = label_files
if vid_files is not None:
dataset_params.videos = vid_files
return MicroscopyDataset(**dataset_params)
dataset = MicroscopyDataset(**dataset_params)

elif "raw_images" in dataset_params:
if label_files is not None:
dataset_params.gt_images = label_files
if vid_files is not None:
dataset_params.raw_images = vid_files
return CellTrackingDataset(**dataset_params)
dataset = CellTrackingDataset(**dataset_params)

else:
raise ValueError(
"Could not resolve dataset type from Config! Please include \
either `slp_files` or `tracks`/`source`"
)
if len(dataset) == 0:
logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
return None
return dataset

def get_dataloader(
self,
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset",
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None,
mode: str,
) -> torch.utils.data.DataLoader:
) -> torch.utils.data.DataLoader | None:
"""Getter for dataloader.
Args:
Expand All @@ -267,6 +271,15 @@ def get_dataloader(
Returns:
A torch dataloader for `dataset` with parameters configured as specified
"""
if dataset is None:
logger.warn(f"{mode} dataset passed was `None`! Returning `None`")
return None

elif len(dataset) == 0:
logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning `None`")
return None


if mode.lower() == "train":
dataloader_params = self.cfg.dataloader.train_dataloader
elif mode.lower() == "val":
Expand All @@ -284,14 +297,21 @@ def get_dataloader(
else:
pin_memory = False

return torch.utils.data.DataLoader(
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
collate_fn=dataset.no_batching_fn,
**dataloader_params,
)

if len(dataloader) == 0:
logger.warn(
f"Length of {mode} dataloader is {len(dataloader)}! Returning `None`"
)
return None
return dataloader

def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
"""Getter for optimizer.
Expand Down Expand Up @@ -396,7 +416,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
checkpointers.append(checkpointer)
return checkpointers

Expand Down
1 change: 1 addition & 0 deletions dreem/models/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""
temp_lookup = self.lookup
N = times.shape[0]
times = times / times.max()

left_ind, right_ind, left_weight, right_weight = self._compute_weights(times)

Expand Down
13 changes: 10 additions & 3 deletions dreem/models/gtr_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def on_test_epoch_end(self):
avg_result = results_df[key].mean()
results_file.attrs.create(key, avg_result)
for i, (metrics, frames) in enumerate(zip(metrics_dict, preds)):
vid_name = frames[0].vid_name.split("/")[-1].split(".")[0]
vid_name = frames[0].vid_name.split("/")[-1]
vid_group = results_file.require_group(vid_name)
clip_group = vid_group.require_group(f"clip_{i}")
for key, val in metrics.items():
Expand All @@ -310,11 +310,18 @@ def on_test_epoch_end(self):
if metrics.get("num_switches", 0) > 0:
_ = frame.to_h5(
clip_group,
frame.get_gt_track_ids().cpu().numpy(),
[
instance.gt_track_id.item()
for instance in frame.instances
],
save={"crop": True, "features": True, "embeddings": True},
)
else:
_ = frame.to_h5(
clip_group, frame.get_gt_track_ids().cpu().numpy()
clip_group,
[
instance.gt_track_id.item()
for instance in frame.instances
],
)
self.test_results = {"metrics": [], "preds": [], "save_path": fname}

0 comments on commit 9eddead

Please sign in to comment.