Skip to content

Commit

Permalink
Merge pull request #2710 from tomaarsen/add_generated_from_trainer_tag
Browse files Browse the repository at this point in the history
Add "generated_from_trainer" tag to auto-generated model cards
  • Loading branch information
tomaarsen authored Jun 4, 2024
2 parents 2224477 + 10d70ff commit aa57232
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sentence_transformers/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(self, trainer: "SentenceTransformerTrainer", default_args_dict: Dic
trainer.model.model_card_data.code_carbon_callback = callbacks[0]

trainer.model.model_card_data.trainer = trainer
if "generated_from_trainer" not in trainer.model.model_card_data.tags:
trainer.model.model_card_data.tags.append("generated_from_trainer")

def on_init_end(
self,
Expand Down
10 changes: 9 additions & 1 deletion tests/test_model_card_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer


@pytest.mark.parametrize(
Expand All @@ -22,3 +22,11 @@ def test_model_card_data(revision, expected_base_revision) -> None:
assert len(model.model_card_data.base_model_revision) == 40
else:
assert model.model_card_data.base_model_revision == expected_base_revision


def test_generated_from_trainer_tag(stsb_bert_tiny_model: SentenceTransformer) -> None:
model = stsb_bert_tiny_model

assert "generated_from_trainer" not in model.model_card_data.tags
SentenceTransformerTrainer(model)
assert "generated_from_trainer" in model.model_card_data.tags

0 comments on commit aa57232

Please sign in to comment.