diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 9d74247c..806e9833 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -5,7 +5,7 @@ from typing import List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase, PretrainedConfig -from text_generation_server.models.types import Batch, GeneratedText +from text_generation_server.models.types import Batch, Generation from text_generation_server.pb.generate_pb2 import InfoResponse B = TypeVar("B", bound=Batch) @@ -52,7 +52,7 @@ def batch_type(self) -> Type[B]: raise NotImplementedError @abstractmethod - def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: + def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]: raise NotImplementedError def warmup(self, batch: B) -> Optional[int]: