Skip to content

Commit

Permalink
Merge pull request #11 from FieldsMedal/hotwords
Browse files Browse the repository at this point in the history
Add hotwords
  • Loading branch information
Slyne authored May 18, 2023
2 parents 68f8bae + a505681 commit 03259fd
Show file tree
Hide file tree
Showing 11 changed files with 617 additions and 23 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,27 @@ How language model in called in this implementation of ctc prefix beam search ?

If the language model is char based (like the Mandarin lm), it will call the language model scorer all the times.
If the language model is word based (like the English lm), it will only call the scorer whenever `space_id` is detected.

### Adding hotwords
Please refer to the following steps how to use hotwordsboosting.
* Step 1. Initialize HotWordsScorer
```
# if you don't want to use hotwords. set hotwords_scorer=None(default),
# vocab_list is Chinese characters.
hot_words = {'再接': 10, '再厉': -10, '好好学习': 100}
hotwords_scorer = HotWordsScorer(hot_words, vocab_list, is_character_based=True)
```
If you set is_character_based is True (default mode), the first step is to combine Chinese characters into words, if words in hotwords dictionary then add hotwords score. If you set is_character_based is False, all words in the fixed window will be enumerated.

* Step 2. Add hotwords_scorer when decoding
```
result2 = decoder.ctc_beam_search_decoder_batch(batch_chunk_log_prob_seq,
batch_chunk_log_probs_idx,
batch_root_trie,
batch_start,
beam_size, num_processes,
blank_id, space_id,
cutoff_prob, scorer, hotwords_scorer)
```
Please refer to ```swig/test/test_zh.py``` for how to decode with hotwordsboosting.

77 changes: 58 additions & 19 deletions swig/ctc_beam_search_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Copyright (c) 2022,DeepSpeech Authors
// 2023, 58.com(Wuba) Inc AI Lab
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Modified from DeepSpeech(https://github.com/mozilla/DeepSpeech)

#include "ctc_beam_search_decoder.h"
#include <algorithm>
#include <cmath>
Expand All @@ -44,7 +59,8 @@ std::vector<std::pair<double, std::vector<int>>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &log_probs_seq,
const std::vector<std::vector<int>> &log_probs_idx, PathTrie &root,
const bool start, size_t beam_size, int blank_id, int space_id,
double cutoff_prob, Scorer *ext_scorer) {
double cutoff_prob, Scorer *ext_scorer,
HotWordsScorer *hotwords_scorer) {
if (start) {
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
Expand Down Expand Up @@ -117,24 +133,46 @@ std::vector<std::pair<double, std::vector<int>>> ctc_beam_search_decoder(
log_p = log_prob_c + prefix->score;
}

// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
// hotwords boosting
float hotwords_score = 0.0;
std::vector<std::string> ngram;
PathTrie *prefix_to_score = nullptr;
if (hotwords_scorer != nullptr && !hotwords_scorer->hotwords_dict.empty()) {
if (hotwords_scorer->is_character_based) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
int offset;
std::tie(offset,ngram) = hotwords_scorer->make_ngram(prefix_to_score);
hotwords_score = hotwords_scorer->get_hotwords_score(ngram, offset);
}
log_p += hotwords_score;

// language model scoring
float ngram_score = 0.0;
if (ext_scorer != nullptr ) {
if (hotwords_scorer != nullptr && !hotwords_scorer->hotwords_dict.empty() &&
!(hotwords_scorer->is_character_based ^ ext_scorer->is_character_based()) &&
hotwords_scorer->window_length >= ext_scorer->get_max_order()) {
std::vector<std::string>::const_iterator first = ngram.end() - ext_scorer->get_max_order();
std::vector<std::string>::const_iterator last = ngram.end();
std::vector<std::string> slice_ngram(first, last);
ngram_score = ext_scorer->get_log_cond_prob(slice_ngram) * ext_scorer->alpha + ext_scorer->beta;
} else {
if (c == space_id || ext_scorer->is_character_based()) {
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
ngram = ext_scorer->make_ngram(prefix_to_score);
ngram_score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha + ext_scorer->beta;
}
}
}
log_p += ngram_score;
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
Expand Down Expand Up @@ -204,7 +242,8 @@ ctc_beam_search_decoder_batch(
std::vector<PathTrie *> &batch_root_trie,
const std::vector<bool> &batch_start, size_t beam_size,
size_t num_processes, int blank_id, int space_id, double cutoff_prob,
Scorer *ext_scorer) {
Scorer *ext_scorer,
HotWordsScorer *hotwords_scorer) {
// thread pool
ThreadPool pool(num_processes);
// number of samples
Expand All @@ -220,7 +259,7 @@ ctc_beam_search_decoder_batch(
pool.enqueue(ctc_beam_search_decoder, std::ref(batch_log_probs_seq[i]),
std::ref(batch_log_probs_idx[i]),
std::ref(*batch_root_trie[i]), batch_start[i], beam_size,
blank_id, space_id, cutoff_prob, ext_scorer));
blank_id, space_id, cutoff_prob, ext_scorer, hotwords_scorer));
}

