Refine the llava-next warmup #264

Open · wants to merge 1 commit into habana-main
173 changes: 75 additions & 98 deletions server/text_generation_server/models/vlm_causal_lm.py
@@ -79,16 +79,17 @@
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 256))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get('PT_HPU_LAZY_MODE', 1))
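# Bucket sizes used to pad decode (BATCH_BUCKET_SIZE) and prefill (PREFILL_BATCH_BUCKET_SIZE) batch sizes to a fixed set of shapes; the defaults of 1 leave batch sizes unpadded.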
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 1))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 1))


PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = []
DECODE_WARMUP_BATCH_SIZE_LIST = []
def round_up(warmup_list:list, num) :
i = 0
for i in warmup_list:
if num <= i :
break
return i

def round_up(number, k):
return (number + k - 1) // k * k


def split(string) -> List[Dict[str, str]]:
parts = []
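The list-based round_up removed above returned the smallest precomputed warmup bucket that fits the value, while its replacement rounds up to the nearest multiple of a bucket size. A minimal side-by-side sketch, with made-up inputs that are not taken from the diff:

```python
def round_up_old(warmup_list, num):
    # first bucket >= num; falls back to the last bucket when num exceeds them all
    i = 0
    for i in warmup_list:
        if num <= i:
            break
    return i

def round_up_new(number, k):
    # next multiple of k that is >= number
    return (number + k - 1) // k * k

assert round_up_old([1, 2, 4, 8], 3) == 4
assert round_up_new(3, 2) == 4        # e.g. BATCH_BUCKET_SIZE = 2
assert round_up_new(300, 256) == 512  # e.g. PAD_SEQUENCE_TO_MULTIPLE_OF = 256
```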
@@ -216,7 +217,7 @@ def from_tokenized(
# TODO: by tokenizing all inputs at once we lose information on actual input lengths
# this means that we cannot shift inputs to the left after a long input sequence
# was filtered out
new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
parameters = [r.parameters for r in pb.requests]
# append the dummy parameters for dummy request
parameters = pad_next_token_chooser_parameters(parameters, new_bs)
@@ -235,7 +236,7 @@ def from_tokenized(
left_padding = max_input_length - input_len
if is_warmup is False:
if input_len < max_input_length :
rounded_seq_len = round_up(PREFILL_WARMUP_SEQLEN_LIST, input_len + 1)
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
else:
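For live requests the input length is now rounded with the new helper instead of being matched against a warmup sequence-length list; the `+ 1` before rounding and the `- 1` afterwards keep one position free for the first generated token. A small illustration with the default padding multiple and a hypothetical input length:

```python
PAD_SEQUENCE_TO_MULTIPLE_OF = 256

def round_up(number, k):
    return (number + k - 1) // k * k

input_len = 300
rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)  # 512
bucket_size = rounded_seq_len - 1                                       # 511
# 511 padded input positions plus the next generated token fill the 512 bucket exactly
```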
@@ -333,7 +334,7 @@ def batch_tokenized_inputs(
missing_inputs = 0
dummy_images = None
if is_warmup is False:
new_bs = round_up(PREFILL_WARMUP_BATCH_SIZE_LIST, len(requests))
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
missing_inputs = new_bs - len(requests)
if missing_inputs > 0:
dummy_inputs = []
@@ -427,7 +428,7 @@ def recombine(cls, batches: List["VlmCausalLMBatch"], pad_token_id: int, is_warm
total_requests = sum(len(b) for b in batches)
new_bs = total_requests
if is_warmup is False :
new_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, total_requests)
new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
batch_id = batches[0].batch_id
device = batches[0].input_ids.device

@@ -897,9 +898,9 @@ def generate_token(
batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id, is_warmup)

scenario = 'PREFILL' if prefill else 'GENERATE'
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size) != self.prev_bs:
if self.enable_hpu_graph and self.limit_hpu_graph and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs:
self.model.clear_cache()
self.prev_bs = round_up(DECODE_WARMUP_BATCH_SIZE_LIST, batch.batch_size)
self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
dbg_trace(
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
#assert batch.right_padding > 0, 'No more room for next token!'
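With `limit_hpu_graph` enabled, the cached HPU graph is now invalidated when the bucketed batch size changes, so batches that fall into the same bucket keep reusing one graph. A short sketch of the condition, with a hypothetical bucket size and batch sizes:

```python
BATCH_BUCKET_SIZE = 4

def round_up(number, k):
    return (number + k - 1) // k * k

prev_bs = round_up(6, BATCH_BUCKET_SIZE)          # bucket 8
assert round_up(7, BATCH_BUCKET_SIZE) == prev_bs  # same bucket: cached graph is kept
assert round_up(3, BATCH_BUCKET_SIZE) != prev_bs  # new bucket: clear_cache() would run first
```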
@@ -1112,116 +1113,92 @@ def generate_warmup_batch(self, request, seq_len, batch_size, is_warmup):
return self.batch_from_pb(batch, is_warmup)

def warmup(self, request) -> None:
is_warmup = True
batch = self.batch_from_pb(request.batch, is_warmup)

MAX_TOTAL_TOKENS = request.max_total_tokens
MAX_BATCH_TOTAL_TOKENS = request.max_batch_total_tokens
batch = self.batch_from_pb(request.batch, is_warmup=True)
max_prefill_batch_size = batch.input_ids.shape[0]
try:
# max prefill batch size warmup
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, prefill_batch, _ = self.generate_token([batch])
except:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
)

global BASE_IMAGE_TOKENS, MAX_TOTAL_TOKENS, MAX_BATCH_TOTAL_TOKENS, PREFILL_WARMUP_BATCH_SIZE_LIST, PREFILL_WARMUP_SEQLEN_LIST, DECODE_WARMUP_BATCH_SIZE_LIST
max_input_length = batch.input_ids.shape[1]
max_prefill_batch_size = batch.input_ids.shape[0]
PREFILL_WARMUP_BATCH_SIZE_LIST = []
batch_size = 1
while batch_size <= max_prefill_batch_size:
PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2
if PREFILL_WARMUP_BATCH_SIZE_LIST[-1] < max_prefill_batch_size :
PREFILL_WARMUP_BATCH_SIZE_LIST.append(max_prefill_batch_size)

seq_len = BASE_IMAGE_TOKENS
PREFILL_WARMUP_SEQLEN_LIST = []
i = 0
while seq_len <= max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
i += 1
if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)

#Prefill and decode warmup
DECODE_WARMUP_BATCH_SIZE_LIST = []
prefill_batch = None
decode_batch = None
try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
batch = self.generate_warmup_batch(request, seq_len, batch_size, is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
_, decode_batch, _ = self.generate_token([prefill_batch], is_warmup)

DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
del prefill_batch

# Warmup prefill batch_size
max_input_length = request.max_input_length
prefill_batch_size_list = [batch for batch in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE)]
prefill_batch_size_list.append(max_prefill_batch_size)
prefill_seqlen_list = [round_up(seq, PAD_SEQUENCE_TO_MULTIPLE_OF) for seq in range(BASE_IMAGE_TOKENS, max_input_length, PAD_SEQUENCE_TO_MULTIPLE_OF)]
prefill_seqlen_list.append(max_input_length)
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)
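# Walk the warmup grid from the largest batch size and sequence length down to the
# smallest, so the most memory-hungry prefill shapes are exercised first.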
try:
for batch_size in prefill_batch_size_list:
for seq_len in prefill_seqlen_list:
batch = self.generate_warmup_batch(request, seq_len-1, batch_size)
_, prefill_batch, _ = self.generate_token([batch])
except:
prefill_batch_size_list.sort()
prefill_seqlen_list.sort()
raise RuntimeError(
f"Not enough memory to handle following prefill and decode warmup."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"Not enough memory to run following prefill batch_size."
f"Prefill batch size list:{prefill_batch_size_list}"
f"Prefill sequence length list:{prefill_seqlen_list}"
f"You need to decrease `--max-batch-prefill-tokens`"
)

prefill_seqlen_list.sort()
prefill_batch_size_list.sort()
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats} "
)
f"\nFollowing prefill warmup successfully.\n"
f"Prefill batch size list:{prefill_batch_size_list}\n"
f"Prefill sequence length list:{prefill_seqlen_list}\n"
f"Memory stats: {mem_stats} "
)

#warmup decode batch size
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
batch_size = max_prefill_batch_size * 2
# Decode warmup with bigger batch_size
max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
decode_batch_size_list = [i for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)]
decode_batch_size_list.append(max_decode_batch_size)
decode_batch_size_list.sort(reverse=True)

try:
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size and batch_size <= max_decode_batch_size:
batches = []
for i in range(int(batch_size/max_prefill_batch_size)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
for batch_size in decode_batch_size_list:
batches = []
iters = math.floor(batch_size/max_prefill_batch_size)
for i in range(iters):
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch)
while batch_size <= max_decode_batch_size:
_, decode_batch, _ = self.generate_token(batches, is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
batch_size = batch_size * 2
batches.clear()

for i in range(int(batch_size/max_prefill_batch_size)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], DECODE_WARMUP_BATCH_SIZE_LIST[-1], is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
if batch_size % max_prefill_batch_size != 0:
batch = self.generate_warmup_batch(request, PAD_SEQUENCE_TO_MULTIPLE_OF - 1, batch_size % max_prefill_batch_size)
_, prefill_batch, _ = self.generate_token([batch])
batches.append(prefill_batch)

_, decode_batch, _ = self.generate_token(batches)
_, decode_batch, _ = self.generate_token([decode_batch])
del decode_batch
batches.clear()
if DECODE_WARMUP_BATCH_SIZE_LIST[-1] < max_decode_batch_size:
max_decode_batch_size = math.floor( max_decode_batch_size / 2) * 2
batch_size = max_decode_batch_size
for i in range(int(max_decode_batch_size / 2)) :
batch = self.generate_warmup_batch(request, PREFILL_WARMUP_SEQLEN_LIST[0], 2, is_warmup)
_, prefill_batch, _ = self.generate_token([batch], is_warmup)
batches.append(prefill_batch)
_, decode_batch, _ = self.generate_token(batches, is_warmup)
DECODE_WARMUP_BATCH_SIZE_LIST.append(max_decode_batch_size)
max_batch_total_tokens = max_decode_batch_size * MAX_TOTAL_TOKENS
MAX_BATCH_TOTAL_TOKENS = max_batch_total_tokens
except :
raise RuntimeError(
f"Not enough memory to handle batch_size({batch_size}) decode warmup."
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"max_decode_batch_size is {max_decode_batch_size}"
f"You need to decrease env `MAX_BATCH_TOTAL_TOKENS` or '--max_batch_total_tokens'"
)

except:
raise RuntimeError(
f"Not enough memory to warmup decode batch_sizes({decode_batch_size_list})."
f"You need to decrease `--max-batch-total-tokens`"
)

decode_batch_size_list.sort()
MAX_BATCH_TOTAL_TOKENS = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats}"
)
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{decode_batch_size_list}\n"
f"Memory stats: {mem_stats} "
)

return MAX_BATCH_TOTAL_TOKENS
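
Taken together, the new warmup derives its shapes directly from the bucket sizes instead of from power-of-two warmup lists. The sketch below mirrors the list-building logic in the diff to show the schedule it would produce; all concrete values (including BASE_IMAGE_TOKENS) are hypothetical and chosen only for illustration:

```python
import math

PAD_SEQUENCE_TO_MULTIPLE_OF = 256
PREFILL_BATCH_BUCKET_SIZE = 2
BATCH_BUCKET_SIZE = 4
BASE_IMAGE_TOKENS = 2048          # hypothetical; the real value comes from the model

def round_up(number, k):
    return (number + k - 1) // k * k

max_prefill_batch_size = 5        # from the warmup request batch
max_input_length = 2560           # request.max_input_length, hypothetical
MAX_TOTAL_TOKENS = 8192
MAX_BATCH_TOTAL_TOKENS = 65536

# Prefill warmup grid: batch sizes in PREFILL_BATCH_BUCKET_SIZE steps plus the true maximum,
# sequence lengths in PAD_SEQUENCE_TO_MULTIPLE_OF steps starting at BASE_IMAGE_TOKENS.
prefill_batch_size_list = list(range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size,
                                     PREFILL_BATCH_BUCKET_SIZE))
prefill_batch_size_list.append(max_prefill_batch_size)
prefill_seqlen_list = [round_up(s, PAD_SEQUENCE_TO_MULTIPLE_OF)
                       for s in range(BASE_IMAGE_TOKENS, max_input_length,
                                      PAD_SEQUENCE_TO_MULTIPLE_OF)]
prefill_seqlen_list.append(max_input_length)
print(sorted(prefill_batch_size_list))   # [2, 4, 5]
print(sorted(prefill_seqlen_list))       # [2048, 2304, 2560]

# Decode warmup: batch sizes in BATCH_BUCKET_SIZE steps up to the token-budget limit.
max_decode_batch_size = round_up(math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS),
                                 BATCH_BUCKET_SIZE)
decode_batch_size_list = list(range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE))
decode_batch_size_list.append(max_decode_batch_size)
print(sorted(decode_batch_size_list))    # [4, 8]

# The warmup finally reports the batch-total-token budget it actually covered.
print(MAX_TOTAL_TOKENS * sorted(decode_batch_size_list)[-1])  # 65536
```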