-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRewriteTesterUtils.cpp
343 lines (319 loc) · 12.5 KB
/
RewriteTesterUtils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
//
// Created by Daniel Schnell on 19.05.20.
//
#include "RewriteTesterUtils.h"
#include <boost/algorithm/string.hpp>
#include <boost/filesystem.hpp>
#include <boost/locale.hpp>
#include <fst/arc.h>
#include <fst/flags.h>
#include <fst/fst.h>
#include <fst/string.h>
#include <fst/symbol-table.h>
#include <fst/vector-fst.h>
#include <thrax/grm-manager.h>
#include <thrax/algo/paths.h>
#include <thrax/symbols.h>
using ::fst::StdArc;
using ::fst::StdVectorFst;
using ::fst::StringCompiler;
using ::fst::StringTokenType;
using ::fst::SymbolTable;
using ::thrax::FstToStrings;
using ::thrax::GetGeneratedSymbolTable;
using ::thrax::RuleTriple;
// Default value of the --history_file flag.
#define HISTORY_FILE ".rewrite-tester-history"

// Grammar archive to load and the rules (comma-separated) to apply from it.
DEFINE_string(far, "g2p.far", "Path to the FAR.");
DEFINE_string(rules, "G2P", "Names of the rewrite rules.");

// How input strings are parsed and how output labels are rendered.
DEFINE_string(input_mode, "utf8", "Either \"byte\", \"utf8\", or the path to a "
              "symbol table for input parsing.");
DEFINE_string(output_mode, "utf8", "Either \"byte\", \"utf8\", or the path to "
              "a symbol table for output strings.");
DEFINE_string(history_file, HISTORY_FILE,
              "Location of history file");

// Output shaping.
DEFINE_int64(noutput, 1, "Maximum number of output strings for each input.");
DEFINE_bool(show_details, false, "Show the output of each individual rule when"
            " multiple rules are specified.");
DEFINE_string(field_separator, " ",
              "Field separator for strings of symbols from a symbol table.");

// When non-empty, switches Run() from the interactive prompt to batch mode.
DEFINE_string(word_file, "",
              "File with newline separated fields that should be used for input");
namespace thrax {
namespace {
using ::fst::kNoStateId;
using ::fst::LabelsToUTF8String;
using ::fst::PathIterator;
using ::fst::Project;
using ::fst::PROJECT_OUTPUT;
using ::fst::RmEpsilon;
using ::fst::ShortestPath;
using ::fst::StdArc;
using ::fst::StdVectorFst;
using ::fst::SymbolTable;
using Label = StdArc::Label;
// Appends the textual form of one output label to *path.
//
//   label             arc label to render; label 0 is skipped.
//   type              how to render labels that are not generated symbols
//                     (SYMBOL, BYTE or UTF8).
//   generated_symtab  optional table of generated symbols; may be null.
//   symtab            user symbol table, consulted only when type == SYMBOL.
//   path              string being built; modified in place.
//
// Returns false, after logging an error, when the label cannot be rendered;
// returns true otherwise (including when nothing had to be appended).
inline bool AppendLabel(Label label, TokenType type,
                        const SymbolTable *generated_symtab,
                        SymbolTable *symtab, std::string *path) {
  // Label 0 contributes nothing to the output string.
  if (label == 0) return true;

  // Check first to see if this label is in the generated symbol set. Note
  // that this should not conflict with a user-provided symbol table since
  // the parser used by GrmCompiler doesn't generate extra labels if a
  // string is parsed using a user-provided symbol table.
  if (generated_symtab) {
    // Look the label up once and reuse the result.
    const auto &generated = generated_symtab->Find(label);
    if (!generated.empty()) {
      *path += "[" + generated + "]";
      return true;
    }
  }

  if (type == SYMBOL) {
    // Without a table there is nothing to look the label up in.
    if (!symtab) {
      LOG(ERROR) << "No symbol table available for id: " << label;
      return false;
    }
    const auto &sym = symtab->Find(label);
    if (sym.empty()) {
      LOG(ERROR) << "Missing symbol in symbol table for id: " << label;
      return false;
    }
    // For non-byte, non-UTF8 symbols, one overwhelmingly wants these to be
    // space-separated.
    if (!path->empty()) *path += FLAGS_field_separator;
    *path += sym;
  } else if (type == BYTE) {
    // The label is a raw byte value; the narrowing to char is intentional.
    path->push_back(static_cast<char>(label));
  } else if (type == UTF8) {
    // Encode the single code point as UTF-8 before appending it.
    std::string utf8_string;
    const std::vector<Label> labels(1, label);
    if (!LabelsToUTF8String(labels, &utf8_string)) {
      LOG(ERROR) << "LabelsToUTF8String: Bad code point: " << label;
      return false;
    }
    *path += utf8_string;
  }
  return true;
}
} // namespace
bool FstToStrings(const StdVectorFst &fst,
std::vector<std::pair<std::string, float>> *strings,
const SymbolTable *generated_symtab, TokenType type,
SymbolTable *symtab, size_t n) {
StdVectorFst shortest_path;
if (n == 1) {
ShortestPath(fst, &shortest_path, n);
} else {
// The uniqueness feature of ShortestPath requires us to have an acceptor,
// so we project and remove epsilon arcs.
StdVectorFst temp(fst);
Project(&temp, PROJECT_OUTPUT);
RmEpsilon(&temp);
ShortestPath(temp, &shortest_path, n, /*unique=*/true);
}
if (shortest_path.Start() == kNoStateId) return false;
for (PathIterator<StdArc> iter(shortest_path, /*check_acyclic=*/false);
!iter.Done(); iter.Next()) {
std::string path;
for (const auto label : iter.OLabels()) {
if (!AppendLabel(label, type, generated_symtab, symtab, &path)) {
return false;
}
}
strings->emplace_back(std::move(path), iter.Weight().Value());
}
return true;
}
// Returns a copy of the input symbol table attached to the grammar's
// "*StringFstSymbolTable" FST, or nullptr when the grammar has no such FST or
// that FST carries no input symbols. The result comes from
// SymbolTable::Copy(), so the caller is responsible for releasing it.
const SymbolTable *GetGeneratedSymbolTable(GrmManagerSpec<StdArc> *grm) {
  const auto *symbolfst = grm->GetFst("*StringFstSymbolTable");
  if (!symbolfst) return nullptr;
  // InputSymbols() may itself be null; guard before copying.
  const auto *symbols = symbolfst->InputSymbols();
  return symbols ? symbols->Copy() : nullptr;
}
} // namespace thrax
// Creates a tester with every owned pointer cleared; Initialize() has to be
// called before ProcessInput()/Run() because it creates the string compiler.
//
// Throws std::runtime_error when --word_file is set but the file is missing.
RewriteTesterUtils::RewriteTesterUtils() :
    compiler_(nullptr),
    byte_symtab_(nullptr),
    utf8_symtab_(nullptr),
    input_symtab_(nullptr),
    output_symtab_(nullptr)
{
  // Cleared here so the destructor can delete it unconditionally. Done in the
  // body rather than the initializer list so it does not depend on the member
  // declaration order.
  generated_symtab_ = nullptr;

  if (!FLAGS_word_file.empty())
  {
    namespace FS = boost::filesystem;
    if (!FS::exists(FLAGS_word_file))
    {
      throw std::runtime_error(std::string("No such file: ") + FLAGS_word_file);
    }
  }
}

