Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs #776

Merged
5 commits merged on Jul 28, 2024
Changes from 3 commits
43 changes: 22 additions & 21 deletions docs/sampling_params.md
@@ -13,7 +13,7 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params.
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
@@ -23,7 +23,7 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs.
# Whether to detokenize tokens into text in the returned logprobs.
return_text_in_logprobs: bool = False
# Whether to stream output.
stream: bool = False
@@ -32,27 +32,28 @@ class GenerateReqInput:
The `sampling_params` follows this format

```python
class SamplingParams:
def __init__(
self,
max_new_tokens: int = 16,
stop: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None,
) -> None:
# The maximum number of output tokens
max_new_tokens: int = 16,
# Stop when hitting any of the strings in this list.
stop: Optional[Union[str, List[str]]] = None,
# Sampling temperature
temperature: float = 1.0,
# Top-p sampling
top_p: float = 1.0,
# Top-k sampling
top_k: int = -1,
# Whether to ignore EOS token.
ignore_eos: bool = False,
# Whether to skip the special tokens during detokenization.
skip_special_tokens: bool = True,
# Whether to add spaces between special tokens during detokenization.
spaces_between_special_tokens: bool = True,
# Constrains the output to follow a given regular expression.
regex: Optional[str] = None,
# Do parallel sampling and return `n` outputs.
n: int = 1,
```

- `max_new_tokens`, `stop`, `temperature`, `top_p`, `top_k` are common sampling parameters.
- `ignore_eos` means ignoring the EOS token and continuing to decode, which is helpful for benchmarking.
- `regex` constrains the output to follow a given regular expression.
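
For illustration, here is a minimal request sketch against the runtime's `/generate` HTTP endpoint that exercises these parameters together with the logprob options of `GenerateReqInput` shown above. It assumes a server already running locally on port 30000 and a `return_logprob` request field (not visible in this excerpt); adjust names and values to your setup.

```python
import requests

# Minimal sketch: assumes an SGLang server is already running, e.g.
#   python -m sglang.launch_server --model-path <your-model> --port 30000
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "max_new_tokens": 32,
            "temperature": 0.7,
            "top_p": 0.9,
            "stop": ["\n"],
        },
        "return_logprob": True,  # assumed field name, not shown in the excerpt above
        "top_logprobs_num": 2,
        "return_text_in_logprobs": True,
    },
)
data = response.json()
print(data["text"])

# After this PR, the logprob keys in meta_info use the input_/output_ naming.
meta = data["meta_info"]
print(meta["input_token_logprobs"][:3])
print(meta["output_token_logprobs"][:3])
```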

## Examples

### Normal
8 changes: 4 additions & 4 deletions examples/usage/choices_logprob.py
@@ -20,8 +20,8 @@ def main():
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50)

# Run a batch
@@ -34,8 +34,8 @@ def main():
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50)


14 changes: 7 additions & 7 deletions examples/usage/cot_decoding.py
@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
top_logprobs_num=get_top_k,
return_text_in_logprobs=True,
)
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]

print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
for idx, (f, token) in enumerate(zip(forks, logprobs)):
@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
)

# calculate probability disparity between the top and secondary tokens
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
delta = (sum(x1s) - sum(x2s)) / len(x1s)

# extract the answer span (without the '<|end_of_text|>' token)
@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
answer_tokens = [
xt[0][2]
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
answer_x1s = [
exp(xt[0][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
answer_x2s = [
exp(xt[1][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]

26 changes: 13 additions & 13 deletions examples/usage/json_logprobs.py
@@ -56,14 +56,14 @@ def srt_api_request(name):
# fout.write(json.dumps(res, indent=4))

meta_info = res["meta_info"]
assert len(meta_info["prefill_token_logprobs"]) == len(
meta_info["prefill_top_logprobs"]
assert len(meta_info["input_token_logprobs"]) == len(
meta_info["input_top_logprobs"]
)
assert len(meta_info["decode_token_logprobs"]) == len(
meta_info["decode_top_logprobs"]
assert len(meta_info["output_token_logprobs"]) == len(
meta_info["output_top_logprobs"]
)
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1

return res

@@ -72,21 +72,21 @@ def pretty_print(res):
meta_info = res["meta_info"]

print("\n\n", "=" * 30, "Prefill", "=" * 30)
for i in range(len(meta_info["prefill_token_logprobs"])):
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
for i in range(len(meta_info["input_token_logprobs"])):
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = (
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
if meta_info["prefill_top_logprobs"][i]
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
if meta_info["input_top_logprobs"][i]
else []
)
for top_k in top_ks:
print(f"{top_k: <15}", end="")
print()

print("\n\n", "=" * 30, "Decode", "=" * 30)
for i in range(len(meta_info["decode_token_logprobs"])):
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
for i in range(len(meta_info["output_token_logprobs"])):
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
for top_k in top_ks:
print(f"{top_k: <15}", end="")
print()
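For readers following the indexing in the examples above (`[i][2]` for the decoded token text, `xt[0][0]` for the top-1 logprob), the renamed `meta_info` fields have roughly the following shape when `return_text_in_logprobs=True`. This is a sketch with made-up values; the exact padding of the first input position may differ.

```python
# Illustrative sketch of the renamed meta_info logprob fields (values are made up).
meta_info = {
    # One (logprob, token_id, text) tuple per token; the text element is only
    # present when return_text_in_logprobs=True.
    "input_token_logprobs": [(-7.31, 450, " The"), (-2.05, 7483, " capital"), (-0.98, 310, " of")],
    "output_token_logprobs": [(-0.12, 3681, " Paris"), (-0.45, 29889, ".")],
    # One list of top-k (logprob, token_id, text) tuples per position; the first
    # input position may be empty/None, as json_logprobs.py above guards against.
    "input_top_logprobs": [None, [(-1.05, 7483, " capital"), (-2.44, 4272, " city")]],
    "output_top_logprobs": [[(-0.12, 3681, " Paris"), (-2.98, 278, " the")]],
}

# Top-1 logprob at the first output position, as used in cot_decoding.py:
top1 = meta_info["output_top_logprobs"][0][0][0]
# Decoded text of the second input token, as printed in json_logprobs.py:
second_token_text = meta_info["input_token_logprobs"][1][2]
```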
8 changes: 4 additions & 4 deletions python/sglang/lang/backend/runtime_endpoint.py
@@ -253,14 +253,14 @@ def select(
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

return (
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
input_token_logprobs,
output_token_logprobs,
)

def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
8 changes: 4 additions & 4 deletions python/sglang/lang/interpreter.py
@@ -541,16 +541,16 @@ def _execute_select(self, expr: SglSelect):
(
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
input_token_logprobs,
output_token_logprobs,
) = self.backend.select(self, expr.choices, expr.temperature)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprobs": normalized_prompt_logprobs,
"prefill_token_logprobs": prefill_token_logprobs,
"decode_token_logprobs": decode_token_logprobs,
"input_token_logprobs": input_token_logprobs,
"output_token_logprobs": output_token_logprobs,
}
self.variable_event[name].set()
self.text_ += decision
74 changes: 35 additions & 39 deletions python/sglang/srt/layers/logits_processor.py
@@ -22,13 +22,13 @@ class LogitProcessorOutput:

# The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs: torch.Tensor
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor

# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs: List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs: List
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List


@dataclasses.dataclass
@@ -58,20 +58,16 @@ def __init__(self, config):
self.tp_size = get_tensor_model_parallel_world_size()

def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
self, input_token_logprobs, logits_metadata: LogitsMetadata
):
logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32
)
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

start = logits_metadata.extend_start_loc.clone()
end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_token_logprobs[start]
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
@@ -83,34 +79,34 @@ def _get_normalized_prompt_logprobs(
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
output_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, output_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
input_top_logprobs, output_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
input_top_logprobs.append([])
output_top_logprobs.append([])
continue
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len

return prefill_top_logprobs, decode_top_logprobs
return input_top_logprobs, output_top_logprobs

def forward(
self,
@@ -150,9 +146,9 @@ def forward(
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
else:
# When logprob is requested, compute the logits for all tokens.
@@ -164,19 +160,19 @@ def forward(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
decode_top_logprobs = self.get_top_logprobs(
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
decode_top_logprobs = None
output_top_logprobs = None

return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=output_top_logprobs,
)
else:
all_logits = torch.matmul(hidden_states, weight.T)
@@ -193,32 +189,32 @@ def forward(
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None
input_top_logprobs = output_top_logprobs = None

last_logprobs = all_logprobs[last_index]

# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_token_logprobs = all_logprobs[
input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]

normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, logits_metadata
input_token_logprobs, logits_metadata
)

return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_top_logprobs=output_top_logprobs,
)


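Unrelated to the rename itself, but useful context for `_get_normalized_prompt_logprobs` above: as I read the code, the batched cumsum computes, per sequence, the mean logprob of the prompt tokens after the first one (the first token has no conditional logprob). A naive per-sequence sketch of the same quantity, ignoring the packed-batch and clamping details:

```python
def normalized_prompt_logprob_naive(token_logprobs):
    # token_logprobs: logprob of each prompt token given its prefix,
    # for every prompt token except the first.
    return sum(token_logprobs) / max(len(token_logprobs), 1)
```

The cumsum form in the diff computes the same per-sequence sums for all packed sequences at once, without a Python-level loop.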