Skip to content

Commit

Permalink
[history-embeddings] Tests canceling queries and computing embeddings
Browse files Browse the repository at this point in the history
- Adds tests for canceling query tasks.
- Adds tests for using DB as cache and rebuilding DB.

SchedulingEmbedder does two things which are specific to the History
Embeddings feature:

#1 Canceling stale query passages.
#2 Removing non-ascii chars before sending the passages to the embedder.

These functionalities need to be factored out before we can use the
SchedulingEmbedder in the passage_embeddings:: namespace. This CL adds
some tests around those functionalities so future changes don't break
things.

Adding these tests helped identify three potential bugs. #1 the first
query in a series of queries always gets processed, which it shouldn't be.
#2 enabling the feature param for removing non-ascii characters causes
a crash when the passage is entirely made up of non-ascii characters.
#3 attempting to compute embeddings for an empty passage causes a crash.

I'll attempt to address these bugs in follow-up changes.

Bug: 390241271
Change-Id: I5002b1fbaf6ea5068da405219ba153cce8378faf
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6186446
Code-Coverage: [email protected] <[email protected]>
Commit-Queue: Moe Ahmadi <[email protected]>
Reviewed-by: Orin Jaworski <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1409805}
  • Loading branch information
Moe Ahmadi authored and Chromium LUCI CQ committed Jan 22, 2025
1 parent 37cd84a commit b7ad8fc
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 11 deletions.
3 changes: 1 addition & 2 deletions components/history_embeddings/history_embeddings_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,7 @@ void HistoryEmbeddingsService::OnQueryEmbeddingComputed(
<< (query_passages.empty() ? "(NONE)" : query_passages[0]) << "'";

if (!succeeded) {
// Query embedding failed. Just return no search results.
std::move(callback).Run({});
std::move(callback).Run(std::move(result));
return;
}

Expand Down
199 changes: 196 additions & 3 deletions components/history_embeddings/history_embeddings_service_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/task/cancelable_task_tracker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
Expand Down Expand Up @@ -53,6 +54,27 @@ base::FilePath GetTestFilePath(const std::string& file_name) {
.AppendASCII(file_name);
}

// A MockEmbedder whose results arrive asynchronously: the callback is posted
// to the current default sequenced task runner and runs after `kTimeout`,
// always reporting ComputeEmbeddingsStatus::kSuccess.
class MockEmbedderWithDelay : public MockEmbedder {
 public:
  // Delay between a ComputePassagesEmbeddings() call and its callback.
  static constexpr base::TimeDelta kTimeout = base::Seconds(1);

  MockEmbedderWithDelay() = default;
  ~MockEmbedderWithDelay() override = default;

  // Embedder:
  // Replies with `passages` and one mock embedding per passage. `priority` is
  // not used by this mock.
  void ComputePassagesEmbeddings(
      passage_embeddings::PassagePriority priority,
      std::vector<std::string> passages,
      ComputePassagesEmbeddingsCallback callback) override {
    // Compute the embeddings in their own statement, before `passages` is
    // handed to the bound task, so the read of `passages` never shares an
    // expression with the std::move() of it.
    auto embeddings = ComputeEmbeddingsForPassages(passages);
    base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
        FROM_HERE,
        base::BindOnce(std::move(callback), std::move(passages),
                       std::move(embeddings),
                       passage_embeddings::ComputeEmbeddingsStatus::kSuccess),
        kTimeout);
  }
};

} // namespace

