add training docs
ksikka committed Jan 18, 2025
1 parent 2b99f81 commit b1e2494
Showing 2 changed files with 72 additions and 50 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
@@ -54,7 +54,6 @@
 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 html_theme = 'sphinx_rtd_theme'
-html_theme_options = {"logo": {"text": "Lightning Pose Docs - Home"}}
 html_logo = "images/LightningPose_logo_light.png"
 html_favicon = "images/favicon.ico"

121 changes: 72 additions & 49 deletions docs/source/user_guide/training.rst
@@ -4,86 +4,108 @@
Training
########

Lightning Pose provides the ``litpose train`` command to train models.
It expects a valid config file and outputs the newly trained model to a directory.

Training on your own dataset
============================

Create a valid config file
--------------------------

Copy the default config (`config_default.yaml`_) to a local file, and modify the
``data`` section to point to your own dataset. Sections other than ``data`` have
reasonable defaults for getting started. For example:

.. code-block:: yaml

    data:
      image_resize_dims:
        height: 256
        width: 256
      data_dir: /home/user1/data/
      video_dir: /home/user1/data/videos
      csv_file: labeled_frames.csv
      downsample_factor: 2
      # total number of keypoints
      num_keypoints: 3
      keypoint_names:
        - paw_left
        - paw_right
        - nose_tip

.. _config_default.yaml: https://github.com/paninski-lab/lightning-pose/blob/main/scripts/configs/config_default.yaml
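
For orientation, here is a hypothetical dataset layout consistent with the config above
(file and directory names are placeholders; see the :ref:`data format documentation <directory_structure>`
for the authoritative structure):

.. code-block:: text

    /home/user1/data/              # data_dir
    ├── labeled_frames.csv         # csv_file (frame paths inside are relative to data_dir)
    ├── labeled-data/              # extracted frames referenced by the csv (hypothetical)
    │   └── ...
    └── videos/                    # video_dir
        └── session0.mp4           # hypothetical video name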

Train a model
-------------

To train a model, point ``litpose train`` at your config file:

.. code-block:: shell

    # Replace 'config_default.yaml' with the path to your config file.
    litpose train config_default.yaml

The model will be saved in ``./outputs/{YYYY-MM-DD}/{HH:MM:SS}/``; the folder is
created if it does not already exist. To customize the output directory, use the
``--output_dir OUTPUT_DIR`` flag:

.. code-block:: shell

    # Save to 'outputs/lp_test_1'
    litpose train config_default.yaml --output_dir outputs/lp_test_1

.. note::

    If the command ``litpose`` is not found, ensure that you've activated the conda
    environment with lightning-pose installed, and that you're using version >= 1.7.0
    (verify this using ``pip show lightning-pose``).

You can find more information on the structure of the output model directory
:ref:`below <model_directory_structure>`.

For the full listing of training options, run ``litpose train --help``.

Config overrides
----------------

If you want to override config values at training time, use the ``--overrides`` flag.
This uses hydra under the hood, so refer to the `hydra syntax for config overrides`_.

.. _hydra syntax for config overrides: https://hydra.cc/docs/advanced/override_grammar/basic/

.. code-block:: shell

    # Train for only 5 epochs
    litpose train config_default.yaml --overrides training.min_epochs=5 training.max_epochs=5

    # Train a fully supervised model (disable all unsupervised losses)
    litpose train config_default.yaml --output_dir outputs/supervised --overrides \
        model.losses_to_use=null

Post-training flags
-------------------

After training, Lightning Pose can automatically run inference on videos and save
copies of those videos overlaid with its predictions. The config settings that control
this behavior are listed below, followed by an example config snippet:

* ``eval.predict_vids_after_training``: if ``true``, automatically run inference after training on
  all videos located in the directory given by ``eval.test_videos_directory``; results are saved
  to the model directory
* ``eval.save_vids_after_training``: if ``true`` (as well as ``eval.predict_vids_after_training``),
  the keypoints predicted during the inference step will be overlaid on the videos and saved with
  inference outputs to the model directory
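
A minimal sketch of the corresponding ``eval`` section of the config (the directory
path is a hypothetical example):

.. code-block:: yaml

    eval:
      # Run inference on all videos in test_videos_directory after training.
      predict_vids_after_training: true
      # Overlay predicted keypoints on those videos and save the labeled copies.
      save_vids_after_training: true
      test_videos_directory: /home/user1/data/videos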

Training on a sample dataset
============================

To quickly try Lightning Pose without your own dataset, use the small sample dataset
provided in the lightning-pose git repository. Clone the repository and point the train
command at the sample config:

.. code-block:: shell

    # (Skip this if you've already cloned, i.e. to install from source.)
    git clone https://github.com/paninski-lab/lightning-pose

    # Run from a directory containing the lightning-pose repo.
    litpose train lightning-pose/scripts/configs/config_mirror-mouse-example.yaml

Tensorboard
===========

@@ -186,3 +208,4 @@ We also compute all unsupervised losses, where applicable, and store them
* ``predictions_pca_multiview_error.csv``: pca multiview reprojection error between predictions and labeled keypoints

* ``predictions_pca_singleview_error.csv``: pca singleview reprojection error between predictions and labeled keypoints
