[`bug`] Prevent `to` from being ignored #2351

Merged — 4 commits, Dec 12, 2023
sentence_transformers/SentenceTransformer.py (23 changes: 17 additions & 6 deletions)
@@ -122,7 +122,7 @@ def __init__(self, model_name_or_path: Optional[str] = None,
             device = get_device_name()
             logger.info("Use pytorch device_name: {}".format(device))
 
-        self._target_device = torch.device(device)
+        self.to(device)
 
     def encode(self, sentences: Union[str, List[str]],
                batch_size: int = 32,
@@ -167,7 +167,7 @@ def encode(self, sentences: Union[str, List[str]],
             input_was_string = True
 
         if device is None:
-            device = self._target_device
+            device = self.device
 
         self.to(device)
 
@@ -658,7 +658,7 @@ def fit(self,
             from torch.cuda.amp import autocast
             scaler = torch.cuda.amp.GradScaler()
 
-        self.to(self._target_device)
+        self.to(self.device)
 
         dataloaders = [dataloader for dataloader, _ in train_objectives]
 
@@ -668,7 +668,7 @@ def fit(self,
 
         loss_models = [loss for _, loss in train_objectives]
         for loss_model in loss_models:
-            loss_model.to(self._target_device)
+            loss_model.to(self.device)
 
         self.best_score = -9999999
 
@@ -724,8 +724,8 @@ def fit(self,
                     data = next(data_iterator)
 
                 features, labels = data
-                labels = labels.to(self._target_device)
-                features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
+                labels = labels.to(self.device)
+                features = list(map(lambda batch: batch_to_device(batch, self.device), features))
 
                 if use_amp:
                     with autocast():
@@ -949,3 +949,14 @@ def max_seq_length(self, value):
         Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
         """
         self._first_module().max_seq_length = value
+
+    @property
+    def _target_device(self) -> torch.device:
+        logger.warning(
+            "`SentenceTransformer._target_device` has been removed, please use `SentenceTransformer.device` instead.",
+        )
+        return self.device
+
+    @_target_device.setter
+    def _target_device(self, device: Optional[Union[int, str, torch.device]] = None) -> None:
+        self.to(device)
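
Taken together, the diff above stops storing a `_target_device` attribute (which shadowed the model's real device, so a later `.to()` call was silently undone by `encode()` and `fit()`) and instead reads the live `device` property, while re-exposing `_target_device` as a deprecated property for backwards compatibility. Below is a minimal, self-contained sketch of that deprecation-property pattern; the `TinyModel` class and all of its names are illustrative assumptions, not code from the PR:

```python
# Sketch of the backwards-compatibility pattern used in the diff:
# a removed private attribute is re-exposed as a property that warns
# and forwards to the public replacement. TinyModel is a toy stand-in.
import logging

import torch

logger = logging.getLogger(__name__)


class TinyModel(torch.nn.Module):
    def __init__(self, device: str = "cpu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.to(device)  # move via .to(), rather than storing a shadow attribute

    @property
    def device(self) -> torch.device:
        # Derive the device from the parameters, so .to() is always honored.
        return next(self.parameters()).device

    @property
    def _target_device(self) -> torch.device:
        logger.warning("`_target_device` is deprecated, use `device` instead.")
        return self.device

    @_target_device.setter
    def _target_device(self, device) -> None:
        # Old code that assigned `_target_device` now actually moves the model.
        self.to(device)


model = TinyModel()
model._target_device = "cpu"  # warns, but still works
print(model.device)           # device(type='cpu')
```

The design choice here is that the getter stays readable for legacy callers while the setter delegates to `.to()`, so assigning the deprecated attribute can no longer put the bookkeeping out of sync with where the parameters actually live.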
tests/test_sentence_transformer.py (19 changes: 19 additions & 0 deletions)
@@ -2,6 +2,7 @@
 Tests general behaviour of the SentenceTransformer class
 """
 
+
 from pathlib import Path
 import tempfile
 
@@ -45,3 +46,21 @@ def test_load_with_safetensors(self):
             torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)),
             msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",
         )
+
+    @unittest.skipUnless(torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
+    def test_to(self):
+        model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")
+
+        test_device = torch.device("cuda")
+        self.assertEqual(model.device.type, "cpu")
+        self.assertEqual(test_device.type, "cuda")
+
+        model.to(test_device)
+        self.assertEqual(model.device.type, "cuda", msg="The model device should have updated")
+
+        model.encode("Test sentence")
+        self.assertEqual(model.device.type, "cuda", msg="Encoding shouldn't change the device")
+
+        self.assertEqual(model._target_device, model.device, msg="Prevent backwards compatibility failure for _target_device")
+        model._target_device = "cpu"
+        self.assertEqual(model.device.type, "cpu", msg="Ensure that setting `_target_device` doesn't crash.")
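
For reference, here is a hedged usage sketch of the behavior the new test locks in: after this change, `.to()` moves the model and `encode()` keeps it there, instead of snapping back to the device captured at construction time. It assumes a CUDA-capable machine and reuses the tiny test checkpoint from the test above:

```python
# Usage sketch mirroring test_to(): device moves now stick.
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu"
)
assert model.device.type == "cpu"

if torch.cuda.is_available():
    model.to("cuda")
    model.encode("Test sentence")       # previously this could move the model back to "cpu"
    assert model.device.type == "cuda"  # now the explicit .to() call is respected
```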