class HistoryEmbeddingsServicePublic : public HistoryEmbeddingsService {
Expand All @@ -79,6 +101,7 @@ class HistoryEmbeddingsServicePublic : public HistoryEmbeddingsService {
using HistoryEmbeddingsService::OnPassagesEmbeddingsComputed;
using HistoryEmbeddingsService::OnSearchCompleted;
using HistoryEmbeddingsService::QueryIsFiltered;
using HistoryEmbeddingsService::RebuildAbsentEmbeddings;

using HistoryEmbeddingsService::answerer_;
using HistoryEmbeddingsService::embedder_metadata_;
Expand Down Expand Up @@ -116,7 +139,8 @@ class HistoryEmbeddingsServiceTest : public testing::Test {
os_crypt_.get(), history_service_.get(),
page_content_annotations_service_.get(),
/*optimization_guide_decider=*/nullptr,
std::make_unique<MockEmbedder>(), std::make_unique<MockAnswerer>(),
std::make_unique<MockEmbedderWithDelay>(),
std::make_unique<MockAnswerer>(),
std::make_unique<MockIntentClassifier>());

ASSERT_TRUE(listener()->filter_words_hashes().empty());
Expand Down Expand Up @@ -204,7 +228,8 @@ class HistoryEmbeddingsServiceTest : public testing::Test {
false);
}

base::test::TaskEnvironment task_environment_;
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};

base::ScopedTempDir history_dir_;
std::unique_ptr<os_crypt_async::OSCryptAsync> os_crypt_;
Expand Down Expand Up @@ -1016,7 +1041,6 @@ class AddSyncedVisitTask : public history::HistoryDBTask {
history::VisitID visit_id = backend->AddSyncedVisit(
url_, u"Title", /*hidden=*/false, visit_, std::nullopt, std::nullopt);
EXPECT_NE(visit_id, history::kInvalidVisitID);
LOG(ERROR) << "Added visit!";
return true;
}

Expand Down Expand Up @@ -1080,4 +1104,173 @@ TEST_F(HistoryEmbeddingsServiceTest, SearchGetsIfUrlIsKnownToSync) {
EXPECT_EQ(result.scored_url_rows[1].is_url_known_to_sync, true);
}

// Issues four identical searches back to back and checks which of them yield
// results: per the expectations below, the first and the last query return the
// stored URL, while the two queries in between come back with no rows.
TEST_F(HistoryEmbeddingsServiceTest, CancelPreviousSearches) {
  const base::Time visit_time = base::Time::Now();
  AddTestHistoryPage("http://test1.com");
  OnPassagesEmbeddingsComputed(UrlData(1, 1, visit_time),
                               {"test passage 1", "test passage 2"},
                               {Embedding(std::vector<float>(768, 1.0f)),
                                Embedding(std::vector<float>(768, 1.0f))},
                               ComputeEmbeddingsStatus::kSuccess);
  OverrideVisibilityScoresForTesting({
      {"test passage 1", 0.99},
      {"test passage 2", 0.99},
  });
  // Service uses the default .9 score threshold when neither the feature param
  // nor the metadata thresholds are set.
  SetMetadataScoreThreshold(0.01);

  // Starts one search for "passage" that reports into `future`.
  const auto start_search =
      [this](base::test::TestFuture<SearchResult>& future) {
        service_->Search(nullptr, "passage", {}, 3, /*skip_answering=*/true,
                         future.GetRepeatingCallback());
      };

  // Queue all four searches before waiting on any of their results.
  base::test::TestFuture<SearchResult> first_future;
  start_search(first_future);

  base::test::TestFuture<SearchResult> second_future;
  start_search(second_future);

  base::test::TestFuture<SearchResult> third_future;
  start_search(third_future);

  base::test::TestFuture<SearchResult> fourth_future;
  start_search(fourth_future);

  // The first query is processed.
  // TODO(crbug.com/390241271): The first query should NOT be processed.
  SearchResult first_result = first_future.Take();
  EXPECT_FALSE(first_result.session_id.empty());
  EXPECT_EQ(first_result.query, "passage");
  ASSERT_EQ(first_result.scored_url_rows.size(), 1u);
  EXPECT_EQ(first_result.scored_url_rows[0].scored_url.url_id, 1);
  EXPECT_EQ(first_result.scored_url_rows[0].scored_url.visit_id, 1);
  EXPECT_EQ(first_result.scored_url_rows[0].scored_url.visit_time, visit_time);

  // The second query is skipped.
  SearchResult second_result = second_future.Take();
  EXPECT_FALSE(second_result.session_id.empty());
  EXPECT_EQ(second_result.query, "passage");
  ASSERT_EQ(second_result.scored_url_rows.size(), 0u);

  // The third query is skipped.
  SearchResult third_result = third_future.Take();
  EXPECT_FALSE(third_result.session_id.empty());
  EXPECT_EQ(third_result.query, "passage");
  ASSERT_EQ(third_result.scored_url_rows.size(), 0u);

  // The last query is processed.
  SearchResult fourth_result = fourth_future.Take();
  EXPECT_FALSE(fourth_result.session_id.empty());
  EXPECT_EQ(fourth_result.query, "passage");
  ASSERT_EQ(fourth_result.scored_url_rows.size(), 1u);
  EXPECT_EQ(fourth_result.scored_url_rows[0].scored_url.url_id, 1);
  EXPECT_EQ(fourth_result.scored_url_rows[0].scored_url.visit_id, 1);
  EXPECT_EQ(fourth_result.scored_url_rows[0].scored_url.visit_time,
            visit_time);
}

