Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584072375
  • Loading branch information
roark-google authored and copybara-github committed Nov 20, 2023
1 parent 974518e commit 7285a6d
Show file tree
Hide file tree
Showing 5 changed files with 729 additions and 0 deletions.
61 changes: 61 additions & 0 deletions nisaba/translit/tools/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2023 Nisaba Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Error rate calculation utilities for specific transliteration use cases.

# Package-wide defaults inherited by every target declared in this BUILD file.
package(
# No license targets are listed for this package.
default_applicable_licenses = [
],
# Unless a target overrides it, targets here are visible only to
# //nlp/sweet/translit and the packages beneath it.
default_visibility = [
"//nlp/sweet/translit:__subpackages__",
],
)

# Command-line binary built from calculate_error_rate_main.cc. It links
# :calculate_error_rate_lib together with the Abseil flag, flag-parsing and
# CHECK-macro libraries.
cc_binary(
name = "calculate_error_rate",
srcs = ["calculate_error_rate_main.cc"],
deps = [
":calculate_error_rate_lib",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/log:check",
],
)

# Library built from calculate_error_rate.cc / calculate_error_rate.h.
# Depends on Nisaba's file and UTF-8 helpers, Abseil (CHECK macros, string
# utilities, Span) and OpenFst (core library and symbol tables).
cc_library(
name = "calculate_error_rate_lib",
srcs = ["calculate_error_rate.cc"],
hdrs = ["calculate_error_rate.h"],
deps = [
"//nisaba/port:file_util",
"//nisaba/port:utf8_util",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@org_openfst//:lib_lite",
"@org_openfst//:symbol-table",
],
)

