Skip to content

Commit

Permalink
Update dry_penalty.py
Browse files Browse the repository at this point in the history
  • Loading branch information
81549361 authored Sep 9, 2024
1 parent 62b7e59 commit 7360ec2
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions python/sglang/srt/sampling/penaltylib/penalizers/dry_penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,40 +72,45 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs):
# Apply a repetition penalty to `logits`: when the most recent token has occurred
# earlier in the token history, the token that followed that earlier occurrence
# is penalised, with a penalty that grows exponentially with the length of the
# repeated run that precedes it. Returns the same `logits` tensor, modified in place.
#
# NOTE(review): this span is a diff rendering whose +/- markers and indentation
# were lost. Removed (pre-commit) and added (post-commit) lines appear together;
# each is labelled below. The labelling matches the "16 additions & 11 deletions"
# reported in the commit header.
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
# Only the first dimension is used as the row count below; `seq_length` is not
# referenced again in the lines shown.
# NOTE(review): for sampling logits dim 1 is presumably the vocabulary axis — confirm.
batch_size, seq_length = logits.shape[0], logits.shape[1]
max_back_length = 50 # Limit the backward match to 50 to prevent overflow
# [removed] Iterate over every row of the batch.
for i in range(batch_size):
# [added] Process only index 0; every other row of `logits` is left untouched.
for i in range(1):
if self.output_ids is not None:
# [removed] Append the newly generated tokens to row i's history and store the
# result back on self, so the stored history changes on every call.
input_ids = self.input_ids[i] = torch.cat(
[self.input_ids[i], self.output_ids], dim=0
)
# [added] Same concatenation, guarded by a bounds check on self.input_ids.
if i < len(self.input_ids):
input_ids = self.input_ids[i] = torch.cat(
[self.input_ids[i], self.output_ids], dim=0
)
else:
# [added] Fallback: concatenates the whole of self.input_ids (not one row).
# NOTE(review): this still assigns to self.input_ids[i] after the check found
# i out of range — verify this cannot raise for the container type used.
input_ids = self.input_ids[i] = torch.cat(
[self.input_ids, self.output_ids], dim=0
)
else:
input_ids = self.input_ids[i]
# Work on a plain Python list from here on.
input_ids = input_ids.tolist()
# Restrict the search to the last `range_limit` tokens; a non-positive limit
# keeps the full history.
# [removed] Per-row range.
range_limit = min(self.ranges[i].item(), len(input_ids))
# [added] Always uses the range stored at index 0.
range_limit = min(self.ranges[0].item(), len(input_ids))
input_ids = input_ids[-range_limit:] if range_limit > 0 else input_ids
# NOTE(review): an empty history would raise IndexError here.
last_token = input_ids[-1]
# A sequence breaker as the last token means nothing is penalised for this row.
# [removed] Per-row breaker set.
if last_token in self.sequence_breakers[i]:
# [added] Breaker set stored at index 0.
if last_token in self.sequence_breakers[0]:
continue

# Positions of every earlier occurrence of the last token (the final position
# itself is excluded).
match_indices = [idx for idx, val in enumerate(input_ids[:-1]) if val == last_token]
# Maps candidate next token -> longest repeated run found before it.
match_lengths = defaultdict(int)

for idx in match_indices:
# The token that followed the earlier occurrence is the penalty candidate.
next_token = input_ids[idx + 1]
# [removed] / [added] variants of the breaker check on the candidate token.
if next_token in self.sequence_breakers[i]:
if next_token in self.sequence_breakers[0]:
continue
# The last token already matches, so the run length starts at 1.
match_length = 1
# Walk backwards in lockstep from `idx` and from the end of the history,
# extending the run while the tokens agree.
while match_length < max_back_length and idx - match_length >= 0:
previous_token = input_ids[-(match_length + 1)]
if input_ids[idx - match_length] != previous_token:
break
# A sequence breaker ends the run without being counted.
# [removed] / [added] variants of the breaker check.
if previous_token in self.sequence_breakers[i]:
if previous_token in self.sequence_breakers[0]:
break
match_length += 1
# Keep the longest run seen for this candidate token.
match_lengths[next_token] = max(match_length, match_lengths[next_token])

for token, match_length in match_lengths.items():
# Penalise only runs at least `allowed_length` long:
# penalty = multiplier * base ** (match_length - allowed_length).
# [removed] Per-row parameters.
if match_length >= self.allowed_lengths[i].item():
penalty = self.multipliers[i].item() * self.bases[i].item() ** (match_length - self.allowed_lengths[i].item())
# [added] Parameters stored at index 0.
if match_length >= self.allowed_lengths[0].item():
penalty = self.multipliers[0].item() * self.bases[0].item() ** (match_length - self.allowed_lengths[0].item())
# In-place update of the caller's tensor.
logits[i, token] -= penalty

return logits
Expand All @@ -124,4 +129,4 @@ def _merge(self, their: "BatchedDryPenalizer"):
self.bases = torch.cat([self.bases, their.bases], dim=0)
self.allowed_lengths = torch.cat([self.allowed_lengths, their.allowed_lengths], dim=0)
self.sequence_breakers.extend(their.sequence_breakers)
self.ranges = torch.cat([self.ranges, their.ranges], dim=0)
self.ranges = torch.cat([self.ranges, their.ranges], dim=0)

0 comments on commit 7360ec2

Please sign in to comment.