// Releases the string compiler and every symbol table copy held by this
// object, including the generated symbol table obtained in Initialize().
RewriteTesterUtils::~RewriteTesterUtils() {
  delete compiler_;
  delete input_symtab_;
  delete output_symtab_;
  delete byte_symtab_;
  delete utf8_symtab_;
  delete generated_symtab_;
}
void RewriteTesterUtils::Initialize() {
boost::locale::generator gen;
std::locale::global(gen("is_IS.UTF-8"));
CHECK(grm_.LoadArchive(FLAGS_far));
rules_ = ::fst::StringSplit(FLAGS_rules, ',');
byte_symtab_ = nullptr;
utf8_symtab_ = nullptr;
if (rules_.empty()) LOG(FATAL) << "--rules must be specified";
for (size_t i = 0; i < rules_.size(); ++i) {
RuleTriple triple(rules_[i]);
const auto *fst = grm_.GetFst(triple.main_rule);
if (!fst) {
LOG(FATAL) << "grm.GetFst() must be non nullptr for rule: "
<< triple.main_rule;
}
StdVectorFst vfst(*fst);
// If the input transducers in the FAR have symbol tables then we need to
// add the appropriate symbol table(s) to the input strings, according to
// the parse mode.
if (vfst.InputSymbols()) {
if (!byte_symtab_ &&
vfst.InputSymbols()->Name() ==
::thrax::function::kByteSymbolTableName) {
byte_symtab_ = vfst.InputSymbols()->Copy();
} else if (!utf8_symtab_ &&
vfst.InputSymbols()->Name() ==
::thrax::function::kUtf8SymbolTableName) {
utf8_symtab_ = vfst.InputSymbols()->Copy();
}
}
if (!triple.pdt_parens_rule.empty()) {
fst = grm_.GetFst(triple.pdt_parens_rule);
if (!fst) {
LOG(FATAL) << "grm.GetFst() must be non nullptr for rule: "
<< triple.pdt_parens_rule;
}
}
if (!triple.mpdt_assignments_rule.empty()) {
fst = grm_.GetFst(triple.mpdt_assignments_rule);
if (!fst) {
LOG(FATAL) << "grm.GetFst() must be non nullptr for rule: "
<< triple.mpdt_assignments_rule;
}
}
}
generated_symtab_ = GetGeneratedSymbolTable(&grm_);
if (FLAGS_input_mode == "byte") {
compiler_ = new StringCompiler<StdArc>(StringTokenType::BYTE);
} else if (FLAGS_input_mode == "utf8") {
compiler_ = new StringCompiler<StdArc>(StringTokenType::UTF8);
} else {
input_symtab_ = SymbolTable::ReadText(FLAGS_input_mode);
if (!input_symtab_) {
LOG(FATAL) << "Invalid mode or symbol table path.";
}
compiler_ =
new StringCompiler<StdArc>(StringTokenType::SYMBOL, input_symtab_);
}
output_symtab_ = nullptr;
if (FLAGS_output_mode == "byte") {
type_ = BYTE;
} else if (FLAGS_output_mode == "utf8") {
type_ = UTF8;
} else {
type_ = SYMBOL;
output_symtab_ = SymbolTable::ReadText(FLAGS_output_mode);
if (!output_symtab_) {
LOG(FATAL) << "Invalid mode or symbol table path.";
}
}
}
const std::string RewriteTesterUtils::ProcessInput(const std::string& input,
bool prepend_output) {
StdVectorFst input_fst;
StdVectorFst output_fst;
if (!compiler_->operator()(input, &input_fst)) {
return "Unable to parse input string.";
}
std::ostringstream sstrm;
// Set symbols for the input, if appropriate
if (byte_symtab_ && type_ == BYTE) {
input_fst.SetInputSymbols(byte_symtab_);
input_fst.SetOutputSymbols(byte_symtab_);
} else if (utf8_symtab_ && type_ == UTF8) {
input_fst.SetInputSymbols(utf8_symtab_);
input_fst.SetOutputSymbols(utf8_symtab_);
} else if (input_symtab_ && type_ == SYMBOL) {
input_fst.SetInputSymbols(input_symtab_);
input_fst.SetOutputSymbols(input_symtab_);
}
bool succeeded = true;
for (size_t i = 0; i < rules_.size(); ++i) {
RuleTriple triple(rules_[i]);
if (grm_.Rewrite(triple.main_rule, input_fst, &output_fst,
triple.pdt_parens_rule, triple.mpdt_assignments_rule)) {
if (FLAGS_show_details && rules_.size() > 1) {
std::vector<std::pair<std::string, float>> tmp;
FstToStrings(output_fst, &tmp, generated_symtab_, type_,
output_symtab_, FLAGS_noutput);
for (const auto& one_result : tmp) {
sstrm << "output of rule[" << triple.main_rule
<< "] is: " << one_result.first << '\n';
}
}
input_fst = output_fst;
} else {
succeeded = false;
break;
}
}
std::vector<std::pair<std::string, float>> strings;
std::set<std::string> seen;
if (succeeded && FstToStrings(output_fst, &strings,
generated_symtab_, type_,
output_symtab_, FLAGS_noutput)) {
for (auto it = strings.cbegin(); it != strings.cend(); ++it) {
const auto sx = seen.find(it->first);
if (sx != seen.end()) continue;
if (prepend_output) {
sstrm << "Output string: " << it->first;
} else {
sstrm << it->first;
}
if (FLAGS_noutput != 1 && it->second != 0) {
sstrm << " <cost=" << it->second << '>';
}
seen.insert(it->first);
if (it + 1 != strings.cend()) sstrm << '\n';
}
return sstrm.str();
} else {
return "Rewrite failed.";
}
}
// Main loop. With --word_file set, every line of that file is processed as
// one word (batch mode); otherwise lines are read interactively and each one
// is lower-cased, rewritten and printed.
void RewriteTesterUtils::Run()
{
  const bool batch_mode = !FLAGS_word_file.empty();
  if (batch_mode)
  {
    processFile(FLAGS_word_file);
    return;
  }

  // Interactive mode: keep prompting until ReadInput() reports failure.
  for (std::string line; ReadInput(&line);)
  {
    std::cout << ProcessInput(boost::locale::to_lower(line)) << std::endl;
  }
}
void RewriteTesterUtils::processFile(const std::string& filename)
{
std::ifstream file(filename);
std::string word;
while(std::getline(file, word))
{
std::cout << word << "\t" << processWord(word) << std::endl;
}
}
// Rewrites a single word and returns the result with surrounding whitespace
// removed and without the "Output string: " prefix.
std::string RewriteTesterUtils::processWord(const std::string &word)
{
  // The model only covers lowercase words (this keeps its size manageable).
  // Icelandic input is UTF-8 encoded, so std::tolower is not sufficient;
  // boost::locale performs the locale-aware conversion instead.
  const auto lowered = boost::locale::to_lower(word);
  return boost::trim_copy(ProcessInput(lowered, false));
}
// Prints the interactive prompt and reads one line from standard input into
// *s. Returns true when a line was read, false once the read fails.
bool RewriteTesterUtils::ReadInput(std::string* s)
{
  std::cout << "Input string: ";
  std::getline(std::cin, *s);
  // Same test as converting the stream to bool: true unless the read failed.
  return !std::cin.fail();
}