# engine_v2.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import json
import pickle
from typing import Iterable, Tuple

import torch

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.comm.comm import init_distributed

from .model_implementations import InferenceV2Policy
from .logging import inference_logger
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
from .scheduling_utils import SchedulingError, SchedulingResult
from .model_implementations.flat_model_helpers import make_param_filename, make_metadata_filename
from .config_v2 import RaggedInferenceEngineConfig

INFERENCE_MODEL_TIMER = "model-forward-inference"

class InferenceEngineV2:

    _config: RaggedInferenceEngineConfig
    """
    Configuration of the inference engine.
    """

    #_model: DSInferenceModelBase
    """
    Inference model supporting ragged inference.
    """

    _state_manager: DSStateManager
    """
    Persistent state manager for sequences and KV-cache.
    """

    @property
    def free_blocks(self) -> int:
        """
        Number of free KV blocks.
        """
        return self._state_manager.free_blocks

    def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
        """
        Create the Inference V2 engine.

        Arguments:
            policy (InferenceV2Policy): Policy for the model implementation. This policy object
                will be used to build the model and load the checkpoint associated with it.
            engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine.
        """
        self._config = engine_config
        self._policy = policy
        self._base_mp_group = self._initialize_tp_group()

        # Build model from policy
        inference_logger().info("Building model...")
        self._model = self._policy.build_model(self._config, self._base_mp_group)
        inference_logger().info("Model built.")

        # Create state manager
        self._batch = RaggedBatchWrapper(self._config.state_manager)
        self._state_manager = DSStateManager(self._config.state_manager,
                                             self._model.kv_cache_config(),
                                             base_mp_group=self._base_mp_group)
        self._model.set_state_manager(self._state_manager)
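
    # A minimal construction sketch (hypothetical, kept in comments so the module stays
    # importable): the policy is usually produced by a model-specific factory, but given a
    # ready `policy` object the engine can be built with a default config:
    #
    #   engine_config = RaggedInferenceEngineConfig()   # assumes default TP size of 1
    #   engine = InferenceEngineV2(policy, engine_config)
    #   print(engine.free_blocks)   # KV blocks available before any sequences are scheduled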

    def _initialize_tp_group(self):
        """
        Implementation of our TP group initialization.
        """
        init_distributed()
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        get_accelerator().set_device(local_rank)

        if local_rank >= self._config.tensor_parallel.tp_size:
            raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.")

        ranks = list(range(self._config.tensor_parallel.tp_size))
        return dist.new_group(ranks=ranks)
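
    # Note on LOCAL_RANK: it is normally set by the process launcher rather than by hand.
    # A hedged example, assuming the standard DeepSpeed launcher and a 2-way TP config
    # (`run_inference.py` is a hypothetical driver script):
    #
    #   deepspeed --num_gpus 2 run_inference.py
    #
    # Each launched process then binds its accelerator device from LOCAL_RANK above and
    # joins the single TP group built from ranks [0, tp_size).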

    def put(self, batch_uids: Iterable[int], batch_tokens: Iterable[torch.Tensor]) -> torch.Tensor:
        """
        Put a ragged batch onto the inference engine. This will perform one forward and return
        a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens
        are not calculated.

        Arguments:
            batch_uids: Iterable of uids for the batch on the host
            batch_tokens: Iterable of token tensors for the batch on the host
        """
        token_lens = [len(tokens) for tokens in batch_tokens]
        schedule_check = self.can_schedule(batch_uids, token_lens)
        if schedule_check != SchedulingResult.Success:
            raise SchedulingError(schedule_check)

        self._batch.clear()
        for uid, tokens in zip(batch_uids, batch_tokens):
            host_seq_desc = self._state_manager.get_or_create_sequence(uid)
            self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
            host_seq_desc.pre_forward(tokens.numel())

            # We can disable checks since we already validated schedulability.
            self._batch.insert_sequence(host_seq_desc, tokens, do_checks=False)

        # Send all metadata to the device
        self._batch.finalize()

        # Prep all data structures for the actual forward (in anticipation of CG in the future)
        # and also to amortize some of the costs in a more straightforward way.
        self._model.prepare_batch(self._batch)

        # Model implementation will pick up in the forward.
        logits = self._model.forward(self._batch)

        # We return one set of logits per sequence in the batch (saves cost on unembedding)
        assert logits.shape[0] == self._batch.current_sequences

        for uid in batch_uids:
            host_seq_desc = self._state_manager.get_sequence(uid)
            host_seq_desc.post_forward()  # Updates sequence metadata.
            self._model.maybe_free_kv(host_seq_desc)

        return logits
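
    # A hedged usage sketch (hypothetical uids and token tensors, kept in comments): a caller
    # typically checks schedulability, runs one forward, then samples from the returned logits:
    #
    #   uids = [0, 1]
    #   prompts = [torch.tensor([101, 2023, 2003]), torch.tensor([101, 2178])]
    #
    #   if engine.can_schedule(uids, [p.numel() for p in prompts]) == SchedulingResult.Success:
    #       logits = engine.put(uids, prompts)      # one row of logits per sequence
    #       next_tokens = logits.argmax(dim=-1)     # greedy sampling, one token per sequence
    #       # Feed each next token back as a length-1 tensor on the following put() call.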

    def query(self, uid: int, max_request_tokens: int, max_request_blocks: int) -> Tuple[int, int]:
        """
        Determine the number of tokens and KV blocks to reserve for a given request. Given a UID
        (this UID may not be recognized by the model yet), this will return the number of tokens
        and blocks to reserve for the request.

        Arguments:
            uid (int): The UID of the sequence (as tracked by the scheduling entity). If
                this is a new sequence (with a UID unknown to the inference engine), then
                an empty placeholder is created to pass to the occupancy logic.
            max_request_tokens (int): The maximum number of tokens to hypothetically send.
            max_request_blocks (int): The maximum number of KV blocks available for the request.

        Returns:
            Tuple[int, int]: The number of tokens and the number of KV blocks required to
                schedule the request.
        """
        seq_desc = self._state_manager.get_sequence(uid)
        if seq_desc is None:
            if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences):
                return (0, 0)
            seq_desc = PlaceholderSequenceDescriptor()

        req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks)

        return (req_tokens, req_blocks)
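
    # A hedged scheduler-side sketch (hypothetical uid and budget): before admitting a request,
    # a scheduler can ask how much of the engine's capacity the request would consume:
    #
    #   req_tokens, req_blocks = engine.query(uid=7,
    #                                         max_request_tokens=512,
    #                                         max_request_blocks=engine.free_blocks)
    #   admit = req_tokens > 0 and req_blocks <= engine.free_blocks   # one plausible admission check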

    def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult:
        """
        Dry run a batch to determine if it can be scheduled. Placeholder sequences will be
        created for any UIDs that are unknown to the inference engine.

        Arguments:
            uids (Iterable[int]): Iterable of UIDs for the batch.
            lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. Each
                length corresponds to the number of tokens to send in the hypothetical forward;
                history tokens will be determined via UID lookup and future tokens are disregarded.

        Returns:
            SchedulingResult: SchedulingResult.Success if the batch can be scheduled, otherwise
                a value indicating why it cannot be.
        """
        cur_seqs = self._state_manager.n_tracked_sequences
        free_blocks = self._state_manager.free_blocks
        batch_len = 0

        if len(uids) > self._config.state_manager.max_ragged_sequence_count:
            # Can only compose a batch from a limited number of sequences
            return SchedulingResult.BatchSequenceLimitExceeded

        for uid, length in zip(uids, lengths):
            seq_desc = self._state_manager.get_sequence(uid)
            if seq_desc is None:
                cur_seqs += 1
                seq_desc = PlaceholderSequenceDescriptor()

            sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks)
            if sched_len != length:
                # We ran out of KV cache
                return SchedulingResult.KVCacheLimitExceeded

            batch_len += length
            free_blocks -= sched_blocks

        if cur_seqs > self._config.state_manager.max_tracked_sequences:
            # Would run out of tracking metadata
            return SchedulingResult.EngineSequenceLimitExceeded

        if batch_len > self._config.state_manager.max_ragged_batch_size:
            # Would exceed the maximum batch size
            return SchedulingResult.BatchTokenLimitExceeded

        return SchedulingResult.Success

    def flush(self, uid: int) -> None:
        """
        Remove all state associated with a sequence from the inference engine.

        Arguments:
            uid (int): The UID of the sequence to flush.
        """
        self._state_manager.flush_sequence(uid)

    def serialize(self, save_path: str) -> None:
        """
        Serialize the model parameters and metadata to disk.

        Arguments:
            save_path (str): Path to the directory to serialize to.
        """
        param_file_name = make_param_filename(save_path, self._model.tp_rank, self._model.tp_size)
        metadata_file_name = make_metadata_filename(save_path, self._model.tp_rank, self._model.tp_size)

        # Save the flattened parameters
        torch.save(self._model.flattened_params, param_file_name)

        # Save the parameter metadata alongside the flattened parameters
        with open(metadata_file_name, "w") as metadata_file:
            json.dump(self._model.flattened_param_metadata.json(), metadata_file)

        # Only rank 0 needs to persist the rank-agnostic model config
        if self._model.tp_rank == 0:
            with open(os.path.join(save_path, "ds_model_config.pkl"), "wb") as config_file:
                pickle.dump(self._model._config, config_file)
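
# A hedged end-to-end sketch (hypothetical paths and token ids, kept in comments so the module
# has no import-time side effects): run one prompt, decode a single step, then snapshot the
# engine's weights for later reloading:
#
#   engine = InferenceEngineV2(policy, RaggedInferenceEngineConfig())
#   logits = engine.put([0], [torch.tensor([101, 7592, 2088])])
#   next_token = int(logits[0].argmax())
#   engine.flush(0)                      # release the sequence's KV blocks and tracking state
#   engine.serialize("/tmp/ds_engine")   # one param/metadata file per TP rank, plus the config on rank 0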