Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Unified return_preds flags across all tasks #741

Merged
merged 3 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def main():

# Forward the image to the model
processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]])
out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
out = predictor.det_predictor.model(processed_batches[0], return_preds=True)
seg_map = out["out_map"]
seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
seg_map = cv2.resize(seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]),
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def forward(
x: torch.Tensor,
target: Optional[List[np.ndarray]] = None,
return_model_output: bool = False,
return_boxes: bool = False,
return_preds: bool = False,
) -> Dict[str, torch.Tensor]:
# Extract feature maps at different stages
feats = self.feat_extractor(x)
Expand All @@ -178,13 +178,13 @@ def forward(
logits = self.prob_head(feat_concat)

out: Dict[str, Any] = {}
if return_model_output or target is None or return_boxes:
if return_model_output or target is None or return_preds:
prob_map = torch.sigmoid(logits)

if return_model_output:
out["out_map"] = prob_map

if target is None or return_boxes:
if target is None or return_preds:
# Post-process boxes (keep only text predictions)
out["preds"] = [
preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def call(
x: tf.Tensor,
target: Optional[List[np.ndarray]] = None,
return_model_output: bool = False,
return_boxes: bool = False,
return_preds: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:

Expand All @@ -232,13 +232,13 @@ def call(
logits = self.probability_head(feat_concat, **kwargs)

out: Dict[str, tf.Tensor] = {}
if return_model_output or target is None or return_boxes:
if return_model_output or target is None or return_preds:
prob_map = tf.math.sigmoid(logits)

if return_model_output:
out["out_map"] = prob_map

if target is None or return_boxes:
if target is None or return_preds:
# Post-process boxes (keep only text predictions)
out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())]

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def forward(
x: torch.Tensor,
target: Optional[List[np.ndarray]] = None,
return_model_output: bool = False,
return_boxes: bool = False,
return_preds: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:

Expand All @@ -142,12 +142,12 @@ def forward(
logits = self.classifier(logits)

out: Dict[str, Any] = {}
if return_model_output or target is None or return_boxes:
if return_model_output or target is None or return_preds:
prob_map = torch.sigmoid(logits)
if return_model_output:
out["out_map"] = prob_map

if target is None or return_boxes:
if target is None or return_preds:
# Post-process boxes
out["preds"] = [
preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,20 @@ def call(
x: tf.Tensor,
target: Optional[List[np.ndarray]] = None,
return_model_output: bool = False,
return_boxes: bool = False,
return_preds: bool = False,
) -> Dict[str, Any]:

logits = self.stem(x)
logits = self.fpn(logits)
logits = self.classifier(logits)

out: Dict[str, tf.Tensor] = {}
if return_model_output or target is None or return_boxes:
if return_model_output or target is None or return_preds:
prob_map = tf.math.sigmoid(logits)
if return_model_output:
out["out_map"] = prob_map

if target is None or return_boxes:
if target is None or return_preds:
# Post-process boxes
out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())]

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def forward(

processed_batches = self.pre_processor(pages)
predicted_batches = [
self.model(batch, return_boxes=True, **kwargs)['preds'] # type:ignore[operator]
self.model(batch, return_preds=True, **kwargs)['preds'] # type:ignore[operator]
for batch in processed_batches
]
return [pred for batch in predicted_batches for pred in batch]
2 changes: 1 addition & 1 deletion doctr/models/detection/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __call__(

processed_batches = self.pre_processor(pages)
predicted_batches = [
self.model(batch, return_boxes=True, training=False, **kwargs)['preds'] # type:ignore[operator]
self.model(batch, return_preds=True, training=False, **kwargs)['preds'] # type:ignore[operator]
for batch in processed_batches
]
return [pred for batch in predicted_batches for pred in batch]
2 changes: 1 addition & 1 deletion tests/pytorch/test_models_detection_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
if torch.cuda.is_available():
model.cuda()
input_tensor = input_tensor.cuda()
out = model(input_tensor, target, return_model_output=True, return_boxes=True)
out = model(input_tensor, target, return_model_output=True, return_preds=True)
assert isinstance(out, dict)
assert len(out) == 3
# Check proba map
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow/test_models_detection_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
np.array([[.5, .5, 1, 1], [0.5, 0.5, .8, .9]], dtype=np.float32),
]
# test training model
out = model(input_tensor, target, return_model_output=True, return_boxes=True, training=True)
out = model(input_tensor, target, return_model_output=True, return_preds=True, training=True)
assert isinstance(out, dict)
assert len(out) == 3
# Check proba map
Expand Down