Commit 3835d088 authored by Thanh Nguyen's avatar Thanh Nguyen Committed by Commit Bot

[local-search-service] Add a function to support update tf-idf cache

Bug: 1080427
Change-Id: Id1e6470e1823016ac8a4bce97c0be8bc4ff68027
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2208311
Commit-Queue: Thanh Nguyen <thanhdng@chromium.org>
Reviewed-by: default avatarJia Meng <jiameng@chromium.org>
Cr-Commit-Position: refs/heads/master@{#772993}
parent 70ef2d45
...@@ -42,6 +42,7 @@ void InvertedIndex::AddDocument(const std::string& document_id, ...@@ -42,6 +42,7 @@ void InvertedIndex::AddDocument(const std::string& document_id,
for (const auto& token : tokens) { for (const auto& token : tokens) {
dictionary_[token.content][document_id] = token.positions; dictionary_[token.content][document_id] = token.positions;
doc_length_[document_id] += token.positions.size(); doc_length_[document_id] += token.positions.size();
terms_to_be_updated_.insert(token.content);
} }
} }
...@@ -49,7 +50,10 @@ void InvertedIndex::RemoveDocument(const std::string& document_id) { ...@@ -49,7 +50,10 @@ void InvertedIndex::RemoveDocument(const std::string& document_id) {
doc_length_.erase(document_id); doc_length_.erase(document_id);
for (auto it = dictionary_.begin(); it != dictionary_.end();) { for (auto it = dictionary_.begin(); it != dictionary_.end();) {
it->second.erase(document_id); if (it->second.find(document_id) != it->second.end()) {
terms_to_be_updated_.insert(it->first);
it->second.erase(document_id);
}
// Removes term from the dictionary if its posting list is empty. // Removes term from the dictionary if its posting list is empty.
if (it->second.empty()) { if (it->second.empty()) {
...@@ -60,18 +64,36 @@ void InvertedIndex::RemoveDocument(const std::string& document_id) { ...@@ -60,18 +64,36 @@ void InvertedIndex::RemoveDocument(const std::string& document_id) {
} }
} }
std::vector<TfidfResult> InvertedIndex::GetTfidf(const base::string16& term) { std::vector<TfidfResult> InvertedIndex::GetTfidf(
if (tfidf_cache_.find(term) != tfidf_cache_.end()) const base::string16& term) const {
if (tfidf_cache_.find(term) != tfidf_cache_.end()) {
return tfidf_cache_.at(term); return tfidf_cache_.at(term);
}
return {}; return {};
} }
void InvertedIndex::PopulateTfidfCache() { void InvertedIndex::BuildInvertedIndex() {
tfidf_cache_.clear(); // If number of documents doesn't change from the last time index was built,
for (const auto& item : dictionary_) { // we only need to update terms in |terms_to_be_updated_|. Otherwise we need
tfidf_cache_[item.first] = CalculateTfidf(item.first); // to rebuild the index.
if (num_docs_from_last_update_ == doc_length_.size()) {
for (const auto& term : terms_to_be_updated_) {
if (dictionary_.find(term) != dictionary_.end()) {
tfidf_cache_[term] = CalculateTfidf(term);
} else {
tfidf_cache_.erase(term);
}
}
} else {
tfidf_cache_.clear();
for (const auto& item : dictionary_) {
tfidf_cache_[item.first] = CalculateTfidf(item.first);
}
} }
terms_to_be_updated_.clear();
num_docs_from_last_update_ = doc_length_.size();
} }
std::vector<TfidfResult> InvertedIndex::CalculateTfidf( std::vector<TfidfResult> InvertedIndex::CalculateTfidf(
...@@ -79,6 +101,7 @@ std::vector<TfidfResult> InvertedIndex::CalculateTfidf( ...@@ -79,6 +101,7 @@ std::vector<TfidfResult> InvertedIndex::CalculateTfidf(
std::vector<TfidfResult> results; std::vector<TfidfResult> results;
const float idf = const float idf =
1.0 + log((1.0 + doc_length_.size()) / (1.0 + dictionary_[term].size())); 1.0 + log((1.0 + doc_length_.size()) / (1.0 + dictionary_[term].size()));
for (const auto& item : dictionary_[term]) { for (const auto& item : dictionary_[term]) {
const float tf = const float tf =
static_cast<float>(item.second.size()) / doc_length_[item.first]; static_cast<float>(item.second.size()) / doc_length_[item.first];
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "base/gtest_prod_util.h" #include "base/gtest_prod_util.h"
...@@ -48,8 +49,10 @@ using PostingList = std::unordered_map<std::string, Posting>; ...@@ -48,8 +49,10 @@ using PostingList = std::unordered_map<std::string, Posting>;
// score. // score.
using TfidfResult = std::tuple<std::string, Posting, float>; using TfidfResult = std::tuple<std::string, Posting, float>;
// InvertedIndex stores the inverted index for local search and provides the // InvertedIndex stores the inverted index for local search. It provides the
// abilities to add/remove documents, find term, etc. // abilities to add/remove documents, find term, etc. Before this class can be
// used to return tf-idf scores of a term, the client should build the index
// first (using BuildInvertedIndex).
class InvertedIndex { class InvertedIndex {
public: public:
InvertedIndex(); InvertedIndex();
...@@ -62,22 +65,29 @@ class InvertedIndex { ...@@ -62,22 +65,29 @@ class InvertedIndex {
// Adds a new document to the inverted index. If the document ID is already in // Adds a new document to the inverted index. If the document ID is already in
// the index, remove the existing and add the new one. All tokens must be // the index, remove the existing and add the new one. All tokens must be
// unique (have unique content). // unique (have unique content). This function doesn't modify any cache. It
// only adds documents and tokens to the index.
void AddDocument(const std::string& document_id, void AddDocument(const std::string& document_id,
const std::vector<Token>& tokens); const std::vector<Token>& tokens);
// Removes a document from the inverted index. Do nothing if document_id is // Removes a document from the inverted index. Do nothing if document_id is
// not in the index. // not in the index. This function doesn't modify any cache. It only removes
// documents and tokens from the index.
void RemoveDocument(const std::string& document_id); void RemoveDocument(const std::string& document_id);
// Gets TF-IDF scores for a term. The result is pre-computed from // Gets TF-IDF scores for a term. This function returns the TF-IDF score from
// |tfidf_cache_|. // the cache.
std::vector<TfidfResult> GetTfidf(const base::string16& term); // Note: client of this function should call BuildInvertedIndex before using
// this function to have up-to-date score.
std::vector<TfidfResult> GetTfidf(const base::string16& term) const;
// Populates the TF-IDF score cache so that TF-IDF scores can be obtained // Builds the inverted index.
// faster. This function should be called after the inverted index is updated void BuildInvertedIndex();
// (after adding/removing documents).
void PopulateTfidfCache(); // Checks if the inverted index has been built: returns |true| if the inverted
// index is up to date, returns |false| if there are some modified document
// since the last time the index has been built.
bool IsInvertedIndexBuilt() const { return terms_to_be_updated_.empty(); }
private: private:
friend class InvertedIndexTest; friend class InvertedIndexTest;
...@@ -85,6 +95,8 @@ class InvertedIndex { ...@@ -85,6 +95,8 @@ class InvertedIndex {
// Calculates TF-IDF scores for a term. // Calculates TF-IDF scores for a term.
std::vector<TfidfResult> CalculateTfidf(const base::string16& term); std::vector<TfidfResult> CalculateTfidf(const base::string16& term);
// Set of the terms that are needed to be update in |tfidf_cache_|.
std::unordered_set<base::string16> terms_to_be_updated_;
// Contains the length of the document (the number of terms in the document). // Contains the length of the document (the number of terms in the document).
// The size of this map will always equal to the number of documents in the // The size of this map will always equal to the number of documents in the
// index. // index.
...@@ -93,6 +105,8 @@ class InvertedIndex { ...@@ -93,6 +105,8 @@ class InvertedIndex {
std::unordered_map<base::string16, PostingList> dictionary_; std::unordered_map<base::string16, PostingList> dictionary_;
// Contains the TF-IDF scores for all the term in the index. // Contains the TF-IDF scores for all the term in the index.
std::unordered_map<base::string16, std::vector<TfidfResult>> tfidf_cache_; std::unordered_map<base::string16, std::vector<TfidfResult>> tfidf_cache_;
// Number of documents when the index was built.
uint32_t num_docs_from_last_update_ = 0;
}; };
} // namespace local_search_service } // namespace local_search_service
......
...@@ -16,6 +16,15 @@ ...@@ -16,6 +16,15 @@
namespace local_search_service { namespace local_search_service {
std::vector<float> GetScoresFromTfidfResult(
const std::vector<TfidfResult>& results) {
std::vector<float> scores;
for (const auto& item : results) {
scores.push_back(std::roundf(std::get<2>(item) * 100) / 100.0);
}
return scores;
}
class InvertedIndexTest : public ::testing::Test { class InvertedIndexTest : public ::testing::Test {
public: public:
InvertedIndexTest() = default; InvertedIndexTest() = default;
...@@ -41,7 +50,9 @@ class InvertedIndexTest : public ::testing::Test { ...@@ -41,7 +50,9 @@ class InvertedIndexTest : public ::testing::Test {
TokenPosition("body", 3, 1), TokenPosition("body", 3, 1),
TokenPosition("header", 5, 1), TokenPosition("header", 5, 1),
TokenPosition("body", 7, 1)})}}); TokenPosition("body", 7, 1)})}});
PopulateTfidfCache(); index_.terms_to_be_updated_.insert(base::UTF8ToUTF16("A"));
index_.terms_to_be_updated_.insert(base::UTF8ToUTF16("B"));
index_.terms_to_be_updated_.insert(base::UTF8ToUTF16("C"));
} }
PostingList FindTerm(const base::string16& term) { PostingList FindTerm(const base::string16& term) {
...@@ -61,7 +72,9 @@ class InvertedIndexTest : public ::testing::Test { ...@@ -61,7 +72,9 @@ class InvertedIndexTest : public ::testing::Test {
return index_.GetTfidf(term); return index_.GetTfidf(term);
} }
void PopulateTfidfCache() { index_.PopulateTfidfCache(); } void BuildInvertedIndex() { index_.BuildInvertedIndex(); }
bool IsInvertedIndexBuilt() { return index_.IsInvertedIndexBuilt(); }
std::unordered_map<base::string16, PostingList> GetDictionary() { std::unordered_map<base::string16, PostingList> GetDictionary() {
return index_.dictionary_; return index_.dictionary_;
...@@ -81,14 +94,14 @@ class InvertedIndexTest : public ::testing::Test { ...@@ -81,14 +94,14 @@ class InvertedIndexTest : public ::testing::Test {
TEST_F(InvertedIndexTest, FindTermTest) { TEST_F(InvertedIndexTest, FindTermTest) {
PostingList result = FindTerm(base::UTF8ToUTF16("A")); PostingList result = FindTerm(base::UTF8ToUTF16("A"));
ASSERT_EQ(result.size(), static_cast<unsigned long>(2)); ASSERT_EQ(result.size(), 2u);
EXPECT_EQ(result["doc1"][0].start, static_cast<uint32_t>(1)); EXPECT_EQ(result["doc1"][0].start, 1u);
EXPECT_EQ(result["doc1"][1].start, static_cast<uint32_t>(3)); EXPECT_EQ(result["doc1"][1].start, 3u);
EXPECT_EQ(result["doc1"][2].start, static_cast<uint32_t>(5)); EXPECT_EQ(result["doc1"][2].start, 5u);
EXPECT_EQ(result["doc1"][3].start, static_cast<uint32_t>(7)); EXPECT_EQ(result["doc1"][3].start, 7u);
EXPECT_EQ(result["doc2"][0].start, static_cast<uint32_t>(2)); EXPECT_EQ(result["doc2"][0].start, 2u);
EXPECT_EQ(result["doc2"][1].start, static_cast<uint32_t>(4)); EXPECT_EQ(result["doc2"][1].start, 4u);
} }
TEST_F(InvertedIndexTest, AddNewDocumentTest) { TEST_F(InvertedIndexTest, AddNewDocumentTest) {
...@@ -103,16 +116,16 @@ TEST_F(InvertedIndexTest, AddNewDocumentTest) { ...@@ -103,16 +116,16 @@ TEST_F(InvertedIndexTest, AddNewDocumentTest) {
// Find "A" // Find "A"
PostingList result = FindTerm(a_utf16); PostingList result = FindTerm(a_utf16);
ASSERT_EQ(result.size(), static_cast<unsigned long>(3)); ASSERT_EQ(result.size(), 3u);
EXPECT_EQ(result["doc3"][0].start, static_cast<uint32_t>(1)); EXPECT_EQ(result["doc3"][0].start, 1u);
EXPECT_EQ(result["doc3"][1].start, static_cast<uint32_t>(2)); EXPECT_EQ(result["doc3"][1].start, 2u);
EXPECT_EQ(result["doc3"][2].start, static_cast<uint32_t>(4)); EXPECT_EQ(result["doc3"][2].start, 4u);
// Find "D" // Find "D"
result = FindTerm(d_utf16); result = FindTerm(d_utf16);
ASSERT_EQ(result.size(), static_cast<unsigned long>(1)); ASSERT_EQ(result.size(), 1u);
EXPECT_EQ(result["doc3"][0].start, static_cast<uint32_t>(3)); EXPECT_EQ(result["doc3"][0].start, 3u);
EXPECT_EQ(result["doc3"][1].start, static_cast<uint32_t>(5)); EXPECT_EQ(result["doc3"][1].start, 5u);
} }
TEST_F(InvertedIndexTest, ReplaceDocumentTest) { TEST_F(InvertedIndexTest, ReplaceDocumentTest) {
...@@ -128,95 +141,104 @@ TEST_F(InvertedIndexTest, ReplaceDocumentTest) { ...@@ -128,95 +141,104 @@ TEST_F(InvertedIndexTest, ReplaceDocumentTest) {
// Find "A" // Find "A"
PostingList result = FindTerm(a_utf16); PostingList result = FindTerm(a_utf16);
ASSERT_EQ(result.size(), static_cast<unsigned long>(2)); ASSERT_EQ(result.size(), 2u);
EXPECT_EQ(result["doc1"][0].start, static_cast<uint32_t>(1)); EXPECT_EQ(result["doc1"][0].start, 1u);
EXPECT_EQ(result["doc1"][1].start, static_cast<uint32_t>(2)); EXPECT_EQ(result["doc1"][1].start, 2u);
EXPECT_EQ(result["doc1"][2].start, static_cast<uint32_t>(4)); EXPECT_EQ(result["doc1"][2].start, 4u);
// Find "B" // Find "B"
result = FindTerm(base::UTF8ToUTF16("B")); result = FindTerm(base::UTF8ToUTF16("B"));
ASSERT_EQ(result.size(), static_cast<unsigned long>(0)); ASSERT_EQ(result.size(), 0u);
// Find "D" // Find "D"
result = FindTerm(d_utf16); result = FindTerm(d_utf16);
ASSERT_EQ(result.size(), static_cast<unsigned long>(1)); ASSERT_EQ(result.size(), 1u);
EXPECT_EQ(result["doc1"][0].start, static_cast<uint32_t>(3)); EXPECT_EQ(result["doc1"][0].start, 3u);
EXPECT_EQ(result["doc1"][1].start, static_cast<uint32_t>(5)); EXPECT_EQ(result["doc1"][1].start, 5u);
} }
TEST_F(InvertedIndexTest, RemoveDocumentTest) { TEST_F(InvertedIndexTest, RemoveDocumentTest) {
EXPECT_EQ(GetDictionary().size(), static_cast<unsigned long>(3)); EXPECT_EQ(GetDictionary().size(), 3u);
EXPECT_EQ(GetDocLength().size(), static_cast<unsigned long>(2)); EXPECT_EQ(GetDocLength().size(), 2u);
RemoveDocument("doc1"); RemoveDocument("doc1");
EXPECT_EQ(GetDictionary().size(), static_cast<unsigned long>(2)); EXPECT_EQ(GetDictionary().size(), 2u);
EXPECT_EQ(GetDocLength().size(), static_cast<unsigned long>(1)); EXPECT_EQ(GetDocLength().size(), 1u);
EXPECT_EQ(GetDocLength()["doc2"], 6); EXPECT_EQ(GetDocLength()["doc2"], 6);
// Find "A" // Find "A"
PostingList result = FindTerm(base::UTF8ToUTF16("A")); PostingList result = FindTerm(base::UTF8ToUTF16("A"));
ASSERT_EQ(result.size(), static_cast<unsigned long>(1)); ASSERT_EQ(result.size(), 1u);
EXPECT_EQ(result["doc2"][0].start, static_cast<uint32_t>(2)); EXPECT_EQ(result["doc2"][0].start, 2u);
EXPECT_EQ(result["doc2"][1].start, static_cast<uint32_t>(4)); EXPECT_EQ(result["doc2"][1].start, 4u);
// Find "B" // Find "B"
result = FindTerm(base::UTF8ToUTF16("B")); result = FindTerm(base::UTF8ToUTF16("B"));
ASSERT_EQ(result.size(), static_cast<unsigned long>(0)); ASSERT_EQ(result.size(), 0u);
// Find "C" // Find "C"
result = FindTerm(base::UTF8ToUTF16("C")); result = FindTerm(base::UTF8ToUTF16("C"));
ASSERT_EQ(result.size(), static_cast<unsigned long>(1)); ASSERT_EQ(result.size(), 1u);
EXPECT_EQ(result["doc2"][0].start, static_cast<uint32_t>(1)); EXPECT_EQ(result["doc2"][0].start, 1u);
EXPECT_EQ(result["doc2"][1].start, static_cast<uint32_t>(3)); EXPECT_EQ(result["doc2"][1].start, 3u);
EXPECT_EQ(result["doc2"][2].start, static_cast<uint32_t>(5)); EXPECT_EQ(result["doc2"][2].start, 5u);
EXPECT_EQ(result["doc2"][3].start, static_cast<uint32_t>(7)); EXPECT_EQ(result["doc2"][3].start, 7u);
} }
TEST_F(InvertedIndexTest, TfidfTest) { TEST_F(InvertedIndexTest, TfidfFromZeroTest) {
EXPECT_EQ(GetTfidfCache().size(), 0u);
EXPECT_FALSE(IsInvertedIndexBuilt());
BuildInvertedIndex();
std::vector<TfidfResult> results = GetTfidf(base::UTF8ToUTF16("A")); std::vector<TfidfResult> results = GetTfidf(base::UTF8ToUTF16("A"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(2)); EXPECT_THAT(GetScoresFromTfidfResult(results),
const std::vector<float> idf_scores = { testing::UnorderedElementsAre(0.5, 0.33));
std::roundf(std::get<2>(results[0]) * 100) / 100.0,
std::roundf(std::get<2>(results[1]) * 100) / 100.0};
EXPECT_THAT(idf_scores, testing::UnorderedElementsAre(0.5, 0.33));
results = GetTfidf(base::UTF8ToUTF16("B")); results = GetTfidf(base::UTF8ToUTF16("B"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(1)); EXPECT_EQ(results.size(), 1u);
EXPECT_NEAR(std::get<2>(results[0]), 0.70, 0.01); EXPECT_THAT(GetScoresFromTfidfResult(results),
testing::UnorderedElementsAre(0.7));
results = GetTfidf(base::UTF8ToUTF16("C")); results = GetTfidf(base::UTF8ToUTF16("C"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(1)); EXPECT_EQ(results.size(), 1u);
EXPECT_NEAR(std::get<2>(results[0]), 0.94, 0.01); EXPECT_THAT(GetScoresFromTfidfResult(results),
testing::UnorderedElementsAre(0.94));
results = GetTfidf(base::UTF8ToUTF16("D")); results = GetTfidf(base::UTF8ToUTF16("D"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(0)); EXPECT_EQ(results.size(), 0u);
} }
TEST_F(InvertedIndexTest, PopulateTfidfCacheTest) { TEST_F(InvertedIndexTest, UpdateIndexTest) {
EXPECT_EQ(GetTfidfCache().size(), 0u);
BuildInvertedIndex();
EXPECT_TRUE(IsInvertedIndexBuilt());
EXPECT_EQ(GetTfidfCache().size(), 3u);
// Replaces "doc1" // Replaces "doc1"
AddDocument("doc1", AddDocument("doc1",
{{base::UTF8ToUTF16("A"), {{base::UTF8ToUTF16("A"),
{{"header", 1, 1}, {"body", 2, 1}, {"header", 4, 1}}}, {{"header", 1, 1}, {"body", 2, 1}, {"header", 4, 1}}},
{base::UTF8ToUTF16("D"), {{"header", 3, 1}, {"body", 5, 1}}}}); {base::UTF8ToUTF16("D"), {{"header", 3, 1}, {"body", 5, 1}}}});
PopulateTfidfCache(); EXPECT_FALSE(IsInvertedIndexBuilt());
BuildInvertedIndex();
EXPECT_EQ(GetTfidfCache().size(), 3u);
std::vector<TfidfResult> results = GetTfidf(base::UTF8ToUTF16("A")); std::vector<TfidfResult> results = GetTfidf(base::UTF8ToUTF16("A"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(2)); EXPECT_THAT(GetScoresFromTfidfResult(results),
const std::vector<float> idf_scores = { testing::UnorderedElementsAre(0.6, 0.33));
std::roundf(std::get<2>(results[0]) * 100) / 100.0,
std::roundf(std::get<2>(results[1]) * 100) / 100.0};
EXPECT_THAT(idf_scores, testing::UnorderedElementsAre(0.6, 0.33));
results = GetTfidf(base::UTF8ToUTF16("B")); results = GetTfidf(base::UTF8ToUTF16("B"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(0)); EXPECT_THAT(GetScoresFromTfidfResult(results),
testing::UnorderedElementsAre());
results = GetTfidf(base::UTF8ToUTF16("C")); results = GetTfidf(base::UTF8ToUTF16("C"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(1)); EXPECT_THAT(GetScoresFromTfidfResult(results),
EXPECT_NEAR(std::get<2>(results[0]), 0.94, 0.01); testing::UnorderedElementsAre(0.94));
results = GetTfidf(base::UTF8ToUTF16("D")); results = GetTfidf(base::UTF8ToUTF16("D"));
EXPECT_EQ(results.size(), static_cast<unsigned long>(1)); EXPECT_THAT(GetScoresFromTfidfResult(results),
EXPECT_NEAR(std::get<2>(results[0]), 0.56, 0.01); testing::UnorderedElementsAre(0.56));
} }
} // namespace local_search_service } // namespace local_search_service
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment