-
Notifications
You must be signed in to change notification settings - Fork 148
/
Copy pathflash_causal_lm.py
2119 lines (1836 loc) · 87.3 KB
/
flash_causal_lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import math
import os
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, ContextManager, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
import torch.distributed
import torch.profiler
from loguru import logger
from opentelemetry import trace
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, GenerationConfig, PretrainedConfig, PreTrainedTokenizerBase
from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
from lorax_server.models.metadata_kernels import (
block_tables_to_padded,
block_tables_to_ragged,
copy_next_input_ids_inplace,
has_triton,
prepare_position_slot_ids,
slots_filtering,
)
from lorax_server.models.model import Model
from lorax_server.models.types import (
AlternativeTokens,
Batch,
GeneratedText,
Generation,
NextTokens,
)
from lorax_server.pb import generate_pb2
from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files
from lorax_server.utils.attention.common import Seqlen
from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed
from lorax_server.utils.graph import GraphCache
from lorax_server.utils.import_utils import get_cuda_free_memory
from lorax_server.utils.punica import LORAX_PUNICA_TRITON_DISABLED, PunicaWrapper
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
from lorax_server.utils.sources import HUB
from lorax_server.utils.sources.hub import weight_files
from lorax_server.utils.state import (
BLOCK_SIZE,
FLASH_INFER,
get_max_prefill_tokens,
get_speculative_tokens,
get_supports_chunking,
warmup_mode,
)
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.utils.torch_utils import is_fp8, is_fp8_kv, is_fp8_supported
from lorax_server.utils.weights import Weights
# Read once at import time from the ADAPTER_MEMORY_FRACTION environment variable (default "0.1").
ADAPTER_MEMORY_FRACTION = float(os.environ.get("ADAPTER_MEMORY_FRACTION", "0.1"))

# Module-level sliding-window settings. They start unset and are populated during model
# initialisation; batch preparation reads SLIDING_WINDOW to decide how much of the prefill
# KV tensor is cached.
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None

# OpenTelemetry tracer used by the `@tracer.start_as_current_span(...)` decorators below.
tracer = trace.get_tracer(__name__)
# Mutable state of one batch of requests for flash-attention causal LM generation.
# A batch is either prefilling (at least one request still consumes prompt tokens; see
# `prefilling` / `prefilling_mask`) or decoding. Several tensors below are only populated
# by `prepare_for_prefill` / `generate_token` and are `None` in between.
# Paged-attention bookkeeping lives in `block_tables*`, `slots` and `cu_slots`.
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values
# Can be a list for easy filtering
# If `input_ids` is a list, it needs to be materialized to a tensor first
input_ids: Union[torch.Tensor, List[List[int]]]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
position_ids: Optional[torch.Tensor]
# Speculative decoding values (None when no speculative ids were computed)
speculative_ids: Optional[torch.Tensor]
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slot_indices: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slots: Optional[torch.Tensor]
# CPU tensor of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
# used for filtering
cu_slots: torch.Tensor
# Largest per-request input (chunk) length in the batch
max_input_length: int
# Largest per-request cache_length + input_length in the batch
max_current_length: int
# Whether this batch contains at least one request that is prefilling
prefilling: bool
# Whether each request is prefilling
prefilling_mask: List[bool]
# Tensor version of `prefilling_mask`; may be None for decoding batches
prefilling_mask_tensor: Optional[torch.Tensor]
# Prefill metadata tensors to efficiently compute logprobs
# tensor of length b+1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill # noqa: E501
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_head_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_next_token_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_cu_outlens: Optional[List[int]]
# Will be set by `generate_token` and reset after each prefill forward
prefill_logprob_tokens: List[Optional[NextTokens]]
# All tokens
all_input_ids: List[List[int]]
# Padded [b, max_length] tensor of `all_input_ids` (zero padded)
all_input_ids_tensor: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
# length b; number of tokens of each request already present in the KV cache
# (new positions for the request start at this offset)
cache_lengths: List[int]
# Number of prompt tokens per request, after truncation
prompt_lengths: List[int]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
input_lengths_tensor: Optional[torch.Tensor]
cache_lengths_tensor: Optional[torch.Tensor]
prompt_lengths_tensor: torch.Tensor
# Token offsets used when turning generated ids back into text
# (initialised by `from_pb` to prompt_length - 5 and prompt_length)
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Adapter metadata for each request
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
adapter_meta: Optional[AdapterBatchMetadata]
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch:
    """Summarise this batch as a ``CachedBatch`` protobuf message.

    ``max_tokens`` is the paged-cache capacity (``num_blocks * BLOCK_SIZE``) and
    ``current_tokens`` is the number of input tokens currently held in ``input_ids``:
    the summed per-request list lengths when ``input_ids`` is still a list, otherwise
    the length of the materialised tensor.
    """
    if isinstance(self.input_ids, list):
        current_tokens = sum(len(request_ids) for request_ids in self.input_ids)
    else:
        current_tokens = len(self.input_ids)

    return generate_pb2.CachedBatch(
        id=self.batch_id,
        request_ids=[request.id for request in self.requests],
        size=len(self),
        max_tokens=self.num_blocks * BLOCK_SIZE,
        current_tokens=current_tokens,
    )
@classmethod
def to_pb_embed(cls, batch, embeddings) -> generate_pb2.EmbedResponse:
    """Pack per-request embeddings into an ``EmbedResponse`` protobuf message.

    The i-th entry of ``embeddings`` is attributed to ``batch.requests[i]``.

    Args:
        batch: batch whose ``requests`` provide the request ids.
        embeddings: iterable of embedding values, one per request, in batch order.

    Returns:
        ``generate_pb2.EmbedResponse`` with one ``Embedding`` per input embedding.

    Raises:
        IndexError: if there are more embeddings than requests in ``batch``.
    """
    # Index (rather than zip) so a length mismatch still fails loudly instead of truncating.
    return generate_pb2.EmbedResponse(
        embeddings=[
            generate_pb2.Embedding(request_id=batch.requests[i].id, values=embedding)
            for i, embedding in enumerate(embeddings)
        ]
    )
@classmethod
def from_pb(
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    tokenizers: TokenizerManager,
    processor,
    config,
    dtype: torch.dtype,
    device: torch.device,
    batch_tokenized_inputs=None,
) -> "FlashCausalLMBatch":
    """Build a new, fully prefilling batch from a protobuf ``Batch``.

    Every request starts with ``prefilling_mask`` set to True. Position ids, slot indices,
    ``cu_seqlen_prefill`` and length tensors are left as ``None`` and are filled in later by
    ``prepare_for_prefill``.

    Args:
        pb: incoming protobuf batch.
        tokenizer: base-model tokenizer. It is used for tokenization when
            ``batch_tokenized_inputs`` is not supplied, for stopping criteria, and for the
            next-token chooser.
        tokenizers: manager used to obtain each request's raw inputs.
        processor: not used by this implementation.
        config: not used by this implementation.
        dtype: dtype handed to the next-token chooser.
        device: device on which batch tensors are allocated.
        batch_tokenized_inputs: optional pre-tokenized ids, one sequence per request. When
            ``None``, the request's own ``tokenized_inputs`` are used if every request
            carries them; otherwise ``tokenizer`` is called.

    Returns:
        A ``FlashCausalLMBatch`` with ``prefilling=True``.

    Raises:
        AssertionError: on inconsistent cache / chunk lengths, or when chunking is requested
            while unsupported or combined with speculative decoding.
    """
    if batch_tokenized_inputs is None:
        batch_inputs = []
        max_truncation = 0
        for request in pb.requests:
            batch_inputs.append(tokenizers.get_inputs(request, tokenizer))
            max_truncation = max(max_truncation, request.truncate)

        if all(
            request.HasField("tokenized_inputs") and len(request.tokenized_inputs.ids) > 0
            for request in pb.requests
        ):
            batch_tokenized_inputs = [request.tokenized_inputs.ids[-max_truncation:] for request in pb.requests]
        else:
            batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)[
                "input_ids"
            ]

    speculative_tokens = get_speculative_tokens()

    cache_lengths = []
    input_lengths = []
    prompt_lengths = []
    prefix_offsets = []
    read_offsets = []
    all_input_ids = []
    all_postfix_ids = []
    requests_idx_mapping = {}
    slots = []
    cu_slots = [0]

    next_token_chooser_parameters = []
    stopping_criterias = []

    num_blocks = 0
    max_input_length = 0
    max_current_length = 0
    max_length = 0
    max_blocks = 0

    cu_blocks = [0]
    block_tables = []
    block_tables_ragged = []

    # Parse batch
    for i, (request, tokenized_input) in enumerate(zip(pb.requests, batch_tokenized_inputs)):
        # request id -> idx in list mapping
        requests_idx_mapping[request.id] = i

        tokenized_input = tokenized_input[-request.truncate :]
        prompt_length = len(tokenized_input)
        prompt_lengths.append(prompt_length)

        cache_length = request.cache_len
        assert cache_length <= prompt_length, f"Prefix {cache_length} vs input {prompt_length}"
        if cache_length == prompt_length:
            assert False, "unreachable"

        # `chunk_len` is an optional field in the protobuf
        # It is only set if the model support chunking
        if request.HasField("chunk_len"):
            input_length = request.chunk_len

            if cache_length + input_length < prompt_length:
                # FIXME: speculate is not supported for context chunking at the moment
                assert speculative_tokens == 0
                assert get_supports_chunking()
                assert input_length > 0

            postfix_ids = tokenized_input[cache_length : cache_length + input_length]
            assert len(postfix_ids) == input_length, "Rust and Python tokenizers are not aligned"
        else:
            # Use all the remaining ids
            postfix_ids = tokenized_input[cache_length:]
            input_length = len(postfix_ids)

        input_lengths.append(input_length)

        prefix_offsets.append(prompt_length - 5)
        read_offsets.append(prompt_length)

        all_postfix_ids.append(postfix_ids)
        all_input_ids.append(tokenized_input)

        next_token_chooser_parameters.append(request.parameters)

        stopping_criteria = StoppingCriteria.from_pb(request.stopping_parameters, tokenizer)
        max_new_tokens = stopping_criteria.max_new_tokens
        stopping_criterias.append(stopping_criteria)

        # Tokens that need to be mapped to blocks.
        # Remove one as the first token does not have a past
        block_tokens = prompt_length + max_new_tokens - 1 + speculative_tokens

        # blocks and slots can be empty (for example in warmup): hand out consecutive block
        # ids following the blocks already assigned in this batch
        if not request.blocks:
            needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
            request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
            request_slots = [s for b in request_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
        else:
            request_blocks = request.blocks
            request_slots = request.slots

        block_tables.append(request_blocks)
        block_tables_ragged.extend(request_blocks)
        cu_blocks.append(len(block_tables_ragged))

        slots.extend(request_slots)
        cu_slots.append(len(slots))

        cache_lengths.append(cache_length)
        num_blocks += len(request_blocks)

        # Update
        max_blocks = max(max_blocks, len(request_blocks))
        max_input_length = max(max_input_length, input_length)
        max_current_length = max(max_current_length, cache_length + input_length)
        max_length = max(
            max_length,
            prompt_length + max_new_tokens + speculative_tokens,
        )

    # always use the base model tokenizer for the next token chooser until we revisit adding back support
    # for per-request tokenizers
    request_tokenizers = [tokenizer for _ in pb.requests]
    next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
        next_token_chooser_parameters, request_tokenizers, dtype, device
    )

    # Padded all_input_ids_tensor
    all_input_ids_tensor = np.zeros((len(all_input_ids), max_length), dtype=np.int64)
    for i, input_ids in enumerate(all_input_ids):
        all_input_ids_tensor[i, : len(input_ids)] = input_ids

    # Create tensors on device
    all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64, device=device)

    block_tables_ragged = torch.tensor(block_tables_ragged, device=device, dtype=torch.int32)
    cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
    # Zero-filled (not `torch.empty`): both fill paths below only write each request's own
    # blocks, so the padding would otherwise be uninitialized memory. This also matches
    # `concatenate`, which allocates the same tensor with `new_zeros`.
    block_tables_tensor = torch.zeros(
        (len(block_tables), max_blocks),
        device=device,
        dtype=torch.int32,
    )

    # If the device supports Triton, we can use a fused kernel
    if has_triton():
        block_tables_to_padded(max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged)
    else:
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)

    prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32, device=device)

    slots = torch.tensor(slots, dtype=torch.int64, device=device)
    # Kept on CPU: `filter` indexes it per request on the host
    cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

    prefilling_mask = [True] * len(pb.requests)
    prefilling_mask_tensor = torch.tensor(prefilling_mask, dtype=torch.bool, device=device)

    return cls(
        batch_id=pb.id,
        requests=pb.requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids=all_postfix_ids,
        block_tables=block_tables,
        block_tables_tensor=block_tables_tensor,
        cache_lengths=cache_lengths,
        max_input_length=max_input_length,
        max_current_length=max_current_length,
        prefilling=True,
        prefilling_mask=prefilling_mask,
        prefilling_mask_tensor=prefilling_mask_tensor,
        prefill_logprob_tokens=[None] * len(pb.requests),
        input_lengths=input_lengths,
        prompt_lengths=prompt_lengths,
        prefix_offsets=prefix_offsets,
        read_offsets=read_offsets,
        all_input_ids=all_input_ids,
        all_input_ids_tensor=all_input_ids_tensor,
        next_token_chooser=next_token_chooser,
        stopping_criterias=stopping_criterias,
        num_blocks=num_blocks,
        max_blocks=max_blocks,
        speculative_ids=None,
        prompt_lengths_tensor=prompt_lengths_tensor,
        # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
        position_ids=None,
        cu_seqlen_prefill=None,
        prefill_cache_indices=None,
        slot_indices=None,
        slots=slots,
        cu_slots=cu_slots,
        prefill_head_indices=None,
        prefill_next_token_indices=None,
        prefill_cu_outlens=None,
        cache_lengths_tensor=None,
        input_lengths_tensor=None,
        adapter_meta=None,
    )
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
    """Build a new batch that keeps only ``request_ids``, in the order given.

    Host-side lists are re-gathered per kept request. Device tensors are indexed with the
    kept positions. Slots are compacted either with a boolean mask (no Triton) or with the
    fused ``slots_filtering`` kernel (Triton). For a prefilling batch, the per-token decode
    tensors are reset to ``None`` so ``prepare_for_prefill`` can rebuild them. For a
    decoding batch they are indexed directly and adapter metadata is recomputed.

    Raises:
        ValueError: if ``request_ids`` is empty.
    """
    if len(request_ids) == 0:
        raise ValueError("Batch must have at least one request")

    device = self.block_tables_tensor.device
    use_triton = has_triton()

    if not use_triton:
        # Boolean mask over the flat slots tensor marking slots to keep
        slot_keep_mask = torch.zeros(self.slots.shape[0], dtype=torch.bool, device=device)

    # Filled on CPU during the loop (decode only), moved to the device once afterwards
    slot_indices = torch.empty(len(request_ids), dtype=torch.int64)

    idx_mapping = {}
    kept_positions = []
    kept_requests = []
    kept_block_tables = []
    kept_all_input_ids = []
    kept_input_ids = []
    kept_prompt_lengths = []
    kept_input_lengths = []
    kept_cache_lengths = []
    kept_prefix_offsets = []
    kept_read_offsets = []
    new_cu_slots = [0]
    kept_prefilling_mask = []
    kept_prefill_logprob_tokens = []
    kept_stopping_criterias = []
    adapter_list = []

    max_input_length = 0
    max_current_length = 0
    num_blocks = 0
    max_blocks = 0
    max_slots = 0
    slot_offset = 0

    for new_pos, request_id in enumerate(request_ids):
        old_pos = self.requests_idx_mapping[request_id]
        kept_positions.append(old_pos)
        idx_mapping[request_id] = new_pos
        kept_requests.append(self.requests[old_pos])
        kept_prefilling_mask.append(self.prefilling_mask[old_pos])

        in_len = self.input_lengths[old_pos]
        cached_len = self.cache_lengths[old_pos]
        max_input_length = max(max_input_length, in_len)
        max_current_length = max(max_current_length, cached_len + in_len)

        kept_all_input_ids.append(self.all_input_ids[old_pos])
        kept_prompt_lengths.append(self.prompt_lengths[old_pos])
        kept_input_lengths.append(in_len)
        kept_cache_lengths.append(cached_len)
        kept_prefix_offsets.append(self.prefix_offsets[old_pos])
        kept_read_offsets.append(self.read_offsets[old_pos])
        kept_stopping_criterias.append(self.stopping_criterias[old_pos])
        kept_prefill_logprob_tokens.append(self.prefill_logprob_tokens[old_pos])
        adapter_list.append(self.requests[old_pos].adapter_index)

        block_table = self.block_tables[old_pos]
        num_blocks += len(block_table)
        kept_block_tables.append(block_table)

        slot_begin = self.cu_slots[old_pos]
        slot_stop = self.cu_slots[old_pos + 1]
        slot_count = slot_stop - slot_begin
        if not use_triton:
            slot_keep_mask[slot_begin:slot_stop] = True
        new_cu_slots.append(slot_offset + slot_count)

        if self.prefilling:
            # Prefilling batches keep per-request input id lists
            kept_input_ids.append(self.input_ids[old_pos])
        else:
            # Decoding batches index the input id tensor after the loop
            slot_indices[new_pos] = slot_offset + cached_len

        slot_offset += slot_count
        max_blocks = max(max_blocks, len(block_table))
        max_slots = max(max_slots, slot_count)

    all_input_ids_tensor = self.all_input_ids_tensor[kept_positions]
    block_tables_tensor = self.block_tables_tensor[kept_positions]
    next_token_chooser = self.next_token_chooser.filter(kept_positions)
    speculative_ids = None if self.speculative_ids is None else self.speculative_ids[kept_positions]
    prompt_lengths_tensor = self.prompt_lengths_tensor[kept_positions]

    cu_slots = torch.tensor(new_cu_slots, dtype=torch.int64)

    if use_triton:
        slots = self.slots.new_empty(slot_offset)
        slots_filtering(
            max_slots,
            self.slots,
            slots,
            cu_slots.to(device),
            self.cu_slots.to(device)[kept_positions],
        )
    else:
        slots = self.slots[slot_keep_mask]

    if not self.prefilling:
        input_ids = self.input_ids[kept_positions]
        position_ids = self.position_ids[kept_positions]
        adapter_indices = self.adapter_meta.adapter_indices[kept_positions]
        input_lengths_tensor = self.input_lengths_tensor[kept_positions]
        cache_lengths_tensor = self.cache_lengths_tensor[kept_positions]
        prefilling_mask_tensor = None

        slot_indices = slot_indices.to(device)

        segments, segment_indices = find_segments(adapter_indices)
        adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_list=adapter_list,
            adapter_set=set(adapter_list),
            adapter_segments=torch.tensor(segments, dtype=torch.int32, device=device),
            segment_indices=segment_indices,
        )
    else:
        # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
        input_ids = kept_input_ids
        position_ids = None
        slot_indices = None
        cache_lengths_tensor = None
        input_lengths_tensor = None
        adapter_meta = None
        prefilling_mask_tensor = self.prefilling_mask_tensor[kept_positions]

    return type(self)(
        batch_id=self.batch_id,
        requests=kept_requests,
        requests_idx_mapping=idx_mapping,
        input_ids=input_ids,
        position_ids=position_ids,
        speculative_ids=speculative_ids,
        cu_seqlen_prefill=None,
        prefill_cache_indices=None,
        slot_indices=slot_indices,
        block_tables=kept_block_tables,
        block_tables_tensor=block_tables_tensor,
        slots=slots,
        cu_slots=cu_slots,
        max_input_length=max_input_length,
        max_current_length=max_current_length,
        prefilling=self.prefilling,
        prefilling_mask=kept_prefilling_mask,
        prefilling_mask_tensor=prefilling_mask_tensor,
        prefill_head_indices=None,
        prefill_next_token_indices=None,
        prefill_cu_outlens=None,
        prefill_logprob_tokens=kept_prefill_logprob_tokens,
        prompt_lengths=kept_prompt_lengths,
        prompt_lengths_tensor=prompt_lengths_tensor,
        input_lengths=kept_input_lengths,
        input_lengths_tensor=input_lengths_tensor,
        cache_lengths=kept_cache_lengths,
        cache_lengths_tensor=cache_lengths_tensor,
        prefix_offsets=kept_prefix_offsets,
        read_offsets=kept_read_offsets,
        all_input_ids=kept_all_input_ids,
        all_input_ids_tensor=all_input_ids_tensor,
        next_token_chooser=next_token_chooser,
        stopping_criterias=kept_stopping_criterias,
        num_blocks=num_blocks,
        max_blocks=max_blocks,
        adapter_meta=adapter_meta,
    )
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
    """Merge several batches into one new batch, keeping request order.

    If any input batch is prefilling, the merged batch is prefilling. In that case
    ``input_ids`` becomes a list of per-request id lists, and the per-token decode tensors
    (position ids, slot indices, length tensors, adapter metadata) are ``None``, to be
    rebuilt by ``prepare_for_prefill``. Otherwise the decode tensors of every batch are
    concatenated. A decoding batch whose ``input_ids`` is a tensor is converted in place
    to a list when merged into a prefilling batch.

    Args:
        batches: batches to merge. ``batches[0]`` supplies the batch id and the
            dtype/device templates for new tensors.

    Returns:
        A new ``FlashCausalLMBatch`` holding every request of every input batch.
    """
    # Batch attributes
    requests = []
    requests_idx_mapping = {}

    prefilling = False
    num_blocks = 0
    total_batch_size = 0
    total_slots = 0
    max_blocks = 0
    max_length = 0
    max_input_length = 0
    max_current_length = 0
    for b in batches:
        total_batch_size += len(b)
        max_blocks = max(max_blocks, b.max_blocks)
        # Every batch's `slots` is copied into the merged tensor below (prefilling or not),
        # so the merged tensor must be sized for all of them.
        total_slots += len(b.slots)
        num_blocks += b.num_blocks
        speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
        max_input_length = max(max_input_length, b.max_input_length)
        max_current_length = max(max_current_length, b.max_current_length)
        max_length = max(
            max_length,
            max(
                prompt_length + stopping_criteria.max_new_tokens + speculative_length
                for prompt_length, stopping_criteria in zip(b.prompt_lengths, b.stopping_criterias)
            ),
        )
        prefilling = prefilling or b.prefilling

    device = batches[0].block_tables_tensor.device
    slots = batches[0].slots.new_empty(total_slots)
    cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
    if prefilling:
        input_ids = []
        # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
        position_ids = None
        slot_indices = None
        cache_lengths_tensor = None
        input_lengths_tensor = None
        # Allocated from a tensor every batch has: a decoding batch (possibly batches[0])
        # may carry no prefilling mask tensor.
        prefilling_mask_tensor = torch.empty(total_batch_size, dtype=torch.bool, device=device)
        adapter_meta = None
        adapter_segment_builder = None
    else:
        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(total_batch_size)
        cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(total_batch_size)
        prefilling_mask_tensor = None
        total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches)
        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size)
        adapter_segment_builder = SegmentConcatBuilder()
        adapter_list = []
        adapter_set = set()

    prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(total_batch_size)
    block_tables_tensor = batches[0].block_tables_tensor.new_zeros((total_batch_size, max_blocks))
    all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros((total_batch_size, max_length))

    block_tables = []
    cache_lengths = []
    all_input_ids = []
    prompt_lengths = []
    input_lengths = []
    prefix_offsets = []
    read_offsets = []
    prefill_logprob_tokens = []
    next_token_chooser_parameters = []
    sequence_processors = []
    stopping_criterias = []
    prefilling_mask = []

    # Cumulative length
    cumulative_batch_size = 0
    cumulative_slots = 0
    cumulative_adapter_indices_size = 0

    for i, batch in enumerate(batches):
        requests.extend(batch.requests)

        if i == 0:
            # Copy so the first input batch's mapping is not mutated by later offsets
            requests_idx_mapping = dict(batch.requests_idx_mapping)
        else:
            # We need to offset the mapping for each batch by the cumulative batch size
            for k, v in batch.requests_idx_mapping.items():
                requests_idx_mapping[k] = v + cumulative_batch_size

        start_index = cumulative_batch_size
        end_index = cumulative_batch_size + len(batch)

        # Copy tensors (GPU)
        all_input_ids_tensor[start_index:end_index, : batch.all_input_ids_tensor.shape[1]] = (
            batch.all_input_ids_tensor[:, :max_length]
        )
        block_tables_tensor[start_index:end_index, : batch.block_tables_tensor.shape[1]] = (
            batch.block_tables_tensor[:, :max_blocks]
        )
        prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

        slots_start_index = cumulative_slots
        slots_end_index = cumulative_slots + len(batch.slots)
        slots[slots_start_index:slots_end_index] = batch.slots
        cu_slots[start_index + 1 : end_index + 1] = batch.cu_slots[1:] + cumulative_slots

        if not prefilling:
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

            # Copy over adapter indices
            adapter_start_index = cumulative_adapter_indices_size
            adapter_end_index = cumulative_adapter_indices_size + batch.adapter_meta.adapter_indices.shape[0]
            adapter_indices[adapter_start_index:adapter_end_index] = batch.adapter_meta.adapter_indices
            cumulative_adapter_indices_size = adapter_end_index
            adapter_list.extend(batch.adapter_meta.adapter_list)
            adapter_set.update(batch.adapter_meta.adapter_set)
            adapter_segment_builder.concat(
                batch.adapter_meta.adapter_segments,
                batch.adapter_meta.segment_indices,
            )
        else:
            if isinstance(batch.input_ids, torch.Tensor):
                batch.input_ids = batch.input_ids.view(-1, 1).tolist()
            input_ids.extend(batch.input_ids)
            if batch.prefilling_mask_tensor is not None:
                prefilling_mask_tensor[start_index:end_index] = batch.prefilling_mask_tensor
            else:
                # Decoding batches may carry only the list form of the mask
                prefilling_mask_tensor[start_index:end_index] = torch.tensor(
                    batch.prefilling_mask, dtype=torch.bool, device=device
                )

        prefilling_mask.extend(batch.prefilling_mask)
        block_tables.extend(batch.block_tables)
        cache_lengths.extend(batch.cache_lengths)
        all_input_ids.extend(batch.all_input_ids)

        prompt_lengths.extend(batch.prompt_lengths)
        input_lengths.extend(batch.input_lengths)
        prefix_offsets.extend(batch.prefix_offsets)
        read_offsets.extend(batch.read_offsets)
        prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)

        next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
        if batch.next_token_chooser.schema_processor is not None:
            sequence_processors.extend(batch.next_token_chooser.schema_processor.sequence_processors)
        else:
            # No sequence processors, so pad with Nones
            sequence_processors.extend([None for _ in batch.requests])
        stopping_criterias.extend(batch.stopping_criterias)

        # Update
        cumulative_slots += len(batch.slots)
        cumulative_batch_size += len(batch)

    next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
        next_token_chooser_parameters,
        tokenizers=[],
        dtype=batches[0].next_token_chooser.dtype,
        device=batches[0].next_token_chooser.device,
        sequence_processors=sequence_processors,
    )

    # We skip computing the speculative_ids when the batch size is too large, so
    # we must check that all batches have them, otherwise they must be discarded
    speculative_ids = None
    if get_speculative_tokens() > 0:
        if all(b.speculative_ids is not None for b in batches):
            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
        else:
            logger.info("Discarding speculative IDs, not every batch has them")

    if adapter_segment_builder is not None:
        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
        adapter_meta = AdapterBatchMetadata(
            adapter_indices=adapter_indices,
            adapter_list=adapter_list,
            adapter_set=adapter_set,
            adapter_segments=adapter_segments,
            segment_indices=adapter_segment_indices,
        )

    return cls(
        batch_id=batches[0].batch_id,
        requests=requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids=input_ids,
        position_ids=position_ids,
        speculative_ids=speculative_ids,
        cu_seqlen_prefill=None,
        prefill_cache_indices=None,
        slot_indices=slot_indices,
        block_tables=block_tables,
        block_tables_tensor=block_tables_tensor,
        cache_lengths=cache_lengths,
        cache_lengths_tensor=cache_lengths_tensor,
        slots=slots,
        cu_slots=cu_slots,
        max_input_length=max_input_length,
        max_current_length=max_current_length,
        prefilling=prefilling,
        prefilling_mask=prefilling_mask,
        prefilling_mask_tensor=prefilling_mask_tensor,
        prefill_head_indices=None,
        prefill_next_token_indices=None,
        prefill_cu_outlens=None,
        prefill_logprob_tokens=prefill_logprob_tokens,
        prompt_lengths=prompt_lengths,
        prompt_lengths_tensor=prompt_lengths_tensor,
        input_lengths=input_lengths,
        input_lengths_tensor=input_lengths_tensor,
        prefix_offsets=prefix_offsets,
        read_offsets=read_offsets,
        all_input_ids=all_input_ids,
        all_input_ids_tensor=all_input_ids_tensor,
        next_token_chooser=next_token_chooser,
        stopping_criterias=stopping_criterias,
        num_blocks=num_blocks,
        max_blocks=max_blocks,
        adapter_meta=adapter_meta,
    )
def prepare_for_prefill(self):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert self.speculative_ids is None
device = self.block_tables_tensor.device
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32, device=device)
self.cu_seqlen_prefill = torch.nn.functional.pad(torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)).to(
torch.int32
)
self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32, device=device)
# If the device supports Triton, we can use a fused kernel
if has_triton():
self.position_ids = torch.empty(len(self.input_ids), dtype=torch.int32, device=device)
self.slot_indices = torch.empty(len(self.input_ids), dtype=torch.int64, device=device)
cu_slots_gpu = self.cu_slots.to(device)
prepare_position_slot_ids(
self.max_input_length,
self.cache_lengths_tensor,
self.cu_seqlen_prefill,
cu_slots_gpu,
self.position_ids,
self.slot_indices,
)
position_ids = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_cu_outlens = [0]
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
adapter_indices_list = []
adapter_list = []
for i, (
r,
cache_length,
input_length,
prompt_length,
request_prefilling,
blocks,
) in enumerate(
zip(
self.requests,
self.cache_lengths,
self.input_lengths,
self.prompt_lengths,
self.prefilling_mask,
self.block_tables,
)
):
next_chunk_length = input_length
if not has_triton():
# Position ids
request_position_ids = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32)
position_ids.append(request_position_ids)
if not r.slots:
request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
else:
request_slots = r.slots
request_slot_indices = torch.arange(
cache_length + cumulative_slot_tokens,
cache_length + cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Update
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
if SLIDING_WINDOW is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - SLIDING_WINDOW),
cumulative_length + input_length,
dtype=torch.int64,
)
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
if SLIDING_WINDOW is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
adapter_indices_list.append(torch.full((next_chunk_length,), r.adapter_index))
adapter_list.append(r.adapter_index)
# Update
cumulative_length += next_chunk_length
if not all_prefill_logprobs and not no_prefill_logprobs:
prefill_head_indices = []
prefill_next_token_indices = []
# Cumulative length
cumulative_length = 0
prefill_out_cumulative_length = 0
for i, (
r,
input_length,
request_prefilling,
) in enumerate(
zip(
self.requests,
self.input_lengths,
self.prefilling_mask,
)
):
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length + input_length - 1)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
if len(self) > 1:
if position_ids:
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
if SLIDING_WINDOW is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
if SLIDING_WINDOW is not None:
prefill_cache_indices = prefill_cache_indices[0]
if not has_triton():
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
if all_prefill_logprobs: