make synchronization device agnostic
mamei16 committed Feb 23, 2025
1 parent b9b270b commit f5fe0e4
Showing 1 changed file with 16 additions and 1 deletion.
utils.py
@@ -200,6 +200,8 @@ def batch_encode(

         if device is None:
             device = self.device
+        else:
+            device = torch.device(device)

         self.to(device)

@@ -273,7 +275,7 @@ def batch_encode(
         # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
         if convert_to_numpy:
             embeddings = embeddings.to("cpu", non_blocking=True)
-            torch.cuda.synchronize()
+            sync_device(device)

         all_embeddings.extend(embeddings)

@@ -305,3 +307,16 @@ def batch_encode(
             all_embeddings = all_embeddings[0]

         return all_embeddings
+
+
+def sync_device(device: torch.device):
+    if device.type == "cpu":
+        return
+    elif device.type == "cuda":
+        torch.cuda.synchronize()
+    elif device.type == "mps":
+        torch.mps.synchronize()
+    elif device.type == "xpu":
+        torch.xpu.synchronize(device)
+    else:
+        warnings.warn("Device type does not match 'cuda', 'xpu' or 'mps'. Not synchronizing")
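
For reference, a minimal usage sketch of the pattern this commit generalizes: queue an asynchronous device-to-host copy, then call the new sync_device helper so the CPU tensor is safe to read on any backend, not just CUDA. The import path and tensor shape below are illustrative assumptions, and torch.mps / torch.xpu are only available in PyTorch builds with those backends.

    import torch

    from utils import sync_device  # helper added in this commit (import path assumed)

    # Pick whatever accelerator is present; "cpu" makes sync_device a no-op.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embeddings = torch.randn(32, 384, device=device)

    # Queue an asynchronous device-to-host copy, then block until it finishes.
    # Without the barrier, reading the CPU tensor could race with the copy
    # still in flight on the accelerator's stream.
    embeddings = embeddings.to("cpu", non_blocking=True)
    sync_device(device)

    print(embeddings.shape)  # safe to consume on the host now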