// Stores passages for the same URL twice and checks, via the
// DatabaseCachedPassage{Try,Hit}Count histograms, that the first round records
// zero cache hits while the second round records hits for the two passages
// that were already stored.
TEST_F(HistoryEmbeddingsServiceTest, UseDatabaseBeforeEmbedder) {
  static constexpr char kTryCountHistogram[] =
      "History.Embeddings.DatabaseCachedPassageTryCount";
  static constexpr char kHitCountHistogram[] =
      "History.Embeddings.DatabaseCachedPassageHitCount";

  base::test::TestFuture<UrlData> stored_future;
  service_->SetPassagesStoredCallbackForTesting(
      stored_future.GetRepeatingCallback());

  const base::Time start_time = base::Time::Now();
  AddTestHistoryPage("http://test1.com");

  // TODO(crbug.com/390241271): Enable erasing non-ascii characters.

  // First visit: nothing is stored yet for this URL.
  {
    // TODO(crbug.com/390241271): Test with an empty passage.
    base::HistogramTester histograms;
    service_->ComputeAndStorePassageEmbeddings(
        /*url_id=*/1,
        /*visit_id=*/1,
        /*visit_time=*/start_time + base::Seconds(1),
        {
            "test passage 1",
            "test passage ß",
            "ßßß",
        });

    UrlData stored = stored_future.Take();
    ASSERT_EQ(stored.passages.passages_size(), 3);
    ASSERT_EQ(stored.embeddings.size(), 3u);
    ASSERT_EQ(stored.passages.passages(0), "test passage 1");
    ASSERT_EQ(stored.embeddings[0].Dimensions(), 768u);
    ASSERT_EQ(stored.passages.passages(1), "test passage ß");
    ASSERT_EQ(stored.embeddings[1].Dimensions(), 768u);
    ASSERT_EQ(stored.passages.passages(2), "ßßß");
    ASSERT_EQ(stored.embeddings[2].Dimensions(), 768u);

    // The cache wasn't used because there was no existing data.
    histograms.ExpectTotalCount(kTryCountHistogram, 1);
    histograms.ExpectBucketCount(kTryCountHistogram, 3, 1);
    histograms.ExpectTotalCount(kHitCountHistogram, 1);
    histograms.ExpectBucketCount(kHitCountHistogram, 0, 1);
  }

  // Second visit to the same URL: two of the three passages repeat.
  {
    // TODO(crbug.com/390241271): Test with an empty passage.
    base::HistogramTester histograms;
    service_->ComputeAndStorePassageEmbeddings(
        /*url_id=*/1,
        /*visit_id=*/2,
        /*visit_time=*/start_time + base::Minutes(1),
        {
            "test passage 1",
            "test passage ßßß",
            "ßßß",
        });

    UrlData stored = stored_future.Take();
    ASSERT_EQ(stored.passages.passages_size(), 3);
    ASSERT_EQ(stored.embeddings.size(), 3u);
    ASSERT_EQ(stored.passages.passages(0), "test passage 1");
    ASSERT_EQ(stored.embeddings[0].Dimensions(), 768u);
    ASSERT_EQ(stored.passages.passages(1), "test passage ßßß");
    ASSERT_EQ(stored.embeddings[1].Dimensions(), 768u);
    ASSERT_EQ(stored.passages.passages(2), "ßßß");
    ASSERT_EQ(stored.embeddings[2].Dimensions(), 768u);

    // The cache was used because there was existing data.
    histograms.ExpectTotalCount(kTryCountHistogram, 1);
    histograms.ExpectBucketCount(kTryCountHistogram, 3, 1);
    histograms.ExpectTotalCount(kHitCountHistogram, 1);
    histograms.ExpectBucketCount(kHitCountHistogram, 2, 1);
  }
}

