can you provide code to generate demo? #46

Open
toyottttttt opened this issue Jul 27, 2023 · 1 comment

No description provided.

@wjn922 (Owner) commented Mar 12, 2024

Here is the demo file I came across from when I was working on the project. I haven't tested it recently, so please treat it just as a reference. Hope it helps!

import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from yt_dlp import YoutubeDL
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from tqdm import trange, tqdm

import argparse
import opts
from models import build_model
from tools.colormap import colormap

import os
import json


# colormap
color_list = colormap()
color_list = color_list.astype('uint8').tolist()

# build transform
transform = T.Compose([
    T.Resize(360),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# ------------------------------------------------------------
# get text_queries and images
images_dir = "valid_all_frames/JPEGImages"
with open("meta_expressions.json", "r") as f:
    videos = json.load(f)["videos"]   # dict: {video_name: dict}
# sub dict: "expressions" (dict) "0": xxxx, "1": xxxx
#           "frames" (list)

video_name = "1ecc34b1bf"
expressions = videos[video_name]["expressions"]    # dict
text_queries = [expressions[str(i)]["exp"] for i in range(len(expressions))][0::2][:2]   # every other expression, then keep the first two objects

images = sorted(os.listdir(os.path.join(images_dir, video_name)))
imgs = []
for image in images:
    img_path = os.path.join(images_dir, video_name, image)
    img = Image.open(img_path).convert('RGB')
    imgs.append(transform(img))
video = torch.stack(imgs, dim=0)

# -----------------------------------------------------
def apply_mask(image, mask, color, transparency=0.7):
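    # alpha-blend a solid `color` onto `image` where the binary mask is 1; `transparency` is the blend weight toward the color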
    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
    mask = mask * transparency
    color_matrix = np.ones(image.shape, dtype=float) * color
    out_image = color_matrix * mask + image * (1.0 - mask)
    return out_image

parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()])
args = parser.parse_args()

args.masks = True
args.binary = True
args.with_box_refine = True
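
# example invocation (assumed; flag names follow this repo's opts.py, and the script/checkpoint names are placeholders):
#   python demo.py --backbone resnet50 --resume referformer_r50.pth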

# model
model, criterion, _ = build_model(args) 
device = args.device
model.to(device)
model.eval()

model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
    unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
    if len(missing_keys) > 0:
        print('Missing Keys: {}'.format(missing_keys))
    if len(unexpected_keys) > 0:
        print('Unexpected Keys: {}'.format(unexpected_keys))
else:
    raise ValueError('Please specify the checkpoint for inference.')


window_length = 36  # length of window during inference
window_overlap = 6  # overlap (in frames) between consecutive windows
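# e.g. window_length=36 with window_overlap=6 gives window start frames 0, 30, 60, ... (stride = 36 - 6 = 30 frames)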

with torch.inference_mode():
    # read and preprocess the video clip:
    # input_video = F.resize(video, size=360, max_size=640).cuda()
    input_video = video.cuda()
    # input_video = input_video.to(torch.float).div_(255)
    # input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    origin_h, origin_w = video.shape[-2:]
    img_h, img_w = input_video.shape[-2:]
    size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device)
    target = {"size": size}
    # video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

    # partition the clip into overlapping windows of frames:
    windows = [input_video[i:i+window_length] for i in range(0, len(input_video), window_length - window_overlap)]
    # clean up the text queries:
    text_queries = [" ".join(q.lower().split()) for q in text_queries]

    pred_masks_per_query = []  # store the pred masks of a video (all text queries)
    t, _, h, w = video.shape   # here, t is the length of full video
    for text_query in tqdm(text_queries, desc='text queries'):
        pred_masks_all_clips = torch.zeros(size=(t, 1, h, w))
        # for each video clip
        for i, window in enumerate(tqdm(windows, desc='windows')):
            # window = nested_tensor_from_videos_list([window])
            window_len = window.shape[0]
            valid_indices = torch.arange(len(window)).cuda()
            with torch.no_grad():
                outputs = model([window], [text_query], [target])

            # postprocess
            pred_logits = outputs["pred_logits"][0] 
            pred_boxes = outputs["pred_boxes"][0]   
            pred_masks = outputs["pred_masks"][0]   
            pred_ref_points = outputs["reference_points"][0]  

            # according to pred_logits, select the query index
            pred_scores = pred_logits.sigmoid() # [t, q, k]
            pred_scores = pred_scores.mean(0)   # [q, k]
            max_scores, _ = pred_scores.max(-1) # [q,]
            _, max_ind = max_scores.max(-1)     # [1,]
            max_inds = max_ind.repeat(window_len)
            pred_masks = pred_masks[range(window_len), max_inds, ...] # [t, h, w]
            pred_masks = torch.nn.functional.interpolate(pred_masks[:, None], size=(origin_h, origin_w), mode='bilinear', align_corners=False)  # [t, 1, h, w]
            
            window_masks = pred_masks > 0.5
            win_start_idx = i * (window_length-window_overlap)
            pred_masks_all_clips[win_start_idx:win_start_idx + window_length] = window_masks # [window_length, 1, h, w], torch
        
        pred_masks_per_query.append(pred_masks_all_clips)

# --------------------------------------------------------
# apply the predicted masks and queries to the video for visualization
# RGB colors for instance masks:
light_blue = (41, 171, 226)
purple = (237, 30, 121)
dark_green = (35, 161, 90)
orange = (255, 148, 59)
# colors = np.array([light_blue, purple, dark_green, orange])
color_idx = np.random.randint(80, size=4)
colors = np.array([color_list[i] for i in color_idx])

# width (in pixels) of the black strip above the video on which the text queries will be displayed:
text_border_height_per_query = 36

# the video was normalized with ImageNet mean/std in the transform above, so un-normalize it for display
PIXEL_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) # RGB
PIXEL_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
for i in range(len(video)):
    video[i] = video[i] * PIXEL_STD + PIXEL_MEAN
video_np = rearrange(video, 't c h w -> t h w c').numpy() # [t, h, w, c], RGB in [0, 1]
# del video
pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy()
masked_video = []
for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'):
    # apply the masks:
    for inst_mask, color in zip(frame_masks, colors):
        vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0)
    vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8))
    # visualize the text queries:
    vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries)*text_border_height_per_query, 0, 0))
    W, H = vid_frame.size
    draw = ImageDraw.Draw(vid_frame)
    font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=25)
    # font = ImageFont.load_default(size=30)
    for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1):
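        # note: draw.textsize was removed in Pillow 10; with newer Pillow, draw.textbbox gives the text extent instead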
        w, h = draw.textsize(text_query, font=font)
        draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 3),
                    text_query, fill=tuple(color) + (255,), font=font)
    masked_video.append(np.array(vid_frame))

# generate and save the output clip:
output_clip_path = 'output_clip.mp4'
clip = ImageSequenceClip(sequence=masked_video, fps=30)
# clip = clip.set_audio(AudioFileClip(input_clip_path))
clip.write_videofile(output_clip_path, fps=30, audio=False)
del masked_video
