Skip to content

Commit

Permalink
pre release changes (#94)
Browse files Browse the repository at this point in the history
* change from RGB2GRAY to RGB2BGR

* remove gray scale conv

* change get_frames_from_idxs

* remove diffrent markers for diffrent views

* add two steps zip and download project file

* fix empty video error in video player

* remove color for same bodyparts

* flaked

* fix tests
  • Loading branch information
Shmuel-columbia authored Jul 10, 2024
1 parent ef18d5c commit dfe93e1
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 75 deletions.
38 changes: 19 additions & 19 deletions lightning_pose_app/backend/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def read_nth_frames(
# If the frame was successfully read, then process it
if frame_counter % n == 0:
frame_resize = cv2.resize(frame, (resize_dims, resize_dims))
frame_gray = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2GRAY)
frame_gray = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2RGB)
frames.append(frame_gray.astype(np.float16))
frame_counter += 1
progress = frame_counter / frame_total * 100.0
Expand Down Expand Up @@ -344,7 +344,7 @@ def export_frames(
for frame, idx in zip(frames, frame_idxs):
cv2.imwrite(
filename=os.path.join(save_dir, "img%s.%s" % (str(idx).zfill(n_digits), format)),
img=frame[0],
img=cv2.cvtColor(frame.transpose(1, 2, 0), cv2.COLOR_RGB2BGR),
)


Expand Down Expand Up @@ -394,28 +394,28 @@ def convert_csv_to_dict(csv_path: str, selected_body_parts: list = None) -> dict

def annotate_frames(image_path: str, annotations: dict, output_path: str):
try:
image = Image.open(image_path).convert('L')
fig, ax = plt.subplots()

ax.imshow(image, cmap="gray")
image = Image.open(image_path)
# Convert image to RGB if necessary
if image.mode == 'L':
image = image.convert('L')
elif image.mode != 'RGB':
image = image.convert('RGB')

# Get a list of unique body parts and determine colors
unique_bodyparts = list(set([label.split('_')[0] for label in annotations.keys()]))
unique_views = list(set([label.split('_')[1] for label in annotations.keys()]))
fig, ax = plt.subplots()
ax.imshow(image, cmap='gray' if image.mode == 'L' else None)
# Get a list of unique body parts
unique_bodyparts = list(set(annotations.keys()))

color_map = plt.cm.get_cmap('tab10', len(unique_bodyparts))
bodypart_colors = {bodypart: color_map(i) for i, bodypart in enumerate(unique_bodyparts)}

# Create suffix_marker_map dynamically
markers = ['o', '^', 's', 'p', '*', 'x', 'd', 'v', '<', '>']
suffix_marker_map = {
view: markers[i % len(markers)]
for i, view
in enumerate(unique_views)
bodypart_markers = {
bodypart: markers[i % len(markers)] for i, bodypart in enumerate(unique_bodyparts)
}

img_width, img_height = image.size
font_size = max(6, min(img_width, img_height) // 50)
font_size = max(6, min(img_width, img_height) // 100)

for label, coords in annotations.items():
try:
Expand All @@ -427,10 +427,9 @@ def annotate_frames(image_path: str, annotations: dict, output_path: str):
_logger.warning(f"Missing x or y in annotation for {label}")
continue

bodypart = label.split('_')[0]
view = label.split('_')[1]
bodypart = label
color = bodypart_colors[bodypart]
marker = suffix_marker_map[view]
marker = bodypart_markers[bodypart]

ax.plot(x, y, marker, color=color, markersize=3)

Expand All @@ -445,7 +444,8 @@ def annotate_frames(image_path: str, annotations: dict, output_path: str):
frame_number = int(get_frame_number(image_path)[0])

title_text = f'Video: {video} | Frame: {frame_number}'
ax.set_title(title_text, fontsize=font_size, pad=15)

ax.set_title(title_text, fontsize=10, pad=15)
ax.axis('off')

# Ensure the output directory exists
Expand Down
7 changes: 4 additions & 3 deletions lightning_pose_app/backend/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ def get_frames_from_idxs(cap: cv2.VideoCapture, idxs: np.ndarray) -> np.ndarray:
cap.set(1, i)
ret, frame = cap.read()
if ret:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if fr == 0:
height, width, _ = frame.shape
frames = np.zeros((n_frames, 1, height, width), dtype="uint8")
frames[fr, 0, :, :] = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
height, width, _ = frame_rgb.shape
frames = np.zeros((n_frames, 3, height, width), dtype="uint8")
frames[fr] = frame_rgb.transpose(2, 0, 1)
else:
_logger.debug(
"warning! reached end of video; returning blank frames for remainder of "
Expand Down
33 changes: 20 additions & 13 deletions lightning_pose_app/ui/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,10 @@ def _render_streamlit_fn(state: AppState):
sorted(state.initialized_projects)
)
st.markdown(
"Project exports only contain frames, labels, and the project config file."
"To download trained models or the results files from video inference, see [here]."
"(https://pose-app.readthedocs.io/en/latest/source/accessing_your_data.html#)",
"Project exports only contain frames, labels, and the project config file. "
"To download trained models or the results files from video inference, see "
"[here]"
"(https://pose-app.readthedocs.io/en/latest/source/accessing_your_data.html#).",
unsafe_allow_html=True,
)

Expand All @@ -594,20 +595,26 @@ def _render_streamlit_fn(state: AppState):
"Please label some frames before attempting to download the project."
)
else:
try:
zip_filepath = zip_project_for_export(proj_dir)
# Add download botton
with open(zip_filepath, "rb") as f:
if 'zip_filepath' not in st.session_state:
st.session_state.zip_filepath = None
if st.button("Zip Project Files"):
try:
st.session_state.zip_filepath = zip_project_for_export(proj_dir)
st.success(
"Project files are zipped and ready. Press the button below to "
"download the zipped files."
)
except FileNotFoundError as e:
st.error(str(e))
if st.session_state.zip_filepath:
with open(st.session_state.zip_filepath, "rb") as f:
st.download_button(
label="Download Project",
data=f,
file_name=os.path.basename(zip_filepath)
file_name=os.path.basename(st.session_state.zip_filepath)
)
os.remove(zip_filepath)

except FileNotFoundError as e:
st.error(str(e))

os.remove(st.session_state.zip_filepath)
st.session_state.zip_filepath = None
st.header("Manage Lightning Pose projects")

st_mode = st.radio(
Expand Down
81 changes: 43 additions & 38 deletions lightning_pose_app/ui/streamlit_video_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,46 +82,51 @@ def _render_streamlit_fn(state: AppState):
)
st.divider()
# Add a section to download results files
video_name = extract_video_name(selected_video)
results_files = list_results_files(video_name, model_dir, selected_model)

if results_files:
selected_result_files = st.multiselect(
"**Step 3:** Select results files to download", list(results_files.keys())
)
if selected_result_files:
if len(selected_result_files) == 1:
selected_result_file = results_files[selected_result_files[0]]
new_file_name = f"{selected_model}_{video_name}_{selected_result_files[0]}"
with open(selected_result_file, "rb") as file:
if selected_video:
video_name = extract_video_name(selected_video)
results_files = list_results_files(video_name, model_dir, selected_model)

if results_files:
selected_result_files = st.multiselect(
"**Step 3:** Select results files to download", list(results_files.keys())
)
if selected_result_files:
if len(selected_result_files) == 1:
selected_result_file = results_files[selected_result_files[0]]
new_file_name = (
f"{selected_model}_{video_name}_{selected_result_files[0]}"
)
with open(selected_result_file, "rb") as file:
st.download_button(
label="Download File",
data=file,
file_name=new_file_name
)
else:
st.warning(
"If you select more than one file, they will be downloaded "
"together as a ZIP folder"
)
# Create a zip file
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zip_file:
for result_file_name in selected_result_files:
result_file_path = results_files[result_file_name]
new_file_name = (
f"{selected_model}_{video_name}_{result_file_name}"
)
with open(result_file_path, "rb") as file:
zip_file.writestr(new_file_name, file.read())
zip_buffer.seek(0)
zip_file_name = f"results_{selected_model}_{video_name}.zip"
st.download_button(
label="Download File",
data=file,
file_name=new_file_name
label="Download Files",
data=zip_buffer,
file_name=zip_file_name,
mime="application/zip"
)
else:
st.warning(
"If you select more than one file, they will be downloaded "
"together as a ZIP folder"
)
# Create a zip file
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "w") as zip_file:
for result_file_name in selected_result_files:
result_file_path = results_files[result_file_name]
new_file_name = f"{selected_model}_{video_name}_{result_file_name}"
with open(result_file_path, "rb") as file:
zip_file.writestr(new_file_name, file.read())
zip_buffer.seek(0)
zip_file_name = f"results_{selected_model}_{video_name}.zip"
st.download_button(
label="Download Files",
data=zip_buffer,
file_name=zip_file_name,
mime="application/zip"
)
else:
st.write("No results files available for this video.")
else:
st.write("No results files available for this video.")

if selected_video:
# read and show the predictions labeled video
Expand Down
2 changes: 1 addition & 1 deletion tests/backend/test_backend_extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_read_nth_frames(video_file):

resize_dims = 8
frames = read_nth_frames(video_file=video_file, n=10, resize_dims=resize_dims)
assert frames.shape == (100, resize_dims, resize_dims)
assert frames.shape == (100, resize_dims, resize_dims, 3)


def test_select_idxs_kmeans(video_file):
Expand Down
2 changes: 1 addition & 1 deletion tests/backend/test_backend_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_get_frames_from_idxs(video_file):
n_frames = 3
frames = get_frames_from_idxs(cap, np.arange(n_frames))
cap.release()
assert frames.shape == (n_frames, 1, 406, 396)
assert frames.shape == (n_frames, 3, 406, 396)
assert frames.dtype == np.uint8


Expand Down

0 comments on commit dfe93e1

Please sign in to comment.