[mieb] Fix torch no grad #1378

Merged
merged 1 commit on Nov 4, 2024
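Why the fix matters: a forward pass that runs outside torch.no_grad() records an autograd graph and keeps every intermediate activation alive for a backward pass that inference never performs, wasting GPU memory and risking out-of-memory errors on large batches. A minimal sketch of the effect, with a toy nn.Linear standing in for the repo's actual vision-language models:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4)  # toy stand-in; the PR patches HF multimodal models
    x = torch.randn(8, 16)

    # Outside no_grad: the output is attached to an autograd graph, so all
    # intermediate activations are retained for a potential backward pass.
    y = model(x)
    print(y.requires_grad, y.grad_fn is not None)  # True True

    # Inside no_grad: no graph is recorded; activations can be freed eagerly.
    with torch.no_grad():
        y = model(x)
    print(y.requires_grad, y.grad_fn)  # False None

Both files below apply the same cure: hoist a single with torch.no_grad(): block so that every encoding path runs under it.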
mteb/models/e5_v.py: 67 changes (33 additions, 34 deletions)
@@ -64,32 +64,33 @@ def get_image_embeddings(
     ):
         all_image_embeddings = []
 
-        if isinstance(images, DataLoader):
-            for batch_images in tqdm(images):
-                img_inputs = self.processor(
-                    [self.img_prompt] * len(batch_images),
-                    batch_images,
-                    return_tensors="pt",
-                    padding=True,
-                ).to("cuda")
-                image_outputs = self.model(
-                    **img_inputs, output_hidden_states=True, return_dict=True
-                ).hidden_states[-1][:, -1, :]
-                all_image_embeddings.append(image_outputs.cpu())
-        with torch.no_grad():
-            for i in tqdm(range(0, len(images), batch_size)):
-                batch_images = images[i : i + batch_size]
-                img_inputs = self.processor(
-                    [self.img_prompt] * len(batch_images),
-                    batch_images,
-                    return_tensors="pt",
-                    padding=True,
-                ).to("cuda")
-                image_outputs = self.model(
-                    **img_inputs, output_hidden_states=True, return_dict=True
-                ).hidden_states[-1][:, -1, :]
-                all_image_embeddings.append(image_outputs.cpu())
+        with torch.no_grad():
+            if isinstance(images, DataLoader):
+                for batch_images in tqdm(images):
+                    img_inputs = self.processor(
+                        [self.img_prompt] * len(batch_images),
+                        batch_images,
+                        return_tensors="pt",
+                        padding=True,
+                    ).to("cuda")
+                    image_outputs = self.model(
+                        **img_inputs, output_hidden_states=True, return_dict=True
+                    ).hidden_states[-1][:, -1, :]
+                    all_image_embeddings.append(image_outputs.cpu())
+            else:
+                for i in tqdm(range(0, len(images), batch_size)):
+                    batch_images = images[i : i + batch_size]
+                    img_inputs = self.processor(
+                        [self.img_prompt] * len(batch_images),
+                        batch_images,
+                        return_tensors="pt",
+                        padding=True,
+                    ).to("cuda")
+                    image_outputs = self.model(
+                        **img_inputs, output_hidden_states=True, return_dict=True
+                    ).hidden_states[-1][:, -1, :]
+                    all_image_embeddings.append(image_outputs.cpu())
         return torch.cat(all_image_embeddings, dim=0)
 
     def calculate_probs(self, text_embeddings, image_embeddings):
         text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
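The hunk above, distilled to its control flow (encode_batch is a hypothetical stand-in for the processor-plus-model call, not a function in the repo). Before the fix, only the indexed-list path ran under no_grad and the DataLoader path built an autograd graph for every batch; hoisting the context manager above the branch covers both:

    import torch
    from torch.utils.data import DataLoader

    def get_image_embeddings(images, encode_batch, batch_size=32):
        # encode_batch: hypothetical callable mapping a batch to a (batch, dim) tensor.
        all_embeddings = []
        with torch.no_grad():  # hoisted above the branch: both paths skip autograd
            if isinstance(images, DataLoader):
                for batch in images:
                    all_embeddings.append(encode_batch(batch).cpu())
            else:
                for i in range(0, len(images), batch_size):
                    all_embeddings.append(encode_batch(images[i : i + batch_size]).cpu())
        return torch.cat(all_embeddings, dim=0)

    # Usage with a dummy encoder, just to show the shape of the contract:
    emb = get_image_embeddings(list(range(8)), lambda b: torch.randn(len(b), 4))
    print(emb.shape, emb.requires_grad)  # torch.Size([8, 4]) False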
@@ -112,8 +113,8 @@ def get_fused_embeddings(
         all_fused_embeddings = []
 
         if texts is not None and images is not None:
-            if isinstance(images, DataLoader):
-                with torch.no_grad():
+            with torch.no_grad():
+                if isinstance(images, DataLoader):
                     for index, batch_images in enumerate(tqdm(images)):
                         batch_texts = texts[
                             index * batch_size : (index + 1) * batch_size
@@ -128,12 +129,11 @@ def get_fused_embeddings(
                             **inputs, output_hidden_states=True, return_dict=True
                         ).hidden_states[-1][:, -1, :]
                         all_fused_embeddings.append(outputs.cpu())
-            else:
-                if len(texts) != len(images):
-                    raise ValueError(
-                        "The number of texts and images must have the same length"
-                    )
-                with torch.no_grad():
+                else:
+                    if len(texts) != len(images):
+                        raise ValueError(
+                            "The number of texts and images must have the same length"
+                        )
                     for i in tqdm(range(0, len(images), batch_size)):
                         batch_texts = texts[i : i + batch_size]
                         batch_images = images[i : i + batch_size]
@@ -148,7 +148,6 @@ def get_fused_embeddings(
                         ).hidden_states[-1][:, -1, :]
                         all_fused_embeddings.append(outputs.cpu())
             return torch.cat(all_fused_embeddings, dim=0)
-
         elif texts is not None:
             text_embeddings = self.get_text_embeddings(texts, batch_size)
             return text_embeddings
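A quick way to verify that embeddings coming out of the patched methods carry no autograd state (a hypothetical test helper, not part of this PR):

    import torch

    def assert_inference_output(emb: torch.Tensor) -> None:
        # Embeddings produced under no_grad and moved with .cpu() in the loops
        # above should be plain CPU tensors with no autograd history.
        assert not emb.requires_grad, "embedding still tracks gradients"
        assert emb.grad_fn is None, "embedding retains an autograd graph"
        assert emb.device.type == "cpu", "embedding was not moved off the GPU"

    # e.g. assert_inference_output(model.get_image_embeddings(images))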
mteb/models/vlm2vec_models.py: 26 changes (11 additions, 15 deletions)
@@ -254,22 +254,18 @@ def get_fused_embeddings(
             all_fused_embeddings = []
             if isinstance(images, DataLoader):
                 import torchvision.transforms.functional as F
-
-                for batch in images:
-                    for b in batch:
-                        text = next(texts)
-                        inputs = self.processor(
-                            f"<|image_1|> Represent the given image with the following question: {text}",
-                            [F.to_pil_image(b.to("cpu"))],
-                        )
-                        inputs = {
-                            key: value.to(self.device) for key, value in inputs.items()
-                        }
-                        outputs = self.encode_input(inputs)
-                        all_fused_embeddings.append(outputs.cpu())
-
+                with torch.no_grad():
+                    for batch in images:
+                        for b in batch:
+                            text = next(texts)
+                            inputs = self.processor(
+                                f"<|image_1|> Represent the given image with the following question: {text}",
+                                [F.to_pil_image(b.to("cpu"))],
+                            )
+                            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                            outputs = self.encode_input(inputs)
+                            all_fused_embeddings.append(outputs.cpu())
                 fused_embeddings = torch.cat(all_fused_embeddings, dim=0)
-
                 return fused_embeddings
             elif text_embeddings is not None:
                 return text_embeddings
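For reference, two equivalent spellings of the same guard; the PR uses the inline context manager, so treat these as alternatives rather than project convention. torch.no_grad also works as a decorator, and torch.inference_mode() (available since PyTorch 1.9) is a stricter, slightly faster variant for code whose outputs will never re-enter autograd:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4)

    @torch.no_grad()  # decorator form: the whole function body runs without autograd
    def embed(x: torch.Tensor) -> torch.Tensor:
        return model(x).cpu()

    def embed_strict(x: torch.Tensor) -> torch.Tensor:
        # inference_mode additionally marks results as inference tensors,
        # which cannot be used later in any autograd-tracked computation.
        with torch.inference_mode():
            return model(x).cpu()

    print(embed(torch.randn(2, 16)).requires_grad)         # False
    print(embed_strict(torch.randn(2, 16)).requires_grad)  # False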