// Hands RebuildAbsentEmbeddings() a UrlData that carries passages but no
// embeddings, and checks that the data subsequently stored has the same three
// passages, each paired with a 768-dimension embedding.
TEST_F(HistoryEmbeddingsServiceTest, RebuildAbsentEmbeddings) {
  base::HistogramTester histogram_tester;

  base::test::TestFuture<UrlData> stored_future;
  service_->SetPassagesStoredCallbackForTesting(
      stored_future.GetRepeatingCallback());

  // TODO(crbug.com/390241271): Enable erasing non-ascii characters.
  // TODO(crbug.com/390241271): Test with an empty passage.

  UrlData passages_only(1, 1, base::Time::Now());
  for (const char* passage : {"test passage 1", "test passage ßßß", "ßßß"}) {
    passages_only.passages.add_passages(passage);
  }
  service_->RebuildAbsentEmbeddings({passages_only});

  UrlData rebuilt = stored_future.Take();
  ASSERT_EQ(rebuilt.passages.passages_size(), 3);
  ASSERT_EQ(rebuilt.embeddings.size(), 3u);
  ASSERT_EQ(rebuilt.passages.passages(0), "test passage 1");
  ASSERT_EQ(rebuilt.embeddings[0].Dimensions(), 768u);
  ASSERT_EQ(rebuilt.passages.passages(1), "test passage ßßß");
  ASSERT_EQ(rebuilt.embeddings[1].Dimensions(), 768u);
  ASSERT_EQ(rebuilt.passages.passages(2), "ßßß");
  ASSERT_EQ(rebuilt.embeddings[2].Dimensions(), 768u);
}

} // namespace history_embeddings
12 changes: 6 additions & 6 deletions components/history_embeddings/mock_embedder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ Embedding ComputeEmbeddingForPassage(const std::string& passage) {
return embedding;
}

std::vector<Embedding> ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages) {
return std::vector<Embedding>(passages.size(),
ComputeEmbeddingForPassage(""));
}

} // namespace

MockEmbedder::MockEmbedder() = default;
Expand All @@ -49,4 +43,10 @@ void MockEmbedder::SetOnEmbedderReady(OnEmbedderReadyCallback callback) {
std::move(callback).Run({kModelVersion, kOutputSize});
}

std::vector<Embedding> MockEmbedder::ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages) {
return std::vector<Embedding>(passages.size(),
ComputeEmbeddingForPassage(""));
}

} // namespace history_embeddings
7 changes: 7 additions & 0 deletions components/history_embeddings/mock_embedder.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#ifndef COMPONENTS_HISTORY_EMBEDDINGS_MOCK_EMBEDDER_H_
#define COMPONENTS_HISTORY_EMBEDDINGS_MOCK_EMBEDDER_H_

#include <string>
#include <vector>

#include "components/history_embeddings/embedder.h"

namespace history_embeddings {
Expand All @@ -21,6 +24,10 @@ class MockEmbedder : public Embedder {
ComputePassagesEmbeddingsCallback callback) override;

void SetOnEmbedderReady(OnEmbedderReadyCallback callback) override;

protected:
std::vector<Embedding> ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages);
};

} // namespace history_embeddings
Expand Down

0 comments on commit b7ad8fc

Please sign in to comment.