Make _evalImgs_cpp pickleable or readable by python #45

Closed
xiuqhou opened this issue Oct 31, 2024 · 9 comments

@xiuqhou

xiuqhou commented Oct 31, 2024

Hi, thanks for the great work on this repository. It has saved me a lot of evaluation time.

Is your feature request related to a problem? Please describe.
I am using this framework to evaluate instance segmentation on COCO. Storing all segmentation results takes too much memory (just 700 images need 50 GB), so I have to set "separate_eval=True" to evaluate images individually and then collect eval_imgs (in pycocotools) or _evalImgs_cpp (in faster-coco-eval) to compute the final result. With DDP evaluation, I can't find a way to collect these results from the other processes, because _evalImgs_cpp is neither picklable nor readable from Python.

Describe the solution you'd like
Could you please make the faster_coco_eval.faster_eval_api_cpp.ImageEvaluation object picklable, or make _C.COCOevalEvaluateImages return an np.ndarray, just like eval_imgs in pycocotools? This would be very helpful for evaluation during DDP training.

Describe alternatives you've considered
For now, I either run the evaluation only in the main process while the other processes wait for it to finish, or use pycocotools to evaluate with multiple processes. Both alternatives are too slow.

Additional context
My code is modified from the official torchvision training reference (https://github.com/pytorch/vision/tree/main/references/detection). I changed eval_imgs to _evalImgs_cpp and set separate_eval=True. Here are some snippets of my code:

import copy
import io
from contextlib import redirect_stdout

import numpy as np
import torch.distributed as dist

import utils  # utils.py from the torchvision detection references

# USE_FASTER_EVAL: flag selecting faster-coco-eval instead of pycocotools

# evaluate each image and save results into eval_imgs
def evaluate(imgs):
    with redirect_stdout(io.StringIO()):
        imgs.evaluate()
    eval_imgs = imgs._evalImgs_cpp if USE_FASTER_EVAL else imgs.evalImgs
    # results are ordered by category, then area range, then image
    return imgs.params.imgIds, np.asarray(eval_imgs).reshape(
        -1, len(imgs.params.areaRng), len(imgs.params.imgIds)
    )

# after evaluation, merge eval_imgs from all processes
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    if USE_FASTER_EVAL:
        coco_eval._evalImgs_cpp = eval_imgs
    else:
        coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)

# use utils.all_gather to collect eval_imgs
def merge(img_ids, eval_imgs):
    all_img_ids = utils.all_gather(img_ids)
    all_eval_imgs = utils.all_gather(eval_imgs)

    ...

# (from utils.py) dist.all_gather_object requires the gathered objects to be
# picklable, which is the problem here
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    dist.all_gather_object(data_list, data, group)
    return data_list
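`dist.all_gather_object` pickles each object before sending it, so one quick way to find out in advance whether `eval_imgs` will survive the gather is to try a pickle round trip locally. A minimal sketch: `is_picklable` is a hypothetical helper, and `NotPicklable` only simulates an extension object (such as `ImageEvaluation`) that refuses to be pickled; it is not the library's class.

```python
import pickle

import numpy as np


def is_picklable(obj) -> bool:
    """Return True if obj survives a pickle round trip (what all_gather_object needs)."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, TypeError, AttributeError):
        return False


class NotPicklable:
    """Hypothetical stand-in for a C++-backed object without pickle support."""

    def __reduce__(self):
        raise TypeError("cannot pickle 'NotPicklable' object")


# An object array of plain dicts and None (like pycocotools evalImgs) pickles fine...
plain = np.asarray([{"image_id": 1, "dtScores": [0.9]}, None], dtype=object)
# ...but an object array holding unpicklable elements does not.
opaque = np.asarray([NotPicklable(), NotPicklable()], dtype=object)

print(is_picklable(plain))   # True
print(is_picklable(opaque))  # False
```

Running this check on the real `eval_imgs` array in a single process should reproduce the failure without needing DDP.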
@MiXaiLL76
Owner

The function COCOevalEvaluateImages returns an array of ImageEvaluation structs, i.e. the std::vector<ImageEvaluation> produced by EvaluateImages:

#include <cstdint>
#include <vector>

// Stores the match between a detected instance and a ground truth instance
// (declared first because ImageEvaluation refers to it)
struct MatchedAnnotation
{
  MatchedAnnotation(
      uint64_t dt_id,
      uint64_t gt_id,
      double iou) : dt_id{dt_id}, gt_id{gt_id}, iou{iou} {}
  uint64_t dt_id;
  uint64_t gt_id;
  double iou;
};

struct ImageEvaluation
{
  // For each of the D detected instances, the id of the matched ground truth
  // instance, or 0 if unmatched
  std::vector<int64_t> detection_matches;

  std::vector<int64_t> ground_truth_matches;
  // The detection score of each of the D detected instances
  std::vector<double> detection_scores;

  // Marks whether or not each of G instances was ignored from evaluation (e.g.,
  // because it's outside area_range)
  std::vector<bool> ground_truth_ignores;

  // Marks whether or not each of D instances was ignored from evaluation (e.g.,
  // because it's outside aRng)
  std::vector<bool> detection_ignores;

  std::vector<MatchedAnnotation> matched_annotations;
};

In principle I can make it picklable, but it will take some work. Do you have a code example I can use to check that it works and to write tests?
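Making such a struct picklable generally means exporting its fields as plain Python data and rebuilding the object from them; in pybind11 this is what `py::pickle(get_state, set_state)` wires up as `__getstate__`/`__setstate__`. A pure-Python sketch of the same idea, using a hypothetical `ImageEvaluationPy` class that mirrors the C++ fields above (an illustration, not the library's actual binding):

```python
import pickle
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ImageEvaluationPy:
    """Hypothetical Python mirror of the C++ ImageEvaluation struct."""
    detection_matches: List[int] = field(default_factory=list)
    ground_truth_matches: List[int] = field(default_factory=list)
    detection_scores: List[float] = field(default_factory=list)
    ground_truth_ignores: List[bool] = field(default_factory=list)
    detection_ignores: List[bool] = field(default_factory=list)
    # Each MatchedAnnotation stored as a (dt_id, gt_id, iou) tuple
    matched_annotations: List[Tuple[int, int, float]] = field(default_factory=list)

    def __getstate__(self):
        # Export every field as plain Python data, in a fixed order
        return (self.detection_matches, self.ground_truth_matches,
                self.detection_scores, self.ground_truth_ignores,
                self.detection_ignores, self.matched_annotations)

    def __setstate__(self, state):
        # Rebuild the object from the tuple produced by __getstate__
        (self.detection_matches, self.ground_truth_matches,
         self.detection_scores, self.ground_truth_ignores,
         self.detection_ignores, self.matched_annotations) = state


ev = ImageEvaluationPy(detection_matches=[7, 0], detection_scores=[0.9, 0.4],
                       matched_annotations=[(1, 7, 0.83)])
restored = pickle.loads(pickle.dumps(ev))
print(restored == ev)  # True
```

Returning a tuple in a fixed field order keeps the state compact; the C++ side would need the same order in its get/set pair.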

@MiXaiLL76
Owner

Can you prepare an example that works with pycocotools? For instance, a .py file where 2-4 different processes each create evalImgs, and another file where all of this is combined to calculate the metrics. If you can prepare something like that, I will implement it in the library; right now I don't fully understand what exactly you want.

@MiXaiLL76
Owner

1caac51

@MiXaiLL76
Owner

If this is enough, use the installation from GitHub. If not, describe the problem and we will solve it in the next release.

pip3 install git+https://github.com/MiXaiLL76/faster_coco_eval.git

@xiuqhou
Author

xiuqhou commented Oct 31, 2024

Hi, @MiXaiLL76

Sorry for the late response. I have prepared a small example to test evaluation with multiple processes. First, it generates some dummy COCO predictions and targets in each process. After each process finishes its evaluation, the results are collected in the main process (rank 0), which prints the summarized metrics. Under `if __name__ == "__main__":`, I test both pycocotools and faster_coco_eval by switching the "FASTEVAL" environment variable. You can change the random seed to test different cases; if everything is correct, both runs should always print the same results. Currently faster_coco_eval throws an exception because _evalImgs_cpp cannot be pickled. I hope this code helps with your testing.

import os
from contextlib import redirect_stdout
import copy
import io
import random
import numpy as np
import torch
import torch.distributed as dist


num_classes = 10
world_size = 3  # the number of processes
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)


def dummy_pred_target(num_boxes, shape, image_id):
    x2 = np.random.randint(10, shape[1], (num_boxes,))
    y2 = np.random.randint(10, shape[0], (num_boxes,))
    x2y2 = np.stack([x2, y2], axis=-1)
    x1y1 = np.random.rand(num_boxes, 2) * x2y2 * 0.95
    boxes = np.concatenate([x1y1, x2y2], axis=-1)
    labels = np.random.randint(0, num_classes, (num_boxes,))
    scores = np.random.rand(num_boxes)

    annotations = [
        {
            "segmentation": [],
            "area": (box[2] - box[0]) * (box[3] - box[1]),
            "iscrowd": 0,
            "image_id": image_id,
            "bbox": box.tolist(),
            "category_id": label.tolist(),
            "score": score.tolist(),
            "id": image_id * 1000 + i,
        }
        for i, (box, label, score) in enumerate(zip(boxes, labels, scores))
    ]

    wh = boxes[:, 2:] - boxes[:, :2]
    # perturb x1y1 with Gaussian noise (std = 50% of wh) and shrink wh to a random fraction (< 50%) of its size
    boxes[:, :2] += np.random.randn(num_boxes, 2) * wh * 0.5
    boxes[:, 2:] = boxes[:, :2] + np.random.rand(num_boxes, 2) * wh * 0.5
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, shape[1])
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, shape[0])
    num_boxes = np.random.randint(1, num_boxes)
    boxes = boxes[:num_boxes]
    labels = np.random.randint(0, num_classes, (num_boxes,))
    targets = [
        {
            "segmentation": [],
            "area": (box[2] - box[0]) * (box[3] - box[1]),
            "iscrowd": 0,
            "image_id": image_id,
            "bbox": box.tolist(),
            "category_id": label.tolist(),
            "score": score.tolist(),
            "id": image_id * 1000 + i,
        }
        for i, (box, label, score) in enumerate(zip(boxes, labels, scores))
    ]
    return annotations, targets


def generate_image_target():
    image_id = random.randint(0, 100)
    shape = (random.randint(256, 512), random.randint(256, 512))
    images = [
        {"file_name": f"{image_id}.jpg", "height": shape[0], "width": shape[1], "id": image_id}
    ]
    preds, targets = dummy_pred_target(100, shape, image_id)

    classes = [
        {"supercategory": idx + 1, "id": idx + 1, "name": str(idx + 1)}
        for idx in range(num_classes)
    ]
    pred = {"images": images, "annotations": preds, "categories": classes}
    target = {"images": images, "annotations": targets, "categories": classes}
    return pred, target


def build_coco_eval(pred, target):
    if os.environ.get("FASTEVAL", "0") == "1":
        import faster_coco_eval

        faster_coco_eval.init_as_pycocotools()

    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    coco_gt = COCO()
    coco_gt.dataset = pred
    coco_gt.createIndex()
    coco_dt = COCO()
    coco_dt.dataset = target
    coco_dt.createIndex()

    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    coco_eval.separate_eval = True  # for faster_coco_eval
    coco_eval.params.imgIds = list(coco_gt.imgs.keys())
    return coco_eval


def evaluate(imgs):
    with redirect_stdout(io.StringIO()):
        imgs.evaluate()
    if hasattr(imgs, "_evalImgs_cpp"):
        eval_imgs = imgs._evalImgs_cpp
    else:
        eval_imgs = imgs.evalImgs
    return imgs.params.imgIds, np.asarray(eval_imgs).reshape(
        -1, len(imgs.params.areaRng), len(imgs.params.imgIds)
    )


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    dist.all_gather_object(data_list, data, group)
    return data_list


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    eval_imgs = np.concatenate(eval_imgs, 2)
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    if hasattr(coco_eval, "_evalImgs_cpp"):
        coco_eval._evalImgs_cpp = eval_imgs
    else:
        coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


def worker(rank, world_size):
    # NCCL needs a distinct GPU per process; fall back to Gloo on CPU-only machines
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    pred, target = generate_image_target()
    coco_eval = build_coco_eval(pred, target)
    img_ids, eval_imgs = evaluate(coco_eval)

    # gather the per-image results from every rank (this needs pickling)
    create_common_coco_eval(coco_eval, img_ids, [eval_imgs])

    if rank == 0:
        coco_eval.accumulate()
        coco_eval.summarize()

    dist.destroy_process_group()


def multi_process_eval():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    import torch.multiprocessing as mp
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    multi_process_eval()

    # set faster eval
    os.environ["FASTEVAL"] = "1"
    multi_process_eval()
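The merge step in the script can be checked without GPUs or `torch.distributed` by feeding it the lists that `all_gather` would return. A minimal sketch, assuming each rank contributes an object array of shape `(num_categories, num_area_ranges, num_local_images)` in the category, then area range, then image order that `evaluate` reshapes into; `merge_gathered` and `fake_rank` are hypothetical helpers, and the string tags stand in for real evalImgs entries:

```python
import numpy as np


def merge_gathered(all_img_ids, all_eval_imgs):
    """Torch-free version of merge(): inputs are what all_gather returns per rank."""
    merged_img_ids = np.array([i for ids in all_img_ids for i in ids])
    # Per-image results live on the last axis, so concatenate along it
    merged_eval_imgs = np.concatenate(all_eval_imgs, 2)
    # Keep the first occurrence of each image id, sorted by id
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    return merged_img_ids, merged_eval_imgs[..., idx]


K, A = 2, 4  # categories and area ranges (illustrative sizes)


def fake_rank(img_ids):
    # Each entry is a "cat/area/img" tag standing in for one evalImgs item
    arr = np.empty((K, A, len(img_ids)), dtype=object)
    for k in range(K):
        for a in range(A):
            for i, img in enumerate(img_ids):
                arr[k, a, i] = f"{k}/{a}/{img}"
    return arr


ids_per_rank = [[5, 2], [9], [2, 7]]  # image 2 appears on two ranks
img_ids, eval_imgs = merge_gathered(ids_per_rank, [fake_rank(ids) for ids in ids_per_rank])
flat = list(eval_imgs.flatten())  # what create_common_coco_eval stores

print(img_ids.tolist())  # [2, 5, 7, 9]
print(flat[:4])          # ['0/0/2', '0/0/5', '0/0/7', '0/0/9']
```

Because the flattened list keeps the category-major order, it can be assigned back to `evalImgs` (or `_evalImgs_cpp`) together with the sorted `imgIds`, which is what `accumulate()` expects.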

@xiuqhou
Author

xiuqhou commented Oct 31, 2024

I will test it using the installation from GitHub and let you know. Thanks for your work!

@xiuqhou
Author

xiuqhou commented Oct 31, 2024

Hi @MiXaiLL76
I've tested the updated code. It produces results consistent with pycocotools, and everything works well! Thanks for your help!

@xiuqhou xiuqhou closed this as completed Oct 31, 2024
@MiXaiLL76
Owner

For now, use the version from GitHub; I will make a release once I finish more features. Thanks!

This was referenced Nov 3, 2024