// get decoding results
Expand Down
26 changes: 24 additions & 2 deletions swig/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Copyright (c) 2023, 58.com(Wuba) Inc AI Lab. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Modified from DeepSpeech(https://github.com/mozilla/DeepSpeech)

#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_

Expand All @@ -34,6 +49,7 @@
#include <vector>
#include "path_trie.h"
#include "scorer.h"
#include "hotwords.h"

/* CTC Beam Search Decoder
Expand All @@ -53,6 +69,8 @@
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* hotwords_scorer: External scorer to add hotwords score. Default null,
* decoding the input sample without hotwordsboosting.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
Expand All @@ -61,7 +79,8 @@ std::vector<std::pair<double, std::vector<int>>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &log_probs_seq,
const std::vector<std::vector<int>> &log_probs_idx, PathTrie &root,
const bool start, size_t beam_size, int blank_id = 0, int space_id = -1,
double cutoff_prob = 0.999, Scorer *ext_scorer = nullptr);
double cutoff_prob = 0.999, Scorer *ext_scorer = nullptr,
HotWordsScorer *hotwords_scorer = nullptr);

/* CTC Beam Search Decoder for batch data
Expand All @@ -84,6 +103,8 @@ std::vector<std::pair<double, std::vector<int>>> ctc_beam_search_decoder(
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* hotwords_scorer: External scorer to add hotwords score. Default null,
* decoding the input sample without hotwordsboosting.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
Expand All @@ -95,7 +116,8 @@ ctc_beam_search_decoder_batch(
std::vector<PathTrie *> &batch_root_trie,
const std::vector<bool> &batch_start, size_t beam_size,
size_t num_processes, int blank_id = 0, int space_id = -1,
double cutoff_prob = 0.999, Scorer *ext_scorer = nullptr);
double cutoff_prob = 0.999, Scorer *ext_scorer = nullptr,
HotWordsScorer *hotwords_scorer = nullptr);

/* Map vector of int to string
Expand Down
5 changes: 5 additions & 0 deletions swig/decoders.i
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
#include "ctc_beam_search_decoder.h"
#include "decoder_utils.h"
#include "path_trie.h"
#include "hotwords.h"
%}

%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
%include "std_map.i"
%include "std_unordered_map.i"
%include "path_trie.h"
%import "decoder_utils.h"

Expand All @@ -27,11 +30,13 @@ namespace std {
%template(IntVector3) std::vector<std::vector<std::vector<int>>>;
%template(TrieVector) std::vector<PathTrie*>;
%template(BoolVector) std::vector<bool>;
%template(HotWordsMap) unordered_map<string, float>;
}
%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;

%include "hotwords.h"
%include "scorer.h"
%include "path_trie.h"
%include "ctc_beam_search_decoder.h"
104 changes: 104 additions & 0 deletions swig/hotwords.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2023, 58.com(Wuba) Inc AI Lab. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Modified from DeepSpeech(https://github.com/mozilla/DeepSpeech)

#include <iostream>
#include <unordered_map>

#include "path_trie.h"
#include "hotwords.h"
#include "scorer.h"

HotWordsScorer::HotWordsScorer(const std::unordered_map<std::string, float> &hotwords_dict, const std::vector<std::string>& char_list,
int window_length, int SPACE_ID, bool is_character_based) {
this->hotwords_dict = hotwords_dict;
this->window_length = window_length;
this->is_character_based = is_character_based;
this->SPACE_ID = SPACE_ID;
this->char_list = char_list;
}

HotWordsScorer::~HotWordsScorer(){
}

std::string HotWordsScorer::vec2str(const std::vector<int>& input) {
std::string word;
for (auto ind : input) {
word += this->char_list[ind];
}
return word;
}

std::pair<int, std::vector<std::string>> HotWordsScorer::make_ngram(PathTrie* prefix){
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
int no_start_token_count = 0;
for (int order = 0; order < this->window_length; order++) {
std::vector<int> prefix_vec;

if (this->is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, this->SPACE_ID, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, this->SPACE_ID);
current_node = new_node->parent; // Skipping spaces
}

// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
no_start_token_count++;

if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < this->window_length - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
std::pair<int, std::vector<std::string>> result(this->window_length - no_start_token_count, ngram);
return result;
}

float HotWordsScorer::get_hotwords_score(const std::vector<std::string>& words, int offset) {
float hotwords_score = 0;
int words_size = words.size();
std::unordered_map<std::string, float>::const_iterator iter;
for (size_t index = 0; index < words_size; index++) {
std::string word = "";
if (this->is_character_based) {
// contains at least two chinese characters.
// words.end()-words.begin()-offset-index+1=words.size()-1-offset-index+1>=2
if( words_size - offset - index <= 1) {
break;
}
// chinese characters in fixed window, combining chinese characters into words.
// word = std::accumulate(words.begin() + offset, words.end() - index, std::string{});
word = std::accumulate(words.begin() + offset + index, words.end(), std::string{});
} else {
// word in fixed window, traverse each word in words.
word = words[index];
}
iter = this->hotwords_dict.find(word);
if (iter != this->hotwords_dict.end()) {
hotwords_score += iter->second;
// break loop after matching the hotwords.
break;
}
}
return hotwords_score;
}
47 changes: 47 additions & 0 deletions swig/hotwords.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2023, 58.com(Wuba) Inc AI Lab. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Modified from DeepSpeech(https://github.com/mozilla/DeepSpeech)

#ifndef HOTWORDS_H
#define HOTWORDS_H

#include <iostream>
#include <string>
#include <unordered_map>

#include "scorer.h"

class HotWordsScorer {
public:
HotWordsScorer(const std::unordered_map<std::string, float> &hotwords_dict, const std::vector<std::string>& char_list,
int window_length=4, int SPACE_ID=-1, bool is_character_based=true);
~HotWordsScorer();

// make ngram for a given prefix
std::pair<int, std::vector<std::string>> make_ngram(PathTrie *prefix);

// translate the vector in index to string
std::string vec2str(const std::vector<int>& input);

// add hotwords score
float get_hotwords_score(const std::vector<std::string>& words, int begin_index);

std::unordered_map<std::string, float> hotwords_dict;
int window_length;
int SPACE_ID;
bool is_character_based;
std::vector<std::string> char_list;
};

#endif // HOTWORDS_H
Loading

0 comments on commit 03259fd

Please sign in to comment.