
Commit

added mapping to batch inference
MaxTeselkin committed Feb 11, 2025
1 parent f30b4fb commit 45c3028
Showing 2 changed files with 37 additions and 23 deletions.
config.json (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@
   "icon_cover": true,
   "poster": "https://github.com/user-attachments/assets/79b71648-f78e-4ae3-be3f-508e385fc0b4",
   "gpu": "required",
-  "session_tags": ["deployed_florence_2", "sly_smart_annotation"],
+  "session_tags": ["deployed_florence_2", "deployed_nn"],
   "community_agent": false,
   "docker_image": "supervisely/florence-2:1.0.1",
   "instance_version": "6.12.12",
src/florence2.py (58 changes: 36 additions & 22 deletions)
@@ -95,33 +95,47 @@ def _predict_pytorch(self, image_path: str, settings: dict = None) -> List[Predi
     def predict_batch(self, images_np, settings):
         self.task_prompt = settings.get("task_prompt", self.default_task_prompt)
         text = settings.get("text", "find all objects")
+        mapping = settings.get("mapping")

         images = [Image.fromarray(img) for img in images_np]
-        prompt = [self.task_prompt + text] * len(images)
-
-        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(
-            self.device, self.torch_dtype
-        )
-        generated_ids = self.model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            early_stopping=False,
-            do_sample=False,
-            num_beams=3,
-        )
-        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=False)
-        parsed_answers = [
-            self.processor.post_process_generation(
-                text, task=self.task_prompt, image_size=(img.width, img.height)
-            )
-            for text, img in zip(generated_texts, images)
-        ]
-
-        batch_predictions = []
-        for answer in parsed_answers:
-            predictions = self._format_predictions_cp(answer, size_scaler=None)
-            batch_predictions.append(predictions)
+        if mapping is None and text is not None:
+            prompt = [self.task_prompt + text] * len(images)
+
+            inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(
+                self.device, self.torch_dtype
+            )
+            generated_ids = self.model.generate(
+                input_ids=inputs["input_ids"],
+                pixel_values=inputs["pixel_values"],
+                max_new_tokens=1024,
+                early_stopping=False,
+                do_sample=False,
+                num_beams=3,
+            )
+            generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=False)
+            parsed_answers = [
+                self.processor.post_process_generation(
+                    text, task=self.task_prompt, image_size=(img.width, img.height)
+                )
+                for text, img in zip(generated_texts, images)
+            ]
+
+            batch_predictions = []
+            for answer in parsed_answers:
+                predictions = self._format_predictions_cp(answer, size_scaler=None)
+                batch_predictions.append(predictions)
+
+        elif mapping is not None and text is None:
+            batch_predictions = []
+            for image in images:
+                predictions_mapping = self._classes_mapping_inference(image, mapping)
+                predictions = self._format_predictions_cm(predictions_mapping, size_scaler=None)
+                batch_predictions.append(predictions)
+
+        else:
+            raise ValueError("Either 'mapping' or 'text' should be provided")

         return batch_predictions

     def _common_prompt_inference(self, img_input: Image.Image, text: str):
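For reference, the rewritten method is now mode-exclusive: a batch is processed either with a common text prompt or with a per-class mapping, never both. A minimal usage sketch follows; the "model" variable, file names, and mapping values are illustrative assumptions, not taken from this repository:

import numpy as np
from PIL import Image

# Assumption: `model` is an instance of the app's Florence-2 model class
# exposing the predict_batch method shown in the diff above.
images_np = [np.asarray(Image.open(p)) for p in ["image1.jpg", "image2.jpg"]]

# Text-prompt branch: no "mapping" key, so every image gets the same prompt.
predictions = model.predict_batch(images_np, settings={"text": "find all objects"})

# Class-mapping branch: "text" must be set to None explicitly, because
# settings.get("text", "find all objects") otherwise falls back to a
# non-None default and the method raises ValueError.
predictions = model.predict_batch(
    images_np,
    settings={"mapping": {"cat": "domestic cat", "dog": "dog"}, "text": None},
)

Requiring the caller to null out "text" keeps the two inference paths unambiguous, at the cost that a mapping passed alongside the default text raises ValueError("Either 'mapping' or 'text' should be provided").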
