diff --git a/components/history_embeddings/history_embeddings_service.cc b/components/history_embeddings/history_embeddings_service.cc index e43a8d3d2526d2..5c44045300caee 100644 --- a/components/history_embeddings/history_embeddings_service.cc +++ b/components/history_embeddings/history_embeddings_service.cc @@ -423,8 +423,7 @@ void HistoryEmbeddingsService::OnQueryEmbeddingComputed( << (query_passages.empty() ? "(NONE)" : query_passages[0]) << "'"; if (!succeeded) { - // Query embedding failed. Just return no search results. - std::move(callback).Run({}); + std::move(callback).Run(std::move(result)); return; } diff --git a/components/history_embeddings/history_embeddings_service_unittest.cc b/components/history_embeddings/history_embeddings_service_unittest.cc index 75a87af54c2537..4fa0e33f47860b 100644 --- a/components/history_embeddings/history_embeddings_service_unittest.cc +++ b/components/history_embeddings/history_embeddings_service_unittest.cc @@ -13,6 +13,7 @@ #include "base/run_loop.h" #include "base/strings/string_number_conversions.h" #include "base/task/cancelable_task_tracker.h" +#include "base/task/sequenced_task_runner.h" #include "base/test/bind.h" #include "base/test/metrics/histogram_tester.h" #include "base/test/scoped_feature_list.h" @@ -53,6 +54,27 @@ base::FilePath GetTestFilePath(const std::string& file_name) { .AppendASCII(file_name); } +class MockEmbedderWithDelay : public MockEmbedder { + public: + static constexpr base::TimeDelta kTimeout = base::Seconds(1); + + MockEmbedderWithDelay() = default; + ~MockEmbedderWithDelay() override = default; + + // Embedder: + void ComputePassagesEmbeddings( + passage_embeddings::PassagePriority priority, + std::vector passages, + ComputePassagesEmbeddingsCallback callback) override { + base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask( + FROM_HERE, + base::BindOnce(std::move(callback), std::move(passages), + ComputeEmbeddingsForPassages(passages), + passage_embeddings::ComputeEmbeddingsStatus::kSuccess), + kTimeout); + } +}; + } // namespace class HistoryEmbeddingsServicePublic : public HistoryEmbeddingsService { @@ -79,6 +101,7 @@ class HistoryEmbeddingsServicePublic : public HistoryEmbeddingsService { using HistoryEmbeddingsService::OnPassagesEmbeddingsComputed; using HistoryEmbeddingsService::OnSearchCompleted; using HistoryEmbeddingsService::QueryIsFiltered; + using HistoryEmbeddingsService::RebuildAbsentEmbeddings; using HistoryEmbeddingsService::answerer_; using HistoryEmbeddingsService::embedder_metadata_; @@ -116,7 +139,8 @@ class HistoryEmbeddingsServiceTest : public testing::Test { os_crypt_.get(), history_service_.get(), page_content_annotations_service_.get(), /*optimization_guide_decider=*/nullptr, - std::make_unique(), std::make_unique(), + std::make_unique(), + std::make_unique(), std::make_unique()); ASSERT_TRUE(listener()->filter_words_hashes().empty()); @@ -204,7 +228,8 @@ class HistoryEmbeddingsServiceTest : public testing::Test { false); } - base::test::TaskEnvironment task_environment_; + base::test::TaskEnvironment task_environment_{ + base::test::TaskEnvironment::TimeSource::MOCK_TIME}; base::ScopedTempDir history_dir_; std::unique_ptr os_crypt_; @@ -1016,7 +1041,6 @@ class AddSyncedVisitTask : public history::HistoryDBTask { history::VisitID visit_id = backend->AddSyncedVisit( url_, u"Title", /*hidden=*/false, visit_, std::nullopt, std::nullopt); EXPECT_NE(visit_id, history::kInvalidVisitID); - LOG(ERROR) << "Added visit!"; return true; } @@ -1080,4 +1104,173 @@ TEST_F(HistoryEmbeddingsServiceTest, SearchGetsIfUrlIsKnownToSync) { EXPECT_EQ(result.scored_url_rows[1].is_url_known_to_sync, true); } +TEST_F(HistoryEmbeddingsServiceTest, CancelPreviousSearches) { + base::Time now = base::Time::Now(); + AddTestHistoryPage("http://test1.com"); + OnPassagesEmbeddingsComputed(UrlData(1, 1, now), + {"test passage 1", "test passage 2"}, + {Embedding(std::vector(768, 1.0f)), + Embedding(std::vector(768, 1.0f))}, + ComputeEmbeddingsStatus::kSuccess); + OverrideVisibilityScoresForTesting({ + {"test passage 1", 0.99}, + {"test passage 2", 0.99}, + }); + // Service uses the default .9 score threshold when neither the feature param + // nor the metadata thresholds are set. + SetMetadataScoreThreshold(0.01); + + base::test::TestFuture future1; + service_->Search(nullptr, "passage", {}, 3, /*skip_answering=*/true, + future1.GetRepeatingCallback()); + + base::test::TestFuture future2; + service_->Search(nullptr, "passage", {}, 3, /*skip_answering=*/true, + future2.GetRepeatingCallback()); + + base::test::TestFuture future3; + service_->Search(nullptr, "passage", {}, 3, /*skip_answering=*/true, + future3.GetRepeatingCallback()); + + base::test::TestFuture future4; + service_->Search(nullptr, "passage", {}, 3, /*skip_answering=*/true, + future4.GetRepeatingCallback()); + + // The first query is processed. + // TODO(crbug.com/390241271): The first query should NOT be processed. + SearchResult result1 = future1.Take(); + EXPECT_FALSE(result1.session_id.empty()); + EXPECT_EQ(result1.query, "passage"); + ASSERT_EQ(result1.scored_url_rows.size(), 1u); + EXPECT_EQ(result1.scored_url_rows[0].scored_url.url_id, 1); + EXPECT_EQ(result1.scored_url_rows[0].scored_url.visit_id, 1); + EXPECT_EQ(result1.scored_url_rows[0].scored_url.visit_time, now); + + // The second query is skipped. + SearchResult result2 = future2.Take(); + EXPECT_FALSE(result2.session_id.empty()); + EXPECT_EQ(result2.query, "passage"); + ASSERT_EQ(result2.scored_url_rows.size(), 0u); + + // The third query is skipped. + SearchResult result3 = future3.Take(); + EXPECT_FALSE(result3.session_id.empty()); + EXPECT_EQ(result3.query, "passage"); + ASSERT_EQ(result3.scored_url_rows.size(), 0u); + + // The last query is processed. + SearchResult result4 = future4.Take(); + EXPECT_FALSE(result4.session_id.empty()); + EXPECT_EQ(result4.query, "passage"); + ASSERT_EQ(result4.scored_url_rows.size(), 1u); + EXPECT_EQ(result4.scored_url_rows[0].scored_url.url_id, 1); + EXPECT_EQ(result4.scored_url_rows[0].scored_url.visit_id, 1); + EXPECT_EQ(result4.scored_url_rows[0].scored_url.visit_time, now); +} + +TEST_F(HistoryEmbeddingsServiceTest, UseDatabaseBeforeEmbedder) { + base::test::TestFuture store_future; + service_->SetPassagesStoredCallbackForTesting( + store_future.GetRepeatingCallback()); + + base::Time now = base::Time::Now(); + AddTestHistoryPage("http://test1.com"); + + // TODO(crbug.com/390241271): Enable erasing non-ascii characters. + + { + // TODO(crbug.com/390241271): Test with an empty passage. + base::HistogramTester histogram_tester; + service_->ComputeAndStorePassageEmbeddings( + /*url_id=*/1, + /*visit_id=*/1, + /*visit_time=*/now + base::Seconds(1), + { + "test passage 1", + "test passage ß", + "ßßß", + }); + + UrlData url_data = store_future.Take(); + ASSERT_EQ(url_data.passages.passages_size(), 3); + ASSERT_EQ(url_data.embeddings.size(), 3u); + ASSERT_EQ(url_data.passages.passages(0), "test passage 1"); + ASSERT_EQ(url_data.embeddings[0].Dimensions(), 768u); + ASSERT_EQ(url_data.passages.passages(1), "test passage ß"); + ASSERT_EQ(url_data.embeddings[1].Dimensions(), 768u); + ASSERT_EQ(url_data.passages.passages(2), "ßßß"); + ASSERT_EQ(url_data.embeddings[2].Dimensions(), 768u); + + // The cache wasn't used because there was no existing data. + histogram_tester.ExpectTotalCount( + "History.Embeddings.DatabaseCachedPassageTryCount", 1); + histogram_tester.ExpectBucketCount( + "History.Embeddings.DatabaseCachedPassageTryCount", 3, 1); + histogram_tester.ExpectTotalCount( + "History.Embeddings.DatabaseCachedPassageHitCount", 1); + histogram_tester.ExpectBucketCount( + "History.Embeddings.DatabaseCachedPassageHitCount", 0, 1); + } + { + // TODO(crbug.com/390241271): Test with an empty passage. + base::HistogramTester histogram_tester; + service_->ComputeAndStorePassageEmbeddings( + /*url_id=*/1, + /*visit_id=*/2, + /*visit_time=*/now + base::Minutes(1), + { + "test passage 1", + "test passage ßßß", + "ßßß", + }); + + UrlData url_data = store_future.Take(); + ASSERT_EQ(url_data.passages.passages_size(), 3); + ASSERT_EQ(url_data.embeddings.size(), 3u); + ASSERT_EQ(url_data.passages.passages(0), "test passage 1"); + ASSERT_EQ(url_data.embeddings[0].Dimensions(), 768u); + ASSERT_EQ(url_data.passages.passages(1), "test passage ßßß"); + ASSERT_EQ(url_data.embeddings[1].Dimensions(), 768u); + ASSERT_EQ(url_data.passages.passages(2), "ßßß"); + ASSERT_EQ(url_data.embeddings[2].Dimensions(), 768u); + + // The cache was used because there was existing data. + histogram_tester.ExpectTotalCount( + "History.Embeddings.DatabaseCachedPassageTryCount", 1); + histogram_tester.ExpectBucketCount( + "History.Embeddings.DatabaseCachedPassageTryCount", 3, 1); + histogram_tester.ExpectTotalCount( + "History.Embeddings.DatabaseCachedPassageHitCount", 1); + histogram_tester.ExpectBucketCount( + "History.Embeddings.DatabaseCachedPassageHitCount", 2, 1); + } +} + +TEST_F(HistoryEmbeddingsServiceTest, RebuildAbsentEmbeddings) { + base::HistogramTester histogram_tester; + + base::test::TestFuture store_future; + service_->SetPassagesStoredCallbackForTesting( + store_future.GetRepeatingCallback()); + + // TODO(crbug.com/390241271): Enable erasing non-ascii characters. + // TODO(crbug.com/390241271): Test with an empty passage. + + UrlData existing_url_data_1(1, 1, base::Time::Now()); + existing_url_data_1.passages.add_passages("test passage 1"); + existing_url_data_1.passages.add_passages("test passage ßßß"); + existing_url_data_1.passages.add_passages("ßßß"); + service_->RebuildAbsentEmbeddings({existing_url_data_1}); + + UrlData url_data = store_future.Take(); + ASSERT_EQ(url_data.passages.passages_size(), 3); + ASSERT_EQ(url_data.embeddings.size(), 3u); + ASSERT_EQ(url_data.passages.passages(0), "test passage 1"); + ASSERT_EQ(url_data.embeddings[0].Dimensions(), 768u); + ASSERT_EQ(url_data.passages.passages(1), "test passage ßßß"); + ASSERT_EQ(url_data.embeddings[1].Dimensions(), 768u); + ASSERT_EQ(url_data.passages.passages(2), "ßßß"); + ASSERT_EQ(url_data.embeddings[2].Dimensions(), 768u); +} + } // namespace history_embeddings diff --git a/components/history_embeddings/mock_embedder.cc b/components/history_embeddings/mock_embedder.cc index a024aed4ea623b..73b07f408692a1 100644 --- a/components/history_embeddings/mock_embedder.cc +++ b/components/history_embeddings/mock_embedder.cc @@ -22,12 +22,6 @@ Embedding ComputeEmbeddingForPassage(const std::string& passage) { return embedding; } -std::vector ComputeEmbeddingsForPassages( - const std::vector& passages) { - return std::vector(passages.size(), - ComputeEmbeddingForPassage("")); -} - } // namespace MockEmbedder::MockEmbedder() = default; @@ -49,4 +43,10 @@ void MockEmbedder::SetOnEmbedderReady(OnEmbedderReadyCallback callback) { std::move(callback).Run({kModelVersion, kOutputSize}); } +std::vector MockEmbedder::ComputeEmbeddingsForPassages( + const std::vector& passages) { + return std::vector(passages.size(), + ComputeEmbeddingForPassage("")); +} + } // namespace history_embeddings diff --git a/components/history_embeddings/mock_embedder.h b/components/history_embeddings/mock_embedder.h index 4f5b6114e49ff6..52eb2ad2073c20 100644 --- a/components/history_embeddings/mock_embedder.h +++ b/components/history_embeddings/mock_embedder.h @@ -5,6 +5,9 @@ #ifndef COMPONENTS_HISTORY_EMBEDDINGS_MOCK_EMBEDDER_H_ #define COMPONENTS_HISTORY_EMBEDDINGS_MOCK_EMBEDDER_H_ +#include +#include + #include "components/history_embeddings/embedder.h" namespace history_embeddings { @@ -21,6 +24,10 @@ class MockEmbedder : public Embedder { ComputePassagesEmbeddingsCallback callback) override; void SetOnEmbedderReady(OnEmbedderReadyCallback callback) override; + + protected: + std::vector ComputeEmbeddingsForPassages( + const std::vector& passages); }; } // namespace history_embeddings