Skip to content

Commit

Permalink
handle windows path issue in pose_config.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
sidhulyalkar committed Jan 26, 2024
1 parent f53852d commit 9844747
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inspect
import importlib
import os
from pathlib import Path
from pathlib import Path, PureWindowsPath, PurePosixPath
from element_interface.utils import find_full_path, dict_to_uuid
from .readers import dlc_reader
import yaml
Expand Down Expand Up @@ -243,6 +243,7 @@ class ModelTraining(dj.Computed):

def make(self, key):
import deeplabcut

try:
from deeplabcut.utils.auxiliaryfunctions import (
get_model_folder,
Expand Down Expand Up @@ -288,26 +289,49 @@ def make(self, key):
)
model_train_folder = project_path / model_folder / "train"

# update init_weight
# update init_weights, dataset, and metadataset paths in pose_config.yaml
with open(model_train_folder / "pose_cfg.yaml", "r") as f:
pose_cfg = yaml.safe_load(f)
init_weights_path = Path(pose_cfg["init_weights"])

if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix():
# handle windows path passed in input
try:
init_weights_windows_path = PureWindowsPath(pose_cfg["init_weights"])
init_weights_path = PurePosixPath(*init_weights_windows_path.parts[2:])
dataset_path = PurePosixPath(PureWindowsPath(pose_cfg["dataset"]))
metadataset_path = PurePosixPath(PureWindowsPath(pose_cfg["metadataset"]))
except ValueError:
init_weights_path = Path(pose_cfg["init_weights"])
dataset_path = Path(pose_cfg["dataset"])
metadataset_path = Path(pose_cfg["metadataset"])

if (
"pose_estimation_tensorflow/models/pretrained"
in init_weights_path.as_posix()
):
# this is the res_net models, construct new path here
init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
init_weights_path = (
Path(deeplabcut.__path__[0])
/ "pose_estimation_tensorflow/models/pretrained"
/ init_weights_path.name
)
else:
# this is existing snapshot weights, update path here
init_weights_path = model_train_folder / init_weights_path.name

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix()},
{
"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix(),
"dataset": dataset_path.as_posix(),
"metadataset": metadataset_path.as_posix(),
},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters)
train_network_input_args = list(
inspect.signature(deeplabcut.train_network).parameters
)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
Expand All @@ -334,7 +358,7 @@ def make(self, key):

# update snapshotindex in the config
snapshotindex = snapshots.index(latest_snapshot_file)

dlc_config["snapshotindex"] = snapshotindex
edit_config(
dlc_cfg_filepath,
Expand All @@ -346,3 +370,9 @@ def make(self, key):
)


def _is_windows_path(path):
try:
PureWindowsPath(path)
return True
except ValueError:
return False

0 comments on commit 9844747

Please sign in to comment.