From 64008d2a81eec58f277e89dc7bb0997ebcfc85f8 Mon Sep 17 00:00:00 2001 From: fatih <34196005+fcakyon@users.noreply.github.com> Date: Sat, 24 Dec 2022 02:39:56 +0300 Subject: [PATCH 1/3] fix torch dependency in HuggingfaceDetectionModel --- sahi/models/huggingface.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sahi/models/huggingface.py b/sahi/models/huggingface.py index 99a179eef..aa12a4bf6 100644 --- a/sahi/models/huggingface.py +++ b/sahi/models/huggingface.py @@ -16,8 +16,6 @@ class HuggingfaceDetectionModel(DetectionModel): - import torch - def __init__( self, model_path: Optional[str] = None, From 10b55c83f227f90d721bc23b5399efa4f52a0387 Mon Sep 17 00:00:00 2001 From: fatih <34196005+fcakyon@users.noreply.github.com> Date: Sat, 24 Dec 2022 11:05:20 +0300 Subject: [PATCH 2/3] Update huggingface.py --- sahi/models/huggingface.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sahi/models/huggingface.py b/sahi/models/huggingface.py index aa12a4bf6..75c13fd13 100644 --- a/sahi/models/huggingface.py +++ b/sahi/models/huggingface.py @@ -120,10 +120,16 @@ def perform_inference(self, image: Union[List, np.ndarray]): self._image_shapes = [image.shape] self._original_predictions = outputs - def get_valid_predictions( - self, logits: torch.Tensor, pred_boxes: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - import torch + def get_valid_predictions(self, logits, pred_boxes) -> Tuple: + """ + Args: + logits: torch.Tensor + pred_boxes: torch.Tensor + Returns: + scores: torch.Tensor + cat_ids: torch.Tensor + boxes: torch.Tensor + """ probs = logits.softmax(-1) scores = probs.max(-1).values From f15d1341fb0a562ef5ff883a8eb04ab0a908a8b1 Mon Sep 17 00:00:00 2001 From: fatih <34196005+fcakyon@users.noreply.github.com> Date: Sat, 24 Dec 2022 11:26:52 +0300 Subject: [PATCH 3/3] Update huggingface.py --- sahi/models/huggingface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sahi/models/huggingface.py b/sahi/models/huggingface.py index 75c13fd13..0b06261ce 100644 --- a/sahi/models/huggingface.py +++ b/sahi/models/huggingface.py @@ -130,6 +130,7 @@ def get_valid_predictions(self, logits, pred_boxes) -> Tuple: cat_ids: torch.Tensor boxes: torch.Tensor """ + import torch probs = logits.softmax(-1) scores = probs.max(-1).values