Skip to content

Commit

Permalink
update to jax-based eks (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt authored Jul 3, 2024
1 parent 807d3e3 commit 443fb2d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 17 deletions.
2 changes: 1 addition & 1 deletion lightning_pose_app/ui/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ def _render_streamlit_fn(state: AppState):
"for your new project, continue work on an ongoing lightning pose project, or remove "
"a project from your projects."
)
st.text(f"Available projects: {state.initialized_projects}")

if st_mode == LOAD_STR:
st_project_name = st.selectbox(
Expand All @@ -638,6 +637,7 @@ def _render_streamlit_fn(state: AppState):
value="" if (not state.st_project_loaded or state.st_reset_project_name)
else state.st_project_name
)
st.text(f"Available projects: {state.initialized_projects}")
state.st_delete_project = False # extra insurance, keep this!

# ----------------------------------------------------
Expand Down
46 changes: 33 additions & 13 deletions lightning_pose_app/ui/train_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import streamlit as st
import torch
import yaml
from eks.singleview_smoother import ensemble_kalman_smoother_single_view
from eks.singlecam_smoother import ensemble_kalman_smoother_singlecam
from eks.utils import convert_lp_dlc, make_output_dataframe, populate_output_dataframe
from lightning.app import CloudCompute, LightningFlow, LightningWork
from lightning.app.structures import Dict
Expand Down Expand Up @@ -305,29 +305,46 @@ def _run_eks(
# -----------------------------------------
# run eks
# -----------------------------------------
self.status_ = "running eks"
self.progress = 0.0
self.status_ = "finding optimal smoothing parameter"
self.progress = 50.0

# make empty dataframe for eks outputs
df_eks = make_output_dataframe(preds_df) # make from unformatted dataframe

# loop over keypoints; apply eks to each individually
for k, keypoint_name in enumerate(keypoints_to_smooth):
# run eks
keypoint_df_dict, s_final, nll_values = ensemble_kalman_smoother_single_view(
markers_list=dfs,
keypoint_ensemble=keypoint_name,
smooth_param=smooth_param, # default (None) is to compute automatically
)
keypoint_df = keypoint_df_dict[keypoint_name + '_df'] # make cleaner 2
# Convert list of DataFrames to a 3D NumPy array
data_arrays = [df.to_numpy() for df in dfs]
markers_3d_array = np.stack(data_arrays, axis=0)

# Map keypoint names to keys in dfs and crop markers_3d_array
keypoint_is = {}
keys = []
for i, col in enumerate(dfs[0].columns):
keypoint_is[col] = i
for part in keypoints_to_smooth:
keys.append(keypoint_is[part + '_x'])
keys.append(keypoint_is[part + '_y'])
keys.append(keypoint_is[part + '_likelihood'])
key_cols = np.array(keys)
markers_3d_array = markers_3d_array[:, :, key_cols]

# Call the smoother function
df_dicts, s_finals = ensemble_kalman_smoother_singlecam(
markers_3d_array=markers_3d_array,
bodypart_list=keypoints_to_smooth,
smooth_param=smooth_param, # default (None) is to compute automatically
s_frames=[(0, min(10000, markers_3d_array.shape[0]))], # optimize on first 10k frames
)

self.status_ = "running eks"
self.progress = 0.0
for k, keypoint_name in enumerate(keypoints_to_smooth):
keypoint_df = df_dicts[k][keypoint_name + '_df']
# put results into new dataframe
df_eks = populate_output_dataframe(
keypoint_df,
keypoint_name,
df_eks,
)

self.progress = (k + 1.0) / len(keypoints_to_smooth) * 100.0

# -----------------------------------------
Expand Down Expand Up @@ -1210,4 +1227,7 @@ def create_ensemble_directory(ensemble_dir: str, model_dirs: list):
with open(text_file_path, 'w') as file:
file.writelines(f"{path}\n" for path in model_dirs)

# save empty predictions.csv file so that `find_models` function will find this
os.mknod(os.path.join(ensemble_dir, "predictions.csv"))

return text_file_path
3 changes: 2 additions & 1 deletion requirements_litpose.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ tensorboard==2.16.2
torchtyping==0.1.4
torchvision==0.18.1
typeguard==4.3.0
typing==3.7.4.3
typing==3.7.4.3
optax==0.1.7
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_version(rel_path):

install_requires = [
"lightning[app]==2.2.5",
"ensemble-kalman-smoother==1.1.0",
"ensemble-kalman-smoother==2.0.1",
"numpy",
"opencv-python-headless",
"pandas",
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/test_ui_train_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_train_infer_ui(root_dir, tmp_proj_dir, video_file):
flow.work_is_done_inference = False
flow.work_is_done_eks = False
# NOTE: this will launch EKS as well
flow.run(action="run_inference", video_files=[video_file], smooth_param=1.0, testing=True)
flow.run(action="run_inference", video_files=[video_file], smooth_param=None, testing=True)

# check flow state
assert len(flow.st_infer_status) == 3 # one for eks, one for each of the two ensemble members
Expand Down

0 comments on commit 443fb2d

Please sign in to comment.