Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Check labels feature #89

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,11 @@ def run(self):
self.extract_ui.run_script_video_model = False
self.label_studio.run(action="update_tasks", videos=self.extract_ui.st_video_files)
self.extract_ui.run_script_video_model = False

#TODO: Add one more IF block
Shmuel-columbia marked this conversation as resolved.
Show resolved Hide resolved
if self.extract_ui.proj_dir and self.extract_ui.run_script_check_labels:
self.extract_ui.run(action="save_annotated_frames",selected_body_parts=self.extract_ui.selected_body_parts)
self.extract_ui.run_script_check_labels = False

# -------------------------------------------------------------
# periodically check labeling task and export new labels
Expand Down
160 changes: 160 additions & 0 deletions lightning_pose_app/backend/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from typing import Optional
from itertools import groupby
from operator import itemgetter
import streamlit as st
from PIL import Image
import zipfile
import logging
import matplotlib.pyplot as plt

from scipy.stats import zscore

from lightning_pose_app.backend.video import (
compute_motion_energy_from_predection_df,
Expand Down Expand Up @@ -338,3 +344,157 @@ def export_frames(
filename=os.path.join(save_dir, "img%s.%s" % (str(idx).zfill(n_digits), format)),
img=frame[0],
)


def get_frame_number(image_path: str) -> int:
base_name = os.path.basename(image_path)
frame_number = int(''.join(filter(str.isdigit, base_name)))
return frame_number


def get_frame_paths(video_folder_path: str):
Shmuel-columbia marked this conversation as resolved.
Show resolved Hide resolved
frame_paths = [
os.path.join(video_folder_path, f)
for f in os.listdir(video_folder_path)
if f.endswith('.png')
]
frame_paths.sort(key=lambda x: int(''.join(filter(str.isdigit, os.path.basename(x)))))
return frame_paths


def convert_csv_to_dict(csv_path: str, selected_body_parts: list = None) -> dict:
proj_dir = os.path.dirname(csv_path)
try:
annotations = pd.read_csv(csv_path, header=[1, 2], index_col=0)
data_dict = {}
for index, row in annotations.iterrows():
frame_rel_path = index
video = os.path.basename(os.path.dirname(frame_rel_path))
frame_number = get_frame_number(os.path.basename(frame_rel_path))

bodyparts = {}
for bodypart in annotations.columns.levels[0]:
if selected_body_parts and "All" not in selected_body_parts:
if bodypart not in selected_body_parts:
continue
try:
x = row[(bodypart, 'x')]
y = row[(bodypart, 'y')]
bodyparts[bodypart] = {'x': x, 'y': y}
except KeyError as e:
print(f"Error extracting {bodypart} coordinates: {e}")

data_dict[frame_rel_path] = {
'frame_full_path': os.path.join(proj_dir, frame_rel_path),
'video': video,
'frame_number': frame_number,
'bodyparts': bodyparts
}
return data_dict
except Exception as e:
print(f"Error converting CSV to dictionary: {e}")
return {}


def annotate_frames(image_path: str, annotations: dict, output_path: str):
try:
image = Image.open(image_path).convert('L')
fig, ax = plt.subplots()

ax.imshow(image, cmap="gray")

# Get a list of unique body parts and determine colors
unique_bodyparts = list(set([label.split('_')[0] for label in annotations.keys()]))
unique_views = list(set([label.split('_')[1] for label in annotations.keys()]))

color_map = plt.cm.get_cmap('tab10', len(unique_bodyparts))
bodypart_colors = {bodypart: color_map(i) for i, bodypart in enumerate(unique_bodyparts)}

# Create suffix_marker_map dynamically
markers = ['o', '^', 's', 'p', '*', 'x', 'd', 'v', '<', '>']
suffix_marker_map = {
view: markers[i % len(markers)]
for i, view
in enumerate(unique_views)
}

img_width, img_height = image.size
font_size = max(6, min(img_width, img_height) // 50)

for label, coords in annotations.items():
try:
x = coords['x']
y = coords['y']

# Skip plotting if coordinates are missing (NaN)
if x is None or y is None or np.isnan(x) or np.isnan(y):
_logger.warning(f"Missing x or y in annotation for {label}")
continue

bodypart = label.split('_')[0]
view = label.split('_')[1]
color = bodypart_colors[bodypart]
marker = suffix_marker_map[view]

ax.plot(x, y, marker, color=color, markersize=3)

ha = 'left' if x < img_width * 0.9 else 'right'
va = 'bottom' if y < img_height * 0.9 else 'top'

ax.text(x + 5, y + 5, label, color='white', fontsize=font_size, ha=ha, va=va)
except ValueError as e:
print(f"Error plotting {label}: {e}")

video = os.path.basename(os.path.dirname(image_path))
frame_number = int(get_frame_number(image_path))

title_text = f'Video: {video} | Frame: {frame_number}'
ax.set_title(title_text, fontsize=font_size, pad=15)
ax.axis('off')

# Ensure the output directory exists
os.makedirs(output_path, exist_ok=True)

output_file = os.path.join(output_path, os.path.basename(image_path))
fig.savefig(output_file, bbox_inches='tight')
plt.close()
_logger.info(f"Annotated frame saved at: {output_file}")
except Exception as e:
_logger.error(f"Failed to plot annotations for {image_path}: {e}")


def zip_annotated_images(labeled_data_check_path):
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(labeled_data_check_path):
for file in files:
file_path = os.path.join(root, file)
zf.write(file_path, os.path.relpath(file_path, labeled_data_check_path))
return zip_buffer


def find_models(model_dir):
Shmuel-columbia marked this conversation as resolved.
Show resolved Hide resolved
trained_models = []
# this returns a list of model training days
dirs_day = os.listdir(model_dir)
# loop over days and find HH-MM-SS
for dir_day in dirs_day:
fullpath1 = os.path.join(model_dir, dir_day)
dirs_time = os.listdir(fullpath1)
for dir_time in dirs_time:
fullpath2 = os.path.join(fullpath1, dir_time)
trained_models.append('/'.join(fullpath2.split('/')[-2:]))
return trained_models


@st.cache_data(show_spinner=False)
def load_image(image_path: str) -> Image:
return Image.open(image_path)


@st.cache_data(show_spinner=False)
def get_all_images(frame_paths: list) -> dict:
images = {}
for path in frame_paths:
images[path] = load_image(path)
return images
1 change: 0 additions & 1 deletion lightning_pose_app/backend/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import glob
import logging


from lightning_pose_app import (
COLLECTED_DATA_FILENAME,
LABELED_DATA_DIR,
Expand Down
Loading