From 552ebb703c9450ed368702c4851e441a6e3acb81 Mon Sep 17 00:00:00 2001 From: Matt Whiteway Date: Fri, 23 Feb 2024 13:12:20 -0500 Subject: [PATCH] v1.1.0 updates (#53) * remove FileSystem backend * zipped frame uploads * unit tests * streamline fiftyone dataset creation * update lightning-pose version * minor bug fixes and documentation updates --------- Co-authored-by: Shmuel-columbia <155922206+Shmuel-columbia@users.noreply.github.com> --- app.py | 98 ++-- demo_app.py | 85 +-- docs/source/faqs.rst | 2 +- docs/source/tabs/extract_frames.rst | 91 ++- docs/source/tabs/fiftyone.rst | 4 +- docs/source/tabs/prepare_fiftyone.rst | 26 - docs/source/using_the_app.rst | 1 - labeling_app.py | 34 +- lightning-pose | 2 +- lightning_pose_app/__init__.py | 6 +- lightning_pose_app/bashwork.py | 38 +- lightning_pose_app/build_configs.py | 26 +- .../check_labeling_task_and_export.py | 5 +- lightning_pose_app/label_studio/component.py | 90 +-- .../label_studio/update_tasks.py | 2 + lightning_pose_app/label_studio/utils.py | 9 +- lightning_pose_app/ui/extract_frames.py | 518 ++++++++++++------ lightning_pose_app/ui/fifty_one.py | 298 ---------- lightning_pose_app/ui/project.py | 234 ++++---- lightning_pose_app/ui/streamlit.py | 19 +- lightning_pose_app/ui/train_infer.py | 361 ++++-------- lightning_pose_app/utilities.py | 250 ++++++--- setup.cfg | 16 + setup.py | 4 +- tests/conftest.py | 81 +++ tests/test_ui/test_extract_frames.py | 187 +++++++ tests/test_ui/test_project.py | 182 ++++++ tests/test_ui/test_train_infer.py | 230 ++++++++ tests/test_utilities.py | 172 ++++++ 29 files changed, 1818 insertions(+), 1253 deletions(-) delete mode 100644 docs/source/tabs/prepare_fiftyone.rst delete mode 100644 lightning_pose_app/ui/fifty_one.py create mode 100644 setup.cfg create mode 100644 tests/conftest.py create mode 100644 tests/test_ui/test_extract_frames.py create mode 100644 tests/test_ui/test_project.py create mode 100644 tests/test_ui/test_train_infer.py create mode 100644 
tests/test_utilities.py diff --git a/app.py b/app.py index 013c5ed..d017bf5 100644 --- a/app.py +++ b/app.py @@ -7,38 +7,29 @@ from lightning.app import CloudCompute, LightningApp, LightningFlow from lightning.app.structures import Dict -from lightning.app.utilities.cloud import is_running_in_cloud import logging import os import sys import time import yaml -from lightning_pose_app import LABELSTUDIO_DB_DIR +from lightning_pose_app import LABELSTUDIO_DB_DIR, LIGHTNING_POSE_DIR from lightning_pose_app.bashwork import LitBashWork -from lightning_pose_app.ui.fifty_one import FiftyoneConfigUI from lightning_pose_app.label_studio.component import LitLabelStudio from lightning_pose_app.ui.extract_frames import ExtractFramesUI from lightning_pose_app.ui.project import ProjectUI from lightning_pose_app.ui.streamlit import StreamlitAppLightningPose -from lightning_pose_app.ui.train_infer import TrainUI, LitPose -from lightning_pose_app.build_configs import TensorboardBuildConfig, LitPoseBuildConfig -from lightning_pose_app.build_configs import lightning_pose_dir - +from lightning_pose_app.ui.train_infer import TrainUI logging.basicConfig(stream=sys.stdout, level=logging.INFO) _logger = logging.getLogger('APP') # TODO: HIGH PRIORITY -# - import previous projects -# * should this be done with a Work, and a frame upload status bar? -# * automatically create context datasets # - `abort` button next to training/inference progress bars so user doesn't have to kill app # - active learning # TODO: LOW PRIORITY -# - ProjectUI._put_to_drive_remove_local does NOT overwrite dirs already on FileSystem - ok? 
# - launch training in parallel (get this working with `extract_frames` standalone app first) # - update label studio xml and CollectedData.csv when user inputs new keypoint in project ui @@ -52,10 +43,10 @@ def __init__(self): # ----------------------------- # paths # ----------------------------- - self.data_dir = "/data" # relative to FileSystem root + self.data_dir = "/data" # # relative to Pose-app root # load default config and pass to project manager - config_dir = os.path.join(lightning_pose_dir, "scripts", "configs") + config_dir = os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs") default_config_dict = yaml.safe_load(open(os.path.join(config_dir, "config_default.yaml"))) # ----------------------------- @@ -74,8 +65,10 @@ def __init__(self): # training tab (flow + work) self.train_ui = TrainUI() - # fiftyone tab (flow + work) - self.fiftyone_ui = FiftyoneConfigUI() + # fiftyone tab (work) + self.fiftyone = LitBashWork( + cloud_compute=CloudCompute("default"), + ) # streamlit tabs (flow + work) self.streamlit_frame = StreamlitAppLightningPose(app_type="frame") @@ -83,9 +76,7 @@ def __init__(self): # tensorboard tab (work) self.tensorboard = LitBashWork( - name="tensorboard", cloud_compute=CloudCompute("default"), - cloud_build_config=TensorboardBuildConfig(), ) # label studio (flow + work) @@ -96,27 +87,20 @@ def __init__(self): # works for inference self.inference = Dict() - # @property - # def ready(self) -> bool: - # """Return true once all works have an assigned url""" - # return all([ - # self.fiftyone_ui.work.url != "", - # self.streamlit_frame.work.url != "", - # self.streamlit_video.work.url != "", - # self.train_ui.work.url != "", - # self.label_studio.label_studio.url != "" - # ]) - def start_tensorboard(self, logdir): """run tensorboard""" cmd = f"tensorboard --logdir {logdir} --host $host --port $port --reload_interval 30" self.tensorboard.run(cmd, wait_for_exit=False, cwd=os.getcwd()) + def start_fiftyone(self): + """run fiftyone""" + cmd = 
"fiftyone app launch --address $host --port $port --remote --wait -1" + self.fiftyone.run(cmd, wait_for_exit=False, cwd=os.getcwd()) + def update_trained_models_list(self, timer): self.project_ui.run(action="update_trained_models_list", timer=timer) if self.project_ui.trained_models: self.train_ui.trained_models = self.project_ui.trained_models - self.fiftyone_ui.trained_models = self.project_ui.trained_models def run(self): @@ -127,12 +111,12 @@ def run(self): # don't interfere /w train; since all Works use the same filesystem when running locally, # one Work updating the filesystem which is also used by the trainer can corrupt data, etc. run_while_training = True - if not is_running_in_cloud() and self.train_ui.run_script_train: + if self.train_ui.run_script_train: run_while_training = False # don't interfere w/ inference run_while_inferring = True - if not is_running_in_cloud() and self.train_ui.run_script_infer: + if self.train_ui.run_script_infer: run_while_inferring = False # ------------------------------------------------------------- @@ -141,15 +125,11 @@ def run(self): # find previously initialized projects, expose to project UI self.project_ui.run(action="find_initialized_projects") - # find previously constructed fiftyone datasets - self.fiftyone_ui.run(action="find_fiftyone_datasets") - # ------------------------------------------------------------- # start background services (run only once) # ------------------------------------------------------------- - self.label_studio.run(action="import_database") self.label_studio.run(action="start_label_studio") - self.fiftyone_ui.run(action="start_fiftyone") + self.start_fiftyone() if self.project_ui.model_dir is not None: # find previously trained models for project, expose to training and diagnostics UIs # timer to force later runs @@ -169,8 +149,6 @@ def run(self): self.train_ui.proj_dir = self.project_ui.proj_dir self.streamlit_frame.proj_dir = self.project_ui.proj_dir self.streamlit_video.proj_dir = 
self.project_ui.proj_dir - self.fiftyone_ui.proj_dir = self.project_ui.proj_dir - self.fiftyone_ui.config_name = self.project_ui.config_name self.label_studio.run( action="update_paths", proj_dir=self.project_ui.proj_dir, proj_name=self.project_ui.st_project_name) @@ -224,9 +202,9 @@ def run(self): self.project_ui.run_script = False # ------------------------------------------------------------- - # extract frames for labeling from uploaded videos + # extract frames for labeling # ------------------------------------------------------------- - if self.extract_ui.proj_dir and self.extract_ui.run_script and run_while_training: + if self.extract_ui.proj_dir and self.extract_ui.run_script_video_random: self.extract_ui.run( action="extract_frames", video_files=self.extract_ui.st_video_files, # add arg for run caching purposes @@ -234,9 +212,23 @@ def run(self): # wait until frame extraction is complete, then update label studio tasks if self.extract_ui.work_is_done_extract_frames: self.project_ui.run(action="update_frame_shapes") - self.extract_ui.run_script = False # hack, app won't advance past ls run + # hack; for some reason the app won't advance past the ls run + self.extract_ui.run_script_video_random = False self.label_studio.run(action="update_tasks", videos=self.extract_ui.st_video_files) - self.extract_ui.run_script = False + self.extract_ui.run_script_video_random = False + + if self.extract_ui.proj_dir and self.extract_ui.run_script_zipped_frames: + self.extract_ui.run( + action="unzip_frames", + video_files=self.extract_ui.st_frame_files, # add arg for run caching purposes + ) + # wait until frame extraction is complete, then update label studio tasks + if self.extract_ui.work_is_done_extract_frames: + self.project_ui.run(action="update_frame_shapes") + # hack; for some reason the app won't advance past the ls run + self.extract_ui.run_script_zipped_frames = False + self.label_studio.run(action="update_tasks", videos=self.extract_ui.st_frame_files) + 
self.extract_ui.run_script_zipped_frames = False # ------------------------------------------------------------- # periodically check labeling task and export new labels @@ -262,17 +254,6 @@ def run(self): # ------------------------------------------------------------- if self.train_ui.run_script_train and run_while_inferring: self.train_ui.run(action="train", config_filename=self.project_ui.config_name) - inputs = [self.project_ui.model_dir] - # have tensorboard pull the new data - self.tensorboard.run( - "null command", - cwd=os.getcwd(), - input_output_only=True, # pull inputs from Drive, but do not run commands - inputs=inputs, - ) - # have streamlit pull the new data - self.streamlit_frame.run(action="pull_models", inputs=inputs) - self.streamlit_video.run(action="pull_models", inputs=inputs) self.project_ui.update_models = True self.train_ui.run_script_train = False @@ -291,13 +272,6 @@ def run(self): ) self.train_ui.run_script_infer = False - # ------------------------------------------------------------- - # build fiftyone dataset on button press from FiftyoneUI - # ------------------------------------------------------------- - if self.fiftyone_ui.run_script: - self.fiftyone_ui.run(action="build_fiftyone_dataset") - self.fiftyone_ui.run_script = False - def configure_layout(self): # init tabs @@ -312,8 +286,7 @@ def configure_layout(self): # diagnostics tabs st_frame_tab = {"name": "Labeled Diagnostics", "content": self.streamlit_frame.work} st_video_tab = {"name": "Video Diagnostics", "content": self.streamlit_video.work} - fo_prep_tab = {"name": "Prepare Fiftyone", "content": self.fiftyone_ui} - fo_tab = {"name": "Fiftyone", "content": self.fiftyone_ui.work} + fo_tab = {"name": "Fiftyone", "content": self.fiftyone} if self.extract_ui.proj_dir: return [ @@ -324,7 +297,6 @@ def configure_layout(self): train_status_tab, st_frame_tab, st_video_tab, - fo_prep_tab, fo_tab, ] else: diff --git a/demo_app.py b/demo_app.py index cbc647c..f2d4a71 100644 --- 
a/demo_app.py +++ b/demo_app.py @@ -6,22 +6,17 @@ """ from lightning.app import CloudCompute, LightningApp, LightningFlow -from lightning.app.structures import Dict -from lightning.app.utilities.cloud import is_running_in_cloud import logging import os import shutil import sys -import time import yaml from lightning_pose_app.bashwork import LitBashWork -from lightning_pose_app.ui.fifty_one import FiftyoneConfigUI from lightning_pose_app.ui.project import ProjectUI from lightning_pose_app.ui.streamlit import StreamlitAppLightningPose from lightning_pose_app.ui.train_infer import TrainUI -from lightning_pose_app.build_configs import TensorboardBuildConfig, lightning_pose_dir - +from lightning_pose_app import LIGHTNING_POSE_DIR logging.basicConfig(stream=sys.stdout, level=logging.INFO) _logger = logging.getLogger('APP') @@ -36,11 +31,11 @@ def __init__(self): # ----------------------------- # paths # ----------------------------- - self.data_dir = "/data" # relative to FileSystem root + self.data_dir = "/data" # relative to Pose-app root self.proj_name = "demo" # load default config and pass to project manager - config_dir = os.path.join(lightning_pose_dir, "scripts", "configs") + config_dir = os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs") default_config_dict = yaml.safe_load(open(os.path.join(config_dir, "config_default.yaml"))) # ----------------------------- @@ -57,8 +52,10 @@ def __init__(self): self.train_ui.n_labeled_frames = 90 # hard-code these values for now self.train_ui.n_total_frames = 90 - # fiftyone tab (flow + work) - self.fiftyone_ui = FiftyoneConfigUI() + # fiftyone tab (work) + self.fiftyone = LitBashWork( + cloud_compute=CloudCompute("default"), + ) # streamlit tabs (flow + work) self.streamlit_frame = StreamlitAppLightningPose(app_type="frame") @@ -66,28 +63,26 @@ def __init__(self): # tensorboard tab (work) self.tensorboard = LitBashWork( - name="tensorboard", cloud_compute=CloudCompute("default"), - 
cloud_build_config=TensorboardBuildConfig(), ) # ----------------------------- # copy toy data to project # ----------------------------- - # here we copy the toy dataset config file, frames, and labels that come packaged with the - # lightning-pose repo and move it to a new directory that is consistent with the project + # here we copy the toy dataset config file, frames, and labels that come packaged with the + # lightning-pose repo and move it to a new directory that is consistent with the project # structure the app expects # later we will write that newly copied data to the FileSystem so other Works have access # copy config file toy_config_file_src = os.path.join( - lightning_pose_dir, "scripts/configs/config_mirror-mouse-example.yaml") + LIGHTNING_POSE_DIR, "scripts/configs/config_mirror-mouse-example.yaml") toy_config_file_dst = os.path.join( os.getcwd(), self.data_dir[1:], self.proj_name, "model_config_demo.yaml") self._copy_file(toy_config_file_src, toy_config_file_dst) # frames, videos, and labels - toy_data_src = os.path.join(lightning_pose_dir, "data/mirror-mouse-example") + toy_data_src = os.path.join(LIGHTNING_POSE_DIR, "data/mirror-mouse-example") toy_data_dst = os.path.join(os.getcwd(), self.data_dir[1:], self.proj_name) self._copy_dir(toy_data_src, toy_data_dst) @@ -129,26 +124,20 @@ def _copy_dir(self, src_path, dst_path): except IOError as e: _logger.warning(f"Unable to copy directory. 
{e}") - # @property - # def ready(self) -> bool: - # """Return true once all works have an assigned url""" - # return all([ - # self.fiftyone_ui.work.url != "", - # self.streamlit_frame.work.url != "", - # self.streamlit_video.work.url != "", - # self.train_ui.work.url != "", - # ]) - def start_tensorboard(self, logdir): """run tensorboard""" cmd = f"tensorboard --logdir {logdir} --host $host --port $port --reload_interval 30" self.tensorboard.run(cmd, wait_for_exit=False, cwd=os.getcwd()) + def start_fiftyone(self): + """run fiftyone""" + cmd = "fiftyone app launch --address $host --port $port --remote --wait -1" + self.fiftyone.run(cmd, wait_for_exit=False, cwd=os.getcwd()) + def update_trained_models_list(self, timer): self.project_ui.run(action="update_trained_models_list", timer=timer) if self.project_ui.trained_models: self.train_ui.trained_models = self.project_ui.trained_models - self.fiftyone_ui.trained_models = self.project_ui.trained_models def run(self): @@ -160,12 +149,12 @@ def run(self): # don't interfere w/ train; since all Works use the same filesystem when running locally, # one Work updating the filesystem which is also used by the trainer can corrupt data, etc. 
run_while_training = True - if not is_running_in_cloud() and self.train_ui.run_script_train: + if self.train_ui.run_script_train: run_while_training = False # don't interfere w/ inference run_while_inferring = True - if not is_running_in_cloud() and self.train_ui.run_script_infer: + if self.train_ui.run_script_infer: run_while_inferring = False # ------------------------------------------------------------- @@ -176,8 +165,6 @@ def run(self): self.train_ui.proj_dir = self.project_ui.proj_dir self.streamlit_frame.proj_dir = self.project_ui.proj_dir self.streamlit_video.proj_dir = self.project_ui.proj_dir - self.fiftyone_ui.proj_dir = self.project_ui.proj_dir - self.fiftyone_ui.config_name = self.project_ui.config_name # write demo data to the FileSystem so other Works have access (run once) if not self.demo_data_transferred: @@ -187,19 +174,12 @@ def run(self): # update config file self.project_ui.run( action="update_project_config", - new_vals_dict={"data": { # TODO: will this work on cloud? 
+ new_vals_dict={"data": { "data_dir": os.path.join(os.getcwd(), self.project_ui.proj_dir)[1:]} }, ) # send params to train ui self.train_ui.config_dict = self.project_ui.config_dict - # put demo data onto FileSystem - self.project_ui.run( - action="put_file_to_drive", - file_or_dir=self.project_ui.proj_dir, - remove_local=False, - ) - _logger.info("Demo data transferred to FileSystem") self.demo_data_transferred = True # find previously trained models for project, expose to training and diagnostics UIs @@ -208,29 +188,15 @@ def run(self): # start background services (only run once) self.start_tensorboard(logdir=self.project_ui.model_dir[1:]) + self.start_fiftyone() self.streamlit_frame.run(action="initialize") self.streamlit_video.run(action="initialize") - self.fiftyone_ui.run(action="start_fiftyone") - - # find previously constructed fiftyone datasets - self.fiftyone_ui.run(action="find_fiftyone_datasets") # ------------------------------------------------------------- # train models on ui button press # ------------------------------------------------------------- if self.train_ui.run_script_train and run_while_inferring: self.train_ui.run(action="train", config_filename=self.project_ui.config_name) - inputs = [self.project_ui.model_dir] - # have tensorboard pull the new data - self.tensorboard.run( - "null command", - cwd=os.getcwd(), - input_output_only=True, # pull inputs from FileSystem, but do not run commands - inputs=inputs, - ) - # have streamlit pull the new data - self.streamlit_frame.run(action="pull_models", inputs=inputs) - self.streamlit_video.run(action="pull_models", inputs=inputs) self.project_ui.update_models = True self.train_ui.run_script_train = False @@ -249,13 +215,6 @@ def run(self): ) self.train_ui.run_script_infer = False - # ------------------------------------------------------------- - # build fiftyone dataset on button press from FiftyoneUI - # ------------------------------------------------------------- - if 
self.fiftyone_ui.run_script: - self.fiftyone_ui.run(action="build_fiftyone_dataset") - self.fiftyone_ui.run_script = False - def configure_layout(self): # training tabs @@ -265,15 +224,13 @@ def configure_layout(self): # diagnostics tabs st_frame_tab = {"name": "Labeled Diagnostics", "content": self.streamlit_frame.work} st_video_tab = {"name": "Video Diagnostics", "content": self.streamlit_video.work} - fo_prep_tab = {"name": "Prepare Fiftyone", "content": self.fiftyone_ui} - fo_tab = {"name": "Fiftyone", "content": self.fiftyone_ui.work} + fo_tab = {"name": "Fiftyone", "content": self.fiftyone} return [ train_tab, train_status_tab, st_frame_tab, st_video_tab, - fo_prep_tab, fo_tab, ] diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst index c434bdd..393a085 100644 --- a/docs/source/faqs.rst +++ b/docs/source/faqs.rst @@ -43,7 +43,7 @@ Note that both semi-supervised and context models will increase memory usage (wi context models needing the most memory). If you encounter this error, reduce batch sizes during training or inference. This feature is currently not supported in the app, so you will need to manually open the config -file, located at `Pose-app/.shared/data//model_config_.yaml`, update bactch +file, located at `Pose-app/data//model_config_.yaml`, update bactch sizes, save the file, then close. We also recommend restarting the app after config updates. You can find the relevant parameters to adjust diff --git a/docs/source/tabs/extract_frames.rst b/docs/source/tabs/extract_frames.rst index e4d2352..b22f4bf 100644 --- a/docs/source/tabs/extract_frames.rst +++ b/docs/source/tabs/extract_frames.rst @@ -4,17 +4,27 @@ Extract Frames ############## -The Extract Frames tab allows you to upload raw videos and select a subset of frames for labeling. +The Extract Frames tab allows you to select frames for labeling using various methods. .. note:: This tab does not appear in the demo app. -.. 
image:: https://imgur.com/OTL8myA.png +* :ref:`Upload videos and automatically extract random frames ` +* :ref:`Upload zipped files of frames ` + +.. _upload_video_random: + +Upload videos and automatically extract random frames +===================================================== + +Select the appropriate option from the list. + +.. image:: https://imgur.com/szEHHtw.png Drag and drop video file(s) using the provided interface. You will see an upload progress bar. -.. image:: https://imgur.com/TChOcFH.png +.. image:: https://imgur.com/GjR2Jb4.png Choose number of frames to label per video - these frames will be automatically selected to maximize the diversity of poses from each video. @@ -30,3 +40,78 @@ Click "Extract frames" once the video upload is complete, and another progress b Once all frames have been extracted you will see "Proceed to the next tab to label frames" in green. .. image:: https://imgur.com/F9y1aPv.png + + +.. _upload_zipped_frames: + +Upload zipped files of frames +============================= + +Select the appropiate option from the list. + +.. image:: https://imgur.com/64ZjGu3.png + +Drag and drop zipped files(s) of frames using the provided interface. +You will see an upload progress bar. + +.. warning:: + + At the moment this feature of the app requires a strict file structure! + +As an example, let's say you have a video named ``subject023_session0.mp4`` and you have extracted +frames 143, 1156, and 34567, which you want to label. + +You will need to create a single zip file named ``subject023_session0.zip``. +The zip file must contain png files, and they must follow the naming convention ``img%08.png``, +for example ``img00000143.png`` +(such that there are 8 digits for the frame number, with leading zeros). + +If you would like to fit context models, you must also include context frames for each labeled +frame. 
Again using frame 143 as an example, you must include five files: + +* img00000141.png +* img00000142.png +* img00000143.png +* img00000144.png +* img00000145.png + +Including context frames is recommended, though not required. + +Finally, you must include a csv file named ``selected_frames.csv`` that is simply a list of the +file names of the frames you wish to *label* (not the context frames), +so that LabelStudio knows which frames to upload into its database. +For the example above, the csv file should look like: + +.. code-block:: + + img00000143.png + img00001156.png + img00034567.png + +Therefore, the final set of files that must be zipped into ``subject023_session0.zip`` for this +example is: + +* img00000141.png +* img00000142.png +* img00000143.png +* img00000144.png +* img00000145.png +* img00001154.png +* img00001155.png +* img00001156.png +* img00001157.png +* img00001158.png +* img00034565.png +* img00034566.png +* img00034567.png +* img00034568.png +* img00034569.png +* selected_frames.csv + +If you would like to upload frames for multiple videos, make one zip file per video. + +Click "Extract frames" once the zip file upload is complete. + +Once all frames have been extracted you will see "Proceed to the next tab to label frames" in green. + +.. image:: https://imgur.com/F9y1aPv.png diff --git a/docs/source/tabs/fiftyone.rst b/docs/source/tabs/fiftyone.rst index d8b962a..ee28843 100644 --- a/docs/source/tabs/fiftyone.rst +++ b/docs/source/tabs/fiftyone.rst @@ -5,14 +5,14 @@ FiftyOne ######## `FiftyOne `_ is a powerful diagnostic tool that allows you to plot -predictions from several models directly on images, with corresponding ground truth labels. +predictions from each model you train directly on images, with corresponding ground truth labels. The flexible user interface allows you to filter labels by keypoint, confidence range, etc., and is useful for qualitative inspection of model performance. .. 
image:: https://imgur.com/Vp9MEar.png :width: 600 -Select the dataset you just created from the drop-down menu in the top left corner. +Select the model name from the drop-down menu in the top left corner. .. image:: https://imgur.com/WB3jkzF.png :width: 400 diff --git a/docs/source/tabs/prepare_fiftyone.rst b/docs/source/tabs/prepare_fiftyone.rst deleted file mode 100644 index 5fdbb15..0000000 --- a/docs/source/tabs/prepare_fiftyone.rst +++ /dev/null @@ -1,26 +0,0 @@ -.. _tab_prepare_fiftyone: - -################ -Prepare FiftyOne -################ - -This tab allows you to select two models to package into a FiftyOne dataset, -which you will then be able to view in the following tab. - -First, select two models; FiftyOne will compare their predictions to ground-truth labels. -Give each model a display name. - -.. image:: https://imgur.com/kcPwpw0.png - :width: 600 - -Next, choose a unique dataset name for this combination of models; -existing names are shown above the text field (in this example, "data-1" and "data-2"). - -.. image:: https://imgur.com/sYZ0UCb.png - :width: 600 - -Click the "Prepare Fiftyone dataset" button or hit "Enter". -The green success message will indicate when it is time to proceed. - -.. 
image:: https://imgur.com/tFuAVt6.png - :width: 600 diff --git a/docs/source/using_the_app.rst b/docs/source/using_the_app.rst index 89a80f9..c6613a1 100644 --- a/docs/source/using_the_app.rst +++ b/docs/source/using_the_app.rst @@ -74,7 +74,6 @@ remember that ``demo_app.py`` and ``labeleing_app.py`` only utilize a subset of tabs/train_status tabs/labeled_diagnostics tabs/video_diagnostics - tabs/prepare_fiftyone tabs/fiftyone **Close the app** diff --git a/labeling_app.py b/labeling_app.py index 2fdffdc..952fe23 100644 --- a/labeling_app.py +++ b/labeling_app.py @@ -5,20 +5,17 @@ """ -from lightning.app import CloudCompute, LightningApp, LightningFlow -from lightning.app.structures import Dict +from lightning.app import LightningApp, LightningFlow import logging import os import sys import time import yaml -from lightning_pose_app import LABELSTUDIO_DB_DIR +from lightning_pose_app import LABELSTUDIO_DB_DIR, LIGHTNING_POSE_DIR from lightning_pose_app.label_studio.component import LitLabelStudio from lightning_pose_app.ui.extract_frames import ExtractFramesUI from lightning_pose_app.ui.project import ProjectUI -from lightning_pose_app.build_configs import lightning_pose_dir - logging.basicConfig(stream=sys.stdout, level=logging.INFO) _logger = logging.getLogger('APP') @@ -33,10 +30,10 @@ def __init__(self): # ----------------------------- # paths # ----------------------------- - self.data_dir = "/data" # relative to FileSystem root + self.data_dir = "/data" # relative to Pose-app root # load default config and pass to project manager - config_dir = os.path.join(lightning_pose_dir, "scripts", "configs") + config_dir = os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs") default_config_dict = yaml.safe_load(open(os.path.join(config_dir, "config_default.yaml"))) # ----------------------------- @@ -72,7 +69,6 @@ def run(self): # ------------------------------------------------------------- # init label studio; this will only happen once # 
------------------------------------------------------------- - self.label_studio.run(action="import_database") self.label_studio.run(action="start_label_studio") # ------------------------------------------------------------- @@ -127,9 +123,9 @@ def run(self): self.project_ui.run_script = False # ------------------------------------------------------------- - # extract frames for labeling from uploaded videos + # extract frames for labeling # ------------------------------------------------------------- - if self.extract_ui.proj_dir and self.extract_ui.run_script: + if self.extract_ui.proj_dir and self.extract_ui.run_script_video_random: self.extract_ui.run( action="extract_frames", video_files=self.extract_ui.st_video_files, # add arg for run caching purposes @@ -137,9 +133,23 @@ def run(self): # wait until frame extraction is complete, then update label studio tasks if self.extract_ui.work_is_done_extract_frames: self.project_ui.run(action="update_frame_shapes") - self.extract_ui.run_script = False # hack, app won't advance past ls run + # hack; for some reason the app won't advance past the ls run + self.extract_ui.run_script_video_random = False self.label_studio.run(action="update_tasks", videos=self.extract_ui.st_video_files) - self.extract_ui.run_script = False + self.extract_ui.run_script_video_random = False + + if self.extract_ui.proj_dir and self.extract_ui.run_script_zipped_frames: + self.extract_ui.run( + action="unzip_frames", + video_files=self.extract_ui.st_frame_files, # add arg for run caching purposes + ) + # wait until frame extraction is complete, then update label studio tasks + if self.extract_ui.work_is_done_extract_frames: + self.project_ui.run(action="update_frame_shapes") + # hack; for some reason the app won't advance past the ls run + self.extract_ui.run_script_zipped_frames = False + self.label_studio.run(action="update_tasks", videos=self.extract_ui.st_frame_files) + self.extract_ui.run_script_zipped_frames = False # 
------------------------------------------------------------- # periodically check labeling task and export new labels diff --git a/lightning-pose b/lightning-pose index 4c730d3..163855a 160000 --- a/lightning-pose +++ b/lightning-pose @@ -1 +1 @@ -Subproject commit 4c730d34e4bb3fe6700ee6ce909190c119785bbf +Subproject commit 163855a8ebe57881eb6c32d58ccd89b9c7a5481f diff --git a/lightning_pose_app/__init__.py b/lightning_pose_app/__init__.py index d014d1c..aaa70e0 100644 --- a/lightning_pose_app/__init__.py +++ b/lightning_pose_app/__init__.py @@ -1,4 +1,7 @@ -# directory name constants; relative to FileSystem root +# dir where lightning pose package lives, relative to Pose-app root +LIGHTNING_POSE_DIR = "lightning-pose" + +# directory name constants; relative to Pose-app/data LABELSTUDIO_DB_DIR = "labelstudio_db" # directory name constants; relative to project_dir @@ -6,6 +9,7 @@ VIDEOS_DIR = "videos" VIDEOS_TMP_DIR = "videos_tmp" VIDEOS_INFER_DIR = "videos_infer" +ZIPPED_TMP_DIR = "frames_tmp" MODELS_DIR = "models" MODEL_VIDEO_PREDS_TRAIN_DIR = "video_preds" MODEL_VIDEO_PREDS_INFER_DIR = "video_preds_infer" diff --git a/lightning_pose_app/bashwork.py b/lightning_pose_app/bashwork.py index 270c033..5a1c7c3 100644 --- a/lightning_pose_app/bashwork.py +++ b/lightning_pose_app/bashwork.py @@ -12,7 +12,7 @@ import threading import time -from lightning_pose_app.utilities import args_to_dict, WorkWithFileSystem +from lightning_pose_app.utilities import args_to_dict _logger = logging.getLogger('APP.BASHWORK') @@ -25,7 +25,7 @@ def add_to_system_env(env_key='env', **kwargs) -> dict: env = kwargs[env_key] if isinstance(env, str): env = args_to_dict(env) - if not(env is None) and not(env == {}): + if not (env is None) and not (env == {}): new_env = os.environ.copy() new_env.update(env) return new_env @@ -46,9 +46,6 @@ def is_port_in_use(host: str, port: int) -> bool: s.close() return in_use - - -def work_calls_len(lwork: LightningWork): """get the number of call in state 
dict. state dict has current and past calls to work.""" # reduce by 1 to remove latest_call_hash entry return len(lwork.state["calls"]) - 1 @@ -58,9 +55,9 @@ def work_is_free(lwork: LightningWork): """work is free to accept new calls. this is expensive when a lot of calls accumulate over time work is when there is there is no pending and running calls at the moment - pending status is verified by examining each call history looking for anything call that is pending history - status.stage is not reliable indicator as there is delay registering new calls - status.stage shows SUCCEEDED even after 3 more calls are accepted in parallel mode + pending status is verified by examining each call history looking for anything call that + is pending history status.stage is not reliable indicator as there is delay registering + new calls status.stage shows SUCCEEDED even after 3 more calls are accepted in parallel mode """ status = lwork.status state = lwork.state @@ -69,9 +66,7 @@ def work_is_free(lwork: LightningWork): # multiple works are queued but # count run that are in pending state if ( - status.stage == "not_started" or - status.stage == "succeeded" or - status.stage == "failed" + status.stage == "not_started" or status.stage == "succeeded" or status.stage == "failed" ): # do not run if jobs are in pending state # not counting to reduce CPU load as looping thru all of the calls can get expensive @@ -86,12 +81,11 @@ def work_is_free(lwork: LightningWork): return False -class LitBashWork(WorkWithFileSystem): +class LitBashWork(LightningWork): def __init__( self, *args, - name="bashwork", wait_seconds_after_run=10, wait_seconds_after_kill=10, **kwargs @@ -100,15 +94,13 @@ def __init__( # required to to grab self.host and self.port in the cloud. 
# otherwise, the values flips from 127.0.0.1 to 0.0.0.0 causing two runs # host='0.0.0.0', - super().__init__(*args, name=name, **kwargs) + super().__init__(*args, **kwargs) self.wait_seconds_after_run = wait_seconds_after_run self.wait_seconds_after_kill = wait_seconds_after_kill self.pid = None self.exit_code = None self.stdout = None - self.inputs = None - self.outputs = None self.args = "" self._wait_proc = None @@ -137,9 +129,6 @@ def on_after_run(self): def work_is_free(self) -> bool: return work_is_free(self) - def work_calls_len(self) -> int: - return work_calls_len(self) - def popen_wait(self, cmd, save_stdout, exception_on_error, **kwargs): with subprocess.Popen( cmd, @@ -210,11 +199,11 @@ def subprocess_call( # should either wait, process is already done, or kill thread.join() # _logger.debug(self._wait_proc.returncode) - _logger.debug("wait completed", cmd) + _logger.debug(f"wait completed {cmd}") else: _logger.debug("no wait popen") self.popen_nowait(cmd, **kwargs) - _logger.debug("no wait completed", cmd) + _logger.debug(f"no wait completed {cmd}") def run( self, @@ -224,8 +213,6 @@ def run( wait_for_exit=True, input_output_only=False, kill_pid=False, - inputs=[], - outputs=[], run_after_run=[], timeout=0, timer=0, # added for uniqueness and caching @@ -236,7 +223,6 @@ def run( # pre processing self.on_before_run() - self.get_from_drive(inputs) self.args = args self.stdout = None @@ -247,7 +233,7 @@ def run( if self.pid and kill_pid: _logger.debug(f"***killing {self.pid}") os.kill(self.pid, signal.SIGTERM) - info = os.waitpid(self.pid, 0) + # info = os.waitpid(self.pid, 0) while is_port_in_use(self.host, self.port): _logger.debug(f"***killed. 
pid {self.pid} waiting to free port") time.sleep(self.wait_seconds_after_kill) @@ -261,8 +247,6 @@ def run( for cmd in run_after_run: self.popen_wait(cmd, save_stdout=True, exception_on_error=False, **kwargs) - # post processing - self.put_to_drive(outputs) # give time for REDIS to catch up and propagate self.stdout back to flow if save_stdout: _logger.debug(f"waiting work to flow message sleeping {self.wait_seconds_after_run}") diff --git a/lightning_pose_app/build_configs.py b/lightning_pose_app/build_configs.py index 2d112b5..df3c291 100644 --- a/lightning_pose_app/build_configs.py +++ b/lightning_pose_app/build_configs.py @@ -1,11 +1,7 @@ from lightning.app import BuildConfig from typing import List -# dir where lightning pose package lives -lightning_pose_dir = "lightning-pose" - -# dir where label studio python venv will be set up -label_studio_venv = None +from lightning_pose_app import LIGHTNING_POSE_DIR class LitPoseBuildConfig(BuildConfig): @@ -15,23 +11,5 @@ def build_commands() -> List[str]: return [ "sudo apt-get update", "sudo apt-get install -y ffmpeg libsm6 libxext6", - f"pip install -e {lightning_pose_dir}", - ] - - -class LabelStudioBuildConfig(BuildConfig): - - @staticmethod - def build_commands() -> List[str]: - return [ - "pip install -e .", # install lightning app to have access to packages - ] - - -class TensorboardBuildConfig(BuildConfig): - - @staticmethod - def build_commands() -> List[str]: - return [ - "pip install tensorboard", + f"pip install -e {LIGHTNING_POSE_DIR}", ] diff --git a/lightning_pose_app/label_studio/check_labeling_task_and_export.py b/lightning_pose_app/label_studio/check_labeling_task_and_export.py index 046598f..dd34d65 100644 --- a/lightning_pose_app/label_studio/check_labeling_task_and_export.py +++ b/lightning_pose_app/label_studio/check_labeling_task_and_export.py @@ -49,7 +49,10 @@ if len(exported_tasks) > 0: # save to pickle for resuming projects _logger.debug("Saving tasks to pickle file") - 
pickle.dump(exported_tasks, open(os.path.join(args.proj_dir, LABELSTUDIO_TASKS_FILENAME), "wb")) + pickle.dump( + exported_tasks, + open(os.path.join(args.proj_dir, LABELSTUDIO_TASKS_FILENAME), "wb"), + ) # save to csv for lightning pose models _logger.debug("Saving annotations to csv file") processor = LabelStudioJSONProcessor( diff --git a/lightning_pose_app/label_studio/component.py b/lightning_pose_app/label_studio/component.py index ab2a814..1a5a1ed 100644 --- a/lightning_pose_app/label_studio/component.py +++ b/lightning_pose_app/label_studio/component.py @@ -10,12 +10,14 @@ COLLECTED_DATA_FILENAME, ) from lightning_pose_app.bashwork import LitBashWork -from lightning_pose_app.build_configs import LabelStudioBuildConfig, label_studio_venv +from lightning_pose_app.utilities import abspath _logger = logging.getLogger('APP.LABELSTUDIO') log_level = "ERROR" # log level sent to label studio sdk +label_studio_venv = None + class LitLabelStudio(LightningFlow): @@ -24,12 +26,9 @@ def __init__(self, *args, database_dir="/data", proj_dir=None, **kwargs) -> None super().__init__(*args, **kwargs) self.label_studio = LitBashWork( - name="labelstudio", cloud_compute=CloudCompute("default"), - cloud_build_config=LabelStudioBuildConfig(), ) self.counts = { - "import_database": 0, "start_label_studio": 0, "create_new_project": 0, "import_existing_annotations": 0, @@ -59,32 +58,6 @@ def __init__(self, *args, database_dir="/data", proj_dir=None, **kwargs) -> None self.proj_name = None self.keypoints = None - @staticmethod - def abspath(path): - if path[0] == "/": - path_ = path[1:] - else: - path_ = path - return os.path.abspath(path_) - - def _import_database(self): - # pull database from FileSystem if it exists - # NOTE: db must be imported _after_ LabelStudio is started, otherwise some nginx error - if self.counts["import_database"] > 0: - return - - self.label_studio.run( - "null command", - venv_name=label_studio_venv, - cwd=os.getcwd(), - input_output_only=True, - 
inputs=[self.database_dir], - wait_for_exit=True, - env={"LOG_LEVEL": log_level}, - ) - - self.counts["import_database"] += 1 - def _start_label_studio(self): if self.counts["start_label_studio"] > 0: @@ -97,7 +70,7 @@ def _start_label_studio(self): # start label-studio self.label_studio.run( - f"label-studio start --no-browser --internal-host $host --port $port", + "label-studio start --no-browser --internal-host $host --port $port", venv_name=label_studio_venv, wait_for_exit=False, env={ @@ -108,7 +81,7 @@ def _start_label_studio(self): "LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED": "true", "LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT": os.path.abspath(os.getcwd()), "LABEL_STUDIO_DISABLE_SIGNUP_WITHOUT_LINK": "true", - "LABEL_STUDIO_BASE_DATA_DIR": self.abspath(self.database_dir), + "LABEL_STUDIO_BASE_DATA_DIR": abspath(self.database_dir), "LABEL_STUDIO_SESSION_COOKIE_SAMESITE": "Lax", "LABEL_STUDIO_CSRF_COOKIE_SAMESITE": "Lax", "LABEL_STUDIO_SESSION_COOKIE_SECURE": "1", @@ -148,10 +121,10 @@ def _create_new_project(self): # build script command script_path = os.path.join( os.getcwd(), "lightning_pose_app", "label_studio", "create_new_project.py") - label_studio_config_file = self.abspath(self.filenames["label_studio_config"]) + label_studio_config_file = abspath(self.filenames["label_studio_config"]) build_command = f"python {script_path} " \ f"--label_studio_url {self.label_studio_url} " \ - f"--proj_dir {self.abspath(self.proj_dir)} " \ + f"--proj_dir {abspath(self.proj_dir)} " \ f"--api_key {self.user_token} " \ f"--project_name {self.proj_name} " \ f"--label_config {label_studio_config_file} " @@ -165,13 +138,6 @@ def _create_new_project(self): venv_name=label_studio_venv, wait_for_exit=True, env={"LOG_LEVEL": log_level}, - inputs=[ - self.filenames["label_studio_config"], - self.filenames["labeled_data_dir"], - ], - outputs=[ - self.filenames["label_studio_metadata"], - ], ) def _update_tasks(self, videos=[]): @@ -182,7 +148,7 @@ def _update_tasks(self, videos=[]): 
os.getcwd(), "lightning_pose_app", "label_studio", "update_tasks.py") build_command = f"python {script_path} " \ f"--label_studio_url {self.label_studio_url} " \ - f"--proj_dir {self.abspath(self.proj_dir)} " \ + f"--proj_dir {abspath(self.proj_dir)} " \ f"--api_key {self.user_token} " # run command to update label studio tasks @@ -192,15 +158,10 @@ def _update_tasks(self, videos=[]): wait_for_exit=True, env={"LOG_LEVEL": log_level}, timer=videos, - inputs=[ - self.filenames["labeled_data_dir"], - self.filenames["label_studio_metadata"], - ], - outputs=[], ) def _check_labeling_task_and_export(self, timer): - """Check for new labels, export to lightning pose format, export database to FileSystem.""" + """Check for new labels, export to lightning pose format.""" script_path = os.path.join( os.getcwd(), "lightning_pose_app", "label_studio", "check_labeling_task_and_export.py") @@ -212,7 +173,7 @@ def _check_labeling_task_and_export(self, timer): keypoints_list = "/".join(self.keypoints) run_command = f"python {script_path} " \ f"--label_studio_url {self.label_studio_url} " \ - f"--proj_dir {self.abspath(self.proj_dir)} " \ + f"--proj_dir {abspath(self.proj_dir)} " \ f"--api_key {self.user_token} " \ f"--keypoints_list '{keypoints_list}' " @@ -223,16 +184,6 @@ def _check_labeling_task_and_export(self, timer): wait_for_exit=True, env={"LOG_LEVEL": log_level}, timer=timer, - inputs=[ - self.filenames["labeled_data_dir"], - self.filenames["label_studio_metadata"], - ], - outputs=[ - self.filenames["collected_data"], - self.filenames["label_studio_tasks"], - self.filenames["label_studio_metadata"], - self.database_dir, # sqlite database - ], ) self.check_labels = True @@ -247,7 +198,7 @@ def _create_labeling_config_xml(self, keypoints): script_path = os.path.join( os.getcwd(), "lightning_pose_app", "label_studio", "create_labeling_config.py") build_command = f"python {script_path} " \ - f"--proj_dir {self.abspath(self.proj_dir)} " \ + f"--proj_dir {abspath(self.proj_dir)} 
" \ f"--filename {os.path.basename(self.filenames['label_studio_config'])} " \ f"--keypoints_list {keypoints_list} " @@ -260,8 +211,6 @@ def _create_labeling_config_xml(self, keypoints): wait_for_exit=True, env={"LOG_LEVEL": log_level}, timer=keypoints, - inputs=[], - outputs=[self.filenames["label_studio_config"]], ) def _import_existing_annotations(self, **kwargs): @@ -275,9 +224,9 @@ def _import_existing_annotations(self, **kwargs): os.getcwd(), "lightning_pose_app", "label_studio", "update_tasks.py") build_command = f"python {script_path} " \ f"--label_studio_url {self.label_studio_url} " \ - f"--proj_dir {self.abspath(self.proj_dir)} " \ + f"--proj_dir {abspath(self.proj_dir)} " \ f"--api_key {self.user_token} " \ - f"--config_file {self.abspath(self.filenames['config_file'])} " \ + f"--config_file {abspath(self.filenames['config_file'])} " \ f"--update_from_csv " self.label_studio.run( @@ -285,22 +234,13 @@ def _import_existing_annotations(self, **kwargs): venv_name=label_studio_venv, wait_for_exit=True, env={"LOG_LEVEL": log_level}, - inputs=[ - self.filenames["labeled_data_dir"], - self.filenames["label_studio_metadata"], - self.filenames["collected_data"], - self.filenames["config_file"], - ], - outputs=[] ) self.counts["import_existing_annotations"] += 1 def run(self, action=None, **kwargs): - if action == "import_database": - self._import_database() - elif action == "start_label_studio": + if action == "start_label_studio": self._start_label_studio() elif action == "create_labeling_config_xml": self._create_labeling_config_xml(**kwargs) @@ -316,6 +256,6 @@ def run(self, action=None, **kwargs): self._import_existing_annotations(**kwargs) def on_exit(self): - # final save to drive + # final save _logger.info("SAVING DATA ONE LAST TIME") self._check_labeling_task_and_export(timer=0.0) diff --git a/lightning_pose_app/label_studio/update_tasks.py b/lightning_pose_app/label_studio/update_tasks.py index 9f47826..02cbcad 100644 --- 
a/lightning_pose_app/label_studio/update_tasks.py +++ b/lightning_pose_app/label_studio/update_tasks.py @@ -103,6 +103,8 @@ def get_annotation( ls_img_path = os.path.join(label_studio_prefix, rel_img) if ls_img_path not in existing_imgs: image_list.append({"img": ls_img_path}) + # update existing tasks + existing_tasks.append(ls_img_path) label_studio_project.import_tasks(image_list) _logger.debug("%i Tasks imported." % len(image_list)) diff --git a/lightning_pose_app/label_studio/utils.py b/lightning_pose_app/label_studio/utils.py index deee4ab..56ef8b5 100644 --- a/lightning_pose_app/label_studio/utils.py +++ b/lightning_pose_app/label_studio/utils.py @@ -25,7 +25,7 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except: - _logging.debug("Could not execute {}, retrying in one second...".format(func.__name__)) + _logger.debug(f"Could not execute {func.__name__}, retrying in one second...") attempts += 1 time.sleep(1) if attempts > MAX_CONNECT_ATTEMPTS: @@ -128,8 +128,7 @@ def get_pixel_coordinates_per_image(result: Dict[str, Any]) -> Tuple[float, floa width, height = result['original_width'], result['original_height'] if all([key in value for key in ['x', 'y']]): # if both x and y are in the dict - return width * value['x'] / 100.0, \ - height * value['y'] / 100.0 + return width * value['x'] / 100.0, height * value['y'] / 100.0 def __call__(self) -> pd.DataFrame: """Build a dataframe with the keypoint names as columns and the image paths as the index""" @@ -156,6 +155,10 @@ def __call__(self) -> pd.DataFrame: def get_rel_image_paths_from_idx_files(basedir: str) -> List[str]: img_list = [] for root, dirs, files in os.walk(basedir): + if LABELED_DATA_DIR not in root: + # make sure we only look in the labeled data directory + # if we do not do this we risk uploading info from temp dirs too + continue for file in files: if file == SELECTED_FRAMES_FILENAME: abspath = os.path.join(root, file) diff --git a/lightning_pose_app/ui/extract_frames.py 
b/lightning_pose_app/ui/extract_frames.py index b5a1c7d..ecfd362 100644 --- a/lightning_pose_app/ui/extract_frames.py +++ b/lightning_pose_app/ui/extract_frames.py @@ -1,37 +1,41 @@ import cv2 -from lightning.app import CloudCompute, LightningFlow -from lightning.app.storage import FileSystem +from lightning.app import CloudCompute, LightningFlow, LightningWork from lightning.app.structures import Dict -from lightning.app.utilities.cloud import is_running_in_cloud from lightning.app.utilities.state import AppState import logging import numpy as np import os +import shutil from sklearn.decomposition import PCA from sklearn.cluster import KMeans import streamlit as st from streamlit_autorefresh import st_autorefresh +import zipfile -from lightning_pose_app import LABELED_DATA_DIR, VIDEOS_DIR, VIDEOS_TMP_DIR +from lightning_pose_app import LABELED_DATA_DIR, VIDEOS_DIR, VIDEOS_TMP_DIR, ZIPPED_TMP_DIR from lightning_pose_app import SELECTED_FRAMES_FILENAME -from lightning_pose_app.utilities import StreamlitFrontend, WorkWithFileSystem -from lightning_pose_app.utilities import reencode_video, check_codec_format, get_frames_from_idxs - +from lightning_pose_app.utilities import StreamlitFrontend, abspath +from lightning_pose_app.utilities import copy_and_reformat_video, get_frames_from_idxs _logger = logging.getLogger('APP.EXTRACT_FRAMES') -class ExtractFramesWork(WorkWithFileSystem): +class ExtractFramesWork(LightningWork): def __init__(self, *args, **kwargs): - super().__init__(*args, name="extract", **kwargs) + super().__init__(*args, **kwargs) self.progress = 0.0 self.progress_delta = 0.5 self.work_is_done_extract_frames = False - def _read_nth_frames(self, video_file, n=1, resize_dims=64): + def _read_nth_frames( + self, + video_file: str, + n: int = 1, + resize_dims: int = 64, + ) -> np.ndarray: from tqdm import tqdm @@ -74,12 +78,12 @@ def _read_nth_frames(self, video_file, n=1, resize_dims=64): def _select_frame_idxs( self, - video_file, - resize_dims=64, - 
n_clusters=20, - frame_skip=1, - frame_range=[0, 1], - ): + video_file: str, + resize_dims: int = 64, + n_clusters: int = 20, + frame_skip: int = 1, + frame_range: list = [0, 1], + ) -> np.ndarray: # check inputs if frame_skip != 1: @@ -139,7 +143,7 @@ def _export_frames( format: str = "png", n_digits: int = 8, context_frames: int = 0, - ): + ) -> None: """ Parameters @@ -158,7 +162,7 @@ def _export_frames( # expand frame_idxs to include context frames if context_frames > 0: context_vec = np.arange(-context_frames, context_frames + 1) - frame_idxs = (frame_idxs.squeeze()[None, :] + context_vec[:, None]).flatten() + frame_idxs = (frame_idxs[None, :] + context_vec[:, None]).flatten() frame_idxs.sort() frame_idxs = frame_idxs[frame_idxs >= 0] frame_idxs = frame_idxs[frame_idxs < int(cap.get(cv2.CAP_PROP_FRAME_COUNT))] @@ -175,24 +179,33 @@ def _export_frames( img=frame[0], ) - def _extract_frames(self, video_file, proj_dir, n_frames_per_video, frame_range=[0, 1]): + def _extract_frames( + self, + video_file: str, + proj_dir: str, + n_frames_per_video: int, + frame_range: list = [0, 1], + ) -> None: _logger.info(f"============== extracting frames from {video_file} ================") # set flag for parent app self.work_is_done_extract_frames = False - # pull video from FileSystem - self.get_from_drive([video_file]) - data_dir_rel = os.path.join(proj_dir, LABELED_DATA_DIR) - data_dir = self.abspath(data_dir_rel) + if not os.path.exists(data_dir_rel): + data_dir = abspath(data_dir_rel) + else: + data_dir = data_dir_rel n_digits = 8 extension = "png" context_frames = 2 # check: does file exist? - video_file_abs = self.abspath(video_file) + if not os.path.exists(video_file): + video_file_abs = abspath(video_file) + else: + video_file_abs = video_file video_file_exists = os.path.exists(video_file_abs) _logger.info(f"video file exists? 
{video_file_exists}") if not video_file_exists: @@ -235,67 +248,87 @@ def _extract_frames(self, video_file, proj_dir, n_frames_per_video, frame_range= video_file=video_file_abs, save_dir=save_dir, frame_idxs=idxs_selected, format=extension, n_digits=n_digits, context_frames=context_frames) - # push extracted frames to drive - self.put_to_drive([data_dir_rel]) - # set flag for parent app self.work_is_done_extract_frames = True - def _reformat_video(self, video_file, **kwargs): + def _unzip_frames( + self, + video_file: str, + proj_dir: str, + ) -> None: - # get new names (ensure mp4 file extension, no tmp directory) - ext = os.path.splitext(os.path.basename(video_file))[1] - video_file_mp4_ext = video_file.replace(f"{ext}", ".mp4") - video_file_new = video_file_mp4_ext.replace(VIDEOS_TMP_DIR, VIDEOS_DIR) - video_file_abs_new = self.abspath(video_file_new) + _logger.info(f"============== unzipping frames from {video_file} ================") - # check 0: do we even need to reformat? - if self._drive.isfile(video_file_new): - return video_file_new + # set flag for parent app + self.work_is_done_extract_frames = False - # pull videos from FileSystem - self.get_from_drive([video_file]) - video_file_abs = self.abspath(video_file) + data_dir_rel = os.path.join(proj_dir, LABELED_DATA_DIR) + if not os.path.exists(data_dir_rel): + data_dir = abspath(data_dir_rel) + else: + data_dir = data_dir_rel + # TODO + # n_digits = 8 + # extension = "png" - # check 1: does file exist? + # check: does file exist? + if not os.path.exists(video_file): + video_file_abs = abspath(video_file) + else: + video_file_abs = video_file video_file_exists = os.path.exists(video_file_abs) + _logger.info(f"zipped file exists? {video_file_exists}") if not video_file_exists: - _logger.info(f"{video_file_abs} does not exist! skipping") - return None - - # check 2: is file in the correct format for DALI? 
- video_file_correct_codec = check_codec_format(video_file_abs) - - # reencode/rename - if not video_file_correct_codec: - _logger.info("re-encoding video to be compatable with Lightning Pose video reader") - reencode_video(video_file_abs, video_file_abs_new) - # remove old video from local files - os.remove(video_file_abs) - else: - # make dir to write into - os.makedirs(os.path.dirname(video_file_abs_new), exist_ok=True) - # rename - os.rename(video_file_abs, video_file_abs_new) + _logger.info("skipping frame extraction") + return - # remove old video(s) from FileSystem - if self._drive.isfile(video_file): - self._drive.rm(video_file) - if self._drive.isfile(video_file_mp4_ext): - self._drive.rm(video_file_mp4_ext) + # create folder to save images + video_name = os.path.splitext(os.path.basename(video_file))[0] + save_dir = os.path.join(data_dir, video_name) + os.makedirs(save_dir, exist_ok=True) - # push possibly reformated, renamed videos to FileSystem - self.put_to_drive([video_file_new]) + # unzip file in tmp directory + with zipfile.ZipFile(video_file_abs) as z: + unzipped_dir = video_file_abs.replace(".zip", "") + z.extractall(path=unzipped_dir) + + # save all contents to data directory + # don't use copytree as the destination dir may already exist + files = os.listdir(unzipped_dir) + for file in files: + src = os.path.join(unzipped_dir, file) + dst = os.path.join(save_dir, file) + shutil.copyfile(src, dst) + + # TODO: + # - if SELECTED_FRAMES_FILENAME does not exist, assume all frames are for labeling and + # make this file + + # # save csv file inside same output directory + # frames_to_label = np.array([ + # "img%s.%s" % (str(idx).zfill(n_digits), extension) for idx in idxs_selected]) + # np.savetxt( + # os.path.join(save_dir, SELECTED_FRAMES_FILENAME), + # np.sort(frames_to_label), + # delimiter=",", + # fmt="%s" + # ) - return video_file_new + # set flag for parent app + self.work_is_done_extract_frames = True def run(self, action, **kwargs): - if 
action == "reformat_video": - self._reformat_video(**kwargs) - elif action == "extract_frames": - new_vid_file = self._reformat_video(**kwargs) - kwargs["video_file"] = new_vid_file + if action == "extract_frames": + new_vid_file = copy_and_reformat_video( + video_file=abspath(kwargs["video_file"]), + dst_dir=abspath(os.path.join(kwargs["proj_dir"], VIDEOS_DIR)), + ) + # save relative rather than absolute path + kwargs["video_file"] = '/'.join(new_vid_file.split('/')[-4:]) self._extract_frames(**kwargs) + elif action == "unzip_frames": + # TODO: maybe we need to reformat the file names? + self._unzip_frames(**kwargs) else: pass @@ -307,9 +340,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # shared storage system - self._drive = FileSystem() - # updated externally by parent app self.proj_dir = None @@ -318,11 +348,13 @@ def __init__(self, *args, **kwargs): self.work_is_done_extract_frames = False # flag; used internally and externally - self.run_script = False + self.run_script_video_random = False + self.run_script_zipped_frames = False # output from the UI self.st_extract_status = {} # 'initialized' | 'active' | 'complete' - self.st_video_files_ = [] + self.st_video_files_ = [] # list of uploaded video files + self.st_frame_files_ = [] # list of uploaded zipped frame files self.st_submits = 0 self.st_frame_range = [0, 1] # limits for frame selection self.st_n_frames_per_video = None @@ -331,22 +363,11 @@ def __init__(self, *args, **kwargs): def st_video_files(self): return np.unique(self.st_video_files_).tolist() - def _push_video(self, video_file): - if video_file[0] == "/": - src = os.path.join(os.getcwd(), video_file[1:]) - dst = video_file - else: - src = os.path.join(os.getcwd(), video_file) - dst = "/" + video_file - if not self._drive.isfile(dst) and os.path.exists(src): - # only put to FileSystem under two conditions: - # 1. file exists locally; if it doesn't, maybe it has already been deleted for a reason - # 2. 
file does not already exist on FileSystem; avoids excessive file transfers - _logger.debug(f"UI try put {dst}") - self._drive.put(src, dst) - _logger.debug(f"UI success put {dst}") - - def _extract_frames(self, video_files=None, n_frames_per_video=None): + @property + def st_frame_files(self): + return np.unique(self.st_frame_files_).tolist() + + def _extract_frames(self, video_files=None, n_frames_per_video=None, testing=False): self.work_is_done_extract_frames = False @@ -355,21 +376,16 @@ def _extract_frames(self, video_files=None, n_frames_per_video=None): if not n_frames_per_video: n_frames_per_video = self.st_n_frames_per_video - # launch works: - # - sequential if local - # - parallel if on cloud + # launch works (sequentially for now) for video_file in video_files: video_key = video_file.replace(".", "_") # keys cannot contain "." if video_key not in self.works_dict.keys(): self.works_dict[video_key] = ExtractFramesWork( cloud_compute=CloudCompute("default"), - parallel=is_running_in_cloud(), ) status = self.st_extract_status[video_file] if status == "initialized" or status == "active": self.st_extract_status[video_file] = "active" - # move video from ui machine to shared FileSystem - self._push_video(video_file=video_file) # extract frames for labeling (automatically reformats video for DALI) self.works_dict[video_key].run( action="extract_frames", @@ -380,6 +396,45 @@ def _extract_frames(self, video_files=None, n_frames_per_video=None): ) self.st_extract_status[video_file] = "complete" + # clean up works + while len(self.works_dict) > 0: + for video_key in list(self.works_dict): + if (video_key in self.works_dict.keys()) \ + and self.works_dict[video_key].work_is_done_extract_frames: + # kill work + _logger.info(f"killing work from video {video_key}") + if not testing: # cannot run stop() from tests for some reason + self.works_dict[video_key].stop() + del self.works_dict[video_key] + + # set flag for parent app + self.work_is_done_extract_frames = True + 
+ def _unzip_frames(self, video_files=None): + + self.work_is_done_extract_frames = False + + if not video_files: + video_files = self.st_frame_files + + # launch works + for video_file in video_files: + video_key = video_file.replace(".", "_") # keys cannot contain "." + if video_key not in self.works_dict.keys(): + self.works_dict[video_key] = ExtractFramesWork( + cloud_compute=CloudCompute("default"), + ) + status = self.st_extract_status[video_file] + if status == "initialized" or status == "active": + self.st_extract_status[video_file] = "active" + # extract frames for labeling (automatically reformats video for DALI) + self.works_dict[video_key].run( + action="unzip_frames", + video_file="/" + video_file, + proj_dir=self.proj_dir, + ) + self.st_extract_status[video_file] = "complete" + # clean up works while len(self.works_dict) > 0: for video_key in list(self.works_dict): @@ -394,99 +449,212 @@ def _extract_frames(self, video_files=None, n_frames_per_video=None): self.work_is_done_extract_frames = True def run(self, action, **kwargs): - if action == "push_video": - self._push_video(**kwargs) - elif action == "extract_frames": + if action == "extract_frames": self._extract_frames(**kwargs) + elif action == "unzip_frames": + self._unzip_frames(**kwargs) def configure_layout(self): return StreamlitFrontend(render_fn=_render_streamlit_fn) def _render_streamlit_fn(state: AppState): - st.markdown( """ ## Extract frames for labeling """ ) - if state.run_script: + if state.run_script_video_random or state.run_script_zipped_frames: # don't autorefresh during large file uploads, only during processing st_autorefresh(interval=5000, key="refresh_extract_frames_ui") - # upload video files to temporary directory - video_dir = os.path.join(state.proj_dir[1:], VIDEOS_TMP_DIR) - os.makedirs(video_dir, exist_ok=True) - - # initialize the file uploader - uploaded_files = st.file_uploader("Select video files", accept_multiple_files=True) - - # for each of the uploaded files - 
st_videos = [] - for uploaded_file in uploaded_files: - # read it - bytes_data = uploaded_file.read() - # name it - filename = uploaded_file.name.replace(" ", "_") - filepath = os.path.join(video_dir, filename) - st_videos.append(filepath) - if not state.run_script: - # write the content of the file to the path, but not while processing - with open(filepath, "wb") as f: - f.write(bytes_data) - - col0, col1 = st.columns(2, gap="large") - with col0: - # select number of frames to label per video - n_frames_per_video = st.text_input("Frames to label per video", 20) - st_n_frames_per_video = int(n_frames_per_video) - with col1: - # select range of video to pull frames from - st_frame_range = st.slider( - "Portion of video used for frame selection", 0.0, 1.0, (0.0, 1.0)) - - st_submit_button = st.button( - "Extract frames", - disabled=(st_n_frames_per_video == 0) or len(st_videos) == 0 or state.run_script + VIDEO_RANDOM_STR = "Upload videos and automatically extract random frames" + ZIPPED_FRAMES_STR = "Upload zipped files of frames" + VIDEO_MODEL_STR = "Upload videos and automatically extract frames using a given model" + + st_mode = st.radio( + "Select data upload option", + options=[VIDEO_RANDOM_STR, ZIPPED_FRAMES_STR], + # disabled=state.st_project_loaded, + index=0, ) - if state.run_script: - keys = [k for k, _ in state.works_dict.items()] # cannot directly call keys()? 
- for vid, status in state.st_extract_status.items(): - if status == "initialized": - p = 0.0 - elif status == "active": - vid_ = vid.replace(".", "_") - if vid_ in keys: - try: - p = state.works_dict[vid_].progress - except: - p = 100.0 # if work is deleted while accessing + + if st_mode == VIDEO_RANDOM_STR: + + # upload video files to temporary directory + video_dir = os.path.join(state.proj_dir[1:], VIDEOS_TMP_DIR) + os.makedirs(video_dir, exist_ok=True) + + # initialize the file uploader + uploaded_files = st.file_uploader( + "Select video files", + type=["mp4", "avi"], + accept_multiple_files=True, + ) + + if len(uploaded_files) > 0: + col1, col2, col3 = st.columns(spec=3, gap="medium") + col1.markdown("**Video Name**") + col2.markdown("**Video Duration**") + col3.markdown("**Number of Frames**") + + # for each of the uploaded files + st_videos = [] + for uploaded_file in uploaded_files: + # read it + bytes_data = uploaded_file.read() + # name it + filename = uploaded_file.name.replace(" ", "_") + filepath = os.path.join(video_dir, filename) + st_videos.append(filepath) + if not state.run_script_video_random: + # write the content of the file to the path, but not while processing + with open(filepath, "wb") as f: + f.write(bytes_data) + + # calculate video duration and frame count + cap = cv2.VideoCapture(filepath) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = float(frame_count) / float(fps) + + col1.write(uploaded_file.name) + col2.write(f"{duration:.2f} seconds") + col3.write(str(frame_count)) + + # relese video + cap.release() + + # insert an empty element to create empty space + st.markdown("###") + + col0, col1 = st.columns(2, gap="large") + with col0: + # select number of frames to label per video + n_frames_per_video = st.text_input( + "Frames to label per video", 20, + help="Specify the desired number of frames for labeling per video. 
" + "The app will select frames to maximize the diversity of animal poses " + "captured within each video." + ) + st_n_frames_per_video = int(n_frames_per_video) + with col1: + # select range of video to pull frames from + st_frame_range = st.slider( + "Portion of video used for frame selection", 0.0, 1.0, (0.0, 1.0), + help="Focus on selecting video sections where the animals are clearly visible and " + "performing the desired behaviors. " + "Skip any parts without the animals or with distracting elements like hands, " + "as these can confuse the model." + ) + + st_submit_button = st.button( + "Extract frames", + disabled=( + (st_n_frames_per_video == 0) + or len(st_videos) == 0 + or state.run_script_video_random + ) + ) + if state.run_script_video_random: + keys = [k for k, _ in state.works_dict.items()] # cannot directly call keys()? + for vid, status in state.st_extract_status.items(): + if status == "initialized": + p = 0.0 + elif status == "active": + vid_ = vid.replace(".", "_") + if vid_ in keys: + try: + p = state.works_dict[vid_].progress + except: + p = 100.0 # if work is deleted while accessing + else: + p = 100.0 # state.work.progress + elif status == "complete": + p = 100.0 else: - p = 100.0 # state.work.progress - elif status == "complete": - p = 100.0 - else: - st.text(status) - st.progress(p / 100.0, f"{vid} progress ({status}: {int(p)}\% complete)") - st.warning(f"waiting for existing extraction to finish") - - if state.st_submits > 0 and not st_submit_button and not state.run_script: - proceed_str = "Please proceed to the next tab to label frames." - proceed_fmt = "

<p style='font-family:sans-serif; color:Green;'>%s</p>

" - st.markdown(proceed_fmt % proceed_str, unsafe_allow_html=True) - - # Lightning way of returning the parameters - if st_submit_button: - - state.st_submits += 1 - - state.st_video_files_ = st_videos - state.st_extract_status = {s: 'initialized' for s in st_videos} - state.st_n_frames_per_video = st_n_frames_per_video - state.st_frame_range = st_frame_range - st.text("Request submitted!") - state.run_script = True # must the last to prevent race condition - - # force rerun to show "waiting for existing..." message - st_autorefresh(interval=2000, key="refresh_extract_frames_after_submit") + st.text(status) + st.progress(p / 100.0, f"{vid} progress ({status}: {int(p)}\% complete)") + st.warning("waiting for existing extraction to finish") + + if state.st_submits > 0 and not st_submit_button and not state.run_script_video_random: + proceed_str = "Please proceed to the next tab to label frames." + proceed_fmt = "

<p style='font-family:sans-serif; color:Green;'>%s</p>

" + st.markdown(proceed_fmt % proceed_str, unsafe_allow_html=True) + + # Lightning way of returning the parameters + if st_submit_button: + + state.st_submits += 1 + + state.st_video_files_ = st_videos + state.st_extract_status = {s: 'initialized' for s in st_videos} + state.st_n_frames_per_video = st_n_frames_per_video + state.st_frame_range = st_frame_range + st.text("Request submitted!") + state.run_script_video_random = True # must the last to prevent race condition + + # force rerun to show "waiting for existing..." message + st_autorefresh(interval=2000, key="refresh_extract_frames_after_submit") + + elif st_mode == ZIPPED_FRAMES_STR: + + # upload zipped files to temporary directory + frames_dir = os.path.join(state.proj_dir[1:], ZIPPED_TMP_DIR) + os.makedirs(frames_dir, exist_ok=True) + + # initialize the file uploader + uploaded_files = st.file_uploader( + "Select zipped folders", + type="zip", + accept_multiple_files=True, + help="Upload one zip file per video. The file name should be the name of the video. " + "The frames should be in the format 'img%08i.png', i.e. a png file with a name " + "that starts with 'img' and contains the frame number with leading zeros such " + "that there are 8 total digits (e.g. 'img00003453.png')." 
+ ) + + # for each of the uploaded files + st_videos = [] + for uploaded_file in uploaded_files: + # read it + bytes_data = uploaded_file.read() + # name it + filename = uploaded_file.name.replace(" ", "_") + filepath = os.path.join(frames_dir, filename) + st_videos.append(filepath) + if not state.run_script_zipped_frames: + # write the content of the file to the path, but not while processing + with open(filepath, "wb") as f: + f.write(bytes_data) + # check files: TODO + # state.st_error_flag, state.st_error_msg = check_files_in_zipfile( + # filepath, project_type=st_prev_format) + + st_submit_button_frames = st.button( + "Extract frames", + disabled=len(st_videos) == 0 or state.run_script_zipped_frames, + ) + + if ( + state.st_submits > 0 + and not st_submit_button_frames + and not state.run_script_zipped_frames + ): + proceed_str = "Please proceed to the next tab to label frames." + proceed_fmt = "

<p style='font-family:sans-serif; color:Green;'>%s</p>

" + st.markdown(proceed_fmt % proceed_str, unsafe_allow_html=True) + + # Lightning way of returning the parameters + if st_submit_button_frames: + + state.st_submits += 1 + + state.st_frame_files_ = st_videos + state.st_extract_status = {s: 'initialized' for s in st_videos} + st.text("Request submitted!") + state.run_script_zipped_frames = True # must the last to prevent race condition + + # force rerun to show "waiting for existing..." message + st_autorefresh(interval=2000, key="refresh_extract_frames_after_submit") diff --git a/lightning_pose_app/ui/fifty_one.py b/lightning_pose_app/ui/fifty_one.py deleted file mode 100644 index b5f0e0d..0000000 --- a/lightning_pose_app/ui/fifty_one.py +++ /dev/null @@ -1,298 +0,0 @@ -from lightning.app import CloudCompute, LightningFlow -from lightning.app.storage import FileSystem -from lightning.app.utilities.state import AppState -import os -import streamlit as st -from streamlit_autorefresh import st_autorefresh -import yaml - -from lightning_pose_app import MODELS_DIR -from lightning_pose_app.build_configs import LitPoseBuildConfig, lightning_pose_dir -from lightning_pose_app.utilities import StreamlitFrontend, WorkWithFileSystem - - -class FiftyoneWork(WorkWithFileSystem): - - def __init__(self, *args, **kwargs): - - super().__init__(*args, name="fiftyone", **kwargs) - - self.fiftyone_launched = False - self.fiftyone_datasets = [] - - def start_fiftyone(self): - """run fiftyone""" - if not self.fiftyone_launched: - import fiftyone as fo - fo.launch_app( - dataset=None, - remote=True, - address=self.host, - port=self.port, - ) - self.fiftyone_launched = True - - def find_fiftyone_datasets(self): - """get existing fiftyone datasets""" - # NOTE: we could migrate the fiftyone database back and forth between the Drive but this - # seems lke overkill? 
the datasets are quick to make and users probably don't care so much - # about these datasets; can return to this later - import fiftyone as fo - out = fo.list_datasets() - datasets = [] - for x in out: - if x.endswith("No datasets found"): - continue - if x.startswith("Migrating database"): - continue - if x.endswith("python"): - continue - if x in datasets: - continue - datasets.append(x) - self.fiftyone_datasets = datasets - - def build_fiftyone_dataset( - self, config_file: str, dataset_name: str, model_dirs: list, model_names: list, - ): - - if dataset_name in self.fiftyone_datasets: - return - - from lightning_pose.utils.fiftyone import FiftyOneFactory, check_dataset - from omegaconf import DictConfig - - # pull models (relative path) - self.get_from_drive(model_dirs) - - # pull config (relative path) - self.get_from_drive([config_file]) - - # load config (absolute path) - cfg = DictConfig(yaml.safe_load(open(self.abspath(config_file), "r"))) - - # edit config (add fiftyone key before making DictConfig, otherwise error) - model_dirs_abs = [os.path.join(os.getcwd(), x[1:]) for x in model_dirs] - cfg.data.data_dir = os.path.join(os.getcwd(), cfg.data.data_dir) - cfg.eval.fiftyone.build_speed = "fast" - cfg.eval.fiftyone.dataset_name = dataset_name - cfg.eval.fiftyone.model_display_names = model_names - cfg.eval.hydra_paths = model_dirs_abs - - # build dataset - FiftyOneClass = FiftyOneFactory(dataset_to_create="images")() - fo_plotting_instance = FiftyOneClass(cfg=cfg) - dataset = fo_plotting_instance.create_dataset() - # create metadata and print if there are problems - check_dataset(dataset) - # print the name of the dataset - fo_plotting_instance.dataset_info_print() - - # add dataset name to list for user to see - self.fiftyone_datasets.append(dataset_name) - - def run(self, action, **kwargs): - - # these functions require fiftyone and/or lighting-pose to be installed - # each function imports the necessary functions directly - # if imports are at the top 
of this module errors will arise in the orchestrator when - # importing from this module since the proper packages are not yet installed (cloud only) - - if action == "start_fiftyone": - self.start_fiftyone(**kwargs) - elif action == "find_fiftyone_datasets": - self.find_fiftyone_datasets(**kwargs) - elif action == "build_fiftyone_dataset": - self.build_fiftyone_dataset(**kwargs) - - -class FiftyoneConfigUI(LightningFlow): - """UI to run Fiftyone and Streamlit apps.""" - - def __init__(self, *args, **kwargs): - - super().__init__(*args, **kwargs) - - self.work = FiftyoneWork( - cloud_compute=CloudCompute("default"), - cloud_build_config=LitPoseBuildConfig(), # get fiftyone - ) - - # control runners - # True = Run Jobs. False = Do not Run jobs - # UI sets to True to kickoff jobs - # Job Runner sets to False when done - self.run_script = False - - # params updated externally by top-level flow - self.fiftyone_datasets = [] - self.trained_models = [] - self.proj_dir = None - self.config_name = None - - # submit count - self.submit_count = 0 - self.submit_success = False - - # output from the UI - self.st_submit = False - self.st_dataset_name = None - self.st_model_dirs = [None for _ in range(2)] - self.st_model_display_names = [None for _ in range(2)] - - def run(self, action, **kwargs): - - if action == "start_fiftyone": - self.work.run(action=action, **kwargs) - - elif action == "find_fiftyone_datasets": - self.work.run(action=action, **kwargs) - self.fiftyone_datasets = self.work.fiftyone_datasets - - elif action == "build_fiftyone_dataset": - self.work.run( - action=action, - config_file=os.path.join(self.proj_dir, self.config_name), # relative paths - dataset_name=self.st_dataset_name, - model_dirs=self.st_model_dirs, # relative paths - model_names=self.st_model_display_names, - **kwargs, - ) - self.fiftyone_datasets = self.work.fiftyone_datasets - - def configure_layout(self): - return StreamlitFrontend(render_fn=_render_streamlit_fn) - - -def 
_render_streamlit_fn(state: AppState): - """Create Fiftyone Dataset""" - - # force rerun to update page - st_autorefresh(interval=2000, key="refresh_page") - - st.markdown( - """ - ## Prepare Fiftyone diagnostics - - Choose two models for evaluation. - - """ - ) - - st.markdown( - """ - #### Select models - """ - ) - - # hard-code two models for now - st_model_dirs = [None for _ in range(2)] - st_model_display_names = [None for _ in range(2)] - - # --------------------------------------------------------- - # collect input from users - # --------------------------------------------------------- - with st.form(key="fiftyone_form", clear_on_submit=True): - - col0, col1 = st.columns(2) - - with col0: - - # select first model (supervised) - options1 = sorted(state.trained_models, reverse=True) - tmp = st.selectbox("Select Model 1", options=options1, disabled=state.run_script) - st_model_dirs[0] = tmp - tmp = st.text_input( - "Display name for Model 1", value="model_1", disabled=state.run_script) - st_model_display_names[0] = tmp - - with col1: - - # select second model (semi-supervised) - options2 = sorted(state.trained_models, reverse=True) - if st_model_dirs[0]: - options2.remove(st_model_dirs[0]) - - tmp = st.selectbox("Select Model 2", options=options2, disabled=state.run_script) - st_model_dirs[1] = tmp - tmp = st.text_input( - "Display name for Model 2", value="model_2", disabled=state.run_script) - st_model_display_names[1] = tmp - - # make model dirs paths relative to FileSystem - for i in range(2): - if st_model_dirs[i] and not os.path.isabs(st_model_dirs[i]): - st_model_dirs[i] = os.path.join(state.proj_dir, MODELS_DIR, st_model_dirs[i]) - - # dataset names - existing_datasets = state.fiftyone_datasets - st.write(f"Existing Fifityone datasets:\n{', '.join(existing_datasets)}") - st_dataset_name = st.text_input( - "Choose dataset name other than the above existing names", disabled=state.run_script) - - # build dataset - st.markdown(""" - Diagnostics will be 
displayed in the following 'Fiftyone' tab. - """) - st_submit_button = st.form_submit_button("Prepare Fiftyone dataset", disabled=state.run_script) - - # --------------------------------------------------------- - # check user input - # --------------------------------------------------------- - if st_model_display_names[0] is None \ - or st_model_display_names[1] is None \ - or st_model_display_names[0] == st_model_display_names[1]: - st_submit_button = False - state.submit_success = False - st.warning(f"Must choose two unique model display names") - if st_model_dirs[0] is None or st_model_dirs[1] is None: - st_submit_button = False - state.submit_success = False - st.warning(f"Must choose two models to continue") - if st_model_dirs[0] == st_model_dirs[1]: - st_submit_button = False - state.submit_success = False - st.warning(f"Must choose two unique models to continue") - if st_submit_button and \ - (st_dataset_name in existing_datasets - or st_dataset_name is None - or st_dataset_name == ""): - st_submit_button = False - state.submit_success = False - st.warning(f"Enter a unique dataset name to continue") - if state.run_script: - st.warning(f"Waiting for existing dataset creation to finish " - f"(may take 30 seconds to update)") - if state.submit_count > 0 \ - and not state.run_script \ - and not st_submit_button \ - and state.submit_success: - proceed_str = "Diagnostics are ready to view in the following tab." - proceed_fmt = "

<p style='font-family:sans-serif; color:Green;'>%s</p>

" - st.markdown(proceed_fmt % proceed_str, unsafe_allow_html=True) - - # --------------------------------------------------------- - # build fiftyone dataset - # --------------------------------------------------------- - # this will only be run once when the user clicks the button; - # on the following pass the button click will be set to False again - if st_submit_button: - - state.submit_count += 1 - - # save streamlit options to flow object only on button click - state.st_dataset_name = st_dataset_name - state.st_model_dirs = st_model_dirs - state.st_model_display_names = st_model_display_names - - # reset form - st_dataset_name = None - st_model_dirs = [None for _ in range(2)] - st_model_display_names = [None for _ in range(2)] - - st.text("Request submitted!") - state.submit_success = True - state.run_script = True # must the last to prevent race condition - - # force rerun to update warnings - st_autorefresh(interval=2000, key="refresh_diagnostics_submitted") diff --git a/lightning_pose_app/ui/project.py b/lightning_pose_app/ui/project.py index 491cc3b..d08a7c7 100644 --- a/lightning_pose_app/ui/project.py +++ b/lightning_pose_app/ui/project.py @@ -1,7 +1,6 @@ import copy import glob -from lightning.app import LightningFlow, LightningWork -from lightning.app.storage import FileSystem +from lightning.app import LightningFlow from lightning.app.utilities.state import AppState import logging import math @@ -28,6 +27,7 @@ StreamlitFrontend, collect_dlc_labels, copy_and_reformat_video_directory, + abspath, ) @@ -41,16 +41,8 @@ def __init__(self, *args, data_dir, default_config_dict, debug=False, **kwargs): super().__init__(*args, **kwargs) - self._drive = FileSystem() - # initialize data_dir if it doesn't yet exist - if not self._drive.isdir(data_dir): - d = self.abspath(data_dir) - os.makedirs(d, exist_ok=True) - f = os.path.join(d, "tmp.txt") - with open(f, "w") as fs: - fs.write("tmp") - self._drive.put(f, os.path.join(data_dir, "tmp.txt")) + 
os.makedirs(abspath(data_dir), exist_ok=True) # save default config info for initializing new projects self.default_config_dict = default_config_dict @@ -112,59 +104,12 @@ def st_keypoints(self): @property def proj_dir_abs(self): - return self.abspath(self.proj_dir) - - def _get_from_drive_if_not_local(self, file_or_dir): - - if not os.path.exists(self.abspath(file_or_dir)): - try: - _logger.debug(f"drive try get {file_or_dir}") - src = file_or_dir # shared - dst = self.abspath(file_or_dir) # local - self._drive.get(src, dst, overwrite=True) - _logger.debug(f"drive success get {file_or_dir}") - except Exception as e: - _logger.debug(e) - _logger.debug(f"could not find {file_or_dir} in {self.data_dir}") - else: - _logger.debug(f"loading local version of {file_or_dir}") - - def _put_to_drive_remove_local(self, file_or_dir, remove_local=True): - _logger.debug(f"put to drive {file_or_dir}") - src = self.abspath(file_or_dir) # local - if os.path.isfile(src): - dst = file_or_dir # shared - self._drive.put(src, dst) - elif os.path.isdir(src): - files_local = os.listdir(src) - existing_files = self._drive.listdir(file_or_dir) - for file_or_dir_local in files_local: - rel_path = os.path.join(file_or_dir, file_or_dir_local) - if rel_path not in existing_files: - src = self.abspath(rel_path) - dst = rel_path - self._drive.put(src, dst) - else: - _logger.debug(f"{rel_path} already exists on FileSystem! 
not updating") - # clean up the local object - if remove_local: - if os.path.isfile(self.abspath(file_or_dir)): - os.remove(self.abspath(file_or_dir)) - else: - shutil.rmtree(self.abspath(file_or_dir)) - - @staticmethod - def abspath(path): - if path[0] == "/": - path_ = path[1:] - else: - path_ = path - return os.path.abspath(path_) + return abspath(self.proj_dir) def _find_initialized_projects(self): # find all directories inside the data_dir; these should be the projects # (except labelstudio database) - projects = self._drive.listdir(self.data_dir) + projects = os.listdir(abspath(self.data_dir)) # strip leading dirs to just get project names projects = [ os.path.basename(p) for p in projects @@ -175,7 +120,7 @@ def _find_initialized_projects(self): def _update_paths(self, project_name=None, **kwargs): if not project_name: project_name = self.st_project_name - # these will all be paths RELATIVE to the FileSystem root + # these will all be paths RELATIVE to the Pose-app directory if project_name: self.proj_dir = os.path.join(self.data_dir, project_name) self.config_name = f"model_config_{project_name}.yaml" @@ -188,12 +133,8 @@ def _update_project_config(self, new_vals_dict=None, **kwargs): if not new_vals_dict: new_vals_dict = self.st_new_vals - # check to see if config exists locally; if not, try pulling from drive - if self.config_file: - self._get_from_drive_if_not_local(self.config_file) - # check to see if config exists; copy default config if not - if (self.config_file is None) or (not os.path.exists(self.abspath(self.config_file))): + if (self.config_file is None) or (not os.path.exists(abspath(self.config_file))): _logger.debug(f"no config file at {self.config_file}") _logger.debug("loading default config") # copy default config @@ -211,7 +152,7 @@ def _update_project_config(self, new_vals_dict=None, **kwargs): else: _logger.debug("loading existing config") # load existing config - config_dict = yaml.safe_load(open(self.abspath(self.config_file))) + 
config_dict = yaml.safe_load(open(abspath(self.config_file))) # update config using new_vals_dict; assume this is a dict of dicts # new_vals_dict = { @@ -230,27 +171,22 @@ def _update_project_config(self, new_vals_dict=None, **kwargs): # save out updated config file locally if not os.path.exists(self.proj_dir_abs): os.makedirs(self.proj_dir_abs) - yaml.dump(config_dict, open(self.abspath(self.config_file), "w")) + yaml.dump(config_dict, open(abspath(self.config_file), "w")) # save current params self.config_dict = config_dict - # push data to drive and clean up local file - self._put_to_drive_remove_local(self.config_file, remove_local=False) - def _update_frame_shapes(self): from PIL import Image - # get labeled data from drive - labeled_data_dir = os.path.join(self.proj_dir, LABELED_DATA_DIR) - # check to see if config exists locally; if not, try pulling from drive - self._get_from_drive_if_not_local(labeled_data_dir) - # load single frame from labeled data imgs = glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*", "*.png")) \ - + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*", "*.jpg")) \ - + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*", "*.jpeg")) + + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*", "*.jpg")) \ + + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*", "*.jpeg")) \ + + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*.png")) \ + + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*.jpg")) \ + + glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*.jpeg")) if len(imgs) > 0: img = imgs[0] image = Image.open(img) @@ -268,15 +204,14 @@ def _update_frame_shapes(self): }) else: _logger.debug(glob.glob(os.path.join(self.proj_dir_abs, LABELED_DATA_DIR, "*"))) - _logger.debug("did not find labeled data directory in FileSystem") + _logger.debug("did not find labeled data directory") def _compute_labeled_frame_fraction(self, timer=0.0): 
metadata_file = os.path.join(self.proj_dir, LABELSTUDIO_METADATA_FILENAME) - self._get_from_drive_if_not_local(metadata_file) try: - proj_details = yaml.safe_load(open(self.abspath(metadata_file), "r")) + proj_details = yaml.safe_load(open(abspath(metadata_file), "r")) n_labeled_frames = proj_details["n_labeled_tasks"] n_total_frames = proj_details["n_total_tasks"] except FileNotFoundError: @@ -288,40 +223,36 @@ def _compute_labeled_frame_fraction(self, timer=0.0): n_labeled_frames = None n_total_frames = None - # remove local file so that Work is forced to load updated versions from Drive - if os.path.exists(self.abspath(metadata_file)): - os.remove(self.abspath(metadata_file)) - self.n_labeled_frames = n_labeled_frames self.n_total_frames = n_total_frames def _load_project_defaults(self, **kwargs): - # check to see if config exists locally; if not, try pulling from drive - if self.config_file: - self._get_from_drive_if_not_local(self.config_file) - # check to see if config exists - if self.config_file and os.path.exists(self.abspath(self.config_file)): + if self.config_file and os.path.exists(abspath(self.config_file)): # set values from config - config_dict = yaml.safe_load(open(self.abspath(self.config_file))) + config_dict = yaml.safe_load(open(abspath(self.config_file))) self.st_keypoints_ = config_dict["data"]["keypoints"] self.st_n_keypoints = config_dict["data"]["num_keypoints"] self.st_pcasv_columns = config_dict["data"]["columns_for_singleview_pca"] self.st_pcamv_columns = config_dict["data"]["mirrored_column_matches"] self.st_n_views = 1 if len(self.st_pcamv_columns) == 0 else len(self.st_pcamv_columns) + # save current params + self.config_dict = config_dict def _update_trained_models_list(self, **kwargs): - if self._drive.isdir(self.model_dir): + if os.path.isdir(abspath(self.model_dir)): trained_models = [] # this returns a list of model training days - dirs_day = self._drive.listdir(self.model_dir) + dirs_day = os.listdir(abspath(self.model_dir)) # 
loop over days and find HH-MM-SS for dir_day in dirs_day: - dirs_time = self._drive.listdir("/" + dir_day) + fullpath1 = os.path.join(abspath(self.model_dir), dir_day) + dirs_time = os.listdir(fullpath1) for dir_time in dirs_time: - trained_models.append('/'.join(dir_time.split('/')[-2:])) + fullpath2 = os.path.join(fullpath1, dir_time) + trained_models.append('/'.join(fullpath2.split('/')[-2:])) self.trained_models = trained_models def _upload_existing_project(self, **kwargs): @@ -372,7 +303,11 @@ def contains_videos(file_or_dir): else: shutil.copyfile(src, dst) - elif self.st_existing_project_format == "DLC": + # ------------------- + # test find models + # -------------------# ------------------- + # test find models + # ------------------- == "DLC": # copy files files_and_dirs = os.listdir(unzipped_dir) @@ -414,9 +349,6 @@ def contains_videos(file_or_dir): # update config file with frame shapes self._update_frame_shapes() - # push files to FileSystem - self._put_to_drive_remove_local(self.proj_dir) - # update counter self.count_upload_existing += 1 @@ -426,15 +358,7 @@ def _delete_project(self, **kwargs): if os.path.exists(self.proj_dir_abs): shutil.rmtree(self.proj_dir_abs) - # recursively delete project on FileSystem - def rmdir(directory, drive): - for item in drive.listdir(directory): - if drive.isdir("/" + item): - rmdir("/" + item, drive) - else: - drive.rm("/" + item) - drive.rm(directory) - rmdir(self.proj_dir, self._drive) + # TODO: how to delete from label studio db? 
# update project info self.st_project_name = "" @@ -457,8 +381,6 @@ def run(self, action, **kwargs): self._load_project_defaults(**kwargs) elif action == "update_trained_models_list": self._update_trained_models_list(**kwargs) - elif action == "put_file_to_drive": - self._put_to_drive_remove_local(**kwargs) elif action == "upload_existing_project": self._upload_existing_project(**kwargs) elif action == "delete_project": @@ -529,8 +451,31 @@ def _render_streamlit_fn(state: AppState): # ---------------------------------------------------- # landing # ---------------------------------------------------- + with st.sidebar: + st.title(""" Welcome to Lightning Pose App! """) + st.write( + "The first tab of the app is the project manager. Here you will be able to" + " create new projects and load or delete existing projects under your account." + ) + st.write("## To move forward, you will need to complete all the steps in this tab.") + st.write("##") + st.markdown("**Need further help? Check the:**") + st.markdown( + "App [documentation]" + "(https://pose-app.readthedocs.io/en/latest/source/tabs/manage_project.html#)", + unsafe_allow_html=True, + ) + st.markdown( + "Github [repository](https://github.com/Lightning-Universe/Pose-app.html#)", + unsafe_allow_html=True, + ) + st.markdown( + "Lightning Pose [documentation]" + "(https://lightning-pose.readthedocs.io/en/latest/.html#)", + unsafe_allow_html=True, + ) - st.markdown(""" ## Manage Lightning Pose project """) + st.header("Manage Lightning Pose projects") CREATE_STR = "Create new project" UPLOAD_STR = "Create new project from source (e.g. 
existing DLC project)" @@ -538,12 +483,14 @@ def _render_streamlit_fn(state: AppState): DELETE_STR = "Delete existing project" st_mode = st.radio( - "", + label="Check the box that applies:", options=[CREATE_STR, UPLOAD_STR, LOAD_STR, DELETE_STR], disabled=state.st_project_loaded, index=2 if (state.st_project_loaded and not state.st_create_new_project) else 0, + help="Create a new project from scratch, upload an existing DLC project as a foundation " + "for your new project, continue work on an ongoing lightning pose project, or remove " + "a project from your projects." ) - st.text(f"Available projects: {state.initialized_projects}") st_project_name = st.text_input( @@ -551,7 +498,6 @@ def _render_streamlit_fn(state: AppState): value="" if (not state.st_project_loaded or state.st_reset_project_name) else state.st_project_name ) - # ---------------------------------------------------- # determine project status - load existing, create new # ---------------------------------------------------- @@ -643,8 +589,10 @@ def _render_streamlit_fn(state: AppState): if st_project_name and st_mode == UPLOAD_STR: st_prev_format = st.radio( - "Uploaded project format", + "Select uploaded project format", options=["DLC", "Lightning Pose"], # TODO: SLEAP, MARS? + help="Select the file format that the project is stored at." 
+ " If DLC selected make sure the zipped folder has meet all reqierments" ) state.st_existing_project_format = st_prev_format @@ -669,7 +617,9 @@ def _render_streamlit_fn(state: AppState): state.st_upload_existing_project_zippath = filepath enter_data = True st_mode = CREATE_STR - + st.caption("If your zip file is larger than the 200MB limit, see the [FAQ]" + "(https://pose-app.readthedocs.io/en/latest/source/faqs.html#faq-upload-limit)", + unsafe_allow_html=True) if state.st_error_flag: st.markdown(state.st_error_msg, unsafe_allow_html=True) enter_data = False @@ -704,13 +654,25 @@ def _render_streamlit_fn(state: AppState): # camera views if enter_data: st.markdown("") + st.divider() st.markdown("") - st.markdown("##### Camera views") + st.markdown( + "##### Camera views", + help="Support for multiple views is currently limited to either fusing the views " + "into single frames or utilizing a mirror to generate multiple views from a " + "single camera", + ) n_views = st.text_input( "Enter number of camera views:", disabled=not enter_data, value="" if not state.st_project_loaded else str(st_n_views), ) + st.caption( + "For a multiview option check the [documentation]" + "(https://lightning-pose.readthedocs.io/en/latest/source/" + "user_guide_advanced/multiview_fused.html#)", + unsafe_allow_html=True + ) if n_views: st_n_views = int(n_views) else: @@ -719,7 +681,9 @@ def _render_streamlit_fn(state: AppState): # keypoints if st_n_views > 0: + st.divider() st.markdown("##### Define keypoints") + e1 = st.expander("Expand to see an exemple") keypoint_instructions = """ **Instructions**: If your data has multiple views, make sure to create an entry for each bodypart @@ -733,10 +697,10 @@ def _render_streamlit_fn(state: AppState): r_ear_bottom corner1_top ``` - It is also possible to track keypoints that are only present in a subset of the + It is also possible to track keypoints that are only present in a subset of the views, such as the keypoint `corner1_top` above. 
""" - st.markdown(keypoint_instructions) + e1.markdown(keypoint_instructions) if state.st_upload_existing_project: value = "\n".join(st_keypoints) elif not state.st_project_loaded: @@ -757,13 +721,24 @@ def _render_streamlit_fn(state: AppState): # pca singleview if st_n_keypoints > 1: + st.divider() st.markdown("##### Select subset of keypoints for Pose PCA") - st.markdown(""" - **Instructions**: - The selected subset will be used for a Pose PCA loss on unlabeled videos. - The subset should be keypoints that are not usually occluded (such as a tongue) - and are not static (such as the corner of a box). + # st.markdown(""" + # **Instructions**: + # The selected subset will be used for a Pose PCA loss on unlabeled videos. + # The subset should be keypoints that are not usually occluded (such as a tongue) + # and are not static (such as the corner of a box). + # """) + e2 = st.expander("Expend for further instractions") + e2.markdown(""" + **When selecting keypoints for Pose PCA on unlabeled videos, focus on**: + * **Selecting points with consistent visibility**, avoiding those prone to + occlusion (e.g., tongue) during movement. + * **Selecting points that exhibit dynamic changes**, + excluding static elements (e.g., corner of a box) + offering minimal pose information. """) + e2.write("*The selected subset will be used for a Pose PCA loss on unlabeled videos") pcasv_selected = [False for _ in st_keypoints] for k, kp in enumerate(st_keypoints): pcasv_selected[k] = st.checkbox( @@ -779,8 +754,9 @@ def _render_streamlit_fn(state: AppState): if st_n_keypoints > 1 and st_n_views > 1: st.markdown("##### Select subset of body parts for Multiview PCA") - st.markdown(""" - **Instructions**: + e3 = st.expander("Expand for further instractions") + e3.markdown(""" + Select the same body part from different POV's. The selected subset will be used for a Multiview PCA loss on unlabeled videos. The subset should be keypoints that are usually visible in all camera views. 
""") @@ -811,7 +787,8 @@ def _render_streamlit_fn(state: AppState): # set bodypart dropdowns for c, col in enumerate(cols[1:]): kp = col.selectbox( - f"", st_keypoints, key=f"Bodypart {r} view {c}", + "", st_keypoints, + key=f"Bodypart {r} view {c}", index=c * st_n_bodyparts + r ) st_pcamv_columns[c, r] = np.where(np.array(st_keypoints) == kp)[0] @@ -877,11 +854,10 @@ def _render_streamlit_fn(state: AppState): if state.st_submits > 0: proceed_str = """ Proceed to the next tab to extract frames for labeling.

