diff --git a/optimum/utils/preprocessing/text_classification.py b/optimum/utils/preprocessing/text_classification.py
index ffda27e61c6..d96468f8ee8 100644
--- a/optimum/utils/preprocessing/text_classification.py
+++ b/optimum/utils/preprocessing/text_classification.py
@@ -46,16 +46,14 @@ def load_datasets(self):
         )
 
         # Preprocessing the raw_datasets
-        def preprocess_function(
-            examples, data_keys: Dict[str, str], tokenizer: PreTrainedTokenizerBase, max_length: int
-        ):
+        def preprocess_function(examples, data_keys: Dict[str, str], tokenizer: PreTrainedTokenizerBase):
             # Tokenize the texts
             tokenized_inputs = tokenizer(
                 text=examples[data_keys["primary"]],
                 text_pair=examples[data_keys["secondary"]] if data_keys["secondary"] else None,
                 padding="max_length",
-                max_length=min(max_length, tokenizer.model_max_length),
+                max_length=tokenizer.model_max_length,
                 truncation=True,
             )
             return tokenized_inputs
 
@@ -74,7 +72,6 @@ def preprocess_function(
                 preprocess_function,
                 tokenizer=self.tokenizer,
                 data_keys=self.data_keys,
-                max_length=self.max_seq_length,
             ),
             batched=True,
             load_from_cache_file=True,
diff --git a/optimum/utils/preprocessing/token_classification.py b/optimum/utils/preprocessing/token_classification.py
index 33ea1359359..90c3b5d5c19 100644
--- a/optimum/utils/preprocessing/token_classification.py
+++ b/optimum/utils/preprocessing/token_classification.py
@@ -52,16 +52,14 @@ def get_label_list(labels):
         max_eval_samples = 100  # TODO remove this
 
         # Preprocessing the raw_datasets
-        def preprocess_function(
-            examples, data_keys: Dict[str, str], tokenizer: PreTrainedTokenizerBase, max_length: Optional[int] = None
-        ):
+        def preprocess_function(examples, data_keys: Dict[str, str], tokenizer: PreTrainedTokenizerBase):
             # Tokenize the texts
             tokenized_inputs = tokenizer(
                 text=examples[data_keys["primary"]],
                 text_pair=examples[data_keys["secondary"]] if data_keys["secondary"] else None,
                 padding="max_length",
                 truncation=True,
-                max_length=min(max_length, tokenizer.model_max_length),
+                max_length=tokenizer.model_max_length,
                 is_split_into_words=True,
             )
             return tokenized_inputs
@@ -79,7 +77,6 @@ def preprocess_function(
                 preprocess_function,
                 tokenizer=self.tokenizer,
                 data_keys=self.data_keys,
-                max_length=self.max_seq_length,
             ),
             batched=True,
             load_from_cache_file=True,