Skip to content

Commit

Permalink
fix reference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aminemindee committed Oct 12, 2022
1 parent 5da5b09 commit 1b5c181
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
13 changes: 10 additions & 3 deletions references/detection/evaluate_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import os

import numpy as np

from doctr.file_utils import CLASS_NAME

os.environ["USE_TORCH"] = "1"

import logging
Expand Down Expand Up @@ -44,9 +48,12 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
out = model(images, targets, return_preds=True)
# Compute metric
loc_preds = out["preds"]
for boxes_gt, boxes_pred in zip(targets, loc_preds):
# Remove scores
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
for target, loc_pred in zip(targets, loc_preds):
if isinstance(target, np.ndarray):
target = {CLASS_NAME: target}
for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
# Remove scores
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])

val_loss += out["loss"].item()
batch_cnt += 1
Expand Down
13 changes: 10 additions & 3 deletions references/detection/evaluate_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import os

import numpy as np

from doctr.file_utils import CLASS_NAME

os.environ["USE_TF"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

Expand Down Expand Up @@ -39,9 +43,12 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
out = model(images, targets, training=False, return_preds=True)
# Compute metric
loc_preds = out["preds"]
for boxes_gt, boxes_pred in zip(targets, loc_preds):
# Remove scores
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])
for target, loc_pred in zip(targets, loc_preds):
if isinstance(target, np.ndarray):
target = {CLASS_NAME: target}
for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
# Remove scores
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1])

val_loss += out["loss"].numpy()
batch_cnt += 1
Expand Down

0 comments on commit 1b5c181

Please sign in to comment.