- LabelStudio login information:
+ Use this LabelStudio login information:
username: user@localhost
password: pw """ - proceed_fmt = "

%s

" st.markdown(proceed_fmt % proceed_str, unsafe_allow_html=True) diff --git a/lightning_pose_app/ui/streamlit.py b/lightning_pose_app/ui/streamlit.py index 5b9cf92..0f3a994 100644 --- a/lightning_pose_app/ui/streamlit.py +++ b/lightning_pose_app/ui/streamlit.py @@ -1,9 +1,8 @@ from lightning.app import CloudCompute, LightningFlow import os -from lightning_pose_app import MODELS_DIR +from lightning_pose_app import MODELS_DIR, LIGHTNING_POSE_DIR from lightning_pose_app.bashwork import LitBashWork -from lightning_pose_app.build_configs import LitPoseBuildConfig, lightning_pose_dir class StreamlitAppLightningPose(LightningFlow): @@ -14,9 +13,7 @@ def __init__(self, *args, app_type, **kwargs): super().__init__(*args, **kwargs) self.work = LitBashWork( - name=f"streamlit_{app_type}", cloud_compute=CloudCompute("default"), - cloud_build_config=LitPoseBuildConfig(), # this may not be necessary ) # choose labeled frame or video option @@ -52,21 +49,9 @@ def initialize(self, **kwargs): + " -- " \ + " " + model_dir_args - self.work.run(cmd, cwd=lightning_pose_dir, wait_for_exit=False) - - def pull_models(self, **kwargs): - inputs = kwargs.get("inputs", None) - if inputs: - self.work.run( - "null command", - cwd=os.getcwd(), - input_output_only=True, # pull inputs from Drive, but do not run commands - inputs=inputs, - ) + self.work.run(cmd, cwd=LIGHTNING_POSE_DIR, wait_for_exit=False) def run(self, action, **kwargs): if action == "initialize": self.initialize(**kwargs) - elif action == "pull_models": - self.pull_models(**kwargs) diff --git a/lightning_pose_app/ui/train_infer.py b/lightning_pose_app/ui/train_infer.py index 565a79f..f27fc96 100644 --- a/lightning_pose_app/ui/train_infer.py +++ b/lightning_pose_app/ui/train_infer.py @@ -1,8 +1,7 @@ """UI for training models.""" from datetime import datetime -from lightning.app import CloudCompute, LightningFlow -from lightning.app.storage import FileSystem +from lightning.app import CloudCompute, LightningFlow, LightningWork from 
lightning.app.structures import Dict from lightning.app.utilities.cloud import is_running_in_cloud from lightning.app.utilities.state import AppState @@ -10,23 +9,23 @@ from lightning.pytorch.utilities import rank_zero_only import lightning.pytorch as pl import logging -import numpy as np import os -import pandas as pd +import shutil import streamlit as st from streamlit_autorefresh import st_autorefresh -import shutil -import subprocess -import sys -import time import yaml -from lightning_pose_app import LABELED_DATA_DIR, VIDEOS_DIR, VIDEOS_TMP_DIR, VIDEOS_INFER_DIR -from lightning_pose_app import MODELS_DIR, COLLECTED_DATA_FILENAME, SELECTED_FRAMES_FILENAME +from lightning_pose_app import VIDEOS_DIR, VIDEOS_TMP_DIR, VIDEOS_INFER_DIR +from lightning_pose_app import LABELED_DATA_DIR, MODELS_DIR, SELECTED_FRAMES_FILENAME from lightning_pose_app import MODEL_VIDEO_PREDS_TRAIN_DIR, MODEL_VIDEO_PREDS_INFER_DIR from lightning_pose_app.build_configs import LitPoseBuildConfig -from lightning_pose_app.utilities import StreamlitFrontend, WorkWithFileSystem -from lightning_pose_app.utilities import reencode_video, check_codec_format +from lightning_pose_app.utilities import ( + StreamlitFrontend, + abspath, + copy_and_reformat_video, + is_context_dataset, + make_video_snippet, +) _logger = logging.getLogger('APP.TRAIN_INFER') @@ -73,11 +72,11 @@ def on_predict_batch_end( self._update_progress(progress) -class LitPose(WorkWithFileSystem): +class LitPose(LightningWork): def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, name="train_infer", **kwargs) + super().__init__(*args, **kwargs) self.pwd = os.getcwd() self.progress = 0.0 @@ -87,55 +86,7 @@ def __init__(self, *args, **kwargs) -> None: self.work_is_done_inference = False self.count = 0 - def _reformat_video(self, video_file, **kwargs): - - # get new names (ensure mp4 file extension, no tmp directory) - ext = os.path.splitext(os.path.basename(video_file))[1] - video_file_mp4_ext = 
video_file.replace(f"{ext}", ".mp4") - video_file_new = video_file_mp4_ext.replace(VIDEOS_TMP_DIR, VIDEOS_INFER_DIR) - video_file_abs_new = self.abspath(video_file_new) - - # check 0: do we even need to reformat? - if self._drive.isfile(video_file_new): - return video_file_new - - # pull videos from FileSystem - self.get_from_drive([video_file]) - video_file_abs = self.abspath(video_file) - - # check 1: does file exist? - video_file_exists = os.path.exists(video_file_abs) - if not video_file_exists: - _logger.info(f"{video_file_abs} does not exist! skipping") - return None - - # check 2: is file in the correct format for DALI? - video_file_correct_codec = check_codec_format(video_file_abs) - - # reencode/rename - if not video_file_correct_codec: - _logger.info("re-encoding video to be compatable with Lightning Pose video reader") - reencode_video(video_file_abs, video_file_abs_new) - # remove old video from local files - os.remove(video_file_abs) - else: - # make dir to write into - os.makedirs(os.path.dirname(video_file_abs_new), exist_ok=True) - # rename - os.rename(video_file_abs, video_file_abs_new) - - # remove old video(s) from FileSystem - if self._drive.isfile(video_file): - self._drive.rm(video_file) - if self._drive.isfile(video_file_mp4_ext): - self._drive.rm(video_file_mp4_ext) - - # push possibly reformated, renamed videos to FileSystem - self.put_to_drive([video_file_new]) - - return video_file_new - - def _train(self, inputs, outputs, cfg_overrides, results_dir): + def _train(self, config_file, config_overrides, results_dir): import gc from omegaconf import DictConfig, OmegaConf @@ -162,20 +113,14 @@ def _train(self, inputs, outputs, cfg_overrides, results_dir): self.work_is_done_training = False # ---------------------------------------------------------------------------------- - # Pull data from FileSystem + # Set up config # ---------------------------------------------------------------------------------- - # pull config, frames, labels, and 
videos (relative paths) - self.get_from_drive(inputs) - - # load config (absolute path) - for i in inputs: - if i.endswith(".yaml"): - config_file = i - cfg = DictConfig(yaml.safe_load(open(self.abspath(config_file), "r"))) + # load config + cfg = DictConfig(yaml.safe_load(open(abspath(config_file), "r"))) # update config with user-provided overrides - for key1, val1 in cfg_overrides.items(): + for key1, val1 in config_overrides.items(): for key2, val2 in val1.items(): cfg[key1][key2] = val2 @@ -299,7 +244,7 @@ def _train(self, inputs, outputs, cfg_overrides, results_dir): f"Found {len(filenames)} videos to predict on " f"(in cfg.eval.test_videos_directory)" ) - + for v, video_file in enumerate(filenames): assert os.path.isfile(video_file) pretty_print_str(f"Predicting video: {video_file}...") @@ -310,7 +255,8 @@ def _train(self, inputs, outputs, cfg_overrides, results_dir): # get save name labeled video csv if cfg.eval.save_vids_after_training: labeled_vid_dir = os.path.join(video_pred_dir, "labeled_videos") - labeled_mp4_file = os.path.join(labeled_vid_dir, video_pred_name + "_labeled.mp4") + labeled_mp4_file = os.path.join(labeled_vid_dir, + video_pred_name + "_labeled.mp4") else: labeled_mp4_file = None # predict on video @@ -337,15 +283,17 @@ def _train(self, inputs, outputs, cfg_overrides, results_dir): continue # ---------------------------------------------------------------------------------- - # Push results to FileSystem, clean up + # Clean up # ---------------------------------------------------------------------------------- # save config file cfg_file_local = os.path.join(results_dir, "config.yaml") with open(cfg_file_local, "w") as fp: OmegaConf.save(config=cfg, f=fp.name) + # remove lightning logs + shutil.rmtree(os.path.join(results_dir, "lightning_logs"), ignore_errors=True) + os.chdir(self.pwd) - self.put_to_drive(outputs) # IMPORTANT! 
must come after changing directories # clean up memory del imgaug_transform @@ -379,26 +327,24 @@ def _run_inference(self, model_dir, video_file): self.work_is_done_inference = False # ---------------------------------------------------------------------------------- - # Pull data from FileSystem + # Set up paths # ---------------------------------------------------------------------------------- - # pull video from FileSystem - self.get_from_drive([video_file]) - # check: does file exist? - video_file_abs = self.abspath(video_file) + # check: does file exist? + if not os.path.exists(video_file): + video_file_abs = abspath(video_file) + else: + video_file_abs = video_file video_file_exists = os.path.exists(video_file_abs) _logger.info(f"video file exists? {video_file_exists}") if not video_file_exists: _logger.info("skipping inference") return - # pull model from FileSystem - self.get_from_drive([model_dir]) - - # load config (absolute path) + # load config config_file = os.path.join(model_dir, "config.yaml") - cfg = DictConfig(yaml.safe_load(open(self.abspath(config_file), "r"))) + cfg = DictConfig(yaml.safe_load(open(abspath(config_file), "r"))) cfg.training.imgaug = "default" # don't do imgaug # define paths @@ -408,7 +354,7 @@ def _run_inference(self, model_dir, video_file): pred_dir = os.path.join(model_dir, MODEL_VIDEO_PREDS_INFER_DIR) preds_file = os.path.join( - self.abspath(pred_dir), os.path.basename(video_file_abs).replace(".mp4", ".csv")) + abspath(pred_dir), os.path.basename(video_file_abs).replace(".mp4", ".csv")) # ---------------------------------------------------------------------------------- # Set up data/model objects @@ -425,7 +371,7 @@ def _run_inference(self, model_dir, video_file): data_module.setup() ckpt_file = ckpt_path_from_base_path( - base_path=self.abspath(model_dir), model_name=cfg.model.model_name + base_path=abspath(model_dir), model_name=cfg.model.model_name ) # 
---------------------------------------------------------------------------------- @@ -458,8 +404,7 @@ def _run_inference(self, model_dir, video_file): # make short labeled snippet for manual inspection self.progress = 0.0 # reset progress so it will again be updated during snippet inference self.status_ = "creating labeled video" - video_file_abs_short = self._make_video_snippet( - video_file=video_file_abs, preds_file=preds_file) + video_file_abs_short = make_video_snippet(video_file=video_file_abs, preds_file=preds_file) preds_file_short = preds_file.replace(".csv", ".short.csv") export_predictions_and_labeled_video( video_file=video_file_abs_short, @@ -471,67 +416,53 @@ def _run_inference(self, model_dir, video_file): data_module=data_module, ) - # ---------------------------------------------------------------------------------- - # Push results to FileSystem, clean up - # ---------------------------------------------------------------------------------- - self.put_to_drive([pred_dir]) - # set flag for parent app self.work_is_done_inference = True @staticmethod - def _make_video_snippet(video_file, preds_file, clip_length=30, likelihood_thresh=0.9): - - import cv2 + def _make_fiftyone_dataset(config_file, results_dir, config_overrides=None, **kwargs): - # save videos with csv file - save_dir = os.path.dirname(preds_file) + from lightning_pose.utils.fiftyone import FiftyOneImagePlotter, check_dataset + from omegaconf import DictConfig - df = pd.read_csv(preds_file, header=[0, 1, 2], index_col=0) + # load config (absolute path) + cfg = DictConfig(yaml.safe_load(open(abspath(config_file), "r"))) - # how large is the clip window? 
- video = cv2.VideoCapture(video_file) - fps = video.get(cv2.CAP_PROP_FPS) - win_len = int(fps * clip_length) + # update config with user-provided overrides (this is mostly for unit testing) + for key1, val1 in config_overrides.items(): + for key2, val2 in val1.items(): + cfg[key1][key2] = val2 - # make a `clip_length` second video clip that contains the highest keypoint motion energy - src = video_file - dst = os.path.join(save_dir, os.path.basename(video_file).replace(".mp4", ".short.mp4")) - if win_len >= df.shape[0]: - # short video, no need to shorten further. just copy existing video - shutil.copyfile(src, dst) - else: - # compute motion energy (averaged over keypoints) - kps_and_conf = df.to_numpy().reshape(df.shape[0], -1, 3) - kps = kps_and_conf[:, :, :2] - conf = kps_and_conf[:, :, -1] - conf2 = np.concatenate([conf[:, :, None], conf[:, :, None]], axis=2) - kps[conf2 < likelihood_thresh] = np.nan - me = np.nanmean(np.linalg.norm(kps[1:] - kps[:1], axis=2), axis=-1) - - # find window - df_me = pd.DataFrame({"me": np.concatenate([[0], me])}) - df_me_win = df_me.rolling(window=win_len, center=False).mean() - # rolling places results in right edge of window, need to subtract this - clip_start_idx = df_me_win.me.argmax() - win_len - # convert to seconds - clip_start_sec = int(clip_start_idx / fps) - # if all predictions are bad, make sure we still create a valid snippet video - if np.isnan(clip_start_sec) or clip_start_sec < 0: - clip_start_sec = 0 - - # make clip - ffmpeg_cmd = f"ffmpeg -ss {clip_start_sec} -i {src} -t {clip_length} {dst}" - subprocess.run(ffmpeg_cmd, shell=True) - - return dst + # edit config + cfg.data.data_dir = os.path.join(os.getcwd(), cfg.data.data_dir) + model_dir = "/".join(results_dir.split("/")[-2:]) + # get project name from config file, assuming first part is model_config_ + proj_name = os.path.basename(config_file).split(".")[0][13:] + cfg.eval.fiftyone.dataset_name = f"{proj_name}_{model_dir}" + 
cfg.eval.fiftyone.model_display_names = [model_dir.split("_")[-1]] + cfg.eval.hydra_paths = [results_dir] + + # build dataset + fo_plotting_instance = FiftyOneImagePlotter(cfg=cfg) + dataset = fo_plotting_instance.create_dataset() + # create metadata and print if there are problems + check_dataset(dataset) + # print the name of the dataset + fo_plotting_instance.dataset_info_print() def run(self, action=None, **kwargs): if action == "train": self._train(**kwargs) + self._make_fiftyone_dataset(**kwargs) elif action == "run_inference": - new_vid_file = self._reformat_video(**kwargs) - kwargs["video_file"] = new_vid_file + proj_dir = '/'.join(kwargs["model_dir"].split('/')[:3]) + new_vid_file = copy_and_reformat_video( + video_file=abspath(kwargs["video_file"]), + dst_dir=abspath(os.path.join(proj_dir, VIDEOS_INFER_DIR)), + remove_old=kwargs.pop("remove_old", True), + ) + # save relative rather than absolute path + kwargs["video_file"] = '/'.join(new_vid_file.split('/')[-4:]) self._run_inference(**kwargs) @@ -542,9 +473,6 @@ def __init__(self, *args, allow_context=True, max_epochs_default=300, **kwargs): super().__init__(*args, **kwargs) - # shared storage system - self._drive = FileSystem() - # updated externally by parent app self.trained_models = [] self.proj_dir = None @@ -593,35 +521,13 @@ def __init__(self, *args, allow_context=True, max_epochs_default=300, **kwargs): self.st_infer_label_opt = None # what to do with video evaluation self.st_inference_videos = [] - def _push_video(self, video_file): - if video_file[0] == "/": - src = os.path.join(os.getcwd(), video_file[1:]) - dst = video_file - else: - src = os.path.join(os.getcwd(), video_file) - dst = "/" + video_file - if not self._drive.isfile(dst) and os.path.exists(src): - # only put to FileSystem under two conditions: - # 1. file exists locally; if it doesn't, maybe it has already been deleted for a reason - # 2. 
file does not already exist on FileSystem; avoids excessive file transfers - _logger.debug(f"TRAIN_INFER UI try put {dst}") - self._drive.put(src, dst) - _logger.debug(f"TRAIN_INFER UI success put {dst}") - - def _train( - self, - config_filename=None, - video_dirname=VIDEOS_DIR, - labeled_data_dirname=LABELED_DATA_DIR, - csv_filename=COLLECTED_DATA_FILENAME, - ): + def _train(self, config_filename=None, video_dirname=VIDEOS_DIR): if config_filename is None: _logger.error("config_filename must be specified for training models") # set config overrides base_dir = os.path.join(os.getcwd(), self.proj_dir[1:]) - model_dir = os.path.join(self.proj_dir, MODELS_DIR) if self.st_train_label_opt == VIDEO_LABEL_NONE: predict_vids = False @@ -633,7 +539,7 @@ def _train( predict_vids = True save_vids = True - cfg_overrides = { + config_overrides = { "data": { "data_dir": base_dir, "video_dir": os.path.join(base_dir, video_dirname), @@ -645,7 +551,6 @@ def _train( }, "model": { # update these below if necessary "model_type": "heatmap", - "do_context": False, }, "training": { "imgaug": "dlc", @@ -653,26 +558,18 @@ def _train( } } - # list files needed from FileSystem - inputs = [ - os.path.join(self.proj_dir, config_filename), - os.path.join(self.proj_dir, labeled_data_dirname), - os.path.join(self.proj_dir, video_dirname), - os.path.join(self.proj_dir, csv_filename), - ] - # train models for m in ["super", "semisuper", "super ctx", "semisuper ctx"]: status = self.st_train_status[m] if status == "initialized" or status == "active": self.st_train_status[m] = "active" - outputs = [os.path.join(model_dir, self.st_datetimes[m])] - cfg_overrides["model"]["losses_to_use"] = self.st_losses[m] + config_overrides["model"]["losses_to_use"] = self.st_losses[m] if m.find("ctx") > -1: - cfg_overrides["model"]["model_type"] = "heatmap_mhcrnn" - cfg_overrides["model"]["do_context"] = True + config_overrides["model"]["model_type"] = "heatmap_mhcrnn" self.work.run( - action="train", inputs=inputs, 
outputs=outputs, cfg_overrides=cfg_overrides, + action="train", + config_file=os.path.join(self.proj_dir, config_filename), + config_overrides=config_overrides, results_dir=os.path.join(base_dir, MODELS_DIR, self.st_datetimes[m]) ) self.st_train_status[m] = "complete" @@ -680,7 +577,7 @@ def _train( self.submit_count_train += 1 - def _run_inference(self, model_dir=None, video_files=None): + def _run_inference(self, model_dir=None, video_files=None, testing=False): self.work_is_done_inference = False @@ -689,9 +586,7 @@ def _run_inference(self, model_dir=None, video_files=None): if not video_files: video_files = self.st_inference_videos - # launch works: - # - sequential if local - # - parallel if on cloud + # launch works (sequentially for now) for video_file in video_files: video_key = video_file.replace(".", "_") # keys cannot contain "." if video_key not in self.works_dict.keys(): @@ -702,13 +597,12 @@ def _run_inference(self, model_dir=None, video_files=None): status = self.st_infer_status[video_file] if status == "initialized" or status == "active": self.st_infer_status[video_file] = "active" - # move video from ui machine to shared FileSystem - self._push_video(video_file=video_file) # run inference (automatically reformats video for DALI) self.works_dict[video_key].run( action="run_inference", model_dir=model_dir, video_file="/" + video_file, + remove_old=not testing, # remove bad format file by default ) self.st_infer_status[video_file] = "complete" @@ -719,7 +613,8 @@ def _run_inference(self, model_dir=None, video_files=None): and self.works_dict[video_key].work_is_done_inference: # kill work _logger.info(f"killing work from video {video_key}") - self.works_dict[video_key].stop() + if not testing: # cannot run stop() from tests for some reason + self.works_dict[video_key].stop() del self.works_dict[video_key] # set flag for parent app @@ -727,58 +622,13 @@ def _run_inference(self, model_dir=None, video_files=None): def _determine_dataset_type(self, 
**kwargs): """Check if labeled data directory contains context frames.""" - - def get_frame_number(basename): - ext = basename.split(".")[-1] # get base name - split_idx = None - for c_idx, c in enumerate(basename): - try: - int(c) - split_idx = c_idx - break - except ValueError: - continue - # remove prefix - prefix = basename[:split_idx] - idx = basename[split_idx:] - # remove file extension - idx = idx.replace(f".{ext}", "") - return int(idx), prefix, ext - - # pull labeled data - src = os.path.join(self.proj_dir, LABELED_DATA_DIR) - dst = os.path.join(os.getcwd(), self.proj_dir[1:], LABELED_DATA_DIR) - if not os.path.exists(dst): - self._drive.get(src, dst) - - # loop over all labeled frames, break as soon as single frame fails - for d in os.listdir(dst): - frames_in_dir_file = os.path.join(dst, d, SELECTED_FRAMES_FILENAME) - if not os.path.exists(frames_in_dir_file): - continue - frames_in_dir = np.genfromtxt(frames_in_dir_file, delimiter=",", dtype=str) - for frame in frames_in_dir: - idx_img, prefix, ext = get_frame_number(frame.split("/")[-1]) - # get the frames -> t-2, t-1, t, t+1, t + 2 - list_idx = [idx_img - 2, idx_img - 1, idx_img, idx_img + 1, idx_img + 2] - for fr_num in list_idx: - # replace frame number with 0 if we're at the beginning of the video - fr_num = max(0, fr_num) - # split name into pieces - img_pieces = frame.split("/") - # figure out length of integer - int_len = len(img_pieces[-1].replace(f".{ext}", "").replace(prefix, "")) - # replace original frame number with context frame number - img_pieces[-1] = f"{prefix}{str(fr_num).zfill(int_len)}.{ext}" - img_name = "/".join(img_pieces) - if not os.path.exists(os.path.join(dst, d, img_name)): - self.allow_context = False - break + self.allow_context = is_context_dataset( + labeled_data_dir=os.path.join(abspath(self.proj_dir), LABELED_DATA_DIR), + selected_frames_filename=SELECTED_FRAMES_FILENAME, + ) def run(self, action, **kwargs): - if action == "push_video": - self._push_video(**kwargs) - 
elif action == "train": + if action == "train": self._train(**kwargs) elif action == "run_inference": self._run_inference(**kwargs) @@ -836,11 +686,12 @@ def _render_streamlit_fn(state: AppState): #### Training options """ ) - expander = st.expander("Change Defaults") - + # expander = st.expander("Change Defaults") + expander = st.expander( + "Expand to adjust maximum training epochs and select unsupervised losses") # max epochs st_max_epochs = expander.text_input( - "Max training epochs (all models)", value=state.max_epochs_default) + "Set the max training epochs (all models)", value=state.max_epochs_default) # unsupervised losses (semi-supervised only; only expose relevant losses) expander.write("Select losses for semi-supervised model") @@ -858,8 +709,10 @@ def _render_streamlit_fn(state: AppState): st.markdown( """ - #### Video handling options - """ + #### Video handling options""", + help="Choose if you want to automatically run inference on the videos uploaded for " + "labeling. **Warning** : Video traces will not be available in the " + "Video Diagnostics tab if you choose “Do not run inference”" ) st_train_label_opt = st.radio( "", @@ -922,7 +775,7 @@ def _render_streamlit_fn(state: AppState): state.st_max_epochs = int(st_max_epochs) state.st_train_label_opt = st_train_label_opt state.st_train_status = { - "super": "initialized" if st_train_super else "none", + "super": "initialized" if st_train_super else "none", "semisuper": "initialized" if st_train_semisuper else "none", "super ctx": "initialized" if st_train_super_ctx else "none", "semisuper ctx": "initialized" if st_train_semisuper_ctx else "none", @@ -949,13 +802,13 @@ def _render_streamlit_fn(state: AppState): # force different datetimes for i in range(4): if i == 0: # supervised model - st_datetimes["super"] = dtime[:-2] + "00" + st_datetimes["super"] = dtime[:-2] + "00_super" if i == 1: # semi-supervised model - st_datetimes["semisuper"] = dtime[:-2] + "01" + st_datetimes["semisuper"] = 
dtime[:-2] + "01_semisuper" if i == 2: # supervised context model - st_datetimes["super ctx"] = dtime[:-2] + "02" + st_datetimes["super ctx"] = dtime[:-2] + "02_super-ctx" if i == 3: # semi-supervised context model - st_datetimes["semisuper ctx"] = dtime[:-2] + "03" + st_datetimes["semisuper ctx"] = dtime[:-2] + "03_semisuper-ctx" # NOTE: cannot set these dicts entry-by-entry in the above loop, o/w don't get set? state.st_datetimes = st_datetimes @@ -966,7 +819,16 @@ def _render_streamlit_fn(state: AppState): with infer_tab: - st.header("Predict on New Videos") + st.header( + body="Predict on New Videos", + help="Select your preferred inference model, then" + " drag and drop your video file(s). Monitor the upload progress bar" + " and click **Run inference** once uploads are complete. After completion," + " a brief snippet is extracted for each video during the period of highest" + " motion energy, and a diagnostic video with raw frames and model" + " predictions is generated. Once inference concludes for all videos, the" + " 'waiting for existing inference to finish' warning will disappear." 
+ ) model_dir = st.selectbox( "Choose model to run inference", sorted(state.trained_models, reverse=True)) @@ -1016,7 +878,8 @@ def _render_streamlit_fn(state: AppState): p = 100.0 else: st.text(status) - st.progress(p / 100.0, f"{vid} progress ({status_ or status}: {int(p)}\% complete)") + st.progress( + p / 100.0, f"{vid} progress ({status_ or status}: {int(p)}\% complete)") st.warning("waiting for existing inference to finish") # Lightning way of returning the parameters diff --git a/lightning_pose_app/utilities.py b/lightning_pose_app/utilities.py index 82c7036..3540b57 100644 --- a/lightning_pose_app/utilities.py +++ b/lightning_pose_app/utilities.py @@ -1,8 +1,6 @@ import cv2 import glob -from lightning.app import LightningWork from lightning.app.frontend import StreamlitFrontend as LitStreamlitFrontend -from lightning.app.storage import FileSystem import logging import numpy as np import os @@ -28,15 +26,6 @@ def args_to_dict(script_args: str) -> dict: return script_args_dict -def dict_to_args(script_args_dict: dict) -> str: - """convert dict {'A':1, 'B':2} to str A=1 B=2 to """ - script_args_array = [] - for k,v in script_args_dict.items(): - script_args_array.append(f"{k}={v}") - # return as a text - return " \n".join(script_args_array) - - class StreamlitFrontend(LitStreamlitFrontend): """Provide helpful print statements for where streamlit tabs are forwarded.""" @@ -52,52 +41,22 @@ def start_server(self, *args, **kwargs): pass -class WorkWithFileSystem(LightningWork): - - def __init__(self, *args, name, **kwargs): - - super().__init__(*args, **kwargs) - - # uniquely identify prints - self.work_name = name - - # initialize shared storage system - self._drive = FileSystem() - - def get_from_drive(self, inputs, overwrite=True): - for i in inputs: - _logger.debug(f"{self.work_name.upper()} get {i}") - try: # file may not be ready - src = i # shared - dst = self.abspath(i) # local - self._drive.get(src, dst, overwrite=overwrite) - 
_logger.debug(f"{self.work_name.upper()} data saved at {dst}") - except Exception as e: - _logger.debug(f"{self.work_name.upper()} did not load {i} from FileSystem: {e}") - continue +def check_codec_format(input_file: str) -> bool: + """Run FFprobe command to get video codec and pixel format.""" - def put_to_drive(self, outputs): - for o in outputs: - _logger.debug(f"{self.work_name.upper()} drive try put {o}") - src = self.abspath(o) # local - dst = o # shared - # make sure dir ends with / so that put works correctly - if os.path.isdir(src): - src = os.path.join(src, "") - dst = os.path.join(dst, "") - # check to make sure file exists locally - if not os.path.exists(src): - continue - self._drive.put(src, dst) - _logger.debug(f"{self.work_name.upper()} drive success put {dst}") + ffmpeg_cmd = f'ffmpeg -i {input_file}' + output_str = subprocess.run(ffmpeg_cmd, shell=True, capture_output=True, text=True) + # stderr because the ffmpeg command has no output file, but the stderr still has codec info. 
+ output_str = output_str.stderr - @staticmethod - def abspath(path): - if path[0] == "/": - path_ = path[1:] - else: - path_ = path - return os.path.abspath(path_) + # search for correct codec (h264) and pixel format (yuv420p) + if output_str.find('h264') != -1 and output_str.find('yuv420p') != -1: + # print('Video uses H.264 codec') + is_codec = True + else: + # print('Video does not use H.264 codec') + is_codec = False + return is_codec def reencode_video(input_file: str, output_file: str) -> None: @@ -116,22 +75,43 @@ def reencode_video(input_file: str, output_file: str) -> None: subprocess.run(ffmpeg_cmd, shell=True) -def check_codec_format(input_file: str) -> bool: - """Run FFprobe command to get video codec and pixel format.""" +def copy_and_reformat_video(video_file: str, dst_dir: str, remove_old=True) -> str: + """Copy a single video, reformatting if necessary, and delete the original.""" - ffmpeg_cmd = f'ffmpeg -i {input_file}' - output_str = subprocess.run(ffmpeg_cmd, shell=True, capture_output=True, text=True) - # stderr because the ffmpeg command has no output file, but the stderr still has codec info. - output_str = output_str.stderr + src = video_file - # search for correct codec (h264) and pixel format (yuv420p) - if output_str.find('h264') != -1 and output_str.find('yuv420p') != -1: - # print('Video uses H.264 codec') - is_codec = True + # make sure copied vid has mp4 extension + dst = os.path.join(dst_dir, os.path.basename(video_file).replace(".avi", ".mp4")) + + # check 0: do we even need to reformat? + if os.path.isfile(dst): + return dst + + # check 1: does file exist? + if not os.path.exists(src): + _logger.info(f"{src} does not exist! skipping") + return None + + # check 2: is file in the correct format for DALI? 
+ video_file_correct_codec = check_codec_format(src) + + # reencode/rename + if not video_file_correct_codec: + _logger.info(f"re-encoding {src} to be compatible with Lightning Pose video reader") + reencode_video(src, dst) + # remove old video + if remove_old: + os.remove(src) else: - # print('Video does not use H.264 codec') - is_codec = False - return is_codec + # make dir to write into + os.makedirs(os.path.dirname(dst), exist_ok=True) + # rename + if remove_old: + os.rename(src, dst) + else: + shutil.copyfile(src, dst) + + return dst def copy_and_reformat_video_directory(src_dir: str, dst_dir: str) -> None: @@ -195,6 +175,114 @@ def get_frames_from_idxs(cap, idxs) -> np.ndarray: return frames +def make_video_snippet( + video_file: str, + preds_file: str, + clip_length: int = 30, + likelihood_thresh: float = 0.9 +) -> str: + + # save videos with csv file + save_dir = os.path.dirname(preds_file) + + # load pose predictions + df = pd.read_csv(preds_file, header=[0, 1, 2], index_col=0) + + # how large is the clip window? + video = cv2.VideoCapture(video_file) + fps = video.get(cv2.CAP_PROP_FPS) + win_len = int(fps * clip_length) + + # make a `clip_length` second video clip that contains the highest keypoint motion energy + src = video_file + dst = os.path.join(save_dir, os.path.basename(video_file).replace(".mp4", ".short.mp4")) + if win_len >= df.shape[0]: + # short video, no need to shorten further. 
just copy existing video + shutil.copyfile(src, dst) + else: + # compute motion energy (averaged over keypoints) + kps_and_conf = df.to_numpy().reshape(df.shape[0], -1, 3) + kps = kps_and_conf[:, :, :2] + conf = kps_and_conf[:, :, -1] + conf2 = np.concatenate([conf[:, :, None], conf[:, :, None]], axis=2) + kps[conf2 < likelihood_thresh] = np.nan + me = np.nanmean(np.linalg.norm(kps[1:] - kps[:1], axis=2), axis=-1) + + # find window + df_me = pd.DataFrame({"me": np.concatenate([[0], me])}) + df_me_win = df_me.rolling(window=win_len, center=False).mean() + # rolling places results in right edge of window, need to subtract this + clip_start_idx = df_me_win.me.argmax() - win_len + # convert to seconds + clip_start_sec = int(clip_start_idx / fps) + # if all predictions are bad, make sure we still create a valid snippet video + if np.isnan(clip_start_sec) or clip_start_sec < 0: + clip_start_sec = 0 + + # make clip + ffmpeg_cmd = f"ffmpeg -ss {clip_start_sec} -i {src} -t {clip_length} {dst}" + subprocess.run(ffmpeg_cmd, shell=True) + + return dst + + +def get_frame_number(basename: str) -> tuple: + """img0000234.png -> (234, "img", ".png")""" + ext = basename.split(".")[-1] # get base name + split_idx = None + for c_idx, c in enumerate(basename): + try: + int(c) + split_idx = c_idx + break + except ValueError: + continue + # remove prefix + prefix = basename[:split_idx] + idx = basename[split_idx:] + # remove file extension + idx = idx.replace(f".{ext}", "") + return int(idx), prefix, ext + + +def is_context_dataset(labeled_data_dir: str, selected_frames_filename: str) -> bool: + """Starting from labeled data directory, determine if this is a context dataset or not.""" + # loop over all labeled frames, break as soon as single frame fails + is_context = True + n_frames = 0 + if os.path.isdir(labeled_data_dir): + for d in os.listdir(labeled_data_dir): + frames_in_dir_file = os.path.join(labeled_data_dir, d, selected_frames_filename) + if not 
os.path.exists(frames_in_dir_file): + continue + frames_in_dir = np.genfromtxt(frames_in_dir_file, delimiter=",", dtype=str) + print(frames_in_dir) + for frame in frames_in_dir: + idx_img, prefix, ext = get_frame_number(frame.split("/")[-1]) + # get the frames -> t-2, t-1, t, t+1, t + 2 + list_idx = [idx_img - 2, idx_img - 1, idx_img, idx_img + 1, idx_img + 2] + print(list_idx) + for fr_num in list_idx: + # replace frame number with 0 if we're at the beginning of the video + fr_num = max(0, fr_num) + # split name into pieces + img_pieces = frame.split("/") + # figure out length of integer + int_len = len(img_pieces[-1].replace(f".{ext}", "").replace(prefix, "")) + # replace original frame number with context frame number + img_pieces[-1] = f"{prefix}{str(fr_num).zfill(int_len)}.{ext}" + img_name = "/".join(img_pieces) + if not os.path.exists(os.path.join(labeled_data_dir, d, img_name)): + is_context = False + break + else: + n_frames += 1 + # set to False if we didn't find any frames + if n_frames == 0: + is_context = False + return is_context + + def collect_dlc_labels(dlc_dir: str) -> pd.DataFrame: """Collect video-specific labels from DLC project and save in a single pandas dataframe.""" @@ -207,17 +295,13 @@ def collect_dlc_labels(dlc_dir: str) -> pd.DataFrame: df_tmp = pd.read_csv(csv_file, header=[0, 1, 2], index_col=0) if len(df_tmp.index.unique()) != df_tmp.shape[0]: # new DLC labeling scheme that splits video/image in different cells - vids = df_tmp.loc[ - :, ("Unnamed: 1_level_0", "Unnamed: 1_level_1", "Unnamed: 1_level_2")] - imgs = df_tmp.loc[ - :, ("Unnamed: 2_level_0", "Unnamed: 2_level_1", "Unnamed: 2_level_2")] + levels1 = ("Unnamed: 1_level_0", "Unnamed: 1_level_1", "Unnamed: 1_level_2") + vids = df_tmp.loc[:, levels1] + levels2 = ("Unnamed: 2_level_0", "Unnamed: 2_level_1", "Unnamed: 2_level_2") + imgs = df_tmp.loc[:, levels2] new_col = [f"labeled-data/{v}/{i}" for v, i in zip(vids, imgs)] - df_tmp1 = df_tmp.drop( - ("Unnamed: 1_level_0", 
"Unnamed: 1_level_1", "Unnamed: 1_level_2"), axis=1, - ) - df_tmp2 = df_tmp1.drop( - ("Unnamed: 2_level_0", "Unnamed: 2_level_1", "Unnamed: 2_level_2"), axis=1, - ) + df_tmp1 = df_tmp.drop(levels1, axis=1) + df_tmp2 = df_tmp1.drop(levels2, axis=1) df_tmp2.index = new_col df_tmp = df_tmp2 except IndexError: @@ -226,7 +310,7 @@ def collect_dlc_labels(dlc_dir: str) -> pd.DataFrame: os.path.join(dlc_dir, "labeled-data", d, "CollectedData*.h5") )[0] df_tmp = pd.read_hdf(h5_file) - if type(df_tmp.index) == pd.core.indexes.multi.MultiIndex: + if isinstance(df_tmp.index, pd.core.indexes.multi.MultiIndex): # new DLC labeling scheme that splits video/image in different cells imgs = [i[2] for i in df_tmp.index] vids = [df_tmp.index[0][1] for _ in imgs] @@ -243,3 +327,11 @@ def collect_dlc_labels(dlc_dir: str) -> pd.DataFrame: df_all = pd.concat(dfs) return df_all + + +def abspath(path): + if path[0] == "/": + path_ = path[1:] + else: + path_ = path + return os.path.abspath(path_) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..23c0117 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,16 @@ +[flake8] +max-line-length = 99 +ignore = F821, W503 +extend-ignore = E203 +exclude = + .git, + __pycache__, + __init__.py, + build, + dist, + docs/ + scripts/ + +[isort] +line_length = 99 +profile = black \ No newline at end of file diff --git a/setup.py b/setup.py index 798d1c7..09932b6 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import find_packages, setup -VERSION = "1.0.0" +VERSION = "1.1.0" # add the README.md file to the long_description with open("README.md", "r") as fh: @@ -20,7 +20,7 @@ "tables", "tqdm", "watchdog", - "google-auth-oauthlib==0.7.1", # freeze this for compatibility between tensorboard and label-studio + "google-auth-oauthlib==0.7.1", # freeze for compatibility between tensorboard and label-studio "label-studio==1.9.1", "label-studio-sdk==0.0.32", ] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 
0000000..6d79268 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,81 @@ +"""Provide pytest fixtures for the entire test suite. + +These fixtures create data and data modules that can be reused by other tests. + +""" + +import numpy as np +import os +import pytest +import shutil + +from lightning_pose_app import ( + LIGHTNING_POSE_DIR, + LABELED_DATA_DIR, + SELECTED_FRAMES_FILENAME, + VIDEOS_DIR, + VIDEOS_TMP_DIR, +) + +ROOT = os.path.dirname(os.path.dirname(__file__)) + + +def make_tmp_project() -> tuple: + + proj_dir = "data/mirror-mouse-example" + + proj_dir_abs = os.path.join(ROOT, proj_dir) + if os.path.isdir(proj_dir_abs): + print(f"{proj_dir_abs} already exists!") + return proj_dir, proj_dir_abs + + # copy full example data directory over + src = os.path.join(ROOT, LIGHTNING_POSE_DIR, proj_dir) + shutil.copytree(src, proj_dir_abs) + + # copy and rename the video for further tests + tmp_vid_dir = os.path.join(proj_dir_abs, VIDEOS_TMP_DIR) + os.makedirs(tmp_vid_dir, exist_ok=True) + src = os.path.join(proj_dir_abs, VIDEOS_DIR, "test_vid.mp4") + dst = os.path.join(tmp_vid_dir, "test_vid_copy.mp4") + shutil.copyfile(src, dst) + + # make csv file for label studio + n_frames = len(os.listdir(os.path.join(proj_dir_abs, LABELED_DATA_DIR))) + idxs_selected = np.arange(n_frames) + n_digits = 2 + extension = "png" + frames_to_label = np.sort(np.array([ + "img%s.%s" % (str(idx).zfill(n_digits), extension) for idx in idxs_selected])) + np.savetxt( + os.path.join(proj_dir_abs, LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME), + frames_to_label, + delimiter=",", + fmt="%s" + ) + + return proj_dir, proj_dir_abs + + +@pytest.fixture +def tmp_proj_dir() -> str: + + proj_dir, proj_dir_abs = make_tmp_project() + + # return to tests + yield proj_dir + + # cleanup after all tests have run (no more calls to yield) + shutil.rmtree(proj_dir_abs) + + +@pytest.fixture +def video_file() -> str: + return os.path.join( + ROOT, LIGHTNING_POSE_DIR, 
"data/mirror-mouse-example/videos/test_vid.mp4", + ) + + +@pytest.fixture +def root_dir() -> str: + return ROOT diff --git a/tests/test_ui/test_extract_frames.py b/tests/test_ui/test_extract_frames.py new file mode 100644 index 0000000..29fea1b --- /dev/null +++ b/tests/test_ui/test_extract_frames.py @@ -0,0 +1,187 @@ +from lightning.app import CloudCompute +import numpy as np +import os +import pandas as pd +import shutil + +from lightning_pose_app import LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME +from lightning_pose_app import VIDEOS_TMP_DIR, VIDEOS_DIR + + +def test_extract_frames_work(video_file, tmpdir): + """Test private methods here; test run method externally from the UI object.""" + + from lightning_pose_app import LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME + from lightning_pose_app.ui.extract_frames import ExtractFramesWork + + work = ExtractFramesWork( + cloud_compute=CloudCompute("default"), + ) + + # ----------------- + # read frame function + # ----------------- + resize_dims = 8 + frames = work._read_nth_frames(video_file, n=10, resize_dims=resize_dims) + assert frames.shape == (100, resize_dims, resize_dims) + + # ----------------- + # select indices + # ----------------- + n_clusters = 5 + idxs = work._select_frame_idxs( + video_file, resize_dims=resize_dims, n_clusters=n_clusters, frame_skip=1, + ) + assert len(idxs) == n_clusters + + # ----------------- + # export frames + # ----------------- + save_dir_0 = os.path.join(str(tmpdir), 'labeled-frames-0') + work._export_frames( + video_file=video_file, + save_dir=save_dir_0, + frame_idxs=idxs, + context_frames=0, # no context + ) + assert len(os.listdir(save_dir_0)) == len(idxs) + + save_dir_1 = os.path.join(str(tmpdir), 'labeled-frames-1') + idxs = np.array([5, 10, 15, 20]) + work._export_frames( + video_file=video_file, + save_dir=save_dir_1, + frame_idxs=idxs, + context_frames=2, # 2-frame context + ) + assert len(os.listdir(save_dir_1)) == 5 * len(idxs) + + save_dir_2 = 
os.path.join(str(tmpdir), 'labeled-frames-2') + idxs = np.array([10]) # try with single frame + work._export_frames( + video_file=video_file, + save_dir=save_dir_2, + frame_idxs=idxs, + context_frames=2, # 2-frame context + ) + assert len(os.listdir(save_dir_2)) == 5 * len(idxs) + + # ----------------- + # extract frames + # ----------------- + proj_dir = os.path.join(str(tmpdir), 'proj-dir-0') + video_name = os.path.splitext(os.path.basename(str(video_file)))[0] + video_dir = os.path.join(proj_dir, LABELED_DATA_DIR, video_name) + os.makedirs(os.path.dirname(video_dir), exist_ok=True) # need to create for path purposes + n_frames_per_video = 10 + work._extract_frames( + video_file=video_file, + proj_dir=proj_dir, + n_frames_per_video=n_frames_per_video, + frame_range=[0, 1], + ) + assert os.path.exists(video_dir) + assert len(os.listdir(video_dir)) > n_frames_per_video + assert os.path.exists(os.path.join(video_dir, SELECTED_FRAMES_FILENAME)) + assert work.work_is_done_extract_frames + + # ----------------- + # unzip frames + # ----------------- + # zip up a subset of the frames extracted from the previous test + n_frames_to_zip = 5 + frame_files = os.listdir(save_dir_1) + new_vid_name = "TEST_VID_ZIPPED_FRAMES" + dst = os.path.join(tmpdir, new_vid_name) + os.makedirs(dst, exist_ok=True) + files = [] + for f in range(n_frames_to_zip): + src = os.path.join(save_dir_1, frame_files[f]) + shutil.copyfile(src, os.path.join(dst, frame_files[f])) + files.append(frame_files[f]) + # make a csv file to accompany frames + np.savetxt( + os.path.join(dst, SELECTED_FRAMES_FILENAME), + np.sort(files), + delimiter=",", + fmt="%s", + ) + # zip it all up + new_video_name = new_vid_name + "_NEW" + new_video_path = os.path.join(tmpdir, new_video_name) + zipped_file = new_video_path + ".zip" + shutil.make_archive(new_video_path, "zip", dst) + + # test unzip frames + proj_dir = os.path.join(str(tmpdir), 'proj-dir-1') + video_dir = os.path.join(proj_dir, LABELED_DATA_DIR, new_video_name) 
+ os.makedirs(os.path.dirname(video_dir), exist_ok=True) # need to create for path purposes + work.work_is_done_extract_frames = False + work._unzip_frames( + video_file=zipped_file, + proj_dir=proj_dir, + ) + assert os.path.exists(video_dir) + assert len(os.listdir(video_dir)) == (n_frames_to_zip + 1) + idx_file_abs = os.path.join(video_dir, SELECTED_FRAMES_FILENAME) + assert os.path.exists(idx_file_abs) + df = pd.read_csv(idx_file_abs, header=None) + assert df.shape[0] == n_frames_to_zip + assert work.work_is_done_extract_frames + + # ----------------- + # cleanup + # ----------------- + del work + + +def test_extract_frames_ui(root_dir, tmp_proj_dir): + + from lightning_pose_app.ui.extract_frames import ExtractFramesUI + + video_name = "test_vid_copy" + video_file_ = video_name + ".mp4" + video_file = os.path.join(tmp_proj_dir, VIDEOS_TMP_DIR, video_file_) + + flow = ExtractFramesUI() + + # set attributes + flow.proj_dir = tmp_proj_dir + flow.st_extract_status[video_file] = "initialized" + + # ------------------- + # extract frames + # ------------------- + n_frames_per_video = 10 + flow.run( + action="extract_frames", + video_files=[video_file], + n_frames_per_video=n_frames_per_video, + testing=True, + ) + + # make sure flow attributes are properly cleaned up + assert flow.st_extract_status[video_file] == "complete" + assert len(flow.works_dict) == 0 + assert flow.work_is_done_extract_frames + + # make sure frames were extracted + proj_dir_abs = os.path.join(root_dir, tmp_proj_dir) + frame_dir_abs = os.path.join(proj_dir_abs, LABELED_DATA_DIR, video_name) + idx_file_abs = os.path.join(frame_dir_abs, SELECTED_FRAMES_FILENAME) + assert os.path.isfile(os.path.join(proj_dir_abs, VIDEOS_DIR, video_file_)) + assert os.path.isdir(frame_dir_abs) + assert os.path.isfile(idx_file_abs) + + df = pd.read_csv(idx_file_abs, header=None) + assert df.shape[0] == n_frames_per_video + + # ------------------- + # unzip frames + # ------------------- + # TODO + + # 
----------------- + # cleanup + # ----------------- + del flow diff --git a/tests/test_ui/test_project.py b/tests/test_ui/test_project.py new file mode 100644 index 0000000..9e07f42 --- /dev/null +++ b/tests/test_ui/test_project.py @@ -0,0 +1,182 @@ +import os +import shutil +import yaml + +from lightning_pose_app import LIGHTNING_POSE_DIR, MODELS_DIR +from lightning_pose_app import LABELSTUDIO_DB_DIR, LABELSTUDIO_METADATA_FILENAME + + +def test_project_ui(root_dir, tmp_proj_dir): + + from lightning_pose_app.ui.project import ProjectUI + + proj_name = os.path.split(tmp_proj_dir)[-1] + proj_dir_abs = os.path.join(root_dir, tmp_proj_dir) + + # load default config and pass to project manager + config_dir = os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs") + default_config_dict = yaml.safe_load(open(os.path.join(config_dir, "config_default.yaml"))) + + flow = ProjectUI( + data_dir="/data", + default_config_dict=default_config_dict, + ) + + # ------------------- + # test find projects + # ------------------- + flow.run(action="find_initialized_projects") + assert proj_name in flow.initialized_projects + assert LABELSTUDIO_DB_DIR not in flow.initialized_projects + + # ------------------- + # test update paths + # ------------------- + flow.run(action="update_paths") + assert flow.proj_dir is None + + flow.run(action="update_paths", project_name=proj_name) + assert flow.proj_dir == "/" + str(tmp_proj_dir) + assert flow.config_name == f"model_config_{proj_name}.yaml" + assert flow.config_file == "/" + os.path.join(tmp_proj_dir, flow.config_name) + assert flow.model_dir == "/" + os.path.join(tmp_proj_dir, MODELS_DIR) + assert flow.proj_dir_abs == proj_dir_abs + + # ------------------- + # test update config + # ------------------- + # config dict is initially none + assert flow.config_dict is None + + # update project config with no vals should set config_dict to defaults (+ "keypoints"} and + # NOT save out config + config_init = flow.default_config_dict.copy() + 
config_init["data"]["keypoints"] = None + flow.run(action="update_project_config", new_vals=None) + assert flow.config_dict == config_init + assert not os.path.exists(os.path.join(root_dir, flow.config_file[1:])) + + # update with new vals; should update object attribute and yaml file + new_vals_dict_0 = { + "data": {"keypoints": ["nose", "tail"], "num_keypoints": 2}, + "training": {"train_batch_size": 2}, + } + flow.run(action="update_project_config", new_vals_dict=new_vals_dict_0) + config_file = os.path.join(root_dir, flow.config_file[1:]) + assert os.path.exists(config_file) + config_dict_saved = yaml.safe_load(open(config_file)) + for key1, val1 in new_vals_dict_0.items(): + for key2, val2 in val1.items(): + assert flow.config_dict[key1][key2] == val2 + assert config_dict_saved[key1][key2] == val2 + + # update with new vals again; make sure previous new vals remain + new_vals_dict_1 = { + "data": {"columns_for_singleview_pca": [0, 1], "mirrored_column_matches": []}, + "training": {"val_batch_size": 2}, + } + flow.run(action="update_project_config", new_vals_dict=new_vals_dict_1) + config_dict_saved = yaml.safe_load(open(config_file)) + for key1, val1 in new_vals_dict_0.items(): + for key2, val2 in val1.items(): + assert flow.config_dict[key1][key2] == val2 + assert config_dict_saved[key1][key2] == val2 + for key1, val1 in new_vals_dict_1.items(): + for key2, val2 in val1.items(): + assert flow.config_dict[key1][key2] == val2 + assert config_dict_saved[key1][key2] == val2 + + # ------------------- + # test update shapes + # ------------------- + new_vals_dict_2 = { + "data": { + "image_orig_dims": { + "height": 406, + "width": 396, + }, + "image_resize_dims": { + "height": 256, + "width": 256, + }, + }, + } + flow.run(action="update_frame_shapes") + config_dict_saved = yaml.safe_load(open(config_file)) + for key1, val1 in new_vals_dict_2.items(): + for key2, val2 in val1.items(): + assert flow.config_dict[key1][key2] == val2 + assert
config_dict_saved[key1][key2] == val2 + + # ------------------- + # test update shapes + # ------------------- + # should return None if no label studio metadata file + metadata_file = os.path.join(root_dir, tmp_proj_dir, LABELSTUDIO_METADATA_FILENAME) + flow.run(action="compute_labeled_frame_fraction") + if not os.path.exists(metadata_file): + assert flow.n_labeled_frames is None + assert flow.n_total_frames is None + else: + assert flow.n_labeled_frames is not None + assert flow.n_total_frames is not None + + # ------------------- + # test load defaults + # ------------------- + assert flow.st_keypoints_ == [] + assert flow.st_n_keypoints == 0 + assert flow.st_pcasv_columns == [] + assert flow.st_pcamv_columns == [] + assert flow.st_n_views == 0 + flow.run(action="load_project_defaults") + assert flow.st_keypoints_ == config_dict_saved["data"]["keypoints"] + assert flow.st_n_keypoints == config_dict_saved["data"]["num_keypoints"] + assert flow.st_n_keypoints == 2 + assert flow.st_pcasv_columns == config_dict_saved["data"]["columns_for_singleview_pca"] + assert len(flow.st_pcasv_columns) == 2 + assert flow.st_pcamv_columns == config_dict_saved["data"]["mirrored_column_matches"] + assert len(flow.st_pcamv_columns) == 0 + assert flow.st_n_views == 1 + + # ------------------- + # test find models + # ------------------- + m1 = "00-11-22/33-44-55" + m2 = "aa-bb-cc/dd-ee-ff" + m1_path = os.path.join(root_dir, tmp_proj_dir, MODELS_DIR, m1) + m2_path = os.path.join(root_dir, tmp_proj_dir, MODELS_DIR, m2) + os.makedirs(m1_path, exist_ok=True) + os.makedirs(m2_path, exist_ok=True) + flow.run(action="update_trained_models_list") + assert len(flow.trained_models) == 2 + assert m1 in flow.trained_models + assert m2 in flow.trained_models + + # ------------------- + # test upload existing + # ------------------- + # TODO: fill out tests for uploading an existing project + + # ------------------- + # test delete project + # ------------------- + # copy project + copy_proj_name 
= "TEMP_TEST_PROJECT" + src = proj_dir_abs + dst = proj_dir_abs.replace(proj_name, copy_proj_name) + shutil.copytree(src, dst) + flow.run(action="find_initialized_projects") + assert proj_name in flow.initialized_projects + assert copy_proj_name in flow.initialized_projects + + flow.proj_dir = flow.proj_dir.replace(proj_name, copy_proj_name) + flow.run(action="delete_project") + assert flow.st_project_name == "" + assert not os.path.exists(dst) + assert copy_proj_name not in flow.initialized_projects + + # ------------------- + # cleanup + # ------------------- + del flow diff --git a/tests/test_ui/test_train_infer.py b/tests/test_ui/test_train_infer.py new file mode 100644 index 0000000..3e176a3 --- /dev/null +++ b/tests/test_ui/test_train_infer.py @@ -0,0 +1,230 @@ +from datetime import datetime +import os +import yaml + +from lightning_pose_app import ( + LIGHTNING_POSE_DIR, + MODELS_DIR, + MODEL_VIDEO_PREDS_INFER_DIR, + MODEL_VIDEO_PREDS_TRAIN_DIR, + VIDEOS_DIR, +) + + +def test_train_infer_work(root_dir, tmp_proj_dir, video_file): + """Test private methods here; test run method externally from the UI object.""" + + from lightning_pose_app.ui.project import ProjectUI + from lightning_pose_app.ui.train_infer import LitPose + + work = LitPose() + + # ---------------- + # helper flow + # ---------------- + # load default config and pass to project manager + config_dir = os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs") + default_config_dict = yaml.safe_load(open(os.path.join(config_dir, "config_default.yaml"))) + flow = ProjectUI( + data_dir="/data", + default_config_dict=default_config_dict, + ) + proj_name = os.path.split(tmp_proj_dir)[-1] + flow.run(action="update_paths", project_name=proj_name) + flow.run(action="update_frame_shapes") + + # ---------------- + # train + # ---------------- + base_dir = os.path.join(root_dir, tmp_proj_dir) + config_overrides = { + "data": { + "data_dir": base_dir, + "video_dir": os.path.join(base_dir, VIDEOS_DIR), + 
"num_keypoints": 17, + }, + "eval": { + "test_videos_directory": os.path.join(base_dir, VIDEOS_DIR), + "predict_vids_after_training": False, + "save_vids_after_training": False, + }, + "model": { + "model_type": "heatmap", + "losses_to_use": [], + }, + "training": { + "imgaug": "dlc", + "max_epochs": 2, + "check_val_every_n_epoch": 2, + }, + } + model_name_0 = datetime.today().strftime("%Y-%m-%d/%H-%M-%S_PYTEST") + results_dir_0 = os.path.join(base_dir, MODELS_DIR, model_name_0) + work._train( + config_file=os.path.join(tmp_proj_dir, flow.config_name), + config_overrides=config_overrides, + results_dir=results_dir_0, + ) + results_artifacts_0 = os.listdir(results_dir_0) + assert work.work_is_done_training + assert os.path.exists(results_dir_0) + assert "predictions.csv" in results_artifacts_0 + assert "lightning_logs" not in results_artifacts_0 + assert "video_preds" not in results_artifacts_0 + + # ---------------- + # output videos + # ---------------- + config_overrides["eval"]["predict_vids_after_training"] = True + config_overrides["eval"]["save_vids_after_training"] = True + model_name_1 = datetime.today().strftime("%Y-%m-%d/%H-%M-%S_PYTEST") + results_dir_1 = os.path.join(base_dir, MODELS_DIR, model_name_1) + work._train( + config_file=os.path.join(tmp_proj_dir, flow.config_name), + config_overrides=config_overrides, + results_dir=results_dir_1, + ) + results_artifacts_1 = os.listdir(results_dir_1) + assert work.work_is_done_training + assert os.path.exists(results_dir_1) + assert "predictions.csv" in results_artifacts_1 + assert "lightning_logs" not in results_artifacts_1 + assert MODEL_VIDEO_PREDS_TRAIN_DIR in results_artifacts_1 + labeled_vid_dir = os.path.join(results_dir_1, MODEL_VIDEO_PREDS_TRAIN_DIR, "labeled_videos") + assert os.path.exists(labeled_vid_dir) + assert len(os.listdir(labeled_vid_dir)) > 0 + + # ---------------- + # infer + # ---------------- + work._run_inference( + model_dir=os.path.join(tmp_proj_dir, MODELS_DIR, model_name_0), + 
video_file=video_file, + ) + results_dir_2 = os.path.join(base_dir, MODELS_DIR, model_name_0, MODEL_VIDEO_PREDS_INFER_DIR) + results_artifacts_2 = os.listdir(results_dir_2) + assert work.work_is_done_inference + preds = os.path.basename(video_file).replace(".mp4", ".csv") + assert preds in results_artifacts_2 + assert preds.replace(".csv", "_temporal_norm.csv") in results_artifacts_2 + # assert preds.replace(".csv", "_pca_singleview_error.csv") in results_artifacts_2 + # assert preds.replace(".csv", "_pca_multiview_error.csv") in results_artifacts_2 + assert preds.replace(".csv", ".short.mp4") in results_artifacts_2 + assert preds.replace(".csv", ".short.csv") in results_artifacts_2 + assert preds.replace(".csv", ".short.labeled.mp4") in results_artifacts_2 + + # ---------------- + # fiftyone + # ---------------- + # just run and make sure it doesn't fail + work._make_fiftyone_dataset( + config_file=os.path.join(tmp_proj_dir, flow.config_name), + results_dir=results_dir_1, + config_overrides=config_overrides, + ) + + # ---------------- + # clean up + # ---------------- + del flow + del work + + +def test_train_infer_ui(root_dir, tmp_proj_dir, video_file): + """Test private methods here; test run method externally from the UI object.""" + + from lightning_pose_app.ui.project import ProjectUI + from lightning_pose_app.ui.train_infer import TrainUI, VIDEO_LABEL_NONE + + base_dir = os.path.join(root_dir, tmp_proj_dir) + + flow = TrainUI() + + # set attributes + flow.proj_dir = "/" + str(tmp_proj_dir) + flow.st_train_status = { + "super": "initialized", + "semisuper": None, + "super ctx": None, + "semisuper ctx": None, + } + flow.st_losses = {"super": []} + flow.st_train_label_opt = VIDEO_LABEL_NONE # don't run inference on vids + flow.st_max_epochs = 2 + + # ---------------- + # helper flow + # ---------------- + # load default config and pass to project manager + config_dir = os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs") + default_config_dict = 
yaml.safe_load(open(os.path.join(config_dir, "config_default.yaml"))) + flowp = ProjectUI( + data_dir="/data", + default_config_dict=default_config_dict, + ) + proj_name = os.path.split(tmp_proj_dir)[-1] + flowp.run(action="update_paths", project_name=proj_name) + flowp.run(action="update_frame_shapes") + flowp.run( + action="update_project_config", + new_vals_dict={ + "data": {"num_keypoints": 17}, + "training": {"check_val_every_n_epoch": 2}, # match flow.st_max_epochs + }, + ) + + # ---------------- + # train + # ---------------- + model_name_0 = datetime.today().strftime("%Y-%m-%d/%H-%M-%S_PYTEST") + flow.st_datetimes = {"super": model_name_0} + flow.run(action="train", config_filename=f"model_config_{proj_name}.yaml") + + # check flow state + assert flow.st_train_status["super"] == "complete" + assert flow.work.progress == 0.0 + assert flow.work.work_is_done_training + + # check output files + results_dir_0 = os.path.join(base_dir, MODELS_DIR, model_name_0) + results_artifacts_0 = os.listdir(results_dir_0) + assert os.path.exists(results_dir_0) + assert "predictions.csv" in results_artifacts_0 + assert "lightning_logs" not in results_artifacts_0 + assert "video_preds" not in results_artifacts_0 + + # ---------------- + # infer + # ---------------- + flow.st_infer_status[video_file] = "initialized" + flow.st_inference_model = model_name_0 + flow.run(action="run_inference", video_files=[video_file], testing=True) + + # check flow state + assert flow.st_infer_status[video_file] == "complete" + assert flow.work_is_done_inference + assert len(flow.works_dict) == 0 + + # check output files + results_dir_1 = os.path.join(base_dir, MODELS_DIR, model_name_0, MODEL_VIDEO_PREDS_INFER_DIR) + results_artifacts_1 = os.listdir(results_dir_1) + preds = os.path.basename(video_file).replace(".mp4", ".csv") + assert preds in results_artifacts_1 + assert preds.replace(".csv", "_temporal_norm.csv") in results_artifacts_1 + # assert preds.replace(".csv", 
"_pca_singleview_error.csv") in results_artifacts_2 + # assert preds.replace(".csv", "_pca_multiview_error.csv") in results_artifacts_2 + assert preds.replace(".csv", ".short.mp4") in results_artifacts_1 + assert preds.replace(".csv", ".short.csv") in results_artifacts_1 + assert preds.replace(".csv", ".short.labeled.mp4") in results_artifacts_1 + + # ---------------- + # determine type + # ---------------- + flow.run(action="determine_dataset_type") + assert not flow.allow_context + + # ---------------- + # clean up + # ---------------- + del flowp + del flow diff --git a/tests/test_utilities.py b/tests/test_utilities.py new file mode 100644 index 0000000..2d49b5a --- /dev/null +++ b/tests/test_utilities.py @@ -0,0 +1,172 @@ +import cv2 +import numpy as np +import os +import pandas as pd + +from lightning_pose_app.utilities import check_codec_format + + +def test_args_to_dict(): + + from lightning_pose_app.utilities import args_to_dict + + string = "A=1 B=2" + args_dict = args_to_dict(string) + assert len(args_dict) == 2 + assert args_dict["A"] == "1" + assert args_dict["B"] == "2" + + +def test_check_codec_format(video_file): + assert check_codec_format(video_file) + + +def test_reencode_video(video_file, tmpdir): + from lightning_pose_app.utilities import reencode_video + video_file_new = os.path.join(str(tmpdir), 'test.mp4') + reencode_video(video_file, video_file_new) + assert check_codec_format(video_file_new) + + +def test_copy_and_reformat_video(video_file, tmpdir): + + from lightning_pose_app.utilities import copy_and_reformat_video + + # check when dst_dir exists + video_file_new_1 = copy_and_reformat_video(video_file, str(tmpdir), remove_old=False) + assert os.path.exists(video_file) + assert check_codec_format(video_file_new_1) + + # check when dst_dir does not exist + dst_dir = str(os.path.join(tmpdir, 'subdir')) + video_file_new_2 = copy_and_reformat_video(video_file, dst_dir, remove_old=False) + assert os.path.exists(video_file) + assert 
check_codec_format(video_file_new_2) + + +def test_copy_and_reformat_video_directory(video_file, tmpdir): + from lightning_pose_app.utilities import copy_and_reformat_video_directory + src_dir = os.path.dirname(video_file) + dst_dir = str(tmpdir) + copy_and_reformat_video_directory(src_dir, dst_dir) + assert os.path.exists(video_file) + files = os.listdir(dst_dir) + for file in files: + assert check_codec_format(os.path.join(dst_dir, file)) + + +def test_get_frames_from_idxs(video_file): + from lightning_pose_app.utilities import get_frames_from_idxs + cap = cv2.VideoCapture(video_file) + n_frames = 3 + frames = get_frames_from_idxs(cap, np.arange(n_frames)) + cap.release() + assert frames.shape == (n_frames, 1, 406, 396) + assert frames.dtype == np.uint8 + + +def test_make_video_snippet(video_file, tmpdir): + + from lightning_pose_app.utilities import make_video_snippet + + # get video info + cap = cv2.VideoCapture(video_file) + n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + + # make fake predictions and save to tmpdir + keypoints = ['paw1', 'paw2'] + n_keypoints = len(keypoints) + xyl_labels = ["x", "y", "likelihood"] + pdindex = pd.MultiIndex.from_product( + [["tracker"], keypoints, xyl_labels], names=["scorer", "bodyparts", "coords"], + ) + preds = np.random.rand(n_frames, n_keypoints * 3) # x, y, likelihood + df = pd.DataFrame(preds, columns=pdindex) + preds_file = os.path.join(str(tmpdir), 'preds.csv') + df.to_csv(preds_file) + + # CHECK 1: requested clip is shorter than actual video + clip_length = 1 + snippet_file = make_video_snippet( + video_file=video_file, + preds_file=preds_file, + clip_length=clip_length, + ) + cap = cv2.VideoCapture(snippet_file) + fps = cap.get(cv2.CAP_PROP_FPS) + n_frames_1 = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + assert n_frames_1 == int(fps * clip_length) + + # CHECK 2: requested clip is longer than actual video (return original video) + clip_length = 100 + snippet_file = make_video_snippet( + 
video_file=video_file, + preds_file=preds_file, + clip_length=clip_length, + ) + cap = cv2.VideoCapture(snippet_file) + n_frames_2 = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + assert n_frames_2 == n_frames + + +def test_get_frame_number(): + + from lightning_pose_app.utilities import get_frame_number + + file = "img000346.png" + out = get_frame_number(file) + assert out == (346, "img", "png") + + file = "frame3.jpg" + out = get_frame_number(file) + assert out == (3, "frame", "jpg") + + file = "im000.jpeg" + out = get_frame_number(file) + assert out == (0, "im", "jpeg") + + +def test_is_context_dataset(tmp_proj_dir): + + from lightning_pose_app import LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME + from lightning_pose_app.utilities import is_context_dataset + + labeled_data_dir = os.path.abspath(tmp_proj_dir) + + # test should fail since each frame has an entry in the csv file + assert not is_context_dataset( + labeled_data_dir=labeled_data_dir, + selected_frames_filename=SELECTED_FRAMES_FILENAME, + ) + + # remove final two entries to provide context + csv_file = os.path.join(labeled_data_dir, LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME) + img_files = np.genfromtxt(csv_file, delimiter=',', dtype=str) + new_csv_file = csv_file.replace(".csv", ".tmp.csv") + np.savetxt(new_csv_file, img_files[:-2], delimiter=",", fmt="%s") + + # test should pass since each frame has context (frame 00 will auto use 00 for negative frames) + assert is_context_dataset( + labeled_data_dir=labeled_data_dir, + selected_frames_filename=os.path.basename(new_csv_file), + ) + + # test should fail if labeled frame directory does not exist + assert not is_context_dataset( + labeled_data_dir=os.path.join(labeled_data_dir, "nonexistent_directory"), + selected_frames_filename=SELECTED_FRAMES_FILENAME, + ) + + +def test_abspath(): + + from lightning_pose_app.utilities import abspath + + path1 = 'test/directory' + abspath1 = abspath(path1) + assert abspath1 == os.path.abspath(path1) + + path2 = 
'/test/directory' + abspath2 = abspath(path2) + assert abspath2 == os.path.abspath(path2[1:])