-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
5 changed files
with
729 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright 2023 Nisaba Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Error rate calculation utilities for specific transliteration use cases. | ||
|
||
package( | ||
default_applicable_licenses = [ | ||
], | ||
default_visibility = [ | ||
"//nlp/sweet/translit:__subpackages__", | ||
], | ||
) | ||
|
||
cc_binary( | ||
name = "calculate_error_rate", | ||
srcs = ["calculate_error_rate_main.cc"], | ||
deps = [ | ||
":calculate_error_rate_lib", | ||
"@com_google_absl//absl/flags:flag", | ||
"@com_google_absl//absl/flags:parse", | ||
"@com_google_absl//absl/log:check", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "calculate_error_rate_lib", | ||
srcs = ["calculate_error_rate.cc"], | ||
hdrs = ["calculate_error_rate.h"], | ||
deps = [ | ||
"//nisaba/port:file_util", | ||
"//nisaba/port:utf8_util", | ||
"@com_google_absl//absl/log:check", | ||
"@com_google_absl//absl/strings", | ||
"@com_google_absl//absl/types:span", | ||
"@org_openfst//:lib_lite", | ||
"@org_openfst//:symbol-table", | ||
], | ||
) | ||
|
||
cc_test( | ||
name = "calculate_error_rate_test", | ||
size = "medium", | ||
srcs = ["calculate_error_rate_test.cc"], | ||
deps = [ | ||
":calculate_error_rate_lib", | ||
"//nisaba/port:status-matchers", | ||
"@com_google_absl//absl/strings", | ||
"@com_google_googletest//:gtest_main", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
// Copyright 2023 Nisaba Authors. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "nisaba/translit/tools/calculate_error_rate.h" | ||
|
||
#include <cmath> | ||
#include <fstream> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "fst/arc.h" | ||
#include "fst/arcsort.h" | ||
#include "fst/compose.h" | ||
#include "fst/shortest-path.h" | ||
#include "fst/symbol-table.h" | ||
#include "fst/vector-fst.h" | ||
#include "absl/log/check.h" | ||
#include "absl/strings/str_cat.h" | ||
#include "absl/strings/str_join.h" | ||
#include "absl/strings/str_split.h" | ||
#include "absl/strings/string_view.h" | ||
#include "absl/types/span.h" | ||
#include "nisaba/port/file_util.h" | ||
#include "nisaba/port/utf8_util.h" | ||
|
||
namespace nisaba { | ||
namespace translit { | ||
namespace tools { | ||
|
||
namespace impl { | ||
namespace { | ||
|
||
// Uses composition and shortest path to return min cost alignment. | ||
::fst::StdVectorFst GetMinErrorAlignment(const ::fst::StdVectorFst &ifst, | ||
const ::fst::StdVectorFst &ofst, | ||
::fst::StdVectorFst *align_mod) { | ||
ArcSort(align_mod, ::fst::ILabelCompare<::fst::StdArc>()); | ||
::fst::StdVectorFst ifst_o_align_mod; | ||
::fst::Compose(ifst, *align_mod, &ifst_o_align_mod); | ||
::fst::ArcSort(&ifst_o_align_mod, | ||
::fst::OLabelCompare<::fst::StdArc>()); | ||
::fst::StdVectorFst ifst_o_align_mod_o_ofst; | ||
::fst::Compose(ifst_o_align_mod, ofst, &ifst_o_align_mod_o_ofst); | ||
::fst::StdVectorFst alignment; | ||
::fst::ShortestPath(ifst_o_align_mod_o_ofst, &alignment); | ||
return alignment; | ||
} | ||
|
||
// Prints 1-4 items in a line to the output file. | ||
void LineToFile(std::ofstream &output_file, const absl::AlphaNum &a, | ||
const absl::AlphaNum &b = "", const absl::AlphaNum &c = "", | ||
const absl::AlphaNum &d = "") { | ||
output_file << absl::StrCat(a, b, c, d, "\n"); | ||
QCHECK(output_file); | ||
} | ||
|
||
// Returns true if first edit distance components yields better error rate than | ||
// second. | ||
bool BetterErrorValues(const EditDistanceDouble &this_min_error_values, | ||
const EditDistanceDouble &min_error_values) { | ||
return this_min_error_values.ErrorRate() < min_error_values.ErrorRate(); | ||
} | ||
|
||
// Creates linear string Fst from vector of symbols, adding to symbol table. | ||
::fst::StdVectorFst GetStringFst(absl::Span<const std::string> token_string, | ||
::fst::SymbolTable *syms) { | ||
::fst::StdVectorFst fst; | ||
fst.SetStart(fst.AddState()); | ||
int curr_state = fst.Start(); | ||
for (const auto &token : token_string) { | ||
int sym = syms->Find(token); | ||
if (sym < 0) { | ||
sym = syms->AddSymbol(token); | ||
} | ||
int next_state = fst.AddState(); | ||
fst.AddArc(curr_state, ::fst::StdArc(sym, sym, 0.0, next_state)); | ||
curr_state = next_state; | ||
} | ||
fst.SetFinal(curr_state, 0.0); | ||
return fst; | ||
} | ||
|
||
// Finds minimum cost alignment between two string automata. | ||
::fst::StdVectorFst AlignStringFsts(const ::fst::StdVectorFst &ref_fst, | ||
const ::fst::StdVectorFst &test_fst, | ||
int num_symbols) { | ||
::fst::StdVectorFst align_mod; | ||
align_mod.SetStart(align_mod.AddState()); | ||
align_mod.SetFinal(align_mod.Start(), 0.0); | ||
for (int sym = 1; sym < num_symbols; ++sym) { | ||
// Adds insertion, deletion and substitution arcs for all symbols. | ||
align_mod.AddArc(align_mod.Start(), | ||
::fst::StdArc(sym, 0, 1.0, align_mod.Start())); | ||
align_mod.AddArc(align_mod.Start(), | ||
::fst::StdArc(0, sym, 1.0, align_mod.Start())); | ||
for (int sub_sym = 1; sub_sym < num_symbols; ++sub_sym) { | ||
// If sym == sub_sym, cost is 0.0, otherwise 1.0. | ||
align_mod.AddArc(align_mod.Start(), | ||
::fst::StdArc(sym, sub_sym, sym == sub_sym ? 0.0 : 1.0, | ||
align_mod.Start())); | ||
} | ||
} | ||
return impl::GetMinErrorAlignment(test_fst, ref_fst, &align_mod); | ||
} | ||
|
||
EditDistanceInt CalculatePairEditDistance(const ::fst::StdVectorFst &ref_fst, | ||
const ::fst::StdVectorFst &test_fst, | ||
int num_symbols) { | ||
const ::fst::StdVectorFst min_cost_alignment = | ||
AlignStringFsts(ref_fst, test_fst, num_symbols); | ||
int curr_state = min_cost_alignment.Start(); | ||
EditDistanceInt ed_int; | ||
ed_int.reference_length = ref_fst.NumStates() - 1; | ||
while (curr_state >= 0 && min_cost_alignment.NumArcs(curr_state) > 0) { | ||
QCHECK_EQ(min_cost_alignment.NumArcs(curr_state), 1); | ||
::fst::ArcIterator<::fst::StdVectorFst> aiter(min_cost_alignment, | ||
curr_state); | ||
::fst::StdArc arc = aiter.Value(); | ||
if (arc.ilabel != arc.olabel) { // This is an edit in the alignment. | ||
if (arc.ilabel == 0) { | ||
// Reference token aligns with nothing in the test string: deletion. | ||
++ed_int.deletions; | ||
} else if (arc.olabel == 0) { | ||
// Test token aligns with nothing in the reference string: insertion. | ||
++ed_int.insertions; | ||
} else { | ||
++ed_int.substitutions; | ||
} | ||
} | ||
curr_state = arc.nextstate; | ||
} | ||
return ed_int; | ||
} | ||
|
||
EditDistanceInt CalculatePairEditDistance( | ||
const std::vector<std::string> &ref_string, | ||
const std::vector<std::string> &test_string) { | ||
::fst::SymbolTable syms; | ||
syms.AddSymbol("<epsilon>"); | ||
const ::fst::StdVectorFst ref_fst = GetStringFst(ref_string, &syms); | ||
const ::fst::StdVectorFst test_fst = GetStringFst(test_string, &syms); | ||
return CalculatePairEditDistance(ref_fst, test_fst, syms.NumSymbols()); | ||
} | ||
|
||
// Splits string on either whitespace or characters. | ||
std::vector<std::string> MaybeSplitChars(absl::string_view str, | ||
bool split_chars) { | ||
std::vector<std::string> tokenized_string = | ||
absl::StrSplit(str, utf8::Utf8WhitespaceDelimiter(), absl::SkipEmpty()); | ||
if (split_chars) { | ||
// Rejoins with a single whitespace prior to splitting on single unicode | ||
// codepoints, thus normalizing the whitespace in the strings. | ||
tokenized_string = | ||
utf8::StrSplitByChar(absl::StrJoin(tokenized_string, " ")); | ||
} | ||
return tokenized_string; | ||
} | ||
|
||
} // namespace | ||
} // namespace impl | ||
|
||
double MultiRefErrorRate::CalcErrorRate() { | ||
EditDistanceDouble tot_ed_double; | ||
for (int i = 0; i < total_ed_double_.size(); ++i) { | ||
tot_ed_double += total_ed_double_[i]; | ||
} | ||
return tot_ed_double.ErrorRate(); | ||
} | ||
|
||
void MultiRefErrorRate::Write(absl::string_view ofile) { | ||
std::ofstream output_file; | ||
output_file.open(std::string(ofile)); | ||
QCHECK(output_file) << "Cannot open " << ofile << " for writing."; | ||
EditDistanceDouble tot_ed_double; | ||
for (int i = 0; i < total_ed_double_.size(); ++i) { | ||
impl::LineToFile(output_file, total_ed_double_[i].ToString()); | ||
tot_ed_double += total_ed_double_[i]; | ||
} | ||
std::string error_rate_label = is_split_chars_ ? "CER" : "WER"; | ||
impl::LineToFile(output_file, "Summary ", error_rate_label); | ||
impl::LineToFile(output_file, "total statistics: ", tot_ed_double.ToString()); | ||
impl::LineToFile(output_file, "overall ", error_rate_label, ": ", | ||
tot_ed_double.ErrorRate()); | ||
} | ||
|
||
void MultiRefErrorRate::ReadInputs(absl::string_view input_file, | ||
bool is_reference) { | ||
if (is_reference) { | ||
references_.clear(); | ||
} else { | ||
test_input_.clear(); | ||
} | ||
const auto &input_lines_status = file::ReadLines(input_file, kMaxLine); | ||
QCHECK(input_lines_status.ok()) << "Failed to read " << input_file; | ||
const std::vector<std::string> input_lines = input_lines_status.value(); | ||
for (const std::string &str : input_lines) { | ||
const std::vector<std::string> seq = | ||
absl::StrSplit(str, '\t', absl::SkipEmpty()); | ||
QCHECK_GE(seq.size(), 2); | ||
int input_idx = std::stoi(seq[0]); | ||
int osym = 0; | ||
if (!seq[1].empty()) { | ||
osym = output_syms_.Find(seq[1]); | ||
if (osym < 0) { | ||
osym = output_syms_.AddSymbol(seq[1]); | ||
} | ||
} | ||
// Default count (in reference) is 1.0; otherwise default (-logP) is 0.0. | ||
double item_value = is_reference ? 1.0 : 0.0; | ||
if (seq.size() > 2) { | ||
QCHECK_LE(seq.size(), 3); | ||
item_value = std::stod(seq[2]); | ||
} | ||
if (is_reference) { | ||
// Converts raw count to -log count for references. | ||
item_value = -log(item_value); | ||
if (input_idx >= references_.size()) { | ||
references_.resize(input_idx + 1); | ||
} | ||
references_[input_idx].push_back(std::make_pair(osym, item_value)); | ||
} else { | ||
if (input_idx >= test_input_.size()) { | ||
test_input_.resize(input_idx + 1); | ||
} | ||
test_input_[input_idx].push_back(std::make_pair(osym, item_value)); | ||
} | ||
} | ||
} | ||
|
||
std::vector<std::string> MultiRefErrorRate::GetTokenizedString( | ||
int idx, int k, bool is_test_item) const { | ||
const auto &inputs = is_test_item ? test_input_ : references_; | ||
if (idx >= inputs.size() || k >= inputs[idx].size()) { | ||
// If requested string does not exist in collection. | ||
return std::vector<std::string>(); | ||
} | ||
const auto input_pair = inputs[idx][k]; | ||
return input_pair.first == 0 | ||
? std::vector<std::string>() | ||
: impl::MaybeSplitChars(output_syms_.Find(input_pair.first), | ||
is_split_chars_); | ||
} | ||
|
||
void MultiRefErrorRate::CalculateMinErrorRate(int idx) { | ||
int min_cost_test_item = 0; | ||
for (int i = 1; i < test_input_[idx].size(); ++i) { | ||
if (test_input_[idx][i].second < | ||
test_input_[idx][min_cost_test_item].second) { | ||
min_cost_test_item = i; | ||
} | ||
} | ||
EditDistanceDouble min_error_values; | ||
// If minimum cost test item is empty string, empty vector; otherwise | ||
// tokenize string as requested. | ||
const std::vector<std::string> test_string = | ||
GetTokenizedString(idx, min_cost_test_item, /*is_test_item=*/true); | ||
for (int i = 0; i < references_[idx].size(); ++i) { | ||
const std::vector<std::string> ref_string = | ||
GetTokenizedString(idx, i, /*is_test_item=*/false); | ||
auto this_min_error_values = | ||
impl::CalculatePairEditDistance(ref_string, test_string); | ||
if (i == 0 || impl::BetterErrorValues( | ||
static_cast<EditDistanceDouble>(this_min_error_values), | ||
min_error_values)) { | ||
min_error_values = this_min_error_values; | ||
} | ||
} | ||
total_ed_double_.push_back(min_error_values); | ||
} | ||
|
||
void MultiRefErrorRate::CalculateErrorRate(int idx) { | ||
if (references_[idx].empty()) { | ||
// Nothing to do for this example. | ||
return; | ||
} | ||
CalculateMinErrorRate(idx); | ||
} | ||
|
||
void MultiRefErrorRate::CalculateErrorRate() { | ||
QCHECK_EQ(references_.size(), test_input_.size()); | ||
total_ed_double_.clear(); | ||
total_ed_double_.reserve(references_.size()); | ||
for (int idx = 0; idx < references_.size(); ++idx) { | ||
// Either both are empty or reference is non-empty. | ||
QCHECK(!references_[idx].empty() || test_input_[idx].empty()); | ||
if (test_input_[idx].empty()) { | ||
// Creates empty string test input if none given for reference item. | ||
test_input_[idx].push_back(std::make_pair(0, 0.0)); | ||
} | ||
CalculateErrorRate(idx); | ||
} | ||
} | ||
|
||
void MultiRefErrorRate::CalculateErrorRate(absl::string_view reffile, | ||
absl::string_view testfile) { | ||
ReadInputs(reffile, /*is_reference=*/true); | ||
ReadInputs(testfile, /*is_reference=*/false); | ||
CalculateErrorRate(); | ||
} | ||
|
||
} // namespace tools | ||
} // namespace translit | ||
} // namespace nisaba |
Oops, something went wrong.