Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pre release changes #94

38 changes: 19 additions & 19 deletions lightning_pose_app/backend/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def read_nth_frames(
# If the frame was successfully read, then process it
if frame_counter % n == 0:
frame_resize = cv2.resize(frame, (resize_dims, resize_dims))
frame_gray = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2GRAY)
frame_gray = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2RGB)
frames.append(frame_gray.astype(np.float16))
frame_counter += 1
progress = frame_counter / frame_total * 100.0
Expand Down Expand Up @@ -344,7 +344,7 @@ def export_frames(
for frame, idx in zip(frames, frame_idxs):
cv2.imwrite(
filename=os.path.join(save_dir, "img%s.%s" % (str(idx).zfill(n_digits), format)),
img=frame[0],
img=cv2.cvtColor(frame.transpose(1, 2, 0), cv2.COLOR_RGB2BGR),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the conversion from RGB to BGR necessary here? this should just work if we save in RGB format right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Im getting an odd images when I remove the conversion. looks like one of the channels is missing..

)


Expand Down Expand Up @@ -394,28 +394,28 @@ def convert_csv_to_dict(csv_path: str, selected_body_parts: list = None) -> dict

def annotate_frames(image_path: str, annotations: dict, output_path: str):
try:
image = Image.open(image_path).convert('L')
fig, ax = plt.subplots()

ax.imshow(image, cmap="gray")
image = Image.open(image_path)
# Convert image to RGB if necessary
if image.mode == 'L':
image = image.convert('L')
elif image.mode != 'RGB':
image = image.convert('RGB')

# Get a list of unique body parts and determine colors
unique_bodyparts = list(set([label.split('_')[0] for label in annotations.keys()]))
unique_views = list(set([label.split('_')[1] for label in annotations.keys()]))
fig, ax = plt.subplots()
ax.imshow(image, cmap='gray' if image.mode == 'L' else None)
# Get a list of unique body parts
unique_bodyparts = list(set(annotations.keys()))

color_map = plt.cm.get_cmap('tab10', len(unique_bodyparts))
bodypart_colors = {bodypart: color_map(i) for i, bodypart in enumerate(unique_bodyparts)}

# Create suffix_marker_map dynamically
markers = ['o', '^', 's', 'p', '*', 'x', 'd', 'v', '<', '>']
suffix_marker_map = {
view: markers[i % len(markers)]
for i, view
in enumerate(unique_views)
bodypart_markers = {
bodypart: markers[i % len(markers)] for i, bodypart in enumerate(unique_bodyparts)
}

img_width, img_height = image.size
font_size = max(6, min(img_width, img_height) // 50)
font_size = max(6, min(img_width, img_height) // 100)

for label, coords in annotations.items():
try:
Expand All @@ -427,10 +427,9 @@ def annotate_frames(image_path: str, annotations: dict, output_path: str):
_logger.warning(f"Missing x or y in annotation for {label}")
continue

bodypart = label.split('_')[0]
view = label.split('_')[1]
bodypart = label
color = bodypart_colors[bodypart]
marker = suffix_marker_map[view]
marker = bodypart_markers[bodypart]

ax.plot(x, y, marker, color=color, markersize=3)

Expand All @@ -445,7 +444,8 @@ def annotate_frames(image_path: str, annotations: dict, output_path: str):
frame_number = int(get_frame_number(image_path)[0])

title_text = f'Video: {video} | Frame: {frame_number}'
ax.set_title(title_text, fontsize=font_size, pad=15)

ax.set_title(title_text, fontsize=10, pad=15)
ax.axis('off')

# Ensure the output directory exists
Expand Down
7 changes: 4 additions & 3 deletions lightning_pose_app/backend/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ def get_frames_from_idxs(cap: cv2.VideoCapture, idxs: np.ndarray) -> np.ndarray:
cap.set(1, i)
ret, frame = cap.read()
if ret:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if fr == 0:
height, width, _ = frame.shape
frames = np.zeros((n_frames, 1, height, width), dtype="uint8")
frames[fr, 0, :, :] = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
height, width, _ = frame_rgb.shape
frames = np.zeros((n_frames, 3, height, width), dtype="uint8")
frames[fr] = frame_rgb.transpose(2, 0, 1)
else:
_logger.debug(
"warning! reached end of video; returning blank frames for remainder of "
Expand Down
33 changes: 20 additions & 13 deletions lightning_pose_app/ui/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,10 @@ def _render_streamlit_fn(state: AppState):
sorted(state.initialized_projects)
)
st.markdown(
"Project exports only contain frames, labels, and the project config file."
"To download trained models or the results files from video inference, see [here]."
"(https://pose-app.readthedocs.io/en/latest/source/accessing_your_data.html#)",
"Project exports only contain frames, labels, and the project config file. "
"To download trained models or the results files from video inference, see "
"[here]"
"(https://pose-app.readthedocs.io/en/latest/source/accessing_your_data.html#).",
unsafe_allow_html=True,
)

Expand All @@ -594,20 +595,26 @@ def _render_streamlit_fn(state: AppState):
"Please label some frames before attempting to download the project."
)
else:
try:
zip_filepath = zip_project_for_export(proj_dir)
# Add download botton
with open(zip_filepath, "rb") as f:
if 'zip_filepath' not in st.session_state:
st.session_state.zip_filepath = None
if st.button("Zip Project Files"):
try:
st.session_state.zip_filepath = zip_project_for_export(proj_dir)
st.success(
"Project files are zipped and ready. Press the button below to "
"download the zipped files."
)
except FileNotFoundError as e:
st.error(str(e))
if st.session_state.zip_filepath:
with open(st.session_state.zip_filepath, "rb") as f:
st.download_button(
label="Download Project",
data=f,
file_name=os.path.basename(zip_filepath)
file_name=os.path.basename(st.session_state.zip_filepath)
)
os.remove(zip_filepath)

except FileNotFoundError as e:
st.error(str(e))

os.remove(st.session_state.zip_filepath)
st.session_state.zip_filepath = None
st.header("Manage Lightning Pose projects")

st_mode = st.radio(
Expand Down
81 changes: 43 additions & 38 deletions lightning_pose_app/ui/streamlit_video_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,46 +82,51 @@ def _render_streamlit_fn(state: AppState):
)
st.divider()
# Add a section to download results files
video_name = extract_video_name(selected_video)
results_files = list_results_files(video_name, model_dir, selected_model)

if results_files:
selected_result_files = st.multiselect(
"**Step 3:** Select results files to download", list(results_files.keys())
)
if selected_result_files:
if len(selected_result_files) == 1:
selected_result_file = results_files[selected_result_files[0]]
new_file_name = f"{selected_model}_{video_name}_{selected_result_files[0]}"
with open(selected_result_file, "rb") as file:
if selected_video:
video_name = extract_video_name(selected_video)
results_files = list_results_files(video_name, model_dir, selected_model)

if results_files:
selected_result_files = st.multiselect(
"**Step 3:** Select results files to download", list(results_files.keys())
)
if selected_result_files:
if len(selected_result_files) == 1:
selected_result_file = results_files[selected_result_files[0]]
new_file_name = (
f"{selected_model}_{video_name}_{selected_result_files[0]}"
)
with open(selected_result_file, "rb") as file:
st.download_button(
label="Download File",
data=file,
file_name=new_file_name
)
else:
st.warning(
"If you select more than one file, they will be downloaded "
"together as a ZIP folder"
)
# Create a zip file
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zip_file:
for result_file_name in selected_result_files:
result_file_path = results_files[result_file_name]
new_file_name = (
f"{selected_model}_{video_name}_{result_file_name}"
)
with open(result_file_path, "rb") as file:
zip_file.writestr(new_file_name, file.read())
zip_buffer.seek(0)
zip_file_name = f"results_{selected_model}_{video_name}.zip"
st.download_button(
label="Download File",
data=file,
file_name=new_file_name
label="Download Files",
data=zip_buffer,
file_name=zip_file_name,
mime="application/zip"
)
else:
st.warning(
"If you select more than one file, they will be downloaded "
"together as a ZIP folder"
)
# Create a zip file
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zip_file:
for result_file_name in selected_result_files:
result_file_path = results_files[result_file_name]
new_file_name = f"{selected_model}_{video_name}_{result_file_name}"
with open(result_file_path, "rb") as file:
zip_file.writestr(new_file_name, file.read())
zip_buffer.seek(0)
zip_file_name = f"results_{selected_model}_{video_name}.zip"
st.download_button(
label="Download Files",
data=zip_buffer,
file_name=zip_file_name,
mime="application/zip"
)
else:
st.write("No results files available for this video.")
else:
st.write("No results files available for this video.")

if selected_video:
# read and show the predictions labeled video
Expand Down
2 changes: 1 addition & 1 deletion tests/backend/test_backend_extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_read_nth_frames(video_file):

resize_dims = 8
frames = read_nth_frames(video_file=video_file, n=10, resize_dims=resize_dims)
assert frames.shape == (100, resize_dims, resize_dims)
assert frames.shape == (100, resize_dims, resize_dims, 3)


def test_select_idxs_kmeans(video_file):
Expand Down
2 changes: 1 addition & 1 deletion tests/backend/test_backend_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_get_frames_from_idxs(video_file):
n_frames = 3
frames = get_frames_from_idxs(cap, np.arange(n_frames))
cap.release()
assert frames.shape == (n_frames, 1, 406, 396)
assert frames.shape == (n_frames, 3, 406, 396)
assert frames.dtype == np.uint8


Expand Down