Falcon: batched generation #26137
@@ -99,19 +99,40 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
 
-    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+    def cos_sin(
+        self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
+    ) -> torch.Tensor:
         total_length = seq_len + past_key_values_length
         if total_length > self.seq_len_cached:
             self._set_cos_sin_cache(total_length, device, dtype)
-        return (
-            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
-        )
+        # Gather cos, sin at the designated position ids
+        cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        return cos, sin
 
-    def forward(self, query, key, past_key_values_length=0):
-        batch, seq_len, head_dim = query.shape
-        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
-        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+    def forward(self, query, key, past_key_values_length, position_ids):
+        _, seq_len, _ = query.shape
+        cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
+        # Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
+        # avoid unnecessary repeat_interleave operations.
+        query_expansion_factor = int(query.shape[0] / cos.shape[0])
+        if query_expansion_factor > 1:
+            query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
+            query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
+        else:
+            query_cos, query_sin = cos, sin
+
+        key_expansion_factor = int(key.shape[0] / cos.shape[0])
+        if key_expansion_factor > 1:
+            if key_expansion_factor != query_expansion_factor:
+                key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
+                key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
+            else:
+                key_cos, key_sin = query_cos, query_sin
+        else:
+            key_cos, key_sin = cos, sin
+
+        return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)
 
 
 class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):

Review comment (on the removed slicing): the slicing here is equivalent to building position ids from the sequence length, without taking into account any potential left-padding.
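As a quick illustration of why the gather above matters: the following standalone sketch (hypothetical tensor sizes and names, not the actual `FalconRotaryEmbedding` code) contrasts the old contiguous slice with indexing the cache at per-token `position_ids`, which is what makes left-padded batches work.

```python
import torch

# Hypothetical cache of shape [1, max_seq_len, head_dim], mirroring cos_cached / sin_cached.
max_seq_len, head_dim = 8, 4
cos_cached = torch.randn(1, max_seq_len, head_dim)

# Two sequences, the first left-padded by two tokens; pad positions are conventionally set to 1.
position_ids = torch.tensor([[1, 1, 0, 1, 2],
                             [0, 1, 2, 3, 4]])

# Old behaviour: a plain slice assigns positions 0..4 to every row, ignoring the padding offset.
sliced = cos_cached[:, 0:5]                      # [1, 5, head_dim], identical for both rows

# New behaviour: gather per-token positions, so each row gets its own rotary angles.
gathered = cos_cached.squeeze(0)[position_ids]   # [2, 5, head_dim]

print(sliced.shape, gathered.shape)
```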
@@ -270,7 +291,7 @@ def __init__(self, config: FalconConfig):
                 f" {self.num_heads})."
             )
 
-        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
+        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -378,6 +399,7 @@ def forward(
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -399,7 +421,7 @@ def forward(
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
 
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
 
         if layer_past is not None:
             past_key, past_value = layer_past
@@ -415,7 +437,8 @@ def forward(
         else:
             present = None
 
-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+        float_min = torch.finfo(query_layer.dtype).min
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)
 
         query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
         key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

Review comment (on the replaced fill value): cc @Rocketknight1 this
Reply: Ah, my bad!
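For context on the fill-value change: a hard-coded `-1e9` is outside the float16 range and silently becomes `-inf` once cast, which can turn a fully masked attention row into NaNs after the softmax, whereas `torch.finfo(dtype).min` stays finite. A minimal check in plain PyTorch, independent of the model code:

```python
import torch

# -1e9 overflows float16 (max magnitude ~65504) and becomes -inf on conversion.
print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)

# The dtype's own minimum plays the same "very negative" role while staying finite.
print(torch.finfo(torch.float16).min)           # -65504.0
print(torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16))  # tensor(-65504.)
```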
@@ -536,6 +559,7 @@ def forward(
         hidden_states: torch.Tensor,
         alibi: Optional[torch.Tensor],
         attention_mask: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
@@ -554,6 +578,7 @@ def forward(
             attention_layernorm_out,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             alibi=alibi,
             head_mask=head_mask,
             use_cache=use_cache,
@@ -632,6 +657,11 @@ def forward(
             - 0 for tokens that are **masked**.
 
             [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
@@ -836,6 +866,7 @@ def forward(
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -892,6 +923,14 @@ def forward(
             alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
         else:
             alibi = None
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
 
         causal_mask = self._prepare_attn_mask(
             attention_mask,
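The fallback above simply numbers the new tokens contiguously, starting right after the cached length, so it is only correct when the batch has no left-padding. A small standalone sketch with hypothetical lengths:

```python
import torch

# Hypothetical sizes, mirroring the fallback branch: 3 cached tokens, 4 new tokens.
past_key_values_length, seq_length = 3, 4

position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[3, 4, 5, 6]]) -- the same positions for every row, i.e. no padding awareness
```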
@@ -922,13 +961,15 @@ def custom_forward(*inputs):
                     hidden_states,
                     alibi,
                     causal_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
                     attention_mask=causal_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -988,13 +1029,23 @@ def prepare_inputs_for_generation(
         input_ids: torch.LongTensor,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
+        # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
+        if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
         return {
             "input_ids": input_ids,
+            "position_ids": position_ids,
             "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
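To make the on-the-fly construction concrete, here is a minimal sketch (hypothetical mask, plain PyTorch) of how the cumulative-sum trick derives left-padding-aware position ids, and how only the newest token's position is kept once a cache exists:

```python
import torch

# Hypothetical left-padded batch: 0 = padding, 1 = real token.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum counts real tokens; subtracting 1 gives 0-based positions per sequence.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # pad slots get a harmless dummy position
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# With a populated cache, only the position of the token being generated is needed.
last_position = position_ids[:, -1].unsqueeze(-1)
print(last_position)  # tensor([[2], [4]])
```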
@@ -1011,6 +1062,7 @@ def forward(
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
@@ -1032,6 +1084,7 @@ def forward(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
Review comment: I'm gonna be a bit noisy here, but this looks a LOT like the rotary embedding we have in Llama, no? The query expansion is also supported there; I'm not sure how much of an overhead it is to first apply rotary and then expand. Also, storing the full-size keys and values is less memory efficient, no? (unrelated to this PR)
Reply: I think they are the same, so we would benefit from copying the structure (at least in terms of complexity for us, the maintainers) 👍 I would like to push that to the future, though, as I'm about to go on long holidays and I'd like to enable batched generation on Falcon :D
Reply: Let's just add a TODO then 😉
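For reference, a rough sketch of the Llama-style structure suggested above, i.e. applying rotary embeddings before any head expansion so that broadcasting replaces `repeat_interleave`. The shapes and helper below are assumptions for illustration, not Falcon's current layout:

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: rotate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Hypothetical shapes: queries [bs, num_heads, seq, dim], keys [bs, num_kv_heads, seq, dim],
# cos/sin [bs, seq, dim] as produced by a position_ids gather.
bs, num_heads, num_kv_heads, seq, dim = 2, 8, 2, 5, 16
q = torch.randn(bs, num_heads, seq, dim)
k = torch.randn(bs, num_kv_heads, seq, dim)
cos = torch.randn(bs, seq, dim)
sin = torch.randn(bs, seq, dim)

# Apply rotary before any expansion: an added head axis broadcasts over both q and k.
cos_, sin_ = cos[:, None, :, :], sin[:, None, :, :]
q = (q * cos_) + (rotate_half(q) * sin_)
k = (k * cos_) + (rotate_half(k) * sin_)

# Key/value heads can then be expanded afterwards, only if the attention kernel needs it.
k = k.repeat_interleave(num_heads // num_kv_heads, dim=1)  # [bs, num_heads, seq, dim]
```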