Stop using pointers to Result. This fix allows new candidates to be inserted in the supplemental model.

PiperOrigin-RevId: 701241583
This commit is contained in:
Taku Kudo
2024-11-29 11:41:18 +00:00
committed by Hiroyuki Komatsu
parent 2b5a6b4d2d
commit 3c4054ba6f
5 changed files with 46 additions and 55 deletions

View File

@ -84,12 +84,12 @@ class SupplementalModelInterface {
// Reranks (boost or promote) the typing corrected candidates at `results`.
virtual void RerankTypingCorrection(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results) const {}
std::vector<prediction::Result> &results) const {}
// Reranks the zero query suggestion generated by suffix dictionary.
virtual void RerankZeroQuerySuggestion(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results) const {}
std::vector<prediction::Result> &results) const {}
// Performs general post correction on `segments`.
virtual void PostCorrect(const ConversionRequest &request,

View File

@ -66,11 +66,11 @@ class MockSupplementalModel : public SupplementalModelInterface {
(const, override));
MOCK_METHOD(void, RerankTypingCorrection,
(const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results),
std::vector<prediction::Result> &results),
(const, override));
MOCK_METHOD(void, RerankZeroQuerySuggestion,
(const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const prediction::Result *>> *results),
std::vector<prediction::Result> &results),
(const, override));
MOCK_METHOD(void, PostCorrect,
(const ConversionRequest &, absl::Nonnull<Segments *> segments),

View File

@ -311,7 +311,8 @@ bool DictionaryPredictor::PredictForRequest(const ConversionRequest &request,
MaybeRescoreResults(request, *segments, absl::MakeSpan(results));
return AddPredictionToCandidates(request, segments, absl::MakeSpan(results));
// `results` is not used after this call, so it can be moved.
return AddPredictionToCandidates(request, segments, std::move(results));
}
void DictionaryPredictor::RewriteResultsForPrediction(
@ -363,7 +364,7 @@ void DictionaryPredictor::MaybePopulateTypingCorrectedResults(
bool DictionaryPredictor::AddPredictionToCandidates(
const ConversionRequest &request, Segments *segments,
absl::Span<Result> results) const {
std::vector<Result> results) const {
DCHECK(segments);
const KeyValueView history = GetHistoryKeyAndValue(*segments);
@ -371,20 +372,13 @@ bool DictionaryPredictor::AddPredictionToCandidates(
Segment *segment = segments->mutable_conversion_segment(0);
DCHECK(segment);
// This pointer array is used to perform heap operations efficiently.
std::vector<const Result *> result_ptrs;
result_ptrs.reserve(results.size());
for (const auto &r : results) result_ptrs.push_back(&r);
// Instead of sorting all the results, we construct a heap.
// This is done in linear time and
// we can pop as many results as we need efficiently.
auto min_heap_cmp = [](const Result *lhs, const Result *rhs) {
// `rhs < lhs` instead of `lhs < rhs`, since `make_heap()` creates max heap
// by default.
return ResultCostLess()(*rhs, *lhs);
};
std::make_heap(result_ptrs.begin(), result_ptrs.end(), min_heap_cmp);
std::make_heap(results.begin(), results.end(),
[](const Result &lhs, const Result &rhs) {
return ResultCostLess()(rhs, lhs);
});
const size_t max_candidates_size = std::min(
request.max_dictionary_prediction_candidates_size(), results.size());
@ -419,41 +413,42 @@ bool DictionaryPredictor::AddPredictionToCandidates(
#endif // MOZC_DEBUG
std::vector<absl::Nonnull<const Result *>> final_results_ptrs;
final_results_ptrs.reserve(result_ptrs.size());
std::shared_ptr<Result> prev_top_result;
std::vector<Result> final_results;
for (size_t i = 0; i < result_ptrs.size(); ++i) {
std::pop_heap(result_ptrs.begin(), result_ptrs.end() - i, min_heap_cmp);
const Result &result = *result_ptrs[result_ptrs.size() - i - 1];
for (size_t i = 0; i < results.size(); ++i) {
std::pop_heap(results.begin(), results.end() - i,
[](const Result &lhs, const Result &rhs) {
return ResultCostLess()(rhs, lhs);
});
Result &result = results[results.size() - i - 1];
if (final_results_ptrs.size() >= max_candidates_size ||
if (final_results.size() >= max_candidates_size ||
result.cost >= kInfinity) {
break;
}
if (i == 0 && (prev_top_result = MaybeGetPreviousTopResult(
result, request, *segments)) != nullptr) {
final_results_ptrs.emplace_back(prev_top_result.get());
final_results.emplace_back(*prev_top_result);
}
std::string log_message;
if (filter.ShouldRemove(result, final_results_ptrs.size(), &log_message)) {
if (filter.ShouldRemove(result, final_results.size(), &log_message)) {
MOZC_ADD_DEBUG_CANDIDATE(result, log_message);
continue;
}
final_results_ptrs.emplace_back(&result);
final_results.emplace_back(std::move(result));
}
MaybeRerankZeroQuerySuggestion(request, *segments, &final_results_ptrs);
MaybeRerankAggressiveTypingCorrection(request, *segments, final_results);
MaybeRerankAggressiveTypingCorrection(request, *segments,
&final_results_ptrs);
MaybeRerankZeroQuerySuggestion(request, *segments, final_results);
// Fill segments from final_results.
for (const Result *result : final_results_ptrs) {
FillCandidate(request, *result, GetCandidateKeyAndValue(*result, history),
for (const Result &result : final_results) {
FillCandidate(request, result, GetCandidateKeyAndValue(result, history),
merged_types, segment->push_back_candidate());
}
@ -464,14 +459,14 @@ bool DictionaryPredictor::AddPredictionToCandidates(
AddRescoringDebugDescription(segments);
}
return !final_results_ptrs.empty();
return !final_results.empty();
#undef MOZC_ADD_DEBUG_CANDIDATE
}
void DictionaryPredictor::MaybeRerankAggressiveTypingCorrection(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const Result *>> *results) const {
if (!IsTypingCorrectionEnabled(request) || results->empty()) {
std::vector<Result> &results) const {
if (!IsTypingCorrectionEnabled(request) || results.empty()) {
return;
}
const engine::SupplementalModelInterface *supplemental_model =
@ -482,7 +477,7 @@ void DictionaryPredictor::MaybeRerankAggressiveTypingCorrection(
void DictionaryPredictor::MaybeRerankZeroQuerySuggestion(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const Result *>> *results) const {
std::vector<Result> &results) const {
if (!IsTypingCorrectionEnabled(request)) {
return;
}

View File

@ -142,9 +142,11 @@ class DictionaryPredictor : public PredictorInterface {
aggregator,
const ImmutableConverterInterface *immutable_converter);
// `results` is taken by value. Pass an rvalue (e.g. via std::move) if the
// caller doesn't use the results after calling this method.
bool AddPredictionToCandidates(const ConversionRequest &request,
Segments *segments,
absl::Span<Result> results) const;
std::vector<Result> results) const;
void FillCandidate(
const ConversionRequest &request, const Result &result,
@ -271,11 +273,11 @@ class DictionaryPredictor : public PredictorInterface {
void MaybeRerankAggressiveTypingCorrection(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const Result *>> *results) const;
std::vector<Result> &results) const;
void MaybeRerankZeroQuerySuggestion(
const ConversionRequest &request, const Segments &segments,
std::vector<absl::Nonnull<const Result *>> *results) const;
void MaybeRerankZeroQuerySuggestion(const ConversionRequest &request,
const Segments &segments,
std::vector<Result> &results) const;
static void MaybeApplyPostCorrection(const ConversionRequest &request,
const engine::Modules &modules,

View File

@ -131,7 +131,7 @@ class DictionaryPredictorTestPeer {
bool AddPredictionToCandidates(const ConversionRequest &request,
Segments *segments,
absl::Span<Result> results) const {
std::vector<Result> &results) const {
return predictor_.AddPredictionToCandidates(request, segments, results);
}
@ -920,8 +920,7 @@ TEST_F(DictionaryPredictorTest, MergeAttributesForDebug) {
config_->set_verbose_level(1);
const ConversionRequest convreq =
CreateConversionRequest(ConversionRequest::SUGGESTION);
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
EXPECT_EQ(segments.conversion_segments_size(), 1);
const Segment &segment = segments.conversion_segment(0);
@ -947,8 +946,7 @@ TEST_F(DictionaryPredictorTest, SetDescription) {
const ConversionRequest convreq =
CreateConversionRequest(ConversionRequest::PREDICTION);
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
EXPECT_EQ(segments.conversion_segments_size(), 1);
const Segment &segment = segments.conversion_segment(0);
@ -988,8 +986,7 @@ TEST_F(DictionaryPredictorTest, PropagateResultCosts) {
.max_dictionary_prediction_candidates_size = kTestSize,
});
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
EXPECT_EQ(segments.conversion_segments_size(), 1);
ASSERT_EQ(kTestSize, segments.conversion_segment(0).candidates_size());
@ -1029,8 +1026,7 @@ TEST_F(DictionaryPredictorTest, PredictNCandidates) {
.max_dictionary_prediction_candidates_size = kLowCostCandidateSize + 1,
});
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
ASSERT_EQ(1, segments.conversion_segments_size());
ASSERT_EQ(kLowCostCandidateSize,
@ -1489,8 +1485,7 @@ TEST_F(DictionaryPredictorTest, Dedup) {
InitSegmentsWithKey("test", &segments);
const ConversionRequest convreq =
CreateConversionRequest(ConversionRequest::PREDICTION);
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
ASSERT_EQ(segments.conversion_segments_size(), 1);
EXPECT_EQ(segments.conversion_segment(0).candidates_size(), kSize);
@ -1528,8 +1523,8 @@ TEST_F(DictionaryPredictorTest, TypingCorrectionResultsLimit) {
InitSegmentsWithKey("original_key", &segments);
const ConversionRequest convreq =
CreateConversionRequest(ConversionRequest::PREDICTION);
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
ASSERT_EQ(segments.conversion_segments_size(), 1);
const Segment segment = segments.conversion_segment(0);
EXPECT_EQ(segment.candidates_size(), 3);
@ -1563,8 +1558,7 @@ TEST_F(DictionaryPredictorTest, SortResult) {
InitSegmentsWithKey("test", &segments);
const ConversionRequest convreq =
CreateConversionRequest(ConversionRequest::PREDICTION);
predictor.AddPredictionToCandidates(convreq, &segments,
absl::MakeSpan(results));
predictor.AddPredictionToCandidates(convreq, &segments, results);
ASSERT_EQ(segments.conversion_segments_size(), 1);
const Segment &segment = segments.conversion_segment(0);