Skip to content

Commit

Permalink
feat: add WithKeepSeparator option for RecursiveCharacter (tmc#721)
Browse files Browse the repository at this point in the history
* feat: add WithKeepSeparator option for RecursiveCharacter


---------

Co-authored-by: Ivan Zhang <[email protected]>
  • Loading branch information
zhangi and Ivan Zhang authored May 7, 2024
1 parent 08132f9 commit 09ac6e0
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 22 deletions.
21 changes: 17 additions & 4 deletions textsplitter/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type Options struct {
ChunkSize int
ChunkOverlap int
Separators []string
KeepSeparator bool
LenFunc func(string) int
ModelName string
EncodingName string
Expand All @@ -20,10 +21,11 @@ type Options struct {
// DefaultOptions returns the default options for all text splitters.
func DefaultOptions() Options {
return Options{
ChunkSize: _defaultTokenChunkSize,
ChunkOverlap: _defaultTokenChunkOverlap,
Separators: []string{"\n\n", "\n", " ", ""},
LenFunc: utf8.RuneCountInString,
ChunkSize: _defaultTokenChunkSize,
ChunkOverlap: _defaultTokenChunkOverlap,
Separators: []string{"\n\n", "\n", " ", ""},
KeepSeparator: false,
LenFunc: utf8.RuneCountInString,

ModelName: _defaultTokenModelName,
EncodingName: _defaultTokenEncoding,
Expand Down Expand Up @@ -118,3 +120,14 @@ func WithReferenceLinks(referenceLinks bool) Option {
o.ReferenceLinks = referenceLinks
}
}

// WithKeepSeparator returns an Option that sets Options.KeepSeparator.
//
// When keepSeparator is true, RecursiveCharacter keeps the separator it split
// on by prefixing it to every piece after the first, so no characters of the
// input are dropped at split points. When it is false, the separator is removed
// from the pieces produced by the split. DefaultOptions leaves the flag false.
func WithKeepSeparator(keepSeparator bool) Option {
	apply := func(opts *Options) {
		opts.KeepSeparator = keepSeparator
	}
	return apply
}
48 changes: 35 additions & 13 deletions textsplitter/recursive_character.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
// RecursiveCharacter is a text splitter that will split texts recursively by different
// characters.
//
// Splitting walks Separators in order and uses the first entry that is either
// the empty string or occurs in the text; pieces that are still too large are
// split again with the separators that follow it.
//
// All fields are copied from Options by NewRecursiveCharacter. When
// KeepSeparator is true, every piece after the first is prefixed with the
// separator it was split on instead of losing it. How ChunkSize, ChunkOverlap
// and LenFunc bound and measure chunks is implemented by the merge step, which
// is not shown alongside this type.
type RecursiveCharacter struct {
Separators []string
ChunkSize int
ChunkOverlap int
LenFunc func(string) int
Separators []string
ChunkSize int
ChunkOverlap int
LenFunc func(string) int
KeepSeparator bool
}

// NewRecursiveCharacter creates a new recursive character splitter with default values. By
Expand All @@ -23,31 +24,52 @@ func NewRecursiveCharacter(opts ...Option) RecursiveCharacter {
}

s := RecursiveCharacter{
Separators: options.Separators,
ChunkSize: options.ChunkSize,
ChunkOverlap: options.ChunkOverlap,
LenFunc: options.LenFunc,
Separators: options.Separators,
ChunkSize: options.ChunkSize,
ChunkOverlap: options.ChunkOverlap,
LenFunc: options.LenFunc,
KeepSeparator: options.KeepSeparator,
}

return s
}

// SplitText splits text into chunks, starting the recursive split with the
// splitter's full, ordered separator list.
//
// Separators must contain at least one entry: the recursive helper indexes the
// last element of the list it is given.
func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
	separators := s.Separators
	return s.splitText(text, separators)
}

// addSeparatorInSplits returns a new slice in which every element of splits
// except the first is prefixed with separator. The input slice is not modified.
//
// Concatenating the returned pieces yields the same string as joining splits
// with separator, so the text that was split can be reconstructed exactly.
func (s RecursiveCharacter) addSeparatorInSplits(splits []string, separator string) []string {
	prefixed := make([]string, len(splits))
	for idx := range splits {
		if idx == 0 {
			// The first piece had no separator in front of it.
			prefixed[idx] = splits[idx]
			continue
		}
		prefixed[idx] = separator + splits[idx]
	}
	return prefixed
}

func (s RecursiveCharacter) splitText(text string, separators []string) ([]string, error) {
finalChunks := make([]string, 0)

// Find the appropriate separator
separator := s.Separators[len(s.Separators)-1]
// Find the appropriate separator.
separator := separators[len(separators)-1]
newSeparators := []string{}
for i, c := range s.Separators {
for i, c := range separators {
if c == "" || strings.Contains(text, c) {
separator = c
newSeparators = s.Separators[i+1:]
newSeparators = separators[i+1:]
break
}
}

splits := strings.Split(text, separator)
if s.KeepSeparator {
splits = s.addSeparatorInSplits(splits, separator)
separator = ""
}
goodSplits := make([]string, 0)

// Merge the splits, recursively splitting larger texts.
Expand All @@ -67,7 +89,7 @@ func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
if len(newSeparators) == 0 {
finalChunks = append(finalChunks, split)
} else {
otherInfo, err := s.SplitText(split)
otherInfo, err := s.splitText(split, newSeparators)
if err != nil {
return nil, err
}
Expand Down
43 changes: 38 additions & 5 deletions textsplitter/recursive_character_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package textsplitter

import (
"strings"
"testing"

"github.com/pkoukk/tiktoken-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/schema"
)

//nolint:dupword,funlen
func TestRecursiveCharacterSplitter(t *testing.T) {
tokenEncoder, _ := tiktoken.GetEncoding("cl100k_base")

t.Parallel()
type testCase struct {
text string
chunkOverlap int
chunkSize int
separators []string
expectedDocs []schema.Document
text string
chunkOverlap int
chunkSize int
separators []string
expectedDocs []schema.Document
keepSeparator bool
LenFunc func(string) int
}
testCases := []testCase{
{
Expand Down Expand Up @@ -106,12 +112,39 @@ Bye!
{PageContent: "Bye!\n\n-H.", Metadata: map[string]any{}},
},
},
{
text: "Hi, Harrison. \nI am glad to meet you",
chunkOverlap: 0,
chunkSize: 10,
separators: []string{"\n", "$"},
keepSeparator: true,
expectedDocs: []schema.Document{
{PageContent: "Hi, Harrison. ", Metadata: map[string]any{}},
{PageContent: "\nI am glad to meet you", Metadata: map[string]any{}},
},
},
{
text: strings.Repeat("The quick brown fox jumped over the lazy dog. ", 2),
chunkOverlap: 0,
chunkSize: 10,
separators: []string{" "},
keepSeparator: true,
LenFunc: func(s string) int { return len(tokenEncoder.Encode(s, nil, nil)) },
expectedDocs: []schema.Document{
{PageContent: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
{PageContent: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
},
},
}
splitter := NewRecursiveCharacter()
for _, tc := range testCases {
splitter.ChunkOverlap = tc.chunkOverlap
splitter.ChunkSize = tc.chunkSize
splitter.Separators = tc.separators
splitter.KeepSeparator = tc.keepSeparator
if tc.LenFunc != nil {
splitter.LenFunc = tc.LenFunc
}

docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
require.NoError(t, err)
Expand Down

0 comments on commit 09ac6e0

Please sign in to comment.