# Unit test for :calculate_error_rate_lib, built from
# calculate_error_rate_test.cc and linked against gtest_main. Declared with
# the "medium" test size.
cc_test(
name = "calculate_error_rate_test",
size = "medium",
srcs = ["calculate_error_rate_test.cc"],
deps = [
":calculate_error_rate_lib",
"//nisaba/port:status-matchers",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
315 changes: 315 additions & 0 deletions nisaba/translit/tools/calculate_error_rate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
// Copyright 2023 Nisaba Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "nisaba/translit/tools/calculate_error_rate.h"

#include <cmath>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "fst/arc.h"
#include "fst/arcsort.h"
#include "fst/compose.h"
#include "fst/shortest-path.h"
#include "fst/symbol-table.h"
#include "fst/vector-fst.h"
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "nisaba/port/file_util.h"
#include "nisaba/port/utf8_util.h"

namespace nisaba {
namespace translit {
namespace tools {

namespace impl {
namespace {

// Returns the single lowest-cost path through `ifst` o `align_mod` o `ofst`.
//
// `align_mod` is arc-sorted in place on its input labels before the first
// composition, which is why it is taken by mutable pointer; callers should
// expect its arc order to change.
::fst::StdVectorFst GetMinErrorAlignment(const ::fst::StdVectorFst &ifst,
                                         const ::fst::StdVectorFst &ofst,
                                         ::fst::StdVectorFst *align_mod) {
  // Stage 1: input string composed with the alignment model.
  ::fst::ArcSort(align_mod, ::fst::ILabelCompare<::fst::StdArc>());
  ::fst::StdVectorFst input_with_edits;
  ::fst::Compose(ifst, *align_mod, &input_with_edits);

  // Stage 2: sort the intermediate result on output labels, then compose it
  // with the output string.
  ::fst::ArcSort(&input_with_edits, ::fst::OLabelCompare<::fst::StdArc>());
  ::fst::StdVectorFst all_alignments;
  ::fst::Compose(input_with_edits, ofst, &all_alignments);

  // Stage 3: keep only the cheapest alignment.
  ::fst::StdVectorFst best_alignment;
  ::fst::ShortestPath(all_alignments, &best_alignment);
  return best_alignment;
}

// Writes the concatenation of up to four items, followed by a newline, to
// `output_file`. Items that are not supplied default to the empty string.
// Terminates via QCHECK if the stream is in a failed state after the write.
void LineToFile(std::ofstream &output_file, const absl::AlphaNum &a,
                const absl::AlphaNum &b = "", const absl::AlphaNum &c = "",
                const absl::AlphaNum &d = "") {
  const std::string line = absl::StrCat(a, b, c, d, "\n");
  output_file << line;
  QCHECK(output_file);
}

// Returns true when the error rate of `this_min_error_values` is strictly
// lower than that of `min_error_values`; equal rates are not an improvement.
bool BetterErrorValues(const EditDistanceDouble &this_min_error_values,
                       const EditDistanceDouble &min_error_values) {
  const auto candidate_rate = this_min_error_values.ErrorRate();
  const auto incumbent_rate = min_error_values.ErrorRate();
  return candidate_rate < incumbent_rate;
}

// Builds a linear acceptor with one zero-weight arc per token of
// `token_string`. Tokens not yet present in `syms` are added to it, so the
// table may grow as a side effect. An empty token sequence yields a single
// state that is both start and final.
::fst::StdVectorFst GetStringFst(absl::Span<const std::string> token_string,
                                 ::fst::SymbolTable *syms) {
  ::fst::StdVectorFst fst;
  int tail = fst.AddState();
  fst.SetStart(tail);
  for (const std::string &token : token_string) {
    // Reuse the existing label when the token is known; otherwise register it.
    int label = syms->Find(token);
    if (label < 0) {
      label = syms->AddSymbol(token);
    }
    const int head = fst.AddState();
    fst.AddArc(tail, ::fst::StdArc(label, label, 0.0, head));
    tail = head;
  }
  // The last state reached is the only final state.
  fst.SetFinal(tail, 0.0);
  return fst;
}

// Returns the minimum-cost alignment between the test and reference string
// automata under unit-cost edits.
//
// A single-state edit transducer is built over labels 1..num_symbols-1
// (label 0 is epsilon), so it holds on the order of num_symbols^2 arcs. The
// test string is placed on the input side and the reference string on the
// output side of the resulting alignment.
::fst::StdVectorFst AlignStringFsts(const ::fst::StdVectorFst &ref_fst,
                                    const ::fst::StdVectorFst &test_fst,
                                    int num_symbols) {
  ::fst::StdVectorFst align_mod;
  const int hub = align_mod.AddState();
  align_mod.SetStart(hub);
  align_mod.SetFinal(hub, 0.0);
  for (int sym = 1; sym < num_symbols; ++sym) {
    // Test symbol consumed with nothing emitted on the reference side.
    align_mod.AddArc(hub, ::fst::StdArc(sym, 0, 1.0, hub));
    // Reference symbol emitted with nothing consumed on the test side.
    align_mod.AddArc(hub, ::fst::StdArc(0, sym, 1.0, hub));
    for (int sub_sym = 1; sub_sym < num_symbols; ++sub_sym) {
      // Matching a symbol with itself is free; any other pairing costs 1.
      const double cost = (sym == sub_sym) ? 0.0 : 1.0;
      align_mod.AddArc(hub, ::fst::StdArc(sym, sub_sym, cost, hub));
    }
  }
  return impl::GetMinErrorAlignment(test_fst, ref_fst, &align_mod);
}

// Counts the insertions, deletions and substitutions on the minimum-cost
// alignment between `ref_fst` and `test_fst`. The reference length is taken
// as the number of states in `ref_fst` minus one, i.e. the number of arcs in
// a linear string automaton.
EditDistanceInt CalculatePairEditDistance(const ::fst::StdVectorFst &ref_fst,
                                          const ::fst::StdVectorFst &test_fst,
                                          int num_symbols) {
  const ::fst::StdVectorFst min_cost_alignment =
      AlignStringFsts(ref_fst, test_fst, num_symbols);
  EditDistanceInt counts;
  counts.reference_length = ref_fst.NumStates() - 1;
  // Walk the single path from the start state until a state with no outgoing
  // arc is reached.
  for (int curr_state = min_cost_alignment.Start();
       curr_state >= 0 && min_cost_alignment.NumArcs(curr_state) > 0;) {
    QCHECK_EQ(min_cost_alignment.NumArcs(curr_state), 1);
    ::fst::ArcIterator<::fst::StdVectorFst> aiter(min_cost_alignment,
                                                  curr_state);
    const ::fst::StdArc &arc = aiter.Value();
    if (arc.ilabel == arc.olabel) {
      // Identical labels: a match, nothing to count.
    } else if (arc.ilabel == 0) {
      // Reference token with no test counterpart: deletion.
      ++counts.deletions;
    } else if (arc.olabel == 0) {
      // Test token with no reference counterpart: insertion.
      ++counts.insertions;
    } else {
      // Two different non-epsilon tokens: substitution.
      ++counts.substitutions;
    }
    curr_state = arc.nextstate;
  }
  return counts;
}

// Computes edit-distance counts between two token sequences by compiling
// both into string automata over a shared, freshly created symbol table.
EditDistanceInt CalculatePairEditDistance(
    const std::vector<std::string> &ref_string,
    const std::vector<std::string> &test_string) {
  // "<epsilon>" is registered first, before any real token is added.
  ::fst::SymbolTable syms;
  syms.AddSymbol("<epsilon>");
  const auto ref_fst = GetStringFst(ref_string, &syms);
  const auto test_fst = GetStringFst(test_string, &syms);
  // Read the table size only after both strings have registered their tokens.
  const int num_symbols = syms.NumSymbols();
  return CalculatePairEditDistance(ref_fst, test_fst, num_symbols);
}

// Tokenizes `str`. With `split_chars` false the result is the list of
// whitespace-separated words (empty pieces dropped). With `split_chars` true
// the words are first rejoined with single spaces, which normalizes the
// whitespace, and the joined string is then split by utf8::StrSplitByChar.
std::vector<std::string> MaybeSplitChars(absl::string_view str,
                                         bool split_chars) {
  std::vector<std::string> words =
      absl::StrSplit(str, utf8::Utf8WhitespaceDelimiter(), absl::SkipEmpty());
  if (!split_chars) {
    return words;
  }
  return utf8::StrSplitByChar(absl::StrJoin(words, " "));
}

} // namespace
} // namespace impl

double MultiRefErrorRate::CalcErrorRate() {
EditDistanceDouble tot_ed_double;
for (int i = 0; i < total_ed_double_.size(); ++i) {
tot_ed_double += total_ed_double_[i];
}
return tot_ed_double.ErrorRate();
}

// Writes the collected statistics to `ofile`: one line per example, followed
// by a summary with the aggregate statistics and the overall error rate. The
// rate is labelled "CER" when splitting into characters and "WER" otherwise.
// Terminates via QCHECK if the file cannot be opened.
void MultiRefErrorRate::Write(absl::string_view ofile) {
  std::ofstream output_file{std::string(ofile)};
  QCHECK(output_file) << "Cannot open " << ofile << " for writing.";

  // Per-example lines, accumulating the totals along the way.
  EditDistanceDouble aggregate;
  for (auto &example_stats : total_ed_double_) {
    impl::LineToFile(output_file, example_stats.ToString());
    aggregate += example_stats;
  }

  // Summary section.
  const std::string error_rate_label = is_split_chars_ ? "CER" : "WER";
  impl::LineToFile(output_file, "Summary ", error_rate_label);
  impl::LineToFile(output_file, "total statistics: ", aggregate.ToString());
  impl::LineToFile(output_file, "overall ", error_rate_label, ": ",
                   aggregate.ErrorRate());
}

// Reads `input_file` into either the reference or the test collection,
// replacing whatever that collection held before.
//
// Each line is split on tabs with empty fields dropped, and must then have
// two or three fields:
//   <example index> <output string> [<value>]
// The example index must be a non-negative integer. The output string is
// mapped to an id in `output_syms_`, being added if not yet present. The
// optional value is a raw count for references (default 1.0, stored as its
// negative log) or a cost for test items (default 0.0, stored as is).
//
// Terminates via QCHECK if the file cannot be read, if a line has the wrong
// number of fields, or if the example index is negative.
void MultiRefErrorRate::ReadInputs(absl::string_view input_file,
                                   bool is_reference) {
  // Both collections have the same layout, so select the destination once.
  auto &destination = is_reference ? references_ : test_input_;
  destination.clear();

  const auto &input_lines_status = file::ReadLines(input_file, kMaxLine);
  QCHECK(input_lines_status.ok()) << "Failed to read " << input_file;
  // Bind by reference: there is no need to copy every line of the file.
  const std::vector<std::string> &input_lines = input_lines_status.value();

  for (const std::string &str : input_lines) {
    const std::vector<std::string> seq =
        absl::StrSplit(str, '\t', absl::SkipEmpty());
    QCHECK_GE(seq.size(), 2);

    const int input_idx = std::stoi(seq[0]);
    // A negative index would wrap around in the unsigned size comparison
    // below and lead to an out-of-range write, so reject it explicitly.
    QCHECK_GE(input_idx, 0) << "Negative example index in " << input_file;

    // Id 0 stands for the empty output string. Because empty fields are
    // dropped by the split above, seq[1] is non-empty in practice.
    int osym = 0;
    if (!seq[1].empty()) {
      osym = output_syms_.Find(seq[1]);
      if (osym < 0) {
        osym = output_syms_.AddSymbol(seq[1]);
      }
    }

    // Default count (in reference) is 1.0; otherwise default (-logP) is 0.0.
    double item_value = is_reference ? 1.0 : 0.0;
    if (seq.size() > 2) {
      QCHECK_LE(seq.size(), 3);
      item_value = std::stod(seq[2]);
    }
    if (is_reference) {
      // Converts raw count to -log count for references.
      item_value = -std::log(item_value);
    }

    // Grow the collection so that `input_idx` is a valid position.
    const std::size_t slot = static_cast<std::size_t>(input_idx);
    if (slot >= destination.size()) {
      destination.resize(slot + 1);
    }
    destination[slot].push_back(std::make_pair(osym, item_value));
  }
}

// Returns the tokenized form of the k-th string stored for example `idx`,
// taken from the test inputs when `is_test_item` is true and from the
// references otherwise. An empty vector is returned when the requested entry
// does not exist or when it holds symbol id 0 (the empty string).
std::vector<std::string> MultiRefErrorRate::GetTokenizedString(
    int idx, int k, bool is_test_item) const {
  const auto &inputs = is_test_item ? test_input_ : references_;
  const bool has_example = idx < inputs.size();
  if (!has_example || k >= inputs[idx].size()) {
    // Requested string does not exist in the collection.
    return {};
  }
  const auto &entry = inputs[idx][k];
  if (entry.first == 0) {
    // Symbol id 0 denotes the empty string: nothing to tokenize.
    return {};
  }
  return impl::MaybeSplitChars(output_syms_.Find(entry.first),
                               is_split_chars_);
}

void MultiRefErrorRate::CalculateMinErrorRate(int idx) {
int min_cost_test_item = 0;
for (int i = 1; i < test_input_[idx].size(); ++i) {
if (test_input_[idx][i].second <
test_input_[idx][min_cost_test_item].second) {
min_cost_test_item = i;
}
}
EditDistanceDouble min_error_values;
// If minimum cost test item is empty string, empty vector; otherwise
// tokenize string as requested.
const std::vector<std::string> test_string =
GetTokenizedString(idx, min_cost_test_item, /*is_test_item=*/true);
for (int i = 0; i < references_[idx].size(); ++i) {
const std::vector<std::string> ref_string =
GetTokenizedString(idx, i, /*is_test_item=*/false);
auto this_min_error_values =
impl::CalculatePairEditDistance(ref_string, test_string);
if (i == 0 || impl::BetterErrorValues(
static_cast<EditDistanceDouble>(this_min_error_values),
min_error_values)) {
min_error_values = this_min_error_values;
}
}
total_ed_double_.push_back(min_error_values);
}

void MultiRefErrorRate::CalculateErrorRate(int idx) {
if (references_[idx].empty()) {
// Nothing to do for this example.
return;
}
CalculateMinErrorRate(idx);
}

// Recomputes the per-example statistics for all loaded examples, discarding
// any previous results. Requires the reference and test collections to have
// the same number of examples, and every example that has test items to also
// have references; both conditions are enforced with QCHECK.
void MultiRefErrorRate::CalculateErrorRate() {
  QCHECK_EQ(references_.size(), test_input_.size());
  const auto num_examples = references_.size();
  total_ed_double_.clear();
  total_ed_double_.reserve(num_examples);
  for (int idx = 0; idx < num_examples; ++idx) {
    // Either both are empty or reference is non-empty.
    QCHECK(!references_[idx].empty() || test_input_[idx].empty());
    auto &hypotheses = test_input_[idx];
    if (hypotheses.empty()) {
      // No test output was given: score an empty string with cost 0.0.
      hypotheses.push_back(std::make_pair(0, 0.0));
    }
    CalculateErrorRate(idx);
  }
}

// Convenience entry point: loads the references from `reffile` and the test
// outputs from `testfile` (in that order), then scores every example.
void MultiRefErrorRate::CalculateErrorRate(absl::string_view reffile,
                                           absl::string_view testfile) {
  constexpr bool kIsReference = true;
  ReadInputs(reffile, kIsReference);
  ReadInputs(testfile, !kIsReference);
  CalculateErrorRate();
}

} // namespace tools
} // namespace translit
} // namespace nisaba
Loading

0 comments on commit 7285a6d

Please sign in to comment.