fix: update examples to be compatible with the latest protocol
xiaofei-du committed Sep 20, 2022
1 parent c4a0c3c commit 36a847b
Showing 6 changed files with 145 additions and 89 deletions.
4 changes: 2 additions & 2 deletions examples/async/yolov7/README.md
@@ -58,8 +58,8 @@ $ python main.py
You can use the following SQL query to convert the raw detections into multiple records and build a _Cow Counter_ dashboard:
```sql
-- Cow counter
SELECT "public"."_airbyte_raw_detection"."_airbyte_ab_id" AS "id", "public"."_airbyte_raw_detection"."_airbyte_data"->'index' AS "index", "public"."_airbyte_raw_detection"."_airbyte_emitted_at" AS "processed_at", ceil(x.score) AS "count", x.category
FROM "public"."_airbyte_raw_detection" CROSS JOIN LATERAL jsonb_to_recordset("public"."_airbyte_raw_detection"."_airbyte_data"->'detection'->'bounding_boxes') AS x(score numeric, category text)
SELECT "public"."_airbyte_raw_vdp"."_airbyte_ab_id" AS "id", "public"."_airbyte_raw_vdp"."_airbyte_data"->'index' AS "index", "public"."_airbyte_raw_vdp"."_airbyte_emitted_at" AS "processed_at", ceil(x.score) AS "count", x.category
FROM "public"."_airbyte_raw_vdp" CROSS JOIN LATERAL jsonb_to_recordset("public"."_airbyte_raw_vdp"."_airbyte_data"->'detection'->'objects') AS x(score numeric, category text)
WHERE x.category = 'cow'
ORDER BY "processed_at" ASC
LIMIT 1048575
```
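Each element of the `objects` array becomes one record via `jsonb_to_recordset`, and `ceil(x.score)` turns every positive confidence score into a count of 1, so summing `count` per frame gives the number of detected cows. For reference, the query assumes each row's `_airbyte_data` column holds a JSON payload roughly shaped like the sketch below (field names inferred from the query and the detection format in `main.py`; the values are illustrative, not from the commit):

```python
# Illustrative shape of one _airbyte_data payload (values are made up).
payload = {
    "index": "some-data-mapping-index",
    "detection": {
        "objects": [
            {
                "bounding_box": {"left": 324, "top": 102, "width": 208, "height": 405},
                "category": "cow",
                "score": 0.97,
            },
        ],
    },
}
```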
145 changes: 108 additions & 37 deletions examples/async/yolov7/main.py
@@ -7,7 +7,7 @@
import shutil
from os import listdir
from os.path import isfile, join
from typing import Final, Any, List, Dict, Tuple

import cv2
import ffmpeg
@@ -21,22 +21,49 @@
from utils import draw_detection


def download_data(bucket_name: str, blob_filename: str, dst_filename: str) -> bool:
    r""" Download a file from a GCS bucket into a local file
    Args:
        bucket_name (str): GCS bucket name
        blob_filename (str): name of the file to download from the GCS bucket
        dst_filename (str): the file name used to save the downloaded file
    Returns: bool
        a flag to indicate whether the download is successful
    """
    print("\n===== Download video {} from GCS bucket {} to {} ...".format(blob_filename, bucket_name, dst_filename))

    client = storage.Client.create_anonymous_client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_filename)
    try:
        blob.download_to_filename(dst_filename)
    except Exception as e:
        print(e)
        os.remove(dst_filename)
        return False
    print("Done!")
    return True

def extract_frames_from_video(image_dir: str, filename: str, framerate: int=30) -> bool:
    r""" Extract frames from a video file at constant frames per second
    Args:
        image_dir (str): the directory where the extracted frames will be stored
        filename (str): name of the video file
        framerate (int): frames per second (fps) at which to extract frames from the video. By default set to 30 fps.
    Returns: bool
        a flag to indicate whether the extraction is successful
    """
    if os.path.exists(image_dir) and os.path.isdir(image_dir):
        shutil.rmtree(image_dir)

    print("\n===== Extract frames from the video {} into {} ...".format(filename, image_dir))

    pathlib.Path(image_dir).mkdir(parents=True, exist_ok=True)
    try:
        (
@@ -47,29 +74,68 @@ def extract_frames_from_video(image_dir, filename, framerate=30):
            .run(capture_stdout=True, capture_stderr=True)
        )
        print("Done!\n")
        return True
    except ffmpeg.Error as error:
        print('stdout:', error.stdout.decode('utf8'))
        print('stderr:', error.stderr.decode('utf8'))
        shutil.rmtree(image_dir)
        return False


def generate_video_from_frames(image_dir: str, output_filename: str, framerate: int=30) -> bool:
    r""" Generate a video from an array of image frames
    Args:
        image_dir (str): the directory where the image frames are stored
        output_filename (str): the name of the video file to be generated
        framerate (int): fps of the video to be generated
    Returns: bool
        a flag to indicate whether the operation is successful
    """
    if os.path.exists(output_filename):
        os.remove(output_filename)

    print("\n=====Generate video {} from image files in {}...".format(output_filename, image_dir))
    try:
        (
            ffmpeg
            .input(join(image_dir, '*.png'), pattern_type='glob', framerate=framerate)
            .output(output_filename, pix_fmt="yuv420p")
            .run()
        )
        print("Done!\n")
        return True
    except ffmpeg.Error as error:
        print(error)
        return False


def parse_detection_from_database(detection_ls: List[Dict[str, Any]]) -> Tuple[List[Tuple[float]], List[str], List[float]]:
    r""" Parse the raw detection output from the database
    Args:
        detection_ls: a list of detection outputs for the standardised VDP object detection task, e.g.,
            [
                {
                    "bounding_box": {
                        "left": 324,
                        "top": 102,
                        "width": 208,
                        "height": 405,
                    },
                    "category": "dog",
                    "score": 0.9
                }
            ]
    Returns: parsed output, a tuple of
        List[Tuple[float]]: a list of detected bounding boxes in the format of (left, top, width, height)
        List[str]: a list of category labels, each of which corresponds to a detected bounding box. The length of this list must be the same as the detected bounding boxes.
        List[float]: a list of scores, each of which corresponds to a detected bounding box. The length of this list must be the same as the detected bounding boxes.
    """
    boxes_ltwh, categories, scores = [], [], []

    for det in detection_ls:
@@ -81,7 +147,7 @@ def parse_detection_from_database(detection_ls: List[Dict[str, Any]]):
categories.append(det["category"])
scores.append(det["score"])

return boxes_ltwh, categories, scores
return boxes_ltwh, categories, scores

###############################################################################
# VDP backends
@@ -112,18 +178,18 @@ def parse_detection_from_database(detection_ls: List[Dict[str, Any]]):

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trigger VDP pipeline')
    parser.add_argument("--pipeline", dest = 'pipeline', help =
        "VDP pipeline ID",
        default = "detection", type = str)
    parser.add_argument("--output-filename", dest = 'output_filename', help =
        "Output video file name",
        default = "output.mp4", type = str)
    parser.add_argument("--framerate", dest = 'framerate', help =
        "Frame rate of the video",
        default = 30, type = int)
    parser.add_argument("--skip-draw", dest="draw", action="store_false", help =
        "Skip draw detections on images")

    opt = parser.parse_args()

    ###############################################################################
@@ -136,12 +202,16 @@ def parse_detection_from_database(detection_ls: List[Dict[str, Any]]):
    if skip_download:
        print("\n===== Skip downloading video")
    else:
        success = download_data(bucket_name='public-europe-west2-c-artifacts',
                                blob_filename="vdp/tutorial/cows_dornick/cows_dornick.mp4",
                                dst_filename=video_filename)
        if not success:
            sys.exit(1)

    ###############################################################################
    # Extract frames from the video file
    ###############################################################################

    image_dir = join(os.path.dirname(os.path.realpath(__file__)), "inputs")

    skip_extract = False
@@ -151,51 +221,52 @@ def parse_detection_from_database(detection_ls: List[Dict[str, Any]]):
    if skip_extract:
        print("\n===== Skip extracting frames from video {}".format(video_filename))
    else:
        success = extract_frames_from_video(image_dir, video_filename, framerate=opt.framerate)
        if not success:
            sys.exit(1)


    ###############################################################################
    # Trigger pipeline to process video frames
    ###############################################################################

    batch_size = 1
    img_files = [filename for filename in sorted(listdir(image_dir)) if isfile(
        join(image_dir, filename)) and not filename.startswith(".")]
    img_batch = [img_files[i:i+batch_size] for i in range(0, len(img_files), batch_size)]
    filenames = [file for files in img_batch for file in files]
    data_mapping_indices = []

    print("\n=====Trigger {} pipeline to process images in '{}'\n".format(opt.pipeline, image_dir))
    for files in tqdm(img_batch):
        resp = requests.post(f'http://{backend["pipeline"]}/{ver}/pipelines/{opt.pipeline}:trigger-multipart',
                             files=[("file", (filename, open(join(image_dir, filename), 'rb'))) for filename in files])
        if resp.status_code == 200:
            data_mapping_indices += resp.json()['data_mapping_indices']
        else:
            print(resp.status_code)
            print(resp.json())
            sys.exit(1)

    ###############################################################################
    # Draw detections on video frames
    ###############################################################################

    if opt.draw:
        time.sleep(10)
        conn = None
        print("#", end="", flush=True)
        assert len(filenames) == len(data_mapping_indices), "number of files {} not consistent with number of records {}".format(len(filenames), len(data_mapping_indices))

        # Create output directory
        output_dir = join(os.path.dirname(os.path.realpath(__file__)), "outputs")
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

        for filename, mapping_index in tzip(filenames, data_mapping_indices):
            # Fetch detections from destination PostgreSQL database
            try:
                conn = psycopg2.connect(
                    user=pq_cfg["username"], password=pq_cfg["password"], host=pq_cfg["host"], port=pq_cfg["port"], database=pq_cfg["database"])
                cur = conn.cursor()
                cur.execute("""SELECT _airbyte_raw_vdp._airbyte_data->'detection'->'objects' AS "objects" from _airbyte_raw_vdp WHERE _airbyte_raw_vdp._airbyte_data->>'index' = '{}'""".format(mapping_index))
                row = cur.fetchone()[0]

                boxes_ltwh, categories, scores = parse_detection_from_database(row)
@@ -213,4 +284,4 @@ def parse_detection_from_database(detection_ls: List[Dict[str, Any]]):
                conn.close()

        # Generate video with detections
        success = generate_video_from_frames(output_dir, opt.output_filename, framerate=opt.framerate)
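For readers adapting the example outside of `main.py`, the sketch below condenses the fetch-and-parse step shown above against the updated `_airbyte_raw_vdp` table and `objects` field. It is not part of the commit: the connection settings and the index value are placeholders, and `parse_objects` is a simplified stand-in for `parse_detection_from_database`.

```python
# Hypothetical sketch: fetch one processed record from the Airbyte destination
# database and unpack the standardised VDP detection output.
# Connection parameters and the index value are placeholders.
import psycopg2

def parse_objects(objects):
    # Convert each detected object into (left, top, width, height), category, score.
    boxes_ltwh, categories, scores = [], [], []
    for det in objects:
        box = det["bounding_box"]
        boxes_ltwh.append((box["left"], box["top"], box["width"], box["height"]))
        categories.append(det["category"])
        scores.append(det["score"])
    return boxes_ltwh, categories, scores

conn = psycopg2.connect(user="airbyte", password="password",
                        host="localhost", port=5432, database="airbyte")
try:
    cur = conn.cursor()
    cur.execute(
        """SELECT _airbyte_data->'detection'->'objects'
           FROM _airbyte_raw_vdp
           WHERE _airbyte_data->>'index' = %s""",
        ("some-data-mapping-index",))
    row = cur.fetchone()
    if row is not None:
        boxes, categories, scores = parse_objects(row[0])
        print(boxes, categories, scores)
finally:
    conn.close()
```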
30 changes: 0 additions & 30 deletions examples/async/yolov7/outputs/Dockerfile

This file was deleted.

6 changes: 2 additions & 4 deletions examples/async/yolov7/utils.py
@@ -101,13 +101,11 @@ def draw_detection(img: cv2.Mat, boxes_ltwh: List[Tuple[float]], categories: Lis
    Args:
        img (cv2.Mat): the original image
        boxes_ltwh (List[Tuple[float]]): a list of detected bounding boxes in the format of (left, top, width, height)
        categories (List[str]): a list of category labels, each of which corresponds to a detected bounding box. The length of this list must be the same as the detected bounding boxes.
        scores (List[float]): a list of scores, each of which corresponds to a detected bounding box. The length of this list must be the same as the detected bounding boxes.
    Returns: cv2.Mat: image overlaid with detection results
    """
    img_draw = img.copy()

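As a quick illustration of the signature documented above, a hypothetical call to `draw_detection` might look like the following; the file paths and detection values are made up, and this snippet is not part of the commit:

```python
# Hypothetical usage sketch for draw_detection -- not part of this commit.
import cv2
from utils import draw_detection

img = cv2.imread("inputs/frame_000001.png")  # illustrative input frame
# One detected cow: box as (left, top, width, height), plus category and score.
overlaid = draw_detection(img, [(324, 102, 208, 405)], ["cow"], [0.97])
cv2.imwrite("outputs/frame_000001.png", overlaid)
```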