Skip to content

Commit

Permalink
Update debug session ID and refactor weights directory structure
Browse files Browse the repository at this point in the history
  • Loading branch information
cxnt committed Feb 3, 2025
1 parent ed0c3c7 commit 690050b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 13 deletions.
2 changes: 1 addition & 1 deletion train/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
app_session_id = sly.io.env.task_id()
root_model_checkpoint_dir = sly.app.get_synced_data_dir()
else:
app_session_id = 67078 # for debug
app_session_id = 215 # for debug
root_model_checkpoint_dir = os.path.join(app_root_directory, "runs")


Expand Down
53 changes: 41 additions & 12 deletions train/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,7 @@ def start_training():

if os.path.exists(local_artifacts_dir):
sly.fs.remove_dir(local_artifacts_dir)

# get number of images in selected datasets
dataset_infos = [
api.dataset.get_info_by_id(dataset_id) for dataset_id in dataset_ids
Expand Down Expand Up @@ -1611,6 +1612,7 @@ def freeze_callback(trainer):
# get additional training params
additional_params = train_settings_editor.get_text()
additional_params = yaml.safe_load(additional_params)

if task_type == "pose estimation":
additional_params["fliplr"] = 0.0
# set up epoch progress bar and grid plot
Expand Down Expand Up @@ -1948,14 +1950,14 @@ def train_model():
print(results)
best_epoch = results["fitness"].idxmax()
best_filename = f"best_{best_epoch}.pt"
current_best_filepath = os.path.join(local_artifacts_dir, "checkpoints", "best.pt")
new_best_filepath = os.path.join(local_artifacts_dir, "checkpoints", best_filename)
current_best_filepath = os.path.join(local_artifacts_dir, "weights", "best.pt")
new_best_filepath = os.path.join(local_artifacts_dir, "weights", best_filename)
os.rename(current_best_filepath, new_best_filepath)

# add geometry config to saved weights for pose estimation task
if task_type == "pose estimation":
weights_filepath = os.path.join(
local_artifacts_dir, "checkpoints", best_filename
local_artifacts_dir, "weights", best_filename
)
weights_dict = torch.load(weights_filepath)
if len(cls2config.keys()) == 1:
Expand All @@ -1974,8 +1976,8 @@ def add_sly_metadata_to_ckpt(ckpt_path):
loaded["sly_metadata"] = {"model_name": selected_model_name}
torch.save(loaded, ckpt_path)

best_path = os.path.join(local_artifacts_dir, "checkpoints", best_filename)
last_path = os.path.join(local_artifacts_dir, "checkpoints", "last.pt")
best_path = os.path.join(local_artifacts_dir, "weights", best_filename)
last_path = os.path.join(local_artifacts_dir, "weights", "last.pt")
if os.path.exists(best_path):
add_sly_metadata_to_ckpt(best_path)
if os.path.exists(last_path):
Expand Down Expand Up @@ -2028,6 +2030,10 @@ def upload_monitor(monitor, api: sly.Api, progress: sly.Progress):
progress.set_current_value(value, report=False)
artifacts_pbar.update(progress.current - artifacts_pbar.n)

# Experiments update
weights_dir = os.path.join(local_artifacts_dir, "weights")
os.rename(weights_dir, os.path.join(local_artifacts_dir, "checkpoints"))

local_files = sly.fs.list_files_recursively(local_artifacts_dir)
total_size = sum([sly.fs.get_file_size(file_path) for file_path in local_files])
progress = sly.Progress(
Expand Down Expand Up @@ -2325,6 +2331,8 @@ def get_image_infos_by_split(split: list):
model_meta_file_info = generate_model_meta(
local_artifacts_dir, remote_artifacts_dir
)
hyperparameters_file_info = generate_hyperparameters(local_artifacts_dir, remote_artifacts_dir, additional_params)

experiment_info = {
"experiment_name": f"{g.app_session_id}_{project_info.name}_{selected_model_name}",
"framework_name": YOLOv8.framework_name,
Expand All @@ -2346,8 +2354,8 @@ def get_image_infos_by_split(split: list):
"train_val_split": train_val_file_info.name,
"train_size": train_set_size,
"val_size": val_set_size,
"hyperparameters": None,
"hyperparameters_id": None,
"hyperparameters": hyperparameters_file_info.name,
"hyperparameters_id": hyperparameters_file_info.id,
"artifacts_dir": remote_artifacts_dir,
"datetime": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"evaluation_report_id": report_id,
Expand Down Expand Up @@ -3095,13 +3103,13 @@ def train_model():
print(results)
best_epoch = results["fitness"].idxmax()
best_filename = f"best_{best_epoch}.pt"
current_best_filepath = os.path.join(local_artifacts_dir, "checkpoints", "best.pt")
new_best_filepath = os.path.join(local_artifacts_dir, "checkpoints", best_filename)
current_best_filepath = os.path.join(local_artifacts_dir, "weights", "best.pt")
new_best_filepath = os.path.join(local_artifacts_dir, "weights", best_filename)
os.rename(current_best_filepath, new_best_filepath)

# add geometry config to saved weights for pose estimation task
if task_type == "pose estimation":
weights_filepath = os.path.join(local_artifacts_dir, "checkpoints", best_filename)
weights_filepath = os.path.join(local_artifacts_dir, "weights", best_filename)
weights_dict = torch.load(weights_filepath)
if len(cls2config.keys()) == 1:
geometry_config = list(cls2config.values())[0]
Expand All @@ -3119,8 +3127,8 @@ def add_sly_metadata_to_ckpt(ckpt_path):
loaded["sly_metadata"] = {"model_name": selected_model_name}
torch.save(loaded, ckpt_path)

best_path = os.path.join(local_artifacts_dir, "checkpoints", best_filename)
last_path = os.path.join(local_artifacts_dir, "checkpoints", "last.pt")
best_path = os.path.join(local_artifacts_dir, "weights", best_filename)
last_path = os.path.join(local_artifacts_dir, "weights", "last.pt")
if os.path.exists(best_path):
add_sly_metadata_to_ckpt(best_path)
if os.path.exists(last_path):
Expand Down Expand Up @@ -3173,6 +3181,11 @@ def upload_monitor(monitor, api: sly.Api, progress: sly.Progress):
progress.set_current_value(value, report=False)
artifacts_pbar.update(progress.current - artifacts_pbar.n)


# Experiments update
weights_dir = os.path.join(local_artifacts_dir, "weights")
os.rename(weights_dir, os.path.join(local_artifacts_dir, "checkpoints"))

local_files = sly.fs.list_files_recursively(local_artifacts_dir)
total_size = sum([sly.fs.get_file_size(file_path) for file_path in local_files])
progress = sly.Progress(
Expand Down Expand Up @@ -3654,3 +3667,19 @@ def generate_model_meta(local_dir, remote_dir):
remote_model_meta_file_path,
)
return file_info

def generate_hyperparameters(local_dir:str, remote_dir: str, hyperparameters: str):
hyperparameters_file = "hyperparameters.yaml"

local_hyperparameters_file_path = os.path.join(local_dir, hyperparameters_file)
remote_hyperparameters_file_path = os.path.join(remote_dir, hyperparameters_file)

with open(local_hyperparameters_file_path, "w") as f:
yaml.safe_dump(hyperparameters, f)
sly.logger.debug(f"Uploading '{local_hyperparameters_file_path}' to Supervisely")
file_info = api.file.upload(
team_id,
local_hyperparameters_file_path,
remote_hyperparameters_file_path,
)
return file_info

0 comments on commit 690050b

Please sign in to comment.