Skip to content

Commit

Permalink
Pass previous context
Browse files Browse the repository at this point in the history
  • Loading branch information
streamer45 committed Jul 10, 2024
1 parent 6525661 commit 87f5586
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 16 deletions.
7 changes: 6 additions & 1 deletion speech/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
)

const (
stateLen = 2 * 1 * 128
stateLen = 2 * 1 * 128
contextLen = 64
)

type LogLevel int
Expand Down Expand Up @@ -92,6 +93,7 @@ type Detector struct {
cfg DetectorConfig

state [stateLen]float32
ctx [contextLen]float32

currSample int
triggered bool
Expand Down Expand Up @@ -260,6 +262,9 @@ func (sd *Detector) Reset() error {
for i := 0; i < stateLen; i++ {
sd.state[i] = 0
}
for i := 0; i < contextLen; i++ {
sd.ctx[i] = 0
}

return nil
}
Expand Down
51 changes: 37 additions & 14 deletions speech/detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,32 +117,55 @@ func TestSpeechDetection(t *testing.T) {
require.NoError(t, sd.Destroy())
}()

data, err := os.ReadFile("../testfiles/samples.pcm")
require.NoError(t, err)
readSamplesFromFile := func(path string) []float32 {
data, err := os.ReadFile(path)
require.NoError(t, err)

samples := make([]float32, 0, len(data)/4)
for i := 0; i < len(data); i += 4 {
samples = append(samples, math.Float32frombits(binary.LittleEndian.Uint32(data[i:i+4])))
samples := make([]float32, 0, len(data)/4)
for i := 0; i < len(data); i += 4 {
samples = append(samples, math.Float32frombits(binary.LittleEndian.Uint32(data[i:i+4])))
}
return samples
}

samples := readSamplesFromFile("../testfiles/samples.pcm")
samples2 := readSamplesFromFile("../testfiles/samples2.pcm")

t.Run("detect", func(t *testing.T) {
segments, err := sd.Detect(samples)
require.NoError(t, err)
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 1.088,
SpeechStartAt: 1.056,
SpeechEndAt: 1.632,
},
{
SpeechStartAt: 2.912,
SpeechEndAt: 3.264,
SpeechStartAt: 2.88,
SpeechEndAt: 3.232,
},
{
SpeechStartAt: 4.448,
SpeechEndAt: 0,
},
}, segments)

err = sd.Reset()
require.NoError(t, err)

segments, err = sd.Detect(samples2)
require.NoError(t, err)
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 3.008,
SpeechEndAt: 6.24,
},
{
SpeechStartAt: 7.072,
SpeechEndAt: 8.16,
},
}, segments)
})

t.Run("reset", func(t *testing.T) {
Expand All @@ -154,12 +177,12 @@ func TestSpeechDetection(t *testing.T) {
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 1.088,
SpeechStartAt: 1.056,
SpeechEndAt: 1.632,
},
{
SpeechStartAt: 2.912,
SpeechEndAt: 3.264,
SpeechStartAt: 2.88,
SpeechEndAt: 3.232,
},
{
SpeechStartAt: 4.448,
Expand All @@ -182,12 +205,12 @@ func TestSpeechDetection(t *testing.T) {
require.NotEmpty(t, segments)
require.Equal(t, []Segment{
{
SpeechStartAt: 1.088 - 0.01,
SpeechStartAt: 1.056 - 0.01,
SpeechEndAt: 1.632 + 0.01,
},
{
SpeechStartAt: 2.912 - 0.01,
SpeechEndAt: 3.264 + 0.01,
SpeechStartAt: 2.88 - 0.01,
SpeechEndAt: 3.232 + 0.01,
},
{
SpeechStartAt: 4.448 - 0.01,
Expand Down
10 changes: 9 additions & 1 deletion speech/infer_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ import (
"unsafe"
)

func (sd *Detector) infer(pcm []float32) (float32, error) {
func (sd *Detector) infer(samples []float32) (float32, error) {
pcm := samples
if sd.currSample > 0 {
// Append context from previous iteration.
pcm = append(sd.ctx[:], samples...)
}
// Save the last contextLen samples as context for the next iteration.
copy(sd.ctx[:], samples[len(samples)-contextLen:])

// Create tensors
var pcmValue *C.OrtValue
pcmInputDims := []C.long{
Expand Down
Binary file added testfiles/samples2.pcm
Binary file not shown.

0 comments on commit 87f5586

Please sign in to comment.