Make _evalImgs_cpp pickleable or readable by python #45
Function COCOevalEvaluateImages:
std::vector<ImageEvaluation> EvaluateImages

#include <cstdint>
#include <vector>

// Stores the match between a detected instance and a ground truth instance
struct MatchedAnnotation
{
    MatchedAnnotation(
        uint64_t dt_id,
        uint64_t gt_id,
        double iou) : dt_id{dt_id}, gt_id{gt_id}, iou{iou} {}
    uint64_t dt_id;
    uint64_t gt_id;
    double iou;
};

struct ImageEvaluation
{
    // For each of the D detected instances, the id of the matched ground truth
    // instance, or 0 if unmatched
    std::vector<int64_t> detection_matches;
    std::vector<int64_t> ground_truth_matches;
    // The detection score of each of the D detected instances
    std::vector<double> detection_scores;
    // Marks whether or not each of G instances was ignored from evaluation (e.g.,
    // because it's outside area_range)
    std::vector<bool> ground_truth_ignores;
    // Marks whether or not each of D instances was ignored from evaluation (e.g.,
    // because it's outside area_range)
    std::vector<bool> detection_ignores;
    std::vector<MatchedAnnotation> matched_annotations;
};

In general, I can make it pickleable, but I need to work on it. Do you have a code example I can use to check that it works, and to write tests?
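To make the request concrete: torch.distributed.all_gather_object pickles every object it gathers, so a per-image result built only from ints, floats, bools and lists can travel between ranks with no extra work. Below is a minimal sketch, assuming hypothetical pure-Python mirrors of the two structs above (these classes are not part of faster_coco_eval). If the C++ extension is built with pybind11 (an assumption here), its py::pickle helper would be one way to give the real ImageEvaluation the same property.

```python
import pickle
from dataclasses import dataclass, field
from typing import List


# Hypothetical pure-Python mirror of the C++ MatchedAnnotation struct
@dataclass
class MatchedAnnotation:
    dt_id: int
    gt_id: int
    iou: float


# Hypothetical pure-Python mirror of the C++ ImageEvaluation struct.
# Every field is a plain list, so the default pickle protocol handles it.
@dataclass
class ImageEvaluation:
    detection_matches: List[int] = field(default_factory=list)
    ground_truth_matches: List[int] = field(default_factory=list)
    detection_scores: List[float] = field(default_factory=list)
    ground_truth_ignores: List[bool] = field(default_factory=list)
    detection_ignores: List[bool] = field(default_factory=list)
    matched_annotations: List[MatchedAnnotation] = field(default_factory=list)


def roundtrip(obj):
    """Serialize and deserialize obj with pickle, as all_gather_object does."""
    return pickle.loads(pickle.dumps(obj))


if __name__ == "__main__":
    ev = ImageEvaluation(
        detection_matches=[7, 0],       # detection 1 matched GT 7, detection 2 unmatched
        ground_truth_matches=[1],
        detection_scores=[0.9, 0.3],
        ground_truth_ignores=[False],
        detection_ignores=[False, False],
        matched_annotations=[MatchedAnnotation(dt_id=1, gt_id=7, iou=0.8)],
    )
    print(roundtrip(ev) == ev)  # True
```

The dataclass equality check compares field by field, including the nested MatchedAnnotation entries, so it is a convenient assertion for a pickling test.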
Can you prepare an example that works with pycocotools? For instance, a .py file in which 2-4 different processes create evalImgs, and another file in which all of them are combined to compute the metrics. If you can prepare something like that, I will implement it in the library; right now I don't fully understand what exactly you want.
If this is enough, then use the installation from GitHub. If not, describe the problem and we will solve it in the next release.
Hi @MiXaiLL76, sorry for the late response. I have prepared a small example to test multi-process evaluation. First, each process generates some dummy COCO predictions and targets. After each process has run its evaluation, the results are collected in the main process (rank 0), which prints the summarized metrics.

import os
from contextlib import redirect_stdout
import copy
import io
import random

import numpy as np
import torch
import torch.distributed as dist

num_classes = 10
world_size = 3  # the number of processes

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)


def dummy_pred_target(num_boxes, shape, image_id):
    # random (x1, y1, x2, y2) boxes inside an image of size shape = (h, w)
    x2 = np.random.randint(10, shape[1], (num_boxes,))
    y2 = np.random.randint(10, shape[0], (num_boxes,))
    x2y2 = np.stack([x2, y2], axis=-1)
    x1y1 = np.random.rand(num_boxes, 2) * x2y2 * 0.95
    boxes = np.concatenate([x1y1, x2y2], axis=-1)
    labels = np.random.randint(0, num_classes, (num_boxes,))
    scores = np.random.rand(num_boxes)
    annotations = [
        {
            "segmentation": [],
            "area": (box[2] - box[0]) * (box[3] - box[1]),
            "iscrowd": 0,
            "image_id": image_id,
            "bbox": box.tolist(),
            "category_id": label.tolist(),
            "score": score.tolist(),
            "id": image_id * 1000 + i,
        }
        for i, (box, label, score) in enumerate(zip(boxes, labels, scores))
    ]

    wh = boxes[:, 2:] - boxes[:, :2]
    # add a 50% disturbance to x1y1 and wh
    boxes[:, :2] += np.random.randn(num_boxes, 2) * wh * 0.5
    boxes[:, 2:] = boxes[:, :2] + np.random.rand(num_boxes, 2) * wh * 0.5
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, shape[1])
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, shape[0])
    # keep a random subset of the disturbed boxes with fresh labels
    num_boxes = np.random.randint(1, num_boxes)
    boxes = boxes[:num_boxes]
    labels = np.random.randint(0, num_classes, (num_boxes,))
    targets = [
        {
            "segmentation": [],
            "area": (box[2] - box[0]) * (box[3] - box[1]),
            "iscrowd": 0,
            "image_id": image_id,
            "bbox": box.tolist(),
            "category_id": label.tolist(),
            "score": score.tolist(),
            "id": image_id * 1000 + i,
        }
        for i, (box, label, score) in enumerate(zip(boxes, labels, scores))
    ]
    return annotations, targets


def generate_image_target():
    image_id = random.randint(0, 100)
    shape = (random.randint(256, 512), random.randint(256, 512))
    images = [
        {"file_name": f"{image_id}.jpg", "height": shape[0], "width": shape[1], "id": image_id}
    ]
    preds, targets = dummy_pred_target(100, shape, image_id)
    classes = [
        {"supercategory": idx + 1, "id": idx + 1, "name": str(idx + 1)}
        for idx in range(num_classes)
    ]
    pred = {"images": images, "annotations": preds, "categories": classes}
    target = {"images": images, "annotations": targets, "categories": classes}
    return pred, target


def build_coco_eval(pred, target):
    if os.environ.get("FASTEVAL", "0") == "1":
        # replace pycocotools with faster_coco_eval for the imports below
        import faster_coco_eval
        faster_coco_eval.init_as_pycocotools()
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    coco_gt = COCO()
    coco_gt.dataset = pred
    coco_gt.createIndex()
    coco_dt = COCO()
    coco_dt.dataset = target
    coco_dt.createIndex()
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    coco_eval.separate_eval = True  # for faster_coco_eval
    coco_eval.params.imgIds = list(coco_gt.imgs.keys())
    return coco_eval


def evaluate(imgs):
    with redirect_stdout(io.StringIO()):
        imgs.evaluate()
    if hasattr(imgs, "_evalImgs_cpp"):
        eval_imgs = imgs._evalImgs_cpp
    else:
        eval_imgs = imgs.evalImgs
    # flat list ordered by (category, area range, image) -> 3-D object array
    return imgs.params.imgIds, np.asarray(eval_imgs).reshape(
        -1, len(imgs.params.areaRng), len(imgs.params.imgIds)
    )


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    dist.all_gather_object(data_list, data, group)
    return data_list


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    eval_imgs = np.concatenate(eval_imgs, 2)
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    if hasattr(coco_eval, "_evalImgs_cpp"):
        coco_eval._evalImgs_cpp = eval_imgs
    else:
        coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


def worker(rank, world_size):
    # NCCL needs one GPU per rank; fall back to Gloo on CPU-only machines
    use_cuda = torch.cuda.is_available()
    dist.init_process_group("nccl" if use_cuda else "gloo", rank=rank, world_size=world_size)
    if use_cuda:
        torch.cuda.set_device(rank)
    # spawned processes re-run the module-level seeding, so reseed per rank;
    # otherwise every rank generates the same image and the merge is trivial
    random.seed(rank)
    np.random.seed(rank)
    torch.manual_seed(rank)

    pred, target = generate_image_target()
    coco_eval = build_coco_eval(pred, target)
    img_ids, eval_imgs = evaluate(coco_eval)
    create_common_coco_eval(coco_eval, img_ids, [eval_imgs])
    if rank == 0:
        coco_eval.accumulate()
        coco_eval.summarize()
    dist.destroy_process_group()


def multi_process_eval():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    import torch.multiprocessing as mp
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    # run with plain pycocotools
    multi_process_eval()
    # run again with faster_coco_eval patched in as pycocotools
    os.environ["FASTEVAL"] = "1"
    multi_process_eval()
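The merge() step above relies on np.unique(..., return_index=True), which sorts the image ids and returns the index of each id's first occurrence. The sketch below replays that logic without torch.distributed: the all_gather results are passed in directly as Python lists, and string tags stand in for real per-image results. The helper names merge_gathered and fake_rank_result are hypothetical, for illustration only.

```python
import numpy as np


def merge_gathered(all_img_ids, all_eval_imgs):
    """Merge per-rank results the way merge() does, minus the all_gather calls.

    all_img_ids:   one list of image ids per rank
    all_eval_imgs: one object array per rank, shaped
                   (num_categories, num_area_ranges, num_images_on_that_rank)
    """
    merged_img_ids = np.array([i for ids in all_img_ids for i in ids])
    # stack the ranks along the image axis (axis 2)
    merged_eval_imgs = np.concatenate(all_eval_imgs, 2)
    # sorted unique ids plus the index of each id's first occurrence, so an
    # image evaluated on several ranks is kept once (from the lowest rank)
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    return merged_img_ids, merged_eval_imgs[..., idx]


def fake_rank_result(rank, img_ids, num_cats=2, num_areas=4):
    """Stand-in for evaluate(): a string tag instead of a real per-image result."""
    grid = np.empty((num_cats, num_areas, len(img_ids)), dtype=object)
    for k in range(num_cats):
        for a in range(num_areas):
            for i, img_id in enumerate(img_ids):
                grid[k, a, i] = f"rank{rank}/cat{k}/area{a}/img{img_id}"
    return grid


if __name__ == "__main__":
    ids = [[3, 7], [7, 1]]  # image 7 was evaluated on both ranks
    evals = [fake_rank_result(0, ids[0]), fake_rank_result(1, ids[1])]
    merged_ids, merged = merge_gathered(ids, evals)
    print(merged_ids.tolist())  # [1, 3, 7]
    print(merged[0, 0, 2])      # rank0/cat0/area0/img7
```

Image 7 keeps the entry from rank 0 because that occurrence comes first in the concatenated array; flattening the merged (category, area range, image) array then restores the flat order that create_common_coco_eval() writes back.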
I will test it using the installation from GitHub and let you know. Thanks for your work!
Hi @MiXaiLL76
Then, for now, use the version from git; I will make a release later, once I finish more features. Thanks!
Hi, thanks for the great work in this repository. It saves me a lot of evaluation time.
Is your feature request related to a problem? Please describe.
I am trying to use this framework to evaluate instance segmentation on COCO. Because saving all segmentation results costs too much memory (just 700 images need 50 GB), I have to set separate_eval=True to evaluate the images individually and collect eval_imgs (in pycocotools) or _evalImgs_cpp (in faster-coco-eval) to get the final result. With DDP evaluation, I can't find a way to collect these from the other processes, because _evalImgs_cpp is neither pickleable nor readable from Python.
Describe the solution you'd like
Could you please make the faster_coco_eval.faster_eval_api_cpp.ImageEvaluation object pickleable, or make _C.COCOevalEvaluateImages return np.ndarray, just like eval_imgs in pycocotools? This would be very helpful for evaluation during DDP training.
Describe alternatives you've considered
For now, I either run the evaluation only in the main process and let the other processes wait until it finishes, or use pycocotools to evaluate with multiple processes. Both alternatives are too slow.
Additional context
My code is modified from the official training reference of torchvision (https://github.com/pytorch/vision/tree/main/references/detection). I changed eval_imgs to _evalImgs_cpp and set separate_eval=True. These are some snippets of my code.
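For reference when comparing against pycocotools: its COCOeval.evaluate() builds evalImgs as one flat list, iterating over category ids, then area ranges, then image ids. That ordering is why reshape(-1, len(areaRng), len(imgIds)) in the evaluate() helper of the example earlier in this thread yields a (category, area range, image) grid. A minimal pure-Python sketch of that index arithmetic follows; flat_index and to_grid are hypothetical helpers, and tuples stand in for the per-image result dicts.

```python
def flat_index(k, a, i, num_areas, num_images):
    """Position of (category k, area range a, image i) in a flat evalImgs list."""
    return k * num_areas * num_images + a * num_images + i


def to_grid(flat, num_areas, num_images):
    """Nest a flat evalImgs-style list as [category][area range][image],
    the same layout np.asarray(flat).reshape(-1, num_areas, num_images) gives."""
    per_cat = num_areas * num_images
    return [
        [
            flat[c * per_cat + a * num_images: c * per_cat + (a + 1) * num_images]
            for a in range(num_areas)
        ]
        for c in range(len(flat) // per_cat)
    ]


if __name__ == "__main__":
    cat_ids = [1, 2]
    area_rngs = ["all", "small", "medium", "large"]
    img_ids = [10, 11, 12]
    # same loop nesting as the evalImgs list comprehension in pycocotools
    flat = [(c, a, i) for c in cat_ids for a in area_rngs for i in img_ids]
    print(flat[flat_index(1, 2, 0, len(area_rngs), len(img_ids))])  # (2, 'medium', 10)
    print(to_grid(flat, len(area_rngs), len(img_ids))[1][2][0])     # (2, 'medium', 10)
```

Because the image index varies fastest, per-process results can only be concatenated after reshaping to this grid, which is what the example does before merging along axis 2.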