From 8754e75652b73d26b49092f0524cbecbe4893a3d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 31 Jan 2025 11:49:30 +0000 Subject: [PATCH 1/2] moshi cant compile --- .../models/moshi/modeling_moshi.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 3796e2dc5f3..49e1f4e8ef2 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1062,6 +1062,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): """ config_class = MoshiDepthConfig + _supports_static_cache = False # When switching this to true, delete the overwritten `_get_cache` method def __init__(self, config: MoshiDepthConfig): super().__init__(config) @@ -1441,6 +1442,15 @@ def _prepare_4d_causal_attention_mask_with_cache_position( ) return causal_mask + def _get_cache(self, *args, **kwargs): + """ + Overwritten: Moshi doesn't support compilation, yet it defaults to the sliding window cache implementation. + The sliding window cache is compileable, which may trigger automatic compilation. This overwrite disables it. + """ + cache = super()._get_cache(*args, **kwargs) + cache.is_compileable = False + return cache + @add_start_docstrings( "The bare Moshi Model outputting raw hidden-states without any specific head on top.", @@ -1912,6 +1922,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_static_cache = False # When switching this to true, delete the overwritten `_get_cache` method def __init__(self, config: MoshiConfig): super().__init__(config) @@ -2727,5 +2738,14 @@ def _reorder_cache( for layer_past in past_key_values ) + def _get_cache(self, *args, **kwargs): + """ + Overwritten: Moshi doesn't support compilation, yet it defaults to the sliding window cache implementation. + The sliding window cache is compileable, which may trigger automatic compilation. This overwrite disables it. + """ + cache = super()._get_cache(*args, **kwargs) + cache.is_compileable = False + return cache + __all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"] From 6e2885980c2d6bdcf61eda72615bb2c223d8ce96 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 31 Jan 2025 12:00:57 +0000 Subject: [PATCH 2/2] disable compilation on external caches as well --- src/transformers/generation/utils.py | 2 +- .../models/moshi/modeling_moshi.py | 40 ++++++++++++++----- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index fed276b3238..97e0cf9b156 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1677,7 +1677,7 @@ def _prepare_cache_for_generation( batch_size: int, max_cache_length: int, device: torch.device, - ) -> bool: + ) -> None: """ Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is instantiated, writes it to `model_kwargs`, under the name expected by the model. diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 49e1f4e8ef2..bdd9b74f667 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1062,7 +1062,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): """ config_class = MoshiDepthConfig - _supports_static_cache = False # When switching this to true, delete the overwritten `_get_cache` method + _supports_static_cache = False # When switching to true, delete the overwritten `_prepare_cache_for_generation` def __init__(self, config: MoshiDepthConfig): super().__init__(config) @@ -1442,14 +1442,24 @@ def _prepare_4d_causal_attention_mask_with_cache_position( ) return causal_mask - def _get_cache(self, *args, **kwargs): + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: Dict, + assistant_model: PreTrainedModel, + batch_size: int, + max_cache_length: int, + device: torch.device, + ) -> None: """ Overwritten: Moshi doesn't support compilation, yet it defaults to the sliding window cache implementation. The sliding window cache is compileable, which may trigger automatic compilation. This overwrite disables it. """ - cache = super()._get_cache(*args, **kwargs) - cache.is_compileable = False - return cache + super()._prepare_cache_for_generation( + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + ) + if "past_key_values" in model_kwargs: + model_kwargs["past_key_values"].is_compileable = False @add_start_docstrings( @@ -1922,7 +1932,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_static_cache = False # When switching this to true, delete the overwritten `_get_cache` method + _supports_static_cache = False # When switching to true, delete the overwritten `_prepare_cache_for_generation` def __init__(self, config: MoshiConfig): super().__init__(config) @@ -2738,14 +2748,24 @@ def _reorder_cache( for layer_past in past_key_values ) - def _get_cache(self, *args, **kwargs): + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: Dict, + assistant_model: PreTrainedModel, + batch_size: int, + max_cache_length: int, + device: torch.device, + ) -> None: """ Overwritten: Moshi doesn't support compilation, yet it defaults to the sliding window cache implementation. The sliding window cache is compileable, which may trigger automatic compilation. This overwrite disables it. """ - cache = super()._get_cache(*args, **kwargs) - cache.is_compileable = False - return cache + super()._prepare_cache_for_generation( + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + ) + if "past_key_values" in model_kwargs: + model_kwargs["past_key_values"].is_compileable = False __all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"]