
Commit

changed batched inference logic
MaxTeselkin committed Feb 11, 2025
1 parent d233d0a commit f30b4fb
Showing 2 changed files with 38 additions and 61 deletions.
96 changes: 37 additions & 59 deletions src/florence2.py
@@ -56,25 +56,19 @@ def load_model(
 
         if runtime == RuntimeType.PYTORCH:
             model_path = model_files["checkpoint"]
-            self.torch_dtype = (
-                torch.float16 if torch.cuda.is_available() else torch.float32
-            )
+            self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path, torch_dtype=self.torch_dtype, trust_remote_code=True
             ).eval()
-            self.processor = AutoProcessor.from_pretrained(
-                model_path, trust_remote_code=True
-            )
+            self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
             self.model = self.model.to(device)
 
     def predict(self, image_path: np.ndarray, settings: dict = None):
         if self.runtime == RuntimeType.PYTORCH:
             return self._predict_pytorch(image_path, settings)
 
     @torch.no_grad()
-    def _predict_pytorch(
-        self, image_path: str, settings: dict = None
-    ) -> List[PredictionBBox]:
+    def _predict_pytorch(self, image_path: str, settings: dict = None) -> List[PredictionBBox]:
         # 1. Preprocess
         self.task_prompt = settings.get("task_prompt", self.default_task_prompt)
         size_scaler = None
@@ -101,42 +95,34 @@ def _predict_pytorch(
     def predict_batch(self, images_np, settings):
         self.task_prompt = settings.get("task_prompt", self.default_task_prompt)
         text = settings.get("text", "find all objects")
-        batch_size = settings.get("batch_size", 2)
 
         images = [Image.fromarray(img) for img in images_np]
-        images_batched = [
-            images[i : i + batch_size] for i in range(0, len(images), batch_size)
-        ]
+        prompt = [self.task_prompt + text] * len(images)
 
-        total_predictions = []
-        for images_batch in images_batched:
-            prompt = [self.task_prompt + text] * len(images_batch)
-            inputs = self.processor(
-                text=prompt, images=images_batch, return_tensors="pt"
-            ).to(self.device, self.torch_dtype)
-            generated_ids = self.model.generate(
-                input_ids=inputs["input_ids"],
-                pixel_values=inputs["pixel_values"],
-                max_new_tokens=1024,
-                early_stopping=False,
-                do_sample=False,
-                num_beams=3,
-            )
-            generated_texts = self.processor.batch_decode(
-                generated_ids, skip_special_tokens=False
-            )
-            parsed_answers = [
-                self.processor.post_process_generation(
-                    text, task=self.task_prompt, image_size=(img.width, img.height)
-                )
-                for text, img in zip(generated_texts, images_batch)
-            ]
-            batch_predictions = []
-            for answer in parsed_answers:
-                predictions = self._format_predictions_cp(answer, size_scaler=None)
-                batch_predictions.append(predictions)
-            total_predictions.extend(batch_predictions)
-        return total_predictions
+        inputs = self.processor(text=prompt, images=images, return_tensors="pt").to(
+            self.device, self.torch_dtype
+        )
+        generated_ids = self.model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=1024,
+            early_stopping=False,
+            do_sample=False,
+            num_beams=3,
+        )
+        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=False)
+        parsed_answers = [
+            self.processor.post_process_generation(
+                text, task=self.task_prompt, image_size=(img.width, img.height)
+            )
+            for text, img in zip(generated_texts, images)
+        ]
+
+        batch_predictions = []
+        for answer in parsed_answers:
+            predictions = self._format_predictions_cp(answer, size_scaler=None)
+            batch_predictions.append(predictions)
+        return batch_predictions
 
     def _common_prompt_inference(self, img_input: Image.Image, text: str):
         if text == "":
@@ -153,9 +139,7 @@ def _common_prompt_inference(self, img_input: Image.Image, text: str):
             do_sample=False,
             num_beams=3,
         )
-        generated_texts = self.processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )
+        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=False)
         parsed_answer = self.processor.post_process_generation(
             generated_texts[0],
             task=self.task_prompt,
@@ -170,9 +154,9 @@ def _classes_mapping_inference(self, img_input: Image.Image, mapping: dict):
             prompt = self.task_prompt
         else:
             prompt = self.task_prompt + text
-        inputs = self.processor(
-            text=prompt, images=img_input, return_tensors="pt"
-        ).to(self.device, self.torch_dtype)
+        inputs = self.processor(text=prompt, images=img_input, return_tensors="pt").to(
+            self.device, self.torch_dtype
+        )
         generated_ids = self.model.generate(
             input_ids=inputs["input_ids"],
             pixel_values=inputs["pixel_values"],
@@ -181,9 +165,7 @@ def _classes_mapping_inference(self, img_input: Image.Image, mapping: dict):
             do_sample=False,
             num_beams=3,
         )
-        generated_texts = self.processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )
+        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=False)
         parsed_answer = self.processor.post_process_generation(
             generated_texts[0],
             task=self.task_prompt,
@@ -195,9 +177,9 @@ def _classes_mapping_inference(self, img_input: Image.Image, mapping: dict):
     def _get_detailed_caption_text(self, img_input: Image.Image) -> str:
         logger.info("Text prompt is empty. Getting detailed caption for the image...")
         task_prompt = "<DETAILED_CAPTION>"
-        inputs = self.processor(
-            text=task_prompt, images=img_input, return_tensors="pt"
-        ).to(self.device, self.torch_dtype)
+        inputs = self.processor(text=task_prompt, images=img_input, return_tensors="pt").to(
+            self.device, self.torch_dtype
+        )
         generated_ids = self.model.generate(
             input_ids=inputs["input_ids"],
             pixel_values=inputs["pixel_values"],
@@ -206,9 +188,7 @@ def _get_detailed_caption_text(self, img_input: Image.Image) -> str:
             do_sample=False,
             num_beams=3,
         )
-        generated_text = self.processor.batch_decode(
-            generated_ids, skip_special_tokens=False
-        )[0]
+        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
         parsed_answer = self.processor.post_process_generation(
             generated_text,
             task=task_prompt,
@@ -349,9 +329,7 @@ def _load_model(self, deploy_params: dict):
         self.model_source = deploy_params.get("model_source")
         self.device = deploy_params.get("device")
         self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
-        self.model_precision = (
-            torch.float16 if torch.cuda.is_available() else torch.float32
-        )
+        self.model_precision = torch.float16 if torch.cuda.is_available() else torch.float32
         self._hardware = get_hardware_info(self.device)
         self.load_model(**deploy_params)
         self._model_served = True
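In short, predict_batch no longer reads a batch_size setting and no longer loops over chunks of that size: the whole list of images is encoded with one processor call and decoded with one model.generate call, and only the post-processing stays per image. The snippet below is a minimal standalone sketch of that single-pass flow, not the repository's serving code; the microsoft/Florence-2-base checkpoint, the image paths, and the "find all objects" query are placeholder assumptions.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Placeholder checkpoint; the app loads its own downloaded weights instead.
MODEL_ID = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=dtype, trust_remote_code=True
).eval().to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
images = [Image.open(p).convert("RGB") for p in ["img1.jpg", "img2.jpg"]]  # placeholder paths
prompts = [task_prompt + "find all objects"] * len(images)  # one prompt per image

# One processor call and one generate call for the whole batch.
inputs = processor(text=prompts, images=images, return_tensors="pt").to(device, dtype)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    do_sample=False,
    num_beams=3,
)
texts = processor.batch_decode(generated_ids, skip_special_tokens=False)

# Post-processing stays per image, since it needs each image's size.
answers = [
    processor.post_process_generation(t, task=task_prompt, image_size=(im.width, im.height))
    for t, im in zip(texts, images)
]
print(answers)
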
3 changes: 1 addition & 2 deletions src/inference_settings.yaml
@@ -1,3 +1,2 @@
 task_prompt: "<CAPTION_TO_PHRASE_GROUNDING>"
-text: "find all objects"
-batch_size: 2
+text: "find all objects"
