mirror of
https://github.com/mii443/mozc.git
synced 2025-08-23 00:25:34 +00:00
Stop using the pointer of Result. This fix allows to insert new candidates in supplemental model.
PiperOrigin-RevId: 701241583
This commit is contained in:
committed by
Hiroyuki Komatsu
parent
2b5a6b4d2d
commit
3c4054ba6f
@ -84,12 +84,12 @@ class SupplementalModelInterface {
|
||||
// Reranks (boost or promote) the typing corrected candidates at `results`.
|
||||
virtual void RerankTypingCorrection(
|
||||
const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const prediction::Result *>> *results) const {}
|
||||
std::vector<prediction::Result> &results) const {}
|
||||
|
||||
// Reranks the zero query suggestion generated by suffix dictionary.
|
||||
virtual void RerankZeroQuerySuggestion(
|
||||
const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const prediction::Result *>> *results) const {}
|
||||
std::vector<prediction::Result> &results) const {}
|
||||
|
||||
// Performs general post correction on `segments`.
|
||||
virtual void PostCorrect(const ConversionRequest &request,
|
||||
|
@ -66,11 +66,11 @@ class MockSupplementalModel : public SupplementalModelInterface {
|
||||
(const, override));
|
||||
MOCK_METHOD(void, RerankTypingCorrection,
|
||||
(const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const prediction::Result *>> *results),
|
||||
std::vector<prediction::Result> &results),
|
||||
(const, override));
|
||||
MOCK_METHOD(void, RerankZeroQuerySuggestion,
|
||||
(const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const prediction::Result *>> *results),
|
||||
std::vector<prediction::Result> &results),
|
||||
(const, override));
|
||||
MOCK_METHOD(void, PostCorrect,
|
||||
(const ConversionRequest &, absl::Nonnull<Segments *> segments),
|
||||
|
@ -311,7 +311,8 @@ bool DictionaryPredictor::PredictForRequest(const ConversionRequest &request,
|
||||
|
||||
MaybeRescoreResults(request, *segments, absl::MakeSpan(results));
|
||||
|
||||
return AddPredictionToCandidates(request, segments, absl::MakeSpan(results));
|
||||
// `results` are no longer used.
|
||||
return AddPredictionToCandidates(request, segments, std::move(results));
|
||||
}
|
||||
|
||||
void DictionaryPredictor::RewriteResultsForPrediction(
|
||||
@ -363,7 +364,7 @@ void DictionaryPredictor::MaybePopulateTypingCorrectedResults(
|
||||
|
||||
bool DictionaryPredictor::AddPredictionToCandidates(
|
||||
const ConversionRequest &request, Segments *segments,
|
||||
absl::Span<Result> results) const {
|
||||
std::vector<Result> results) const {
|
||||
DCHECK(segments);
|
||||
|
||||
const KeyValueView history = GetHistoryKeyAndValue(*segments);
|
||||
@ -371,20 +372,13 @@ bool DictionaryPredictor::AddPredictionToCandidates(
|
||||
Segment *segment = segments->mutable_conversion_segment(0);
|
||||
DCHECK(segment);
|
||||
|
||||
// This pointer array is used to perform heap operations efficiently.
|
||||
std::vector<const Result *> result_ptrs;
|
||||
result_ptrs.reserve(results.size());
|
||||
for (const auto &r : results) result_ptrs.push_back(&r);
|
||||
|
||||
// Instead of sorting all the results, we construct a heap.
|
||||
// This is done in linear time and
|
||||
// we can pop as many results as we need efficiently.
|
||||
auto min_heap_cmp = [](const Result *lhs, const Result *rhs) {
|
||||
// `rhs < lhs` instead of `lhs < rhs`, since `make_heap()` creates max heap
|
||||
// by default.
|
||||
return ResultCostLess()(*rhs, *lhs);
|
||||
};
|
||||
std::make_heap(result_ptrs.begin(), result_ptrs.end(), min_heap_cmp);
|
||||
std::make_heap(results.begin(), results.end(),
|
||||
[](const Result &lhs, const Result &rhs) {
|
||||
return ResultCostLess()(rhs, lhs);
|
||||
});
|
||||
|
||||
const size_t max_candidates_size = std::min(
|
||||
request.max_dictionary_prediction_candidates_size(), results.size());
|
||||
@ -419,41 +413,42 @@ bool DictionaryPredictor::AddPredictionToCandidates(
|
||||
|
||||
#endif // MOZC_DEBUG
|
||||
|
||||
std::vector<absl::Nonnull<const Result *>> final_results_ptrs;
|
||||
final_results_ptrs.reserve(result_ptrs.size());
|
||||
std::shared_ptr<Result> prev_top_result;
|
||||
std::vector<Result> final_results;
|
||||
|
||||
for (size_t i = 0; i < result_ptrs.size(); ++i) {
|
||||
std::pop_heap(result_ptrs.begin(), result_ptrs.end() - i, min_heap_cmp);
|
||||
const Result &result = *result_ptrs[result_ptrs.size() - i - 1];
|
||||
for (size_t i = 0; i < results.size(); ++i) {
|
||||
std::pop_heap(results.begin(), results.end() - i,
|
||||
[](const Result &lhs, const Result &rhs) {
|
||||
return ResultCostLess()(rhs, lhs);
|
||||
});
|
||||
Result &result = results[results.size() - i - 1];
|
||||
|
||||
if (final_results_ptrs.size() >= max_candidates_size ||
|
||||
if (final_results.size() >= max_candidates_size ||
|
||||
result.cost >= kInfinity) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i == 0 && (prev_top_result = MaybeGetPreviousTopResult(
|
||||
result, request, *segments)) != nullptr) {
|
||||
final_results_ptrs.emplace_back(prev_top_result.get());
|
||||
final_results.emplace_back(*prev_top_result);
|
||||
}
|
||||
|
||||
std::string log_message;
|
||||
if (filter.ShouldRemove(result, final_results_ptrs.size(), &log_message)) {
|
||||
if (filter.ShouldRemove(result, final_results.size(), &log_message)) {
|
||||
MOZC_ADD_DEBUG_CANDIDATE(result, log_message);
|
||||
continue;
|
||||
}
|
||||
|
||||
final_results_ptrs.emplace_back(&result);
|
||||
final_results.emplace_back(std::move(result));
|
||||
}
|
||||
|
||||
MaybeRerankZeroQuerySuggestion(request, *segments, &final_results_ptrs);
|
||||
MaybeRerankAggressiveTypingCorrection(request, *segments, final_results);
|
||||
|
||||
MaybeRerankAggressiveTypingCorrection(request, *segments,
|
||||
&final_results_ptrs);
|
||||
MaybeRerankZeroQuerySuggestion(request, *segments, final_results);
|
||||
|
||||
// Fill segments from final_results_ptrs.
|
||||
for (const Result *result : final_results_ptrs) {
|
||||
FillCandidate(request, *result, GetCandidateKeyAndValue(*result, history),
|
||||
for (const Result &result : final_results) {
|
||||
FillCandidate(request, result, GetCandidateKeyAndValue(result, history),
|
||||
merged_types, segment->push_back_candidate());
|
||||
}
|
||||
|
||||
@ -464,14 +459,14 @@ bool DictionaryPredictor::AddPredictionToCandidates(
|
||||
AddRescoringDebugDescription(segments);
|
||||
}
|
||||
|
||||
return !final_results_ptrs.empty();
|
||||
return !final_results.empty();
|
||||
#undef MOZC_ADD_DEBUG_CANDIDATE
|
||||
}
|
||||
|
||||
void DictionaryPredictor::MaybeRerankAggressiveTypingCorrection(
|
||||
const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const Result *>> *results) const {
|
||||
if (!IsTypingCorrectionEnabled(request) || results->empty()) {
|
||||
std::vector<Result> &results) const {
|
||||
if (!IsTypingCorrectionEnabled(request) || results.empty()) {
|
||||
return;
|
||||
}
|
||||
const engine::SupplementalModelInterface *supplemental_model =
|
||||
@ -482,7 +477,7 @@ void DictionaryPredictor::MaybeRerankAggressiveTypingCorrection(
|
||||
|
||||
void DictionaryPredictor::MaybeRerankZeroQuerySuggestion(
|
||||
const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const Result *>> *results) const {
|
||||
std::vector<Result> &results) const {
|
||||
if (!IsTypingCorrectionEnabled(request)) {
|
||||
return;
|
||||
}
|
||||
|
@ -142,9 +142,11 @@ class DictionaryPredictor : public PredictorInterface {
|
||||
aggregator,
|
||||
const ImmutableConverterInterface *immutable_converter);
|
||||
|
||||
// It is better to pass the rvalue of `results` if the
|
||||
// caller doesn't use the results after calling this method.
|
||||
bool AddPredictionToCandidates(const ConversionRequest &request,
|
||||
Segments *segments,
|
||||
absl::Span<Result> results) const;
|
||||
std::vector<Result> results) const;
|
||||
|
||||
void FillCandidate(
|
||||
const ConversionRequest &request, const Result &result,
|
||||
@ -271,11 +273,11 @@ class DictionaryPredictor : public PredictorInterface {
|
||||
|
||||
void MaybeRerankAggressiveTypingCorrection(
|
||||
const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const Result *>> *results) const;
|
||||
std::vector<Result> &results) const;
|
||||
|
||||
void MaybeRerankZeroQuerySuggestion(
|
||||
const ConversionRequest &request, const Segments &segments,
|
||||
std::vector<absl::Nonnull<const Result *>> *results) const;
|
||||
void MaybeRerankZeroQuerySuggestion(const ConversionRequest &request,
|
||||
const Segments &segments,
|
||||
std::vector<Result> &results) const;
|
||||
|
||||
static void MaybeApplyPostCorrection(const ConversionRequest &request,
|
||||
const engine::Modules &modules,
|
||||
|
@ -131,7 +131,7 @@ class DictionaryPredictorTestPeer {
|
||||
|
||||
bool AddPredictionToCandidates(const ConversionRequest &request,
|
||||
Segments *segments,
|
||||
absl::Span<Result> results) const {
|
||||
std::vector<Result> &results) const {
|
||||
return predictor_.AddPredictionToCandidates(request, segments, results);
|
||||
}
|
||||
|
||||
@ -920,8 +920,7 @@ TEST_F(DictionaryPredictorTest, MergeAttributesForDebug) {
|
||||
config_->set_verbose_level(1);
|
||||
const ConversionRequest convreq =
|
||||
CreateConversionRequest(ConversionRequest::SUGGESTION);
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
EXPECT_EQ(segments.conversion_segments_size(), 1);
|
||||
const Segment &segment = segments.conversion_segment(0);
|
||||
@ -947,8 +946,7 @@ TEST_F(DictionaryPredictorTest, SetDescription) {
|
||||
|
||||
const ConversionRequest convreq =
|
||||
CreateConversionRequest(ConversionRequest::PREDICTION);
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
EXPECT_EQ(segments.conversion_segments_size(), 1);
|
||||
const Segment &segment = segments.conversion_segment(0);
|
||||
@ -988,8 +986,7 @@ TEST_F(DictionaryPredictorTest, PropagateResultCosts) {
|
||||
.max_dictionary_prediction_candidates_size = kTestSize,
|
||||
});
|
||||
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
EXPECT_EQ(segments.conversion_segments_size(), 1);
|
||||
ASSERT_EQ(kTestSize, segments.conversion_segment(0).candidates_size());
|
||||
@ -1029,8 +1026,7 @@ TEST_F(DictionaryPredictorTest, PredictNCandidates) {
|
||||
.max_dictionary_prediction_candidates_size = kLowCostCandidateSize + 1,
|
||||
});
|
||||
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
ASSERT_EQ(1, segments.conversion_segments_size());
|
||||
ASSERT_EQ(kLowCostCandidateSize,
|
||||
@ -1489,8 +1485,7 @@ TEST_F(DictionaryPredictorTest, Dedup) {
|
||||
InitSegmentsWithKey("test", &segments);
|
||||
const ConversionRequest convreq =
|
||||
CreateConversionRequest(ConversionRequest::PREDICTION);
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
ASSERT_EQ(segments.conversion_segments_size(), 1);
|
||||
EXPECT_EQ(segments.conversion_segment(0).candidates_size(), kSize);
|
||||
@ -1528,8 +1523,8 @@ TEST_F(DictionaryPredictorTest, TypingCorrectionResultsLimit) {
|
||||
InitSegmentsWithKey("original_key", &segments);
|
||||
const ConversionRequest convreq =
|
||||
CreateConversionRequest(ConversionRequest::PREDICTION);
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
ASSERT_EQ(segments.conversion_segments_size(), 1);
|
||||
const Segment segment = segments.conversion_segment(0);
|
||||
EXPECT_EQ(segment.candidates_size(), 3);
|
||||
@ -1563,8 +1558,7 @@ TEST_F(DictionaryPredictorTest, SortResult) {
|
||||
InitSegmentsWithKey("test", &segments);
|
||||
const ConversionRequest convreq =
|
||||
CreateConversionRequest(ConversionRequest::PREDICTION);
|
||||
predictor.AddPredictionToCandidates(convreq, &segments,
|
||||
absl::MakeSpan(results));
|
||||
predictor.AddPredictionToCandidates(convreq, &segments, results);
|
||||
|
||||
ASSERT_EQ(segments.conversion_segments_size(), 1);
|
||||
const Segment &segment = segments.conversion_segment(0);
|
||||
|
Reference in New Issue
Block a user