-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow batch eval/inference flexibility #80
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -26,24 +26,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: | |||||
""" | ||||||
eval_cfg = Config(cfg) | ||||||
|
||||||
if "checkpoints" in cfg.keys(): | ||||||
# update with parameters for batch train job | ||||||
if "batch_config" in cfg.keys(): | ||||||
try: | ||||||
index = int(os.environ["POD_INDEX"]) | ||||||
# For testing without deploying a job on runai | ||||||
except KeyError: | ||||||
index = input("Pod Index Not found! Please choose a pod index: ") | ||||||
|
||||||
logger.info(f"Pod Index: {index}") | ||||||
|
||||||
checkpoints = pd.read_csv(cfg.checkpoints) | ||||||
checkpoint = checkpoints.iloc[index] | ||||||
except KeyError as e: | ||||||
index = int( | ||||||
input(f"{e}. Assuming single run!\nPlease input task index to run:") | ||||||
) | ||||||
|
||||||
hparams_df = pd.read_csv(cfg.batch_config) | ||||||
hparams = hparams_df.iloc[index].to_dict() | ||||||
_ = hparams.pop("Unnamed: 0", None) | ||||||
|
||||||
if eval_cfg.set_hparams(hparams): | ||||||
logger.info("Updated the following hparams to the following values") | ||||||
logger.info(hparams) | ||||||
else: | ||||||
checkpoint = eval_cfg.cfg.ckpt_path | ||||||
hparams = {} | ||||||
|
||||||
checkpoint = eval_cfg.cfg.ckpt_path | ||||||
|
||||||
logger.info(f"Testing model saved at {checkpoint}") | ||||||
model = GTRRunner.load_from_checkpoint(checkpoint) | ||||||
|
||||||
model.tracker_cfg = eval_cfg.cfg.tracker | ||||||
model.tracker = Tracker(**model.tracker_cfg) | ||||||
|
||||||
logger.info(f"Using the following tracker:") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove extraneous f-string prefix. The f-string used in the logging statement does not contain any placeholders, making the Remove the extraneous - logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:") Committable suggestion
Suggested change
ToolsRuff
|
||||||
|
||||||
print(model.tracker) | ||||||
model.metrics["test"] = eval_cfg.cfg.runner.metrics.test | ||||||
model.persistent_tracking["test"] = eval_cfg.cfg.tracker.get( | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -96,25 +96,35 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: | |||||||||
""" | ||||||||||
pred_cfg = Config(cfg) | ||||||||||
|
||||||||||
if "checkpoints" in cfg.keys(): | ||||||||||
# update with parameters for batch train job | ||||||||||
if "batch_config" in cfg.keys(): | ||||||||||
Comment on lines
+99
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simplify dictionary key check. The check for Apply this diff to simplify the check: -if "batch_config" in cfg.keys():
+if "batch_config" in cfg: Committable suggestion
Suggested change
ToolsRuff
|
||||||||||
try: | ||||||||||
index = int(os.environ["POD_INDEX"]) | ||||||||||
# For testing without deploying a job on runai | ||||||||||
except KeyError: | ||||||||||
index = input("Pod Index Not found! Please choose a pod index: ") | ||||||||||
|
||||||||||
logger.info(f"Pod Index: {index}") | ||||||||||
|
||||||||||
checkpoints = pd.read_csv(cfg.checkpoints) | ||||||||||
checkpoint = checkpoints.iloc[index] | ||||||||||
except KeyError as e: | ||||||||||
index = int( | ||||||||||
input(f"{e}. Assuming single run!\nPlease input task index to run:") | ||||||||||
) | ||||||||||
Comment on lines
+103
to
+106
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle missing environment variable more robustly. The handling of a missing Consider setting a default index or handling the error in a way that does not require user interaction, which might not be feasible in batch processes. |
||||||||||
|
||||||||||
hparams_df = pd.read_csv(cfg.batch_config) | ||||||||||
hparams = hparams_df.iloc[index].to_dict() | ||||||||||
_ = hparams.pop("Unnamed: 0", None) | ||||||||||
|
||||||||||
if pred_cfg.set_hparams(hparams): | ||||||||||
logger.info("Updated the following hparams to the following values") | ||||||||||
logger.info(hparams) | ||||||||||
Comment on lines
+108
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate and log hyperparameter updates. The code reads hyperparameters from a CSV and updates the configuration. Ensure that the hyperparameters are validated before applying them to avoid runtime errors. Add validation for the hyperparameters read from the CSV to ensure they meet expected formats and constraints. |
||||||||||
else: | ||||||||||
checkpoint = pred_cfg.cfg.ckpt_path | ||||||||||
hparams = {} | ||||||||||
|
||||||||||
checkpoint = pred_cfg.cfg.ckpt_path | ||||||||||
|
||||||||||
logger.info(f"Running inference with model from {checkpoint}") | ||||||||||
model = GTRRunner.load_from_checkpoint(checkpoint) | ||||||||||
|
||||||||||
tracker_cfg = pred_cfg.get_tracker_cfg() | ||||||||||
logger.info("Updating tracker hparams") | ||||||||||
|
||||||||||
model.tracker_cfg = tracker_cfg | ||||||||||
model.tracker = Tracker(**model.tracker_cfg) | ||||||||||
|
||||||||||
logger.info(f"Using the following tracker:") | ||||||||||
logger.info(model.tracker) | ||||||||||
|
||||||||||
|
@@ -124,12 +134,14 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: | |||||||||
os.makedirs(outdir, exist_ok=True) | ||||||||||
|
||||||||||
for label_file, vid_file in zip(labels_files, vid_files): | ||||||||||
logger.info(f"Tracking {label_file} - {vid_file}...") | ||||||||||
dataset = pred_cfg.get_dataset( | ||||||||||
label_files=[label_file], vid_files=[vid_file], mode="test" | ||||||||||
) | ||||||||||
dataloader = pred_cfg.get_dataloader(dataset, mode="test") | ||||||||||
preds = track(model, trainer, dataloader) | ||||||||||
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp") | ||||||||||
logger.info(f"Saving results to {outpath}...") | ||||||||||
preds.save(outpath) | ||||||||||
|
||||||||||
return preds | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -323,6 +323,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: | |||||
""" | ||||||
temp_lookup = self.lookup | ||||||
N = times.shape[0] | ||||||
times = times / times.max() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add safeguard against division by zero in normalization step. The normalization step Consider modifying the normalization step to include a small epsilon value: - times = times / times.max()
+ times = times / (times.max() + 1e-6) Committable suggestion
Suggested change
|
||||||
|
||||||
left_ind, right_ind, left_weight, right_weight = self._compute_weights(times) | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -301,7 +301,7 @@ def on_test_epoch_end(self): | |
avg_result = results_df[key].mean() | ||
results_file.attrs.create(key, avg_result) | ||
for i, (metrics, frames) in enumerate(zip(metrics_dict, preds)): | ||
vid_name = frames[0].vid_name.split("/")[-1].split(".")[0] | ||
vid_name = frames[0].vid_name.split("/")[-1] | ||
vid_group = results_file.require_group(vid_name) | ||
clip_group = vid_group.require_group(f"clip_{i}") | ||
for key, val in metrics.items(): | ||
|
@@ -310,11 +310,18 @@ def on_test_epoch_end(self): | |
if metrics.get("num_switches", 0) > 0: | ||
_ = frame.to_h5( | ||
clip_group, | ||
frame.get_gt_track_ids().cpu().numpy(), | ||
[ | ||
instance.gt_track_id.item() | ||
for instance in frame.instances | ||
], | ||
save={"crop": True, "features": True, "embeddings": True}, | ||
) | ||
else: | ||
_ = frame.to_h5( | ||
clip_group, frame.get_gt_track_ids().cpu().numpy() | ||
clip_group, | ||
[ | ||
instance.gt_track_id.item() | ||
for instance in frame.instances | ||
], | ||
Comment on lines
+313
to
+325
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Review changes to track ID handling. The new implementation uses a list comprehension to extract Verify that the new list structure of track IDs is compatible with all downstream processes that consume this data. |
||
) | ||
self.test_results = {"metrics": [], "preds": [], "save_path": fname} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactor dictionary key check and improve error handling.
The handling of the
POD_INDEX
environment variable and user input is robust, enhancing the user experience by providing clear error messages. However, the check for"batch_config"
can be simplified by removing.keys()
for a more Pythonic approach.Apply this diff to refactor the dictionary key check:
The changes are approved.
Committable suggestion
Tools
Ruff