Fix the issue that typing correction is unexpectedly demoted by rescoring.

There are two issues:
  1) The typing correction bonus/penalty is lost during rescoring because it is
     added directly to `Result::wcost`.
  2) Rescoring is performed for normal results and typing corrections
     separately (i.e., the rescoring function is called twice in one prediction
     call).  Since the conversion from a Transformer LM score to a cost depends
     on the other results, normal and typing correction results need to be
     rescored together to get consistent results.

This CL fixes the above issues as follows:
  1) Record the bonus/penalty of typing correction in `Result` and restore it
     after rescoring.
  2) Stop calling `MaybeRescoreResults()` in `RewriteResultsForPrediction()`.
     Instead, call it after generating both normal and typing correction results
     (see the sketch below).
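
For illustration only, the intended flow after this CL is roughly the following;
the struct and function names here are simplified stand-ins, not the actual
Mozc classes:

// Sketch only: shows how the typing-correction bonus/penalty is recorded in
// Result and re-applied after a single, combined rescoring pass.
#include <vector>

struct Result {
  int wcost = 0;                         // context-insensitive cost
  int cost = 0;                          // context-sensitive cost
  int typing_correction_adjustment = 0;  // negative = bonus, positive = penalty
};

// Placeholder for the transformer-LM rescorer: the real one recomputes costs
// from LM scores relative to the other results, which is why an adjustment
// folded into wcost beforehand would simply be overwritten.
void RescoreWithTransformerLm(std::vector<Result>& results) {
  for (Result& r : results) {
    r.wcost = r.cost;  // dummy recomputation
  }
}

void PredictSketch(const std::vector<Result>& normal,
                   const std::vector<Result>& typing_corrected,
                   std::vector<Result>& out) {
  // 1) Merge normal and typing-corrected results so the LM-score-to-cost
  //    conversion sees one consistent population (fix for issue 2).
  out.clear();
  out.insert(out.end(), normal.begin(), normal.end());
  out.insert(out.end(), typing_corrected.begin(), typing_corrected.end());

  // 2) Rescore everything once.
  RescoreWithTransformerLm(out);

  // 3) Re-apply the recorded typing-correction bonus/penalty on top of the
  //    rescored cost (fix for issue 1).
  for (Result& r : out) {
    r.wcost += r.typing_correction_adjustment;
  }
}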

To rescore typing correction results, this CL also stops using the composing
Hiragana reading when evaluating the transformer LM scores, because a typing
correction naturally has a different reading. Without this treatment, typing
correction results are likely to have lower LM scores due to the mismatch
between the input reading and the corrected surface form. The side effect of
this treatment is that some candidates with a common surface form but an
irregular reading might be promoted, e.g., こうべ -> 頭. I will fix this issue
in future CLs.
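
As a rough sketch of that reading change (the LM interface below is
hypothetical and stubbed; only the idea of scoring each candidate on its own
key/value instead of the shared composing Hiragana comes from this CL):

// Sketch only: build the LM input from each candidate's own surface rather
// than the composing (typed) Hiragana, so a typing-corrected candidate is not
// penalized merely because its reading differs from what was literally typed.
#include <string>
#include <vector>

struct Candidate {
  std::string key;    // candidate reading (may differ from the composition)
  std::string value;  // surface form
};

// Hypothetical LM hook; stubbed so the sketch is self-contained.
float LmScore(const std::string& history, const std::string& text) {
  return -static_cast<float>(history.size() + text.size());  // dummy score
}

std::vector<float> ScoreCandidates(const std::string& history,
                                   const std::vector<Candidate>& candidates) {
  std::vector<float> scores;
  scores.reserve(candidates.size());
  for (const Candidate& c : candidates) {
    // Before: the composing Hiragana was part of the evaluated text, which
    // mismatches corrected surfaces. After: score history + surface only.
    scores.push_back(LmScore(history, c.value));
  }
  return scores;
}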

PiperOrigin-RevId: 638980678
Noriyuki Takahashi authored and hiroyuki-komatsu committed May 31, 2024
1 parent 6a3e223 commit 42cbb3f
Showing 3 changed files with 26 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/prediction/dictionary_prediction_aggregator.cc
@@ -1761,7 +1761,9 @@ void DictionaryPredictionAggregator::AggregateTypingCorrectedPrediction(
     // bias = hyp_score - base_score, so larger is better.
     // bias is computed in log10 domain, so we need to use the different
     // scale factor. 500 * log(10) = ~1150.
-    result.wcost -= 1150 * query.bias;
+    const int adjustment = -1150 * query.bias;
+    result.typing_correction_adjustment = adjustment;
+    result.wcost += adjustment;
     results->emplace_back(std::move(result));
   }
 }
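
For reference, the 1150 factor above follows from Mozc's cost convention of
roughly -500 * ln(probability): a score difference expressed in the log10
domain converts to the cost scale via a factor of 500 * ln(10), about 1151. A
tiny standalone check (not part of the CL; the example bias value is
arbitrary):

// Verifies the scale factor used above: a log10-domain score difference `bias`
// maps to a cost delta of -500 * ln(10) * bias, rounded to 1150 in the diff.
#include <cmath>
#include <cstdio>

int main() {
  const double kCostScale = 500.0;                 // natural-log cost scale
  const double factor = kCostScale * std::log(10.0);
  std::printf("500 * ln(10) = %.1f\n", factor);    // prints ~1151.3
  const double bias = 0.5;                         // example log10-domain bias
  std::printf("adjustment for bias=0.5: %.1f\n", -factor * bias);
  return 0;
}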
4 changes: 2 additions & 2 deletions src/prediction/dictionary_predictor.cc
@@ -311,6 +311,8 @@ bool DictionaryPredictor::PredictForRequest(const ConversionRequest &request,
   const TypingCorrectionMixingParams typing_correction_mixing_params =
       MaybePopulateTypingCorrectedResults(request, *segments, &results);
 
+  MaybeRescoreResults(request, *segments, absl::MakeSpan(results));
+
   return AddPredictionToCandidates(request, segments,
                                    typing_correction_mixing_params,
                                    absl::MakeSpan(results));
@@ -334,8 +336,6 @@ void DictionaryPredictor::RewriteResultsForPrediction(
     SetPredictionCost(request.request_type(), segments, results);
   }
 
-  MaybeRescoreResults(request, segments, absl::MakeSpan(*results));
-
   if (!is_mixed_conversion) {
     const size_t input_key_len =
         Util::CharsLen(segments.conversion_segment(0).key());
31 changes: 21 additions & 10 deletions src/prediction/result.h
@@ -103,6 +103,12 @@ struct Result {
   // Context "insensitive" candidate cost.
   int wcost = 0;
   // Context "sensitive" candidate cost.
+  // TODO(noriyukit): The cost is basically calculated by the underlying LM, but
+  // currently it is updated by other modules and heuristics at many locations;
+  // e.g., see SetPredictionCostForMixedConversion() in
+  // dictionary_predictor.cc. Ideally, such cost adjustments should be kept
+  // separately from the original LM cost to perform rescoring in a rigorous
+  // manner.
   int cost = 0;
   int lid = 0;
   int rid = 0;
@@ -127,23 +133,28 @@
   int cost_before_rescoring = 0;
   // If removed is true, this result is not used for a candidate.
   bool removed = false;
-  // confidence score of typing correction. Larger is more confident.
+  // Confidence score of typing correction. Larger is more confident.
   float typing_correction_score = 0.0;
+  // Adjustment for `wcost` made by the typing correction. This value can be
+  // zero, positive (penalty) or negative (bonus), and it is added to `wcost`.
+  int typing_correction_adjustment = 0;
 #ifndef NDEBUG
   std::string log;
 #endif  // NDEBUG
 
   template <typename S>
   friend void AbslStringify(S &sink, const Result &r) {
-    absl::Format(&sink,
-                 "key: %s, value: %s, types: %d, wcost: %d, cost: %d, lid: %d, "
-                 "rid: %d, attrs: %d, bdd: %s, srcinfo: %d, origkey: %s, "
-                 "consumed_key_size: %d, penalty: %d, removed: %v",
-                 r.key, r.value, r.types, r.wcost, r.cost, r.lid, r.rid,
-                 r.candidate_attributes,
-                 absl::StrJoin(r.inner_segment_boundary, ","), r.source_info,
-                 r.non_expanded_original_key, r.consumed_key_size, r.penalty,
-                 r.removed);
+    absl::Format(
+        &sink,
+        "key: %s, value: %s, types: %d, wcost: %d, cost: %d, cost_before: %d, "
+        "lid: %d, "
+        "rid: %d, attrs: %d, bdd: %s, srcinfo: %d, origkey: %s, "
+        "consumed_key_size: %d, penalty: %d, tc_adjustment: %d, removed: %v",
+        r.key, r.value, r.types, r.wcost, r.cost, r.cost_before_rescoring,
+        r.lid, r.rid, r.candidate_attributes,
+        absl::StrJoin(r.inner_segment_boundary, ","), r.source_info,
+        r.non_expanded_original_key, r.consumed_key_size, r.penalty,
+        r.typing_correction_adjustment, r.removed);
 #ifndef NDEBUG
     sink.Append(", log:\n");
     for (absl::string_view line : absl::StrSplit(r.log, '\n')) {
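
For context on the AbslStringify extension point edited in this hunk: a type
that provides such a friend can be formatted directly with absl::StrCat and
with absl::StrFormat's %v verb. A minimal standalone example with a toy struct
(not the real Result):

// Minimal example of the AbslStringify extension point: defining the friend
// lets absl::StrCat and absl::StrFormat("%v") print the type directly.
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"

struct ToyResult {
  std::string key;
  int wcost = 0;

  template <typename Sink>
  friend void AbslStringify(Sink &sink, const ToyResult &r) {
    absl::Format(&sink, "key: %s, wcost: %d", r.key, r.wcost);
  }
};

int main() {
  const ToyResult r{"こうべ", 1234};
  // Both produce: key: こうべ, wcost: 1234
  const std::string a = absl::StrCat(r);
  const std::string b = absl::StrFormat("%v", r);
  return a == b ? 0 : 1;
}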
