diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index e9dc2f8560..eaf8438191 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -330,16 +330,16 @@ def generate(
             cache.pending_copy_from_to = []
 
         try:
-            next_logits = logits
-            # If JSON mode (constrained sampling) is enabled
-            if request.sampling_params.logits_processor is not None:
-                cs_logits = torch.from_dlpack(logits.to_dlpack())
-                for i, (sequence_id, request) in enumerate(zip(sequence_ids,requests)):
-                    cs_input_ids = request.token_ids if isinstance(request, DecodeRequest) else []
-                    cs_logits[i] = request.sampling_params.logits_processor(sequence_id, cs_input_ids, cs_logits[i])
-                next_logits = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(cs_logits))
-
-            next_tokens = sample(next_logits, sampling_params, self.vocab_size)
+            cs_logits = torch.from_dlpack(logits.to_dlpack())
+
+            for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
+                if request.sampling_params.logits_processor is not None:
+                    cs_input_ids = request.token_ids if isinstance(request, DecodeRequest) else []
+                    cs_logits[i] = request.sampling_params.logits_processor(sequence_id, cs_input_ids, cs_logits[i])
+
+            new_logits = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(cs_logits))
+
+            next_tokens = sample(new_logits, sampling_params, self.vocab_size)
         assert next_tokens is not None
         outputs = []
         for i, (sequence_id, new_token) in enumerate(
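
Note (not part of the patch): the removed code checked `request.sampling_params.logits_processor` *before* the loop that defines `request`, so the check read a variable from an enclosing scope and gated the whole batch at once; the new code moves the check inside the per-request loop so each request's processor is applied individually. The sketch below is a minimal, self-contained illustration of the zero-copy DLPack round trip this hunk relies on, together with a toy `logits_processor` matching the `(sequence_id, input_ids, logits)` call signature the diff implies. The masking rule and the variable values are hypothetical; a real JSON-mode processor would derive the allowed-token mask from the sequence's grammar/parse state.

```python
# Minimal sketch (assumptions noted inline), not the repository's code.
import torch
import tvm

def toy_logits_processor(sequence_id, input_ids, logits):
    # Hypothetical constraint: forbid token id 0 by pushing its logit to -inf.
    # A real JSON-mode processor masks all tokens the grammar disallows here.
    logits[0] = float("-inf")
    return logits

# TVM NDArray -> torch tensor without copying, via DLPack
# (mirrors `torch.from_dlpack(logits.to_dlpack())` in the patch).
tvm_logits = tvm.nd.array(torch.randn(2, 8).numpy())
cs_logits = torch.from_dlpack(tvm_logits.to_dlpack())

# Apply the processor per row, as the patched loop does
# (sequence ids 101/102 are made up for the example).
for i, sequence_id in enumerate([101, 102]):
    cs_logits[i] = toy_logits_processor(sequence_id, [], cs_logits[i])

# torch tensor -> TVM NDArray, again without a copy, ready for `sample`.
new_logits = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(cs_logits))
```

Because both conversions go through DLPack, the processor mutates the same underlying buffer the TVM side sees; the round trip exists only to give PyTorch indexing semantics over the logits, not to materialize a second tensor.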