Falcon: batched generation #26137

Merged: 5 commits, Sep 13, 2023
Changes from 4 commits
77 changes: 65 additions & 12 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -99,19 +99,40 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)

def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
def cos_sin(
Collaborator:

I'm gonna be a bit noisy here, but this looks a LOT like the rotary embedding we have in Llama, no?
The query expansion is also supported there; I'm not sure how much of an overhead it is to first apply rotary and then expand:

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

and also storing the full-size keys and values is less memory efficient, no? (unrelated to the PR)
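For reference, a minimal sketch of the Llama-style repeat_kv helper referred to above (reproduced from memory as an illustration, not taken from this diff): the cache stores only the num_kv_heads key/value heads, and the expansion to the full number of attention heads happens after the cache concatenation, which is why caching pre-expanded keys and values is less memory efficient.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [batch, num_kv_heads, seq_len, head_dim] into
    # [batch, num_kv_heads * n_rep, seq_len, head_dim]; expand() is a view, the copy only happens at reshape time.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Example: a cache with 2 KV heads is expanded to 8 heads only at attention time,
# so the stored tensors stay 4x smaller than a pre-expanded cache would be.
key_cache = torch.randn(1, 2, 10, 64)
key_states = repeat_kv(key_cache, 4)
print(key_cache.shape, key_states.shape)  # torch.Size([1, 2, 10, 64]) torch.Size([1, 8, 10, 64])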

Member Author (@gante), Sep 13, 2023:

I think they are the same. So, we would benefit from copying the structure (at least in terms of complexity for us, the maintainers) 👍

I would like to push it to the future, though, as I'm about to go on long holidays and I'd like to enable batched generation on Falcon :D

Collaborator:

Let's just add a TODO then 😉

self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self._set_cos_sin_cache(total_length, device, dtype)
return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
Member Author (@gante):

the slicing here is equivalent to building position ids from the sequence length, without taking into account any potential left-padding
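A small illustration of that point (not part of the diff, just an assumed toy batch): slicing the cache by sequence length gives every row the same 0..seq_len-1 numbering, while deriving positions from the attention mask, as prepare_inputs_for_generation does further down, keeps left-padded rows correctly numbered.

import torch

# Batch of 2 sequences; the first one is left-padded by two tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Old behaviour: positions implied by the slice, identical for every row.
naive_position_ids = torch.arange(5).expand(2, -1)

# New behaviour: count only real tokens; padded slots get a dummy value of 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(naive_position_ids)  # both rows are [0, 1, 2, 3, 4]
print(position_ids)        # rows are [1, 1, 0, 1, 2] and [0, 1, 2, 3, 4]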

self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
)
# Gather cos, sin at the designated position ids
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
return cos, sin

def forward(self, query, key, past_key_values_length, position_ids):
_, seq_len, _ = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
# Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
# avoid unnecessary repeat_interleave operations.
query_expansion_factor = int(query.shape[0] / cos.shape[0])
if query_expansion_factor > 1:
query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
else:
query_cos, query_sin = cos, sin

key_expansion_factor = int(key.shape[0] / cos.shape[0])
if key_expansion_factor > 1:
if key_expansion_factor != query_expansion_factor:
key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
else:
key_cos, key_sin = query_cos, query_sin
else:
key_cos, key_sin = cos, sin

def forward(self, query, key, past_key_values_length=0):
batch, seq_len, head_dim = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)


class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
@@ -270,7 +291,7 @@ def __init__(self, config: FalconConfig):
f" {self.num_heads})."
)

self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)

# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -378,6 +399,7 @@ def forward(
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -399,7 +421,7 @@ def forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)

if layer_past is not None:
past_key, past_value = layer_past
@@ -415,7 +437,8 @@ def forward(
else:
present = None

attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
Member Author (@gante), Sep 13, 2023:

cc @Rocketknight1 this -1e9 was causing problems in some numerical precisions (it would be converted to -inf) :p
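A quick reproduction of that claim (not part of the diff): -1e9 is outside the float16 range, so the fill value overflows to -inf when cast, while torch.finfo(dtype).min is representable by construction.

import torch

print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.finfo(torch.float16).min)           # -65504.0
# Using the dtype's own minimum stays finite in every precision.
print(torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16))  # tensor(-65504., dtype=torch.float16)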

Member:

Ah, my bad!

float_min = torch.finfo(query_layer.dtype).min
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)

query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
@@ -536,6 +559,7 @@ def forward(
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -554,6 +578,7 @@ def forward(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
Expand Down Expand Up @@ -632,6 +657,11 @@ def forward(
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.

[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

@@ -836,6 +866,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -892,6 +923,14 @@ def forward(
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

causal_mask = self._prepare_attn_mask(
attention_mask,
@@ -922,13 +961,15 @@ def custom_forward(*inputs):
hidden_states,
alibi,
causal_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@@ -988,13 +1029,23 @@ def prepare_inputs_for_generation(
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
@@ -1011,6 +1062,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
@@ -1032,6 +1084,7 @@ def forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
41 changes: 39 additions & 2 deletions tests/models/falcon/test_modeling_falcon.py
@@ -19,8 +19,16 @@

from parameterized import parameterized

from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
FalconConfig,
is_torch_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device
from transformers.utils import logging as transformers_logging

from ...generation.test_utils import GenerationTesterMixin
@@ -502,6 +510,35 @@ def test_lm_generation_use_cache(self):
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)

@require_bitsandbytes
@slow
def test_batched_generation(self):
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
"tiiuae/falcon-7b",
device_map="auto",
load_in_4bit=True,
)

test_text = "A sequence: 1, 2" # should generate the rest of the sequence

unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
unpadded_inputs.pop("token_type_ids")
unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)

dummy_text = "This is a longer text " * 2 # forces left-padding on `test_text`
padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
padded_inputs.pop("token_type_ids")
padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)

expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
self.assertEqual(unpadded_gen_text[0], expected_output)
self.assertEqual(padded_gen_text[0], expected_output)


# TODO Lysandre: Remove this in version v4.34
class FalconOverrideTest(unittest.TestCase):