Allow batch eval/inference flexibility #80

Status: Open. Wants to merge 1 commit into base: main.
31 changes: 21 additions & 10 deletions dreem/inference/eval.py
@@ -26,24 +26,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""
eval_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():
try:
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
index = input("Pod Index Not found! Please choose a pod index: ")

logger.info(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
except KeyError as e:
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)

hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)

if eval_cfg.set_hparams(hparams):
logger.info("Updated the following hparams to the following values")
logger.info(hparams)
Contributor commented on lines +29 to +44:

Refactor dictionary key check and improve error handling.

The handling of the POD_INDEX environment variable and user input is robust, enhancing the user experience by providing clear error messages. However, the check for "batch_config" can be simplified by removing .keys() for a more Pythonic approach.

Apply this diff to refactor the dictionary key check:

- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:

The changes are approved.
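For reference, membership testing on a plain dict already consults the keys, so the two spellings are equivalent; a minimal sketch with a plain dict (omegaconf's DictConfig supports `in` the same way):

```python
# Membership testing on a dict already checks its keys,
# so `key in cfg` and `key in cfg.keys()` are interchangeable;
# the bare `in` form just skips building a keys view.
cfg = {"batch_config": "hparams.csv", "ckpt_path": "model.ckpt"}

same = ("batch_config" in cfg) == ("batch_config" in cfg.keys())
missing = "not_a_key" not in cfg
```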

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if "batch_config" in cfg.keys():
+    if "batch_config" in cfg:
Tools: Ruff
30-30: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)

     else:
-        checkpoint = eval_cfg.cfg.ckpt_path
+        hparams = {}
+
+    checkpoint = eval_cfg.cfg.ckpt_path

     logger.info(f"Testing model saved at {checkpoint}")
     model = GTRRunner.load_from_checkpoint(checkpoint)

     model.tracker_cfg = eval_cfg.cfg.tracker
     model.tracker = Tracker(**model.tracker_cfg)

     logger.info(f"Using the following tracker:")
Contributor commented:

Remove extraneous f-string prefix.

The f-string used in the logging statement does not contain any placeholders, making the f prefix unnecessary.

Remove the extraneous f prefix to clean up the code:

- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")
Committable suggestion


Suggested change
-logger.info(f"Using the following tracker:")
+logger.info("Using the following tracker:")
Tools: Ruff
56-56: f-string without any placeholders; remove extraneous `f` prefix (F541)


     print(model.tracker)
     model.metrics["test"] = eval_cfg.cfg.runner.metrics.test
     model.persistent_tracking["test"] = eval_cfg.cfg.tracker.get(
34 changes: 23 additions & 11 deletions dreem/inference/track.py
@@ -96,25 +96,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""
pred_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():
Contributor commented on lines +99 to +100:

Simplify dictionary key check.

The check for "batch_config" in the configuration dictionary can be simplified by removing .keys().

Apply this diff to simplify the check:

-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:
Committable suggestion


Suggested change
-    # update with parameters for batch train job
-    if "batch_config" in cfg.keys():
+    # update with parameters for batch train job
+    if "batch_config" in cfg:
Tools: Ruff
100-100: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)

         try:
             index = int(os.environ["POD_INDEX"])
-        # For testing without deploying a job on runai
-        except KeyError:
-            index = input("Pod Index Not found! Please choose a pod index: ")
-
-        logger.info(f"Pod Index: {index}")
-
-        checkpoints = pd.read_csv(cfg.checkpoints)
-        checkpoint = checkpoints.iloc[index]
+        except KeyError as e:
+            index = int(
+                input(f"{e}. Assuming single run!\nPlease input task index to run:")
+            )
Contributor commented on lines +103 to +106:

Handle missing environment variable more robustly.

The handling of a missing POD_INDEX environment variable could be improved by providing a default value or a more robust error handling mechanism, rather than prompting the user.

Consider setting a default index or handling the error in a way that does not require user interaction, which might not be feasible in batch processes.
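One non-interactive alternative, sketched here with a hypothetical helper name and a default of 0 (both assumptions, not part of the PR):

```python
import os

def resolve_task_index(default: int = 0) -> int:
    """Resolve the batch task index without blocking on stdin.

    Hypothetical helper: returns POD_INDEX when it is set and numeric,
    otherwise falls back to `default` (an assumed value), so batch jobs
    never hang waiting for user input.
    """
    raw = os.environ.get("POD_INDEX")
    try:
        return int(raw)
    except (TypeError, ValueError):
        # TypeError: variable unset (raw is None); ValueError: non-numeric
        return default
```

With POD_INDEX unset the job proceeds with the default instead of prompting.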


+        hparams_df = pd.read_csv(cfg.batch_config)
+        hparams = hparams_df.iloc[index].to_dict()
+        _ = hparams.pop("Unnamed: 0", None)
+
+        if pred_cfg.set_hparams(hparams):
+            logger.info("Updated the following hparams to the following values")
+            logger.info(hparams)
Contributor commented on lines +108 to +114:

Validate and log hyperparameter updates.

The code reads hyperparameters from a CSV and updates the configuration. Ensure that the hyperparameters are validated before applying them to avoid runtime errors.

Add validation for the hyperparameters read from the CSV to ensure they meet expected formats and constraints.
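A sketch of such a validation step, assuming hyperparameters arrive as a row dict from pandas (NaN for blank cells, plus the unnamed index column); the helper name and rules are illustrative, not part of the PR:

```python
import math

def clean_hparams(row: dict) -> dict:
    """Drop pandas' index column and blank (NaN) cells before the
    values are applied to the config (illustrative helper)."""
    cleaned = {}
    for key, val in row.items():
        if key == "Unnamed: 0":
            continue  # pandas' unnamed index column
        if isinstance(val, float) and math.isnan(val):
            continue  # blank CSV cell read back as NaN
        cleaned[key] = val
    return cleaned
```

Running set_hparams on the cleaned dict then avoids writing NaN into the config.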

     else:
-        checkpoint = pred_cfg.cfg.ckpt_path
+        hparams = {}
+
+    checkpoint = pred_cfg.cfg.ckpt_path

     logger.info(f"Running inference with model from {checkpoint}")
     model = GTRRunner.load_from_checkpoint(checkpoint)

     tracker_cfg = pred_cfg.get_tracker_cfg()
     logger.info("Updating tracker hparams")

     model.tracker_cfg = tracker_cfg
     model.tracker = Tracker(**model.tracker_cfg)

     logger.info(f"Using the following tracker:")
     logger.info(model.tracker)

@@ -124,12 +134,14 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
     os.makedirs(outdir, exist_ok=True)

     for label_file, vid_file in zip(labels_files, vid_files):
+        logger.info(f"Tracking {label_file} - {vid_file}...")
         dataset = pred_cfg.get_dataset(
             label_files=[label_file], vid_files=[vid_file], mode="test"
         )
         dataloader = pred_cfg.get_dataloader(dataset, mode="test")
         preds = track(model, trainer, dataloader)
         outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
+        logger.info(f"Saving results to {outpath}...")
         preds.save(outpath)

     return preds
40 changes: 30 additions & 10 deletions dreem/io/config.py
@@ -174,9 +174,9 @@ def get_data_paths(self, data_cfg: dict) -> tuple[list[str], list[str]]:
             labels_path = f"{dir_cfg.path}/*{labels_suff}"
             vid_path = f"{dir_cfg.path}/*{vid_suff}"
             logger.debug(f"Searching for labels matching {labels_path}")
-            label_files = glob.glob(labels_path)
+            label_files = sorted(glob.glob(labels_path))
             logger.debug(f"Searching for videos matching {vid_path}")
-            vid_files = glob.glob(vid_path)
+            vid_files = sorted(glob.glob(vid_path))
             logger.debug(f"Found {len(label_files)} labels and {len(vid_files)} videos")

         else:
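The switch to sorted() matters because glob.glob returns paths in arbitrary filesystem order, while the label and video lists are later paired positionally; sorting both restores a stable pairing. A minimal sketch, assuming parallel file naming:

```python
# glob.glob order is filesystem-dependent, so positional pairing of
# labels with videos can silently mismatch; sorting both lists fixes it
# (assumes parallel naming such as a.slp <-> a.mp4).
label_files = sorted(["b.slp", "a.slp"])
vid_files = sorted(["b.mp4", "a.mp4"])
pairs = list(zip(label_files, vid_files))
```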
@@ -197,7 +197,7 @@ def get_dataset(
         mode: str,
         label_files: list[str] | None = None,
         vid_files: list[str | list[str]] = None,
-    ) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset":
+    ) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None:
         """Getter for datasets.

         Args:

@@ -230,33 +230,37 @@ def get_dataset(
                 dataset_params.slp_files = label_files
             if vid_files is not None:
                 dataset_params.video_files = vid_files
-            return SleapDataset(**dataset_params)
+            dataset = SleapDataset(**dataset_params)

         elif "tracks" in dataset_params or "source" in dataset_params:
             if label_files is not None:
                 dataset_params.tracks = label_files
             if vid_files is not None:
                 dataset_params.videos = vid_files
-            return MicroscopyDataset(**dataset_params)
+            dataset = MicroscopyDataset(**dataset_params)

         elif "raw_images" in dataset_params:
             if label_files is not None:
                 dataset_params.gt_images = label_files
             if vid_files is not None:
                 dataset_params.raw_images = vid_files
-            return CellTrackingDataset(**dataset_params)
+            dataset = CellTrackingDataset(**dataset_params)

         else:
             raise ValueError(
                 "Could not resolve dataset type from Config! Please include \
                 either `slp_files` or `tracks`/`source`"
             )
+        if len(dataset) == 0:
+            logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+            return None
+        return dataset

     def get_dataloader(
         self,
-        dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset",
+        dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None,
         mode: str,
-    ) -> torch.utils.data.DataLoader:
+    ) -> torch.utils.data.DataLoader | None:
         """Getter for dataloader.

         Args:

@@ -267,6 +271,15 @@ def get_dataloader(
         Returns:
             A torch dataloader for `dataset` with parameters configured as specified
         """
+        if dataset is None:
+            logger.warn(f"{mode} dataset passed was `None`! Returning `None`")
+            return None
+
+        elif len(dataset) == 0:
+            logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning `None`")
+            return None
+
         if mode.lower() == "train":
             dataloader_params = self.cfg.dataloader.train_dataloader
         elif mode.lower() == "val":
@@ -284,14 +297,21 @@ def get_dataloader(
         else:
             pin_memory = False

-        return torch.utils.data.DataLoader(
+        dataloader = torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             pin_memory=pin_memory,
             collate_fn=dataset.no_batching_fn,
             **dataloader_params,
         )

+        if len(dataloader) == 0:
+            logger.warn(
+                f"Length of {mode} dataloader is {len(dataloader)}! Returning `None`"
+            )
+            return None
+        return dataloader
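The None-propagating contract above (a missing or empty dataset short-circuits to None instead of raising downstream) can be exercised in isolation; the helper name is illustrative, with plain lists standing in for datasets and loaders:

```python
def build_loader_or_none(dataset, build=list):
    """Mimic the PR's guard pattern: return None for a missing or
    empty dataset instead of constructing an unusable loader
    (illustrative stand-in using plain lists, not torch objects)."""
    if dataset is None or len(dataset) == 0:
        return None
    loader = build(dataset)
    # An empty loader is equally unusable, so propagate None for it too.
    return loader if len(loader) > 0 else None
```

Callers then check `if dataloader is None` once, rather than catching exceptions from deep inside the loader.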

     def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
         """Getter for optimizer.

@@ -396,7 +416,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
                 filename=f"{{epoch}}-{{{metric}}}",
                 **checkpoint_params,
             )
-            checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
+            checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
             checkpointers.append(checkpointer)
         return checkpointers
1 change: 1 addition & 0 deletions dreem/models/embedding.py
@@ -323,6 +323,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""
temp_lookup = self.lookup
N = times.shape[0]
times = times / times.max()
Contributor commented:

Add safeguard against division by zero in normalization step.

The normalization step times = times / times.max() could potentially lead to a division by zero error if times.max() is zero. It's crucial to add a small epsilon value to avoid this issue.

Consider modifying the normalization step to include a small epsilon value:

- times = times / times.max()
+ times = times / (times.max() + 1e-6)
Committable suggestion


Suggested change
-times = times / times.max()
+times = times / (times.max() + 1e-6)
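The failure mode is easy to reproduce with plain floats: an all-zero clip makes the divisor zero (NaN under torch's elementwise division, ZeroDivisionError in plain Python), while the epsilon keeps every result finite. A plain-Python sketch of the suggested fix:

```python
def normalize_times(times, eps=1e-6):
    """Scale timestamps toward [0, 1); eps guards the all-zero clip
    where max(times) == 0 (plain-float sketch of the suggested fix)."""
    m = max(times)
    return [t / (m + eps) for t in times]
```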


         left_ind, right_ind, left_weight, right_weight = self._compute_weights(times)
13 changes: 10 additions & 3 deletions dreem/models/gtr_runner.py
@@ -301,7 +301,7 @@ def on_test_epoch_end(self):
             avg_result = results_df[key].mean()
             results_file.attrs.create(key, avg_result)
         for i, (metrics, frames) in enumerate(zip(metrics_dict, preds)):
-            vid_name = frames[0].vid_name.split("/")[-1].split(".")[0]
+            vid_name = frames[0].vid_name.split("/")[-1]
             vid_group = results_file.require_group(vid_name)
             clip_group = vid_group.require_group(f"clip_{i}")
             for key, val in metrics.items():

@@ -310,11 +310,18 @@ def on_test_epoch_end(self):
                 if metrics.get("num_switches", 0) > 0:
                     _ = frame.to_h5(
                         clip_group,
-                        frame.get_gt_track_ids().cpu().numpy(),
+                        [
+                            instance.gt_track_id.item()
+                            for instance in frame.instances
+                        ],
                         save={"crop": True, "features": True, "embeddings": True},
                     )
                 else:
                     _ = frame.to_h5(
-                        clip_group, frame.get_gt_track_ids().cpu().numpy()
+                        clip_group,
+                        [
+                            instance.gt_track_id.item()
+                            for instance in frame.instances
+                        ],
Contributor commented on lines +313 to +325:

Review changes to track ID handling.

The new implementation uses a list comprehension to extract gt_track_id. Ensure that this change does not affect the expected data structure in downstream processes.

Verify that the new list structure of track IDs is compatible with all downstream processes that consume this data.
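One way to check the compatibility concern is to compare the two expressions on a mock frame; MockInstance below is a minimal stand-in (the assumption being that gt_track_id behaves like a 0-d tensor exposing .item()):

```python
class MockInstance:
    """Minimal stand-in for dreem's Instance (assumed interface)."""
    def __init__(self, tid: int):
        self._tid = tid

    @property
    def gt_track_id(self):
        return self  # stands in for a 0-d tensor

    def item(self) -> int:
        return self._tid

instances = [MockInstance(0), MockInstance(1), MockInstance(1)]
# New form: a plain Python list of ints.
track_ids = [inst.gt_track_id.item() for inst in instances]
# The old frame.get_gt_track_ids().cpu().numpy() produced the same IDs
# as a numpy array, so any consumer relying on array operations
# (e.g. .shape, vectorized comparison) would need the list converted.
```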

                     )
         self.test_results = {"metrics": [], "preds": [], "save_path": fname}