diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index fed276b3238..97e0cf9b156 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1677,7 +1677,7 @@ def _prepare_cache_for_generation(
         batch_size: int,
         max_cache_length: int,
         device: torch.device,
-    ) -> bool:
+    ) -> None:
         """
         Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
         instantiated, writes it to `model_kwargs`, under the name expected by the model.
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 3796e2dc5f3..bdd9b74f667 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -1062,6 +1062,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
     """

     config_class = MoshiDepthConfig
+    _supports_static_cache = False  # When switching to true, delete the overwritten `_prepare_cache_for_generation`

     def __init__(self, config: MoshiDepthConfig):
         super().__init__(config)
@@ -1441,6 +1442,25 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             )
         return causal_mask

+    def _prepare_cache_for_generation(
+        self,
+        generation_config: GenerationConfig,
+        model_kwargs: Dict,
+        assistant_model: PreTrainedModel,
+        batch_size: int,
+        max_cache_length: int,
+        device: torch.device,
+    ) -> None:
+        """
+        Overwritten: Moshi doesn't support compilation, yet it defaults to the sliding window cache implementation.
+        The sliding window cache is compileable, which may trigger automatic compilation. This overwrite disables it.
+        """
+        super()._prepare_cache_for_generation(
+            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
+        )
+        if "past_key_values" in model_kwargs:
+            model_kwargs["past_key_values"].is_compileable = False
+

 @add_start_docstrings(
     "The bare Moshi Model outputting raw hidden-states without any specific head on top.",
@@ -1912,6 +1932,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
     supports_gradient_checkpointing = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_static_cache = False  # When switching to true, delete the overwritten `_prepare_cache_for_generation`

     def __init__(self, config: MoshiConfig):
         super().__init__(config)
@@ -2727,5 +2748,24 @@ def _reorder_cache(
             for layer_past in past_key_values
         )

+    def _prepare_cache_for_generation(
+        self,
+        generation_config: GenerationConfig,
+        model_kwargs: Dict,
+        assistant_model: PreTrainedModel,
+        batch_size: int,
+        max_cache_length: int,
+        device: torch.device,
+    ) -> None:
+        """
+        Overwritten: Moshi doesn't support compilation, yet it defaults to the sliding window cache implementation.
+        The sliding window cache is compileable, which may trigger automatic compilation. This overwrite disables it.
+        """
+        super()._prepare_cache_for_generation(
+            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
+        )
+        if "past_key_values" in model_kwargs:
+            model_kwargs["past_key_values"].is_compileable = False
+

 __all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"]
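For context, a minimal self-contained sketch of the pattern both overrides apply: call the parent's cache preparation, then mark the resulting cache as non-compileable so that any downstream check reading `is_compileable` sees `False`. This is not the transformers implementation; the `_Toy*` classes and the simplified method signature are hypothetical stand-ins, with only the `past_key_values` / `is_compileable` names taken from the diff above.

```python
from typing import Dict


class _ToyCache:
    """Hypothetical stand-in for a sliding-window-style cache that advertises compileability."""

    def __init__(self) -> None:
        self.is_compileable = True  # the cache claims it can be compiled by default


class _ToyMixin:
    """Hypothetical stand-in for the parent class that instantiates the cache."""

    def _prepare_cache_for_generation(self, model_kwargs: Dict) -> None:
        # Writes the cache to `model_kwargs`, under the name expected by the model.
        model_kwargs["past_key_values"] = _ToyCache()


class _ToyMoshiLike(_ToyMixin):
    """Mirrors the override in the diff: prepare the cache, then disable compilation."""

    def _prepare_cache_for_generation(self, model_kwargs: Dict) -> None:
        super()._prepare_cache_for_generation(model_kwargs)
        if "past_key_values" in model_kwargs:
            model_kwargs["past_key_values"].is_compileable = False


model_kwargs: Dict = {}
_ToyMoshiLike()._prepare_cache_for_generation(model_kwargs)
assert model_kwargs["past_key_values"].is_compileable is False  # automatic compilation stays off
```

The design choice matches the diff's docstring: rather than re-implementing cache selection, the subclass reuses the parent logic and only flips the compileability flag afterwards, which is why the `_supports_static_cache = False` comment says the override can be deleted once compilation is actually supported.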