From e5a2bc445a61ad3ca9fd1822fcc1d37295a89424 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Mon, 4 Nov 2024 05:52:19 +0000
Subject: [PATCH] Fix torch no grad

---
 mteb/models/e5_v.py           | 67 +++++++++++++++++------------------
 mteb/models/vlm2vec_models.py | 26 ++++++--------
 2 files changed, 44 insertions(+), 49 deletions(-)

diff --git a/mteb/models/e5_v.py b/mteb/models/e5_v.py
index 5647bee380..4ead48a7cf 100644
--- a/mteb/models/e5_v.py
+++ b/mteb/models/e5_v.py
@@ -64,32 +64,33 @@ def get_image_embeddings(
     ):
         all_image_embeddings = []
 
-        if isinstance(images, DataLoader):
-            for batch_images in tqdm(images):
-                img_inputs = self.processor(
-                    [self.img_prompt] * len(batch_images),
-                    batch_images,
-                    return_tensors="pt",
-                    padding=True,
-                ).to("cuda")
-                image_outputs = self.model(
-                    **img_inputs, output_hidden_states=True, return_dict=True
-                ).hidden_states[-1][:, -1, :]
-                all_image_embeddings.append(image_outputs.cpu())
         with torch.no_grad():
-            for i in tqdm(range(0, len(images), batch_size)):
-                batch_images = images[i : i + batch_size]
-                img_inputs = self.processor(
-                    [self.img_prompt] * len(batch_images),
-                    batch_images,
-                    return_tensors="pt",
-                    padding=True,
-                ).to("cuda")
-                image_outputs = self.model(
-                    **img_inputs, output_hidden_states=True, return_dict=True
-                ).hidden_states[-1][:, -1, :]
-                all_image_embeddings.append(image_outputs.cpu())
-            return torch.cat(all_image_embeddings, dim=0)
+            if isinstance(images, DataLoader):
+                for batch_images in tqdm(images):
+                    img_inputs = self.processor(
+                        [self.img_prompt] * len(batch_images),
+                        batch_images,
+                        return_tensors="pt",
+                        padding=True,
+                    ).to("cuda")
+                    image_outputs = self.model(
+                        **img_inputs, output_hidden_states=True, return_dict=True
+                    ).hidden_states[-1][:, -1, :]
+                    all_image_embeddings.append(image_outputs.cpu())
+            else:
+                for i in tqdm(range(0, len(images), batch_size)):
+                    batch_images = images[i : i + batch_size]
+                    img_inputs = self.processor(
+                        [self.img_prompt] * len(batch_images),
+                        batch_images,
+                        return_tensors="pt",
+                        padding=True,
+                    ).to("cuda")
+                    image_outputs = self.model(
+                        **img_inputs, output_hidden_states=True, return_dict=True
+                    ).hidden_states[-1][:, -1, :]
+                    all_image_embeddings.append(image_outputs.cpu())
+        return torch.cat(all_image_embeddings, dim=0)
 
     def calculate_probs(self, text_embeddings, image_embeddings):
         text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
@@ -112,8 +113,8 @@ def get_fused_embeddings(
         all_fused_embeddings = []
 
         if texts is not None and images is not None:
-            if isinstance(images, DataLoader):
-                with torch.no_grad():
+            with torch.no_grad():
+                if isinstance(images, DataLoader):
                     for index, batch_images in enumerate(tqdm(images)):
                         batch_texts = texts[
                             index * batch_size : (index + 1) * batch_size
@@ -128,12 +129,11 @@ def get_fused_embeddings(
                             **inputs, output_hidden_states=True, return_dict=True
                         ).hidden_states[-1][:, -1, :]
                         all_fused_embeddings.append(outputs.cpu())
-            else:
-                if len(texts) != len(images):
-                    raise ValueError(
-                        "The number of texts and images must have the same length"
-                    )
-                with torch.no_grad():
+                else:
+                    if len(texts) != len(images):
+                        raise ValueError(
+                            "The number of texts and images must have the same length"
+                        )
                     for i in tqdm(range(0, len(images), batch_size)):
                         batch_texts = texts[i : i + batch_size]
                         batch_images = images[i : i + batch_size]
@@ -148,7 +148,6 @@ def get_fused_embeddings(
                         ).hidden_states[-1][:, -1, :]
                         all_fused_embeddings.append(outputs.cpu())
             return torch.cat(all_fused_embeddings, dim=0)
-
         elif texts is not None:
             text_embeddings = self.get_text_embeddings(texts, batch_size)
             return text_embeddings
diff --git a/mteb/models/vlm2vec_models.py b/mteb/models/vlm2vec_models.py
index 321cba24a4..d75236a93a 100644
--- a/mteb/models/vlm2vec_models.py
+++ b/mteb/models/vlm2vec_models.py
@@ -254,22 +254,18 @@ def get_fused_embeddings(
         all_fused_embeddings = []
         if isinstance(images, DataLoader):
             import torchvision.transforms.functional as F
-
-            for batch in images:
-                for b in batch:
-                    text = next(texts)
-                    inputs = self.processor(
-                        f"<|image_1|> Represent the given image with the following question: {text}",
-                        [F.to_pil_image(b.to("cpu"))],
-                    )
-                    inputs = {
-                        key: value.to(self.device) for key, value in inputs.items()
-                    }
-                    outputs = self.encode_input(inputs)
-                    all_fused_embeddings.append(outputs.cpu())
-
+            with torch.no_grad():
+                for batch in images:
+                    for b in batch:
+                        text = next(texts)
+                        inputs = self.processor(
+                            f"<|image_1|> Represent the given image with the following question: {text}",
+                            [F.to_pil_image(b.to("cpu"))],
+                        )
+                        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                        outputs = self.encode_input(inputs)
+                        all_fused_embeddings.append(outputs.cpu())
             fused_embeddings = torch.cat(all_fused_embeddings, dim=0)
-
             return fused_embeddings
         elif text_embeddings is not None:
             return text_embeddings
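
Both files apply the same pattern: the embedding loops now run inside torch.no_grad(), so no autograd graph is built during pure inference, and get_image_embeddings gains the previously missing else so only one branch executes per input type. Below is a minimal, self-contained sketch of that pattern; the names get_embeddings and encoder, and the nn.Linear stand-in, are illustrative assumptions, not the E5-V or VLM2Vec models.

# Minimal sketch of the pattern this patch applies, with a placeholder
# nn.Linear as the "encoder" (hypothetical; not the MTEB models): the forward
# pass runs under torch.no_grad(), and the input type decides which single
# loop executes.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

encoder = nn.Linear(16, 8)  # stand-in for the real embedding model
features = torch.randn(100, 16)


def get_embeddings(data, batch_size: int = 32) -> torch.Tensor:
    all_embeddings = []
    with torch.no_grad():  # no autograd graph is recorded for these forwards
        if isinstance(data, DataLoader):
            for (batch,) in data:  # DataLoader already yields batched tensors
                all_embeddings.append(encoder(batch).cpu())
        else:  # indexable input: slice it into batches ourselves
            for i in range(0, len(data), batch_size):
                all_embeddings.append(encoder(data[i : i + batch_size]).cpu())
    return torch.cat(all_embeddings, dim=0)


dense = get_embeddings(features)
loaded = get_embeddings(DataLoader(TensorDataset(features), batch_size=32))
assert dense.grad_fn is None and not dense.requires_grad  # no graph kept
assert torch.allclose(dense, loaded)  # both code paths agree

Without the no_grad context every forward pass would retain activations for a backward pass that never happens, which only costs memory during encoding; with the added else, DataLoader inputs no longer fall through into the index-based loop.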