ensembling + eks #74

Merged
merged 23 commits on Apr 30, 2024
Commits (23)
eed4303
first pass on multi-seed training
themattinthehatt Apr 19, 2024
cca05b8
finish multi-seed training
themattinthehatt Apr 22, 2024
0ad3518
update unit tests
themattinthehatt Apr 22, 2024
9018059
refactor status dict for training and inference
themattinthehatt Apr 22, 2024
ed4143b
run inference on ensemble
themattinthehatt Apr 22, 2024
fe4a8c1
dummy eks execution
themattinthehatt Apr 22, 2024
78fc7f3
output eks metrics and labeled video; launch_works method in LitPose …
themattinthehatt Apr 23, 2024
06fbf3b
run eks on short video clips
themattinthehatt Apr 23, 2024
8beaf74
get ensembling+eks working in app
themattinthehatt Apr 23, 2024
c3b1186
ensure user keypoint names do not contain spaces or dashes
themattinthehatt Apr 23, 2024
fedde66
streamline video uploading
themattinthehatt Apr 23, 2024
15f42d3
ensembling+eks docs
themattinthehatt Apr 23, 2024
a4fb20b
add eks code
themattinthehatt Apr 25, 2024
6d89634
test
themattinthehatt Apr 25, 2024
d42fd61
test 2
themattinthehatt Apr 25, 2024
eec003f
point eks to properly formatted video
themattinthehatt Apr 26, 2024
52e82e7
bug fix
themattinthehatt Apr 26, 2024
ab1f0bf
revert label studio startup to run method
themattinthehatt Apr 30, 2024
23d9f2e
import demo dataset updates
themattinthehatt Apr 30, 2024
aadc235
update LP version
themattinthehatt Apr 30, 2024
0320f48
speed up unit tests
themattinthehatt Apr 30, 2024
3508c2e
set min train frames
themattinthehatt Apr 30, 2024
3e2438e
dummy proofing
themattinthehatt Apr 30, 2024
110 changes: 75 additions & 35 deletions app.py
@@ -5,6 +5,7 @@
"""

import json
from lightning.app import CloudCompute, LightningApp, LightningFlow
import logging
import numpy as np
@@ -94,54 +95,78 @@ def __init__(self):
database_dir=os.path.join(self.data_dir, LABELSTUDIO_DB_DIR),
)

# start label studio
self.label_studio.run(action="start_label_studio")
self.import_demo_count = 0

# import mirror-mouse-example dataset
if not os.environ.get("TESTING_LAI"):
self.import_demo_dataset(
src_dir=os.path.join(LIGHTNING_POSE_DIR, "data", "mirror-mouse-example"),
dst_dir=os.path.join(self.data_dir[1:], "mirror-mouse-example")
)
def import_demo_dataset(self, src_dir_abs, dst_dir_abs):

def import_demo_dataset(self, src_dir, dst_dir):
"""NOTE
This is an ugly solution. Previously this function was called from the app constructor,
which required label studio to be started inside the constructor as well. This led to
issues with ports. Therefore this import needs to happen in the app's run method.
However, this means that various parts of this function will execute several times before
it is finished. Furthermore, this function runs *every* time the app is called.
"""

src_dir_abs = os.path.join(os.path.dirname(__file__), src_dir)
proj_dir_abs = os.path.join(os.path.dirname(__file__), dst_dir)
if os.path.isdir(proj_dir_abs):
if self.import_demo_count > 0:
return

_logger.info("Importing demo dataset; this will only take a minute")
proj_dir_abs = dst_dir_abs
project_name = os.path.basename(dst_dir_abs)

project_name = os.path.basename(dst_dir)
# check to see if the demo dataset has already been imported
label_studio_exports = os.path.join(
os.path.dirname(dst_dir_abs), LABELSTUDIO_DB_DIR, "export",
)
projects = {}
if os.path.isdir(label_studio_exports):
files = os.listdir(label_studio_exports)
for f in files:
if f.endswith("info.json"):
try:
json_file = os.path.join(label_studio_exports, f)
d = json.load(open(json_file, "r"))
project_name_curr = d["project"]["title"]
n_labels_curr = d["project"]["task_number"]
projects[project_name_curr] = n_labels_curr
except Exception:
# sometimes there is a json read error, not sure why
continue

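# if the demo project already exists with its full set of labeled tasks, skip the re-import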
if project_name in projects.keys() and projects[project_name] >= 90:
self.import_demo_count += 1
return

_logger.info("Importing demo dataset; this will only take a minute")

# -------------------------------
# copy data
# -------------------------------
# copy full example data directory over
shutil.copytree(src_dir_abs, proj_dir_abs)
if not os.path.isdir(proj_dir_abs):
shutil.copytree(src_dir_abs, proj_dir_abs)

# copy config file
config_file_dst = os.path.join(proj_dir_abs, f"model_config_{project_name}.yaml")
shutil.copyfile(
os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs", f"config_{project_name}.yaml"),
config_file_dst,
)
if not os.path.isfile(config_file_dst):
print(project_name)
shutil.copyfile(
os.path.join(
LIGHTNING_POSE_DIR, "scripts", "configs", f"config_{project_name}.yaml"
),
config_file_dst,
)

# make csv file for label studio
n_frames = len(os.listdir(os.path.join(proj_dir_abs, LABELED_DATA_DIR)))
idxs_selected = np.arange(1, n_frames - 2) # we've stored mock context frames
n_digits = 2
extension = "png"
frames_to_label = np.sort(np.array(
["img%s.%s" % (str(idx).zfill(n_digits), extension) for idx in idxs_selected]
))
np.savetxt(
os.path.join(proj_dir_abs, LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME),
frames_to_label,
delimiter=",",
fmt="%s",
)
csv_file_ls = os.path.join(proj_dir_abs, LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME)
if not os.path.isfile(csv_file_ls):
n_frames = len(os.listdir(os.path.join(proj_dir_abs, LABELED_DATA_DIR)))
idxs_selected = np.arange(1, n_frames - 2) # we've stored mock context frames
n_digits = 2
extension = "png"
frames_to_label = np.sort(np.array(
["img%s.%s" % (str(idx).zfill(n_digits), extension) for idx in idxs_selected]
))
np.savetxt(csv_file_ls, frames_to_label, delimiter=",", fmt="%s")

# make models dir
os.makedirs(os.path.join(proj_dir_abs, MODELS_DIR), exist_ok=True)
@@ -178,9 +203,12 @@ def import_demo_dataset(self, src_dir, dst_dir):

csv_file = os.path.join(proj_dir_abs, COLLECTED_DATA_FILENAME)
df = pd.read_csv(csv_file, index_col=0, header=[0, 1, 2])
df.drop('obs_top', axis=1, level=1, inplace=True)
df.drop('obsHigh_bot', axis=1, level=1, inplace=True)
df.drop('obsLow_bot', axis=1, level=1, inplace=True)
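# each drop is guarded because this function may run more than once (see docstring note)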
if 'obs_top' in df.columns.get_level_values(1):
df.drop('obs_top', axis=1, level=1, inplace=True)
if 'obsHigh_bot' in df.columns.get_level_values(1):
df.drop('obsHigh_bot', axis=1, level=1, inplace=True)
if 'obsLow_bot' in df.columns.get_level_values(1):
df.drop('obsLow_bot', axis=1, level=1, inplace=True)
df.to_csv(csv_file)

# -------------------------------
@@ -231,6 +259,8 @@ def import_demo_dataset(self, src_dir, dst_dir):

del project_ui_demo

self.import_demo_count += 1

def start_tensorboard(self, logdir):
"""run tensorboard"""
cmd = f"tensorboard --logdir {logdir} --host $host --port $port --reload_interval 30"
@@ -272,6 +302,7 @@ def run(self):
# -------------------------------------------------------------
# start background services (run only once)
# -------------------------------------------------------------
self.label_studio.run(action="start_label_studio")
self.start_fiftyone()
if self.project_ui.model_dir is not None:
# find previously trained models for project, expose to training and diagnostics UIs
@@ -282,6 +313,15 @@
self.streamlit_frame.run(action="initialize")
self.streamlit_video.run(action="initialize")

# import mirror-mouse-example dataset
if not os.environ.get("TESTING_LAI"):
self.import_demo_dataset(
src_dir_abs=os.path.join(
os.path.dirname(__file__), LIGHTNING_POSE_DIR, "data", "mirror-mouse-example"),
dst_dir_abs=os.path.join(
os.path.dirname(__file__), self.data_dir[1:], "mirror-mouse-example"),
)

# -------------------------------------------------------------
# update project data (user has clicked button in project UI)
# -------------------------------------------------------------
12 changes: 12 additions & 0 deletions docs/source/faqs.rst
@@ -104,3 +104,15 @@ Model training
You can find the relevant parameters to adjust
`here <https://lightning-pose.readthedocs.io/en/latest/source/user_guide/config_file.html>`_
(this link takes you to another set of docs specifically for Lightning Pose).


Post-processing
---------------

.. _faq_post_processing:

.. dropdown:: Does the Lightning Pose app perform post-processing of the predictions?

We offer the `Ensemble Kalman Smoother (EKS) <https://github.com/paninski-lab/eks>`_
post-processor, which we have found superior to other forms of post-processing.
To run EKS, see the :ref:`Create an ensemble of models<tab_train_infer__ensemble>` section.
4 changes: 4 additions & 0 deletions docs/source/tabs/manage_project.rst
@@ -33,6 +33,10 @@ seen from two views ("top" and "bottom").
If you are using more than one view, we recommend listing all keypoints from one view first,
then all keypoints from the next view, etc.

.. note::

Keypoint names cannot contain spaces or dashes (underscores are ok).
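
A minimal sketch of this kind of name check (the app's actual validation logic may differ):

.. code-block:: python

    def is_valid_keypoint_name(name: str) -> bool:
        """Return False for keypoint names containing spaces or dashes."""
        return " " not in name and "-" not in name

    assert is_valid_keypoint_name("paw_left")      # underscores are ok
    assert not is_valid_keypoint_name("paw left")  # spaces are not
    assert not is_valid_keypoint_name("paw-left")  # dashes are not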

.. image:: https://imgur.com/m0a6TRy.png

You will then be prompted to select a subset of keypoints for the Pose PCA loss.
114 changes: 95 additions & 19 deletions docs/source/tabs/train_infer.rst
@@ -10,26 +10,34 @@ This tab is the interface for training models and running inference on new video

The left side-bar displays your current labeling progress, and contains a drop-down menu showing
all previously trained models.
The "Train Networks" and "Predict on New Videos" columns are for training and inference,
The main workflows are:

* :ref:`Train networks<tab_train_infer__train>`
* :ref:`Predict on new videos<tab_train_infer__infer>`
* :ref:`Create an ensemble of models<tab_train_infer__ensemble>`

"Train Networks" and "Predict on New Videos" columns are for training and inference,
and detailed below.

Train Networks
.. _tab_train_infer__train:

Train networks
==============

Training options
----------------

From the drop-down "Expand to adjust training parameters" menu,
optionally change the max training epochs,
the model seed (different seeds will lead to different model outputs that are useful for ensembling),
and the types of unsupervised losses used for the semi-supervised models.

.. .. image:: https://imgur.com/LiylXxc.png
:width: 400
optionally change the following:

The PCA Multiview option will only appear if your data have more than one view;
the Pose PCA option will only appear if you selected keypoints for the Pose PCA loss during
project creation.
* Max training epochs (default is 300)
* Model seed; different seeds will lead to different model outputs that are useful for ensembling.
Enter either a single integer (e.g. ``0``) to train one model of each type (see below), or a
list of comma-separated integers (e.g. ``0,1,2``) to train multiple models of each type, as
sketched after this list.
* Losses used for the semi-supervised models.
The PCA Multiview option will only appear if your data have more than one view;
the Pose PCA option will only appear if you selected keypoints for the Pose PCA loss during
project creation.
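
A sketch of how a comma-separated seed entry maps onto training runs (illustrative only;
``train_model`` and the model type list are hypothetical stand-ins for the app's internals):

.. code-block:: python

    seed_text = "0,1,2"  # value entered in the "Model seed" field
    seeds = [int(s) for s in seed_text.split(",")]

    # one training run per (model type, seed) combination
    for model_type in ["supervised", "semi-supervised"]:
        for seed in seeds:
            train_model(model_type=model_type, seed=seed)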

Select models to train
----------------------
@@ -62,27 +70,25 @@ Once training is complete for all models you will see

.. _tab_train_infer__infer:

Predict on New Videos
Predict on new videos
=====================

First, select the model you would like to use for inference from the drop-down menu.

Select videos
-------------

You have three options for video selection:
You have two options for video selection:

* **Upload new**:
upload a new video to the app.
To do so, drag and drop the video file(s) using the provided interface.
You will see an upload progress bar.
If your video is larger than the 200MB default limit, see the :ref:`FAQs<faq_upload_limit>`.
* **Select video(s) previously uploaded to the TRAIN/INFER tab**:
any video previously uploaded to this tab will be available in the drop down menu; you may
select multiple videos.
* **Select video(s) previously uploaded to the EXTRACT FRAMES tab**:
any video previously uploaded in the EXTRACT FRAMES tab for labeling will be available in the
drop down menu; you may select multiple videos.
* **Select previously uploaded video(s)**:
any video previously uploaded to this tab (located in the ``videos_infer`` directory) or the
Extract Frames tab (located in the ``videos`` directory) will be available in the drop down menu;
you may select multiple videos.

Video handling options
----------------------
@@ -111,3 +117,73 @@ Once inference is complete for all videos you will see the
"waiting for existing inference to finish" warning disappear.

See :ref:`Accessing your data <directory_structure>` for the location of inference results.


.. _tab_train_infer__ensemble:

Create an ensemble of models
============================

Ensembling is a classical machine learning technique that combines predictions from multiple
models to provide enhanced performance.
We offer the `Ensemble Kalman Smoother (EKS) <https://github.com/paninski-lab/eks>`_,
a Bayesian ensembling technique that combines model predictions with a latent smoothing model.
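
For intuition, plain ensembling (without the smoothing step) amounts to aggregating per-keypoint
predictions across members. A minimal sketch, assuming each member's predictions are saved as a
Lightning Pose-style CSV with three header rows (the file paths are hypothetical):

.. code-block:: python

    import numpy as np
    import pandas as pd

    member_csvs = ["model0/video.csv", "model1/video.csv", "model2/video.csv"]
    dfs = [pd.read_csv(f, index_col=0, header=[0, 1, 2]) for f in member_csvs]

    # stack members into a (n_members, n_frames, n_columns) array
    stacked = np.stack([df.to_numpy() for df in dfs], axis=0)

    ensemble_mean = stacked.mean(axis=0)  # per-frame, per-keypoint average
    ensemble_var = stacked.var(axis=0)    # disagreement across members

EKS goes beyond this simple average: the ensemble mean and variance feed into a latent smoothing
model that also exploits temporal structure in the video.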

To use EKS, you must first create an ensemble of models.
Then, if you run inference using the ensemble, EKS will automatically be run on the ensemble
output.
The steps are outlined in more detail below.

Select models for ensembling
----------------------------
Select a set of previously trained models to create the ensemble.
We recommend an ensemble size of 4-5 models for a good trade-off between computational efficiency
and accuracy.
An ensemble can be composed in many ways;
one way would be to include models of the same type (supervised, semi-supervised, etc.) using
different random seeds;
another way would be to include models of different types (e.g. one supervised, one
semi-supervised, etc.); a combination of these approaches would work too!

Add ensemble name
-----------------
Give your ensemble a name. This text will be appended to the date and time to form the final
ensemble name (just like the other models), to prevent overwriting previous models/ensembles.

Create ensemble
---------------
Click the "Create ensemble" button; you will see a brief success message.
The newly-created ensemble directory will contain a text file that points to the model directories
of the individual ensemble members.
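
For example, the ``models_for_ensemble.txt`` file might list one member model directory per line
(the paths below are hypothetical):

.. code-block:: text

    models/2024-04-30/10-01-15
    models/2024-04-30/10-22-40
    models/2024-04-30/10-44-02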

Running the Ensemble Kalman Smoother post-processor
---------------------------------------------------
Now that the ensemble has been created, you can run inference on videos.
Navigate back to the :ref:`Predict on new videos <tab_train_infer__infer>` part of this tab.
You should now see your ensemble in the drop-down menu of models.

.. note::

If your model is not in the drop-down menu, click on the three vertical dots in the top right
of the tab (next to the "Deploy" button) and click "Rerun".

You can now treat the ensemble as any other model: select one or more videos to run inference on,
select any video labeling options you like, and then click "Run inference".
Upon doing so you will see multiple progress bars appear, one for each model/video combination:

.. image:: https://imgur.com/dGktgCm.png
:width: 400

Inference and labeled video creation will be skipped for any ensemble member that has already
performed these tasks.

After inference and labeled video creation are completed for each ensemble member, a new progress
bar will appear for the EKS model.
You will see the progress of the EKS fitting process, as well as the labeled video creation if you
have selected one of those options.

The outputs of EKS will be stored just like the inference outputs of a single model.
This means that you may inspect the EKS traces in the
:ref:`Video Diagnostics tab<tab_video_diagnostics>`
and view the labeled video (if you have selected one of these options) in the
:ref:`Video Player tab<tab_video_player>`.
2 changes: 1 addition & 1 deletion lightning-pose
5 changes: 4 additions & 1 deletion lightning_pose_app/__init__.py
@@ -1,6 +1,6 @@
"""Package constants"""

__version__ = "1.3.0"
__version__ = "1.4.0"


# dir where lightning pose package lives, relative to Pose-app root
@@ -27,3 +27,6 @@

# file name constants; relative to project_dir/<LABELED_DATA_DIR>/<video_name>
SELECTED_FRAMES_FILENAME = "selected_frames.csv"

# file name constants; relative to project_dir/MODELS_DIR/<date>/<time>/
ENSEMBLE_MEMBER_FILENAME = "models_for_ensemble.txt"
3 changes: 2 additions & 1 deletion lightning_pose_app/backend/extract_frames.py
@@ -25,6 +25,7 @@ def read_nth_frames(
video_file: str,
n: int = 1,
resize_dims: int = 64,
progress_delta: float = 0.5, # for online progress updates
work: Optional[LightningWork] = None, # for online progress updates
) -> np.ndarray:

@@ -51,7 +52,7 @@
progress = frame_counter / frame_total * 100.0
# periodically update progress of worker if available
if work is not None:
if round(progress, 4) - work.progress >= work.progress_delta:
if round(progress, 4) - work.progress >= progress_delta:
if progress > 100:
work.progress = 100.0
else: