Skip to content

Commit

Permalink
Implement even chunking
Browse files (browse the repository at this point in the history)
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Feb 26, 2024
1 parent 57a4a20 commit 93dd2f4
Showing 1 changed file with 86 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@
@Log4j2
public class FixedTokenLengthChunker implements IFieldChunker {

// parameters

public static final String TOKEN_LIMIT = "token_limit";
public static final String OVERLAP_RATE = "overlap_rate";

public static final String MAX_TOKEN_COUNT = "max_token_count";

public static final String TOKENIZER = "tokenizer";

// default values for each parameter
private static final int DEFAULT_TOKEN_LIMIT = 500;
private static final double DEFAULT_OVERLAP_RATE = 0.2;
private static final int DEFAULT_MAX_TOKEN_COUNT = 10000;
private static final String DEFAULT_TOKENIZER = "standard";

private final AnalysisRegistry analysisRegistry;

public FixedTokenLengthChunker(AnalysisRegistry analysisRegistry) {
Expand All @@ -48,15 +54,64 @@ private List<String> tokenize(String content, String tokenizer, int maxTokenCoun
} catch (IOException e) {
throw new RuntimeException(e);
}
};
}

/**
 * Computes the number of tokens shared by two consecutive passages.
 *
 * The overlap is tokenCount * overlapRate rounded down, capped at tokenCount - 1
 * so that the result is always strictly smaller than tokenCount.
 *
 * @param tokenCount  number of tokens in a passage
 * @param overlapRate fraction of a passage that overlaps with the next one
 * @return the overlap size in tokens
 */
private static int getOverlapTokenNumber(int tokenCount, double overlapRate) {
    final int rawOverlap = (int) Math.floor(tokenCount * overlapRate);
    final int largestAllowed = tokenCount - 1;
    return rawOverlap < largestAllowed ? rawOverlap : largestAllowed;
}

/**
 * Returns the token count of a document that is covered exactly by
 * passageCount passages of tokensPerPassage tokens each.
 *
 * The first passage contributes all of its tokens; each later passage
 * contributes only the tokens it does not share with its predecessor.
 *
 * @param tokensPerPassage number of tokens in every passage
 * @param passageCount     number of passages
 * @param overlapRate      fraction of a passage that overlaps with the next one
 * @return the total number of distinct tokens covered
 */
private static int getDocumentTokenCount(int tokensPerPassage, int passageCount, double overlapRate) {
    // tokens a passage adds beyond what the previous passage already covered
    final int stride = tokensPerPassage - getOverlapTokenNumber(tokensPerPassage, overlapRate);
    final int followingPassages = passageCount - 1;
    return tokensPerPassage + followingPassages * stride;
}

/**
 * Returns the number of passages the content is chunked into.
 *
 * This is the minimum integer passageCount such that
 * getDocumentTokenCount(tokenLimit, passageCount, overlapRate) >= contentLength.
 *
 * @param contentLength number of tokens in the content
 * @param tokenLimit    maximum number of tokens per passage
 * @param overlapRate   fraction of a passage that overlaps with the next one
 * @return the number of passages
 */
private static int getPassageCount(int contentLength, int tokenLimit, double overlapRate) {
    // tokens each additional passage covers beyond the previous one
    final double stride = tokenLimit - getOverlapTokenNumber(tokenLimit, overlapRate);
    // tokens left over once the first full passage has been taken
    final int remaining = contentLength - tokenLimit;
    final double extraPassages = Math.ceil(remaining / stride);
    return (int) extraPassages + 1;
}

/**
 * Returns the number of tokens in the longest passages when the content is chunked evenly.
 *
 * To chunk the document evenly, passage lengths differ by at most one token: the output
 * consists of long passages and short passages (one token shorter than the long ones).
 * The result is the minimum integer tokensPerPassage in [1, tokenLimit] such that
 * getDocumentTokenCount(tokensPerPassage, passageCount, overlapRate) >= contentLength.
 * As this problem does not have a closed-form solution, binary search is used.
 *
 * @param contentLength number of tokens in the content
 * @param passageCount  number of passages the content is split into
 * @param tokenLimit    maximum number of tokens per passage (upper bound of the search)
 * @param overlapRate   fraction of a passage that overlaps with the next one
 * @return the smallest sufficient passage length, or tokenLimit if no value in range suffices
 */
private static int getTokensPerPassage(int contentLength, int passageCount, int tokenLimit, double overlapRate) {
    int left = 1;
    int right = tokenLimit;
    // fall back to the limit itself when no candidate in [1, tokenLimit] is sufficient
    int tokensPerPassage = tokenLimit;
    while (left <= right) {
        // overflow-safe midpoint
        int mid = left + ((right - left) >> 1);
        if (getDocumentTokenCount(mid, passageCount, overlapRate) < contentLength) {
            // mid is too small
            left = mid + 1;
        } else {
            // mid suffices; remember it and keep looking for a smaller value.
            // The upper bound must move below mid, otherwise the search never terminates.
            tokensPerPassage = mid;
            right = mid - 1;
        }
    }
    return tokensPerPassage;
}

@Override
public List<String> chunk(String content, Map<String, Object> parameters) {
// parameters has been validated
int tokenLimit = 500;
double overlapRate = 0.2;
int maxTokenCount = 10000;
String tokenizer = "standard";
// assume that parameters has been validated
int tokenLimit = DEFAULT_TOKEN_LIMIT;
double overlapRate = DEFAULT_OVERLAP_RATE;
int maxTokenCount = DEFAULT_MAX_TOKEN_COUNT;
String tokenizer = DEFAULT_TOKENIZER;

if (parameters.containsKey(TOKEN_LIMIT)) {
tokenLimit = ((Number) parameters.get(TOKEN_LIMIT)).intValue();
Expand All @@ -72,26 +127,35 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
}

List<String> tokens = tokenize(content, tokenizer, maxTokenCount);
List<String> passages = new ArrayList<>();
int tokenLength = tokens.size();

if (tokenLength == 0) {
return new ArrayList<>();
}
if (tokenLength <= tokenLimit) {
return List.of(content);
}

int passageCount = getPassageCount(tokenLength, tokenLimit, overlapRate);
int tokensPerPassage = getTokensPerPassage(tokenLength, passageCount, tokenLimit, overlapRate);
int overlapTokenNumber = getOverlapTokenNumber(tokensPerPassage, overlapRate);
int exceedingTokenCount = getDocumentTokenCount(tokensPerPassage, passageCount, overlapRate) - tokenLength;

String passage;
int startToken = 0;
int overlapTokenNumber = (int) Math.floor(tokenLimit * overlapRate);
// overlapTokenNumber must be smaller than the token limit
overlapTokenNumber = Math.min(overlapTokenNumber, tokenLimit - 1);

while (startToken < tokens.size()) {
if (startToken + tokenLimit >= tokens.size()) {
// break the loop when already cover the last token
passage = String.join(" ", tokens.subList(startToken, tokens.size()));
passages.add(passage);
break;
List<String> passages = new ArrayList<>();

for (int i = 0; i < passageCount; i++) {
if (i + exceedingTokenCount < passageCount) {
passage = String.join(" ", tokens.subList(startToken, startToken + tokensPerPassage));
} else {
passage = String.join(" ", tokens.subList(startToken, startToken + tokenLimit));
passages.add(passage);
// (exceedingTokenCount) passages contain 1 less token
passage = String.join(" ", tokens.subList(startToken, startToken + tokensPerPassage - 1));
}
startToken += tokenLimit - overlapTokenNumber;
passages.add(passage);
startToken += tokensPerPassage - overlapTokenNumber;
}

return passages;
}

Expand Down

0 comments on commit 93dd2f4

Please sign in to comment.