Skip to content

Commit

Permalink
Reverted to the per-request logits-processor loop
Browse files Browse the repository at this point in the history
  • Loading branch information
vegaluisjose committed Feb 2, 2024
1 parent 082527f commit 754bb59
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions serve/mlc_serve/model/tvm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,16 +330,16 @@ def generate(
cache.pending_copy_from_to = []

try:
next_logits = logits
# If JSON mode (constrained sampling) is enabled
if request.sampling_params.logits_processor is not None:
cs_logits = torch.from_dlpack(logits.to_dlpack())
for i, (sequence_id, request) in enumerate(zip(sequence_ids,requests)):
cs_input_ids = request.token_ids if isinstance(request, DecodeRequest) else []
cs_logits[i] = request.sampling_params.logits_processor(sequence_id, cs_input_ids, cs_logits[i])
next_logits = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(cs_logits))

next_tokens = sample(next_logits, sampling_params, self.vocab_size)
cs_logits = torch.from_dlpack(logits.to_dlpack())

for i, (sequence_id, request) in enumerate(zip(sequence_ids,requests)):
if request.sampling_params.logits_processor is not None:
cs_input_ids = request.token_ids if isinstance(request, DecodeRequest) else []
cs_logits[i] = request.sampling_params.logits_processor(sequence_id, cs_input_ids, cs_logits[i])

new_logits = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(cs_logits))

next_tokens = sample(new_logits, sampling_params, self.vocab_size)
assert next_tokens is not None
outputs = []
for i, (sequence_id, new_token) in enumerate(
Expand Down

0 comments on commit 754bb59

Please sign in to comment.