Commit fa21cf28 authored by Ce Chen, committed by Commit Bot

[omnibox] Move private static functions/classes inside OnDeviceHeadModel header into anonymous namespace.

For simplicity, this change only moves declarations and adjusts the order of
these private functions and nested classes; no other changes are made.

Bug: 925072
Change-Id: Iafeaa2d71b18328061844721187e061bedeb32a8
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2444674
Reviewed-by: Tommy Li <tommycli@chromium.org>
Commit-Queue: Ce Chen <cch@chromium.org>
Cr-Commit-Position: refs/heads/master@{#814118}
parent 8fbb21d4
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <fstream>
#include <list>
#include <memory>
#include "base/logging.h" #include "base/logging.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
...@@ -17,6 +20,62 @@ namespace { ...@@ -17,6 +20,62 @@ namespace {
// specify the size (num of bytes) of the address and the score in each node. // specify the size (num of bytes) of the address and the score in each node.
const int kRootNodeOffset = 2; const int kRootNodeOffset = 2;
// Describes one tree node tracked during traversal: either a node that still
// has to be visited, or one that has already been visited.
// Every scalar member has a default value so that a freshly declared
// MatchCandidate never carries indeterminate data.
struct MatchCandidate {
  // The sequence of characters from the start node to the current node.
  std::string text;
  // Whether the text above can be returned as a suggestion; if false it is the
  // prefix of some other complete suggestion.
  bool is_complete_suggestion = false;
  // If is_complete_suggestion is true, this is the score for the suggestion;
  // otherwise it is set to max_score_as_root of the node.
  uint32_t score = 0;
  // The address of the node in the model file. It is not required if
  // is_complete_suggestion is true.
  uint32_t address = 0;
};
// Doubly linked list, kept sorted by the candidates' scores (from low to
// high), used to track nodes during tree search. Two of these lists together
// keep max_num_matches_to_return_ nodes in total with the highest scores
// during the search, so that children and branches with low scores can be
// pruned.
// In theory, a red-black tree might give better search performance
// (i.e. log(n), compared with the linear cost of a linked list) when inserting
// new high-scoring candidates, but since n is usually small, a linked list is
// acceptable.
using CandidateQueue = std::list<MatchCandidate>;
// A mini class that holds all parameters needed to access the model on disk:
// the open model file stream, the byte sizes of the score and address fields,
// and the maximum number of matches to return.
// Instances are obtained through Create(); the default constructor is private
// and copying is disabled.
class OnDeviceModelParams {
 public:
  // Builds a params instance for |model_filename|. The definition in this file
  // returns nullptr when the model file cannot be opened, when the two size
  // bytes at the start of the file cannot be read, or when those sizes are
  // rejected by AreSizesValid().
  static std::unique_ptr<OnDeviceModelParams> Create(
      const std::string& model_filename,
      const uint32_t max_num_matches_to_return);

  ~OnDeviceModelParams();
  OnDeviceModelParams(const OnDeviceModelParams&) = delete;
  OnDeviceModelParams& operator=(const OnDeviceModelParams&) = delete;

  // Returns a non-owning pointer to the stream used to read the model file.
  std::ifstream* GetModelFileStream() { return &model_filestream_; }

  // Number of bytes used to encode a score in the model file.
  uint32_t score_size() const { return score_size_; }

  // Number of bytes used to encode an address in the model file.
  uint32_t address_size() const { return address_size_; }

  // Maximum number of matches to return.
  uint32_t max_num_matches_to_return() const {
    return max_num_matches_to_return_;
  }

 private:
  OnDeviceModelParams() = default;

  std::ifstream model_filestream_;
  // Zero-initialized so the accessors above never expose indeterminate values
  // before Create() has filled the fields in.
  uint32_t score_size_ = 0;
  uint32_t address_size_ = 0;
  uint32_t max_num_matches_to_return_ = 0;
};
uint32_t ConvertByteArrayToInt(char byte_array[], uint32_t num_bytes) { uint32_t ConvertByteArrayToInt(char byte_array[], uint32_t num_bytes) {
uint32_t result = 0; uint32_t result = 0;
for (uint32_t i = 0; i < num_bytes; ++i) { for (uint32_t i = 0; i < num_bytes; ++i) {
...@@ -25,50 +84,76 @@ uint32_t ConvertByteArrayToInt(char byte_array[], uint32_t num_bytes) { ...@@ -25,50 +84,76 @@ uint32_t ConvertByteArrayToInt(char byte_array[], uint32_t num_bytes) {
return result; return result;
} }
} // namespace bool OpenModelFileStream(OnDeviceModelParams* params,
// static
std::unique_ptr<OnDeviceHeadModel::OnDeviceModelParams>
OnDeviceHeadModel::OnDeviceModelParams::Create(
const std::string& model_filename, const std::string& model_filename,
const uint32_t max_num_matches_to_return) { const uint32_t start_address) {
std::unique_ptr<OnDeviceModelParams> params(new OnDeviceModelParams()); if (model_filename.empty()) {
DVLOG(1) << "Model filename is empty";
return false;
}
// TODO(crbug.com/925072): Add DCHECK and code to report failures to UMA // First close the file if it's still open.
// histogram. if (params->GetModelFileStream()->is_open()) {
if (!OpenModelFileStream(params.get(), model_filename, 0)) { DVLOG(1) << "Previous file is still open";
DVLOG(1) << "On Device Head Params: cannot access on device head " params->GetModelFileStream()->close();
<< "params instance because model file cannot be opened";
return nullptr;
} }
char sizes[2]; params->GetModelFileStream()->open(model_filename,
if (!ReadNextNumBytes(params.get(), 2, sizes)) { std::ios::in | std::ios::binary);
DVLOG(1) << "On Device Head Params: failed to read size information " if (!params->GetModelFileStream()->is_open()) {
<< "in the first 2 bytes of the model file: " << model_filename; DVLOG(1) << "Failed to open model file from [" << model_filename << "]";
return nullptr; return false;
} }
params->address_size_ = sizes[0]; if (start_address > 0) {
params->score_size_ = sizes[1]; params->GetModelFileStream()->seekg(start_address);
if (!AreSizesValid(params.get())) {
return nullptr;
} }
return true;
}
params->max_num_matches_to_return_ = max_num_matches_to_return; void MaybeCloseModelFileStream(OnDeviceModelParams* params) {
return params; if (params->GetModelFileStream()->is_open()) {
params->GetModelFileStream()->close();
}
} }
OnDeviceHeadModel::OnDeviceModelParams::~OnDeviceModelParams() { // Reads next num_bytes from the file stream.
if (model_filestream_.is_open()) { bool ReadNextNumBytes(OnDeviceModelParams* params,
model_filestream_.close(); uint32_t num_bytes,
char* buf) {
uint32_t address = params->GetModelFileStream()->tellg();
params->GetModelFileStream()->read(buf, num_bytes);
if (params->GetModelFileStream()->fail()) {
DVLOG(1) << "On Device Head model: ifstream read error at address ["
<< address << "], when trying to read [" << num_bytes << "] bytes";
return false;
} }
return true;
} }
OnDeviceHeadModel::OnDeviceModelParams::OnDeviceModelParams() = default; // Reads next num_bytes from the file stream but returns as an integer.
uint32_t ReadNextNumBytesAsInt(OnDeviceModelParams* params,
uint32_t num_bytes,
bool* is_successful) {
char* buf = new char[num_bytes];
*is_successful = ReadNextNumBytes(params, num_bytes, buf);
if (!*is_successful) {
delete[] buf;
return 0;
}
// static uint32_t result = ConvertByteArrayToInt(buf, num_bytes);
bool OnDeviceHeadModel::AreSizesValid(OnDeviceModelParams* params) { delete[] buf;
return result;
}
// Checks if size of score and size of address read from the model file are
// valid.
// For score, we use size of 2 bytes (15 bits), 3 bytes (23 bits) or 4 bytes
// (31 bits); For address, we use size of 3 bytes (23 bits) or 4 bytes
// (31 bits).
bool AreSizesValid(OnDeviceModelParams* params) {
bool is_score_size_valid = bool is_score_size_valid =
(params->score_size() >= 2 && params->score_size() <= 4); (params->score_size() >= 2 && params->score_size() <= 4);
bool is_address_size_valid = bool is_address_size_valid =
...@@ -84,84 +169,7 @@ bool OnDeviceHeadModel::AreSizesValid(OnDeviceModelParams* params) { ...@@ -84,84 +169,7 @@ bool OnDeviceHeadModel::AreSizesValid(OnDeviceModelParams* params) {
return is_score_size_valid && is_address_size_valid; return is_score_size_valid && is_address_size_valid;
} }
// static void InsertCandidateToQueue(const MatchCandidate& candidate,
std::vector<std::pair<std::string, uint32_t>>
OnDeviceHeadModel::GetSuggestionsForPrefix(const std::string& model_filename,
uint32_t max_num_matches_to_return,
const std::string& prefix) {
std::vector<std::pair<std::string, uint32_t>> suggestions;
if (prefix.empty() || max_num_matches_to_return < 1) {
return suggestions;
}
std::unique_ptr<OnDeviceModelParams> params =
OnDeviceModelParams::Create(model_filename, max_num_matches_to_return);
if (params && params->GetModelFileStream()->is_open()) {
params->GetModelFileStream()->seekg(kRootNodeOffset);
MatchCandidate start_match;
if (FindStartNode(params.get(), prefix, &start_match)) {
suggestions = DoSearch(params.get(), start_match);
}
MaybeCloseModelFileStream(params.get());
}
return suggestions;
}
// static
std::vector<std::pair<std::string, uint32_t>> OnDeviceHeadModel::DoSearch(
OnDeviceModelParams* params,
const MatchCandidate& start_match) {
std::vector<std::pair<std::string, uint32_t>> suggestions;
CandidateQueue leaf_queue, non_leaf_queue;
uint32_t min_score_in_queues = start_match.score;
InsertCandidateToQueue(start_match, &leaf_queue, &non_leaf_queue);
// Do the search until there is no non leaf candidates in the queue.
while (!non_leaf_queue.empty()) {
// Always fetch the intermediate node with highest score at the back of the
// queue.
auto next_candidates = ReadTreeNode(params, non_leaf_queue.back());
non_leaf_queue.pop_back();
min_score_in_queues =
GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue);
for (const auto& candidate : next_candidates) {
if (candidate.score > min_score_in_queues ||
(leaf_queue.size() + non_leaf_queue.size() <
params->max_num_matches_to_return())) {
InsertCandidateToQueue(candidate, &leaf_queue, &non_leaf_queue);
}
// If there are too many candidates in the queues, remove the one with
// lowest score since it will never be shown to users.
if (leaf_queue.size() + non_leaf_queue.size() >
params->max_num_matches_to_return()) {
if (leaf_queue.empty() ||
(!non_leaf_queue.empty() &&
leaf_queue.front().score > non_leaf_queue.front().score)) {
non_leaf_queue.pop_front();
} else {
leaf_queue.pop_front();
}
}
min_score_in_queues =
GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue);
}
}
while (!leaf_queue.empty()) {
suggestions.push_back(
std::make_pair(leaf_queue.back().text, leaf_queue.back().score));
leaf_queue.pop_back();
}
return suggestions;
}
// static
void OnDeviceHeadModel::InsertCandidateToQueue(const MatchCandidate& candidate,
CandidateQueue* leaf_queue, CandidateQueue* leaf_queue,
CandidateQueue* non_leaf_queue) { CandidateQueue* non_leaf_queue) {
CandidateQueue* queue_ptr = CandidateQueue* queue_ptr =
...@@ -177,9 +185,7 @@ void OnDeviceHeadModel::InsertCandidateToQueue(const MatchCandidate& candidate, ...@@ -177,9 +185,7 @@ void OnDeviceHeadModel::InsertCandidateToQueue(const MatchCandidate& candidate,
} }
} }
// static uint32_t GetMinScoreFromQueues(OnDeviceModelParams* params,
uint32_t OnDeviceHeadModel::GetMinScoreFromQueues(
OnDeviceModelParams* params,
const CandidateQueue& queue_1, const CandidateQueue& queue_1,
const CandidateQueue& queue_2) { const CandidateQueue& queue_2) {
uint32_t min_score = 0x1 << (params->score_size() * 8 - 1); uint32_t min_score = 0x1 << (params->score_size() * 8 - 1);
...@@ -192,58 +198,16 @@ uint32_t OnDeviceHeadModel::GetMinScoreFromQueues( ...@@ -192,58 +198,16 @@ uint32_t OnDeviceHeadModel::GetMinScoreFromQueues(
return min_score; return min_score;
} }
// static // Reads block max_score_as_root at the beginning of the node from the given
bool OnDeviceHeadModel::FindStartNode(OnDeviceModelParams* params, // address. If there is a leaf score at the end of the block, return the leaf
const std::string& prefix, // score using param leaf_candidate;
MatchCandidate* start_match) { uint32_t ReadMaxScoreAsRoot(OnDeviceModelParams* params,
if (start_match == nullptr) {
return false;
}
start_match->text = "";
start_match->score = 0;
start_match->address = kRootNodeOffset;
start_match->is_complete_suggestion = false;
while (start_match->text.size() < prefix.size()) {
auto children = ReadTreeNode(params, *start_match);
bool has_match = false;
for (auto const& child : children) {
// The way we build the model ensures that there will be only one child
// matching the given prefix at each node.
if (!child.text.empty() &&
(base::StartsWith(child.text, prefix, base::CompareCase::SENSITIVE) ||
base::StartsWith(prefix, child.text,
base::CompareCase::SENSITIVE))) {
// A leaf only partially matching the given prefix cannot be the right
// start node.
if (child.is_complete_suggestion && child.text.size() < prefix.size()) {
continue;
}
start_match->text = child.text;
start_match->is_complete_suggestion = child.is_complete_suggestion;
start_match->score = child.score;
start_match->address = child.address;
has_match = true;
break;
}
}
if (!has_match) {
return false;
}
}
return start_match->text.size() >= prefix.size();
}
// static
uint32_t OnDeviceHeadModel::ReadMaxScoreAsRoot(OnDeviceModelParams* params,
uint32_t address, uint32_t address,
MatchCandidate* leaf_candidate, MatchCandidate* leaf_candidate,
bool* is_successful) { bool* is_successful) {
if (is_successful == nullptr) { if (is_successful == nullptr) {
DVLOG(1) << "On Device Head model: a boolean var is_successful " DVLOG(1) << "On Device Head model: a boolean var is_successful is required "
<< "is required when calling function ReadMaxScoreAsRoot"; << "when calling function ReadMaxScoreAsRoot";
return 0; return 0;
} }
...@@ -271,9 +235,9 @@ uint32_t OnDeviceHeadModel::ReadMaxScoreAsRoot(OnDeviceModelParams* params, ...@@ -271,9 +235,9 @@ uint32_t OnDeviceHeadModel::ReadMaxScoreAsRoot(OnDeviceModelParams* params,
return max_score; return max_score;
} }
// static // Reads a child block and move ifstream cursor to next child; returns false
bool OnDeviceHeadModel::ReadNextChild(OnDeviceModelParams* params, // when reaching the end of the node or ifstream read error happens.
MatchCandidate* candidate) { bool ReadNextChild(OnDeviceModelParams* params, MatchCandidate* candidate) {
if (candidate == nullptr) { if (candidate == nullptr) {
return false; return false;
} }
...@@ -341,9 +305,9 @@ bool OnDeviceHeadModel::ReadNextChild(OnDeviceModelParams* params, ...@@ -341,9 +305,9 @@ bool OnDeviceHeadModel::ReadNextChild(OnDeviceModelParams* params,
return is_successful; return is_successful;
} }
// static // Reads tree node from given match candidate, convert all possible suggestions
std::vector<OnDeviceHeadModel::MatchCandidate> OnDeviceHeadModel::ReadTreeNode( // and children of this node into structure MatchCandidate.
OnDeviceModelParams* params, std::vector<MatchCandidate> ReadTreeNode(OnDeviceModelParams* params,
const MatchCandidate& current) { const MatchCandidate& current) {
std::vector<MatchCandidate> candidates; std::vector<MatchCandidate> candidates;
// The current candidate passed in is a leaf node and we shall stop here. // The current candidate passed in is a leaf node and we shall stop here.
...@@ -383,68 +347,160 @@ std::vector<OnDeviceHeadModel::MatchCandidate> OnDeviceHeadModel::ReadTreeNode( ...@@ -383,68 +347,160 @@ std::vector<OnDeviceHeadModel::MatchCandidate> OnDeviceHeadModel::ReadTreeNode(
return candidates; return candidates;
} }
// static // Finds start node which matches given prefix, returns true if found and the
bool OnDeviceHeadModel::ReadNextNumBytes(OnDeviceModelParams* params, // start node using param match_candidate.
uint32_t num_bytes, bool FindStartNode(OnDeviceModelParams* params,
char* buf) { const std::string& prefix,
uint32_t address = params->GetModelFileStream()->tellg(); MatchCandidate* start_match) {
params->GetModelFileStream()->read(buf, num_bytes); if (start_match == nullptr) {
if (params->GetModelFileStream()->fail()) {
DVLOG(1) << "On Device Head model: ifstream read error at address ["
<< address << "], when trying to read [" << num_bytes << "] bytes";
return false; return false;
} }
return true;
start_match->text = "";
start_match->score = 0;
start_match->address = kRootNodeOffset;
start_match->is_complete_suggestion = false;
while (start_match->text.size() < prefix.size()) {
auto children = ReadTreeNode(params, *start_match);
bool has_match = false;
for (auto const& child : children) {
// The way we build the model ensures that there will be only one child
// matching the given prefix at each node.
if (!child.text.empty() &&
(base::StartsWith(child.text, prefix, base::CompareCase::SENSITIVE) ||
base::StartsWith(prefix, child.text,
base::CompareCase::SENSITIVE))) {
// A leaf only partially matching the given prefix cannot be the right
// start node.
if (child.is_complete_suggestion && child.text.size() < prefix.size()) {
continue;
}
start_match->text = child.text;
start_match->is_complete_suggestion = child.is_complete_suggestion;
start_match->score = child.score;
start_match->address = child.address;
has_match = true;
break;
}
}
if (!has_match) {
return false;
}
}
return start_match->text.size() >= prefix.size();
} }
// static std::vector<std::pair<std::string, uint32_t>> DoSearch(
uint32_t OnDeviceHeadModel::ReadNextNumBytesAsInt(OnDeviceModelParams* params, OnDeviceModelParams* params,
uint32_t num_bytes, const MatchCandidate& start_match) {
bool* is_successful) { std::vector<std::pair<std::string, uint32_t>> suggestions;
char* buf = new char[num_bytes];
*is_successful = ReadNextNumBytes(params, num_bytes, buf); CandidateQueue leaf_queue, non_leaf_queue;
if (!*is_successful) { uint32_t min_score_in_queues = start_match.score;
delete[] buf; InsertCandidateToQueue(start_match, &leaf_queue, &non_leaf_queue);
return 0;
// Do the search until there is no non leaf candidates in the queue.
while (!non_leaf_queue.empty()) {
// Always fetch the intermediate node with highest score at the back of the
// queue.
auto next_candidates = ReadTreeNode(params, non_leaf_queue.back());
non_leaf_queue.pop_back();
min_score_in_queues =
GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue);
for (const auto& candidate : next_candidates) {
if (candidate.score > min_score_in_queues ||
(leaf_queue.size() + non_leaf_queue.size() <
params->max_num_matches_to_return())) {
InsertCandidateToQueue(candidate, &leaf_queue, &non_leaf_queue);
} }
uint32_t result = ConvertByteArrayToInt(buf, num_bytes); // If there are too many candidates in the queues, remove the one with
delete[] buf; // lowest score since it will never be shown to users.
if (leaf_queue.size() + non_leaf_queue.size() >
params->max_num_matches_to_return()) {
if (leaf_queue.empty() ||
(!non_leaf_queue.empty() &&
leaf_queue.front().score > non_leaf_queue.front().score)) {
non_leaf_queue.pop_front();
} else {
leaf_queue.pop_front();
}
}
min_score_in_queues =
GetMinScoreFromQueues(params, leaf_queue, non_leaf_queue);
}
}
return result; while (!leaf_queue.empty()) {
suggestions.emplace_back(leaf_queue.back().text, leaf_queue.back().score);
leaf_queue.pop_back();
}
return suggestions;
} }
} // namespace
// static // static
bool OnDeviceHeadModel::OpenModelFileStream(OnDeviceModelParams* params, std::unique_ptr<OnDeviceModelParams> OnDeviceModelParams::Create(
const std::string& model_filename, const std::string& model_filename,
const uint32_t start_address) { const uint32_t max_num_matches_to_return) {
if (model_filename.empty()) { std::unique_ptr<OnDeviceModelParams> params(new OnDeviceModelParams());
DVLOG(1) << "Model filename is empty";
return false; // TODO(crbug.com/925072): Add DCHECK and code to report failures to UMA
// histogram.
if (!OpenModelFileStream(params.get(), model_filename, 0)) {
DVLOG(1) << "On Device Head Params: cannot access on device head params "
<< "instance because model file cannot be opened";
return nullptr;
} }
// First close the file if it's still open. char sizes[2];
if (params->GetModelFileStream()->is_open()) { if (!ReadNextNumBytes(params.get(), 2, sizes)) {
DVLOG(1) << "Previous file is still open"; DVLOG(1) << "On Device Head Params: failed to read size information in the "
params->GetModelFileStream()->close(); << "first 2 bytes of the model file: " << model_filename;
return nullptr;
} }
params->GetModelFileStream()->open(model_filename, params->address_size_ = sizes[0];
std::ios::in | std::ios::binary); params->score_size_ = sizes[1];
if (!params->GetModelFileStream()->is_open()) { if (!AreSizesValid(params.get())) {
DVLOG(1) << "Failed to open model file from [" << model_filename << "]"; return nullptr;
return false;
} }
if (start_address > 0) { params->max_num_matches_to_return_ = max_num_matches_to_return;
params->GetModelFileStream()->seekg(start_address); return params;
}
OnDeviceModelParams::~OnDeviceModelParams() {
if (model_filestream_.is_open()) {
model_filestream_.close();
} }
return true;
} }
// static // static
void OnDeviceHeadModel::MaybeCloseModelFileStream(OnDeviceModelParams* params) { std::vector<std::pair<std::string, uint32_t>>
if (params->GetModelFileStream()->is_open()) { OnDeviceHeadModel::GetSuggestionsForPrefix(const std::string& model_filename,
params->GetModelFileStream()->close(); uint32_t max_num_matches_to_return,
const std::string& prefix) {
std::vector<std::pair<std::string, uint32_t>> suggestions;
if (prefix.empty() || max_num_matches_to_return < 1) {
return suggestions;
}
std::unique_ptr<OnDeviceModelParams> params =
OnDeviceModelParams::Create(model_filename, max_num_matches_to_return);
if (params && params->GetModelFileStream()->is_open()) {
params->GetModelFileStream()->seekg(kRootNodeOffset);
MatchCandidate start_match;
if (FindStartNode(params.get(), prefix, &start_match)) {
suggestions = DoSearch(params.get(), start_match);
} }
MaybeCloseModelFileStream(params.get());
}
return suggestions;
} }
\ No newline at end of file
...@@ -5,12 +5,8 @@ ...@@ -5,12 +5,8 @@
#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_MODEL_H_ #ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_MODEL_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_MODEL_H_ #define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_MODEL_H_
#include <fstream>
#include <list>
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
// On device head suggest feature uses an on device model which encodes some // On device head suggest feature uses an on device model which encodes some
// top queries into a radix tree (https://en.wikipedia.org/wiki/Radix_tree), to // top queries into a radix tree (https://en.wikipedia.org/wiki/Radix_tree), to
...@@ -65,9 +61,6 @@ ...@@ -65,9 +61,6 @@
// //
// The size of score and address will be given in the first two bytes of the // The size of score and address will be given in the first two bytes of the
// model file. // model file.
// TODO(crbug.com/925072): make some cleanups after converting this class into
// a static class, e.g. move private class functions into anonymous namespace.
class OnDeviceHeadModel { class OnDeviceHeadModel {
public: public:
// Gets top "max_num_matches_to_return" suggestions and their scores which // Gets top "max_num_matches_to_return" suggestions and their scores which
...@@ -76,124 +69,6 @@ class OnDeviceHeadModel { ...@@ -76,124 +69,6 @@ class OnDeviceHeadModel {
const std::string& model_filename, const std::string& model_filename,
const uint32_t max_num_matches_to_return, const uint32_t max_num_matches_to_return,
const std::string& prefix); const std::string& prefix);
private:
// A useful data structure to keep track of the tree nodes should be and have
// been visited during tree traversal.
struct MatchCandidate {
// The sequences of characters from the start node to current node.
std::string text;
// Whether the text above can be returned as a suggestion; if false it is
// the prefix of some other complete suggestion.
bool is_complete_suggestion;
// If is_complete_suggestion is true, this is the score for the suggestion;
// Otherwise it will be set as max_score_as_root of the node.
uint32_t score;
// The address of the node in the model file. It is not required if
// is_complete_suggestion is true.
uint32_t address;
};
// Doubly linked list structure, which will be sorted based on candidates'
// scores (from low to high), to track nodes during tree search. We use two of
// this list to keep max_num_matches_to_return_ nodes in total with
// highest score during the search, and prune children and branches with low
// score.
// In theory, using RBTree might give a better search performance
// (i.e. log(n)) compared with linear from linked list here when inserting
// new candidates with high score into the struct, but since n is usually
// small, using linked list shall be okay.
using CandidateQueue = std::list<MatchCandidate>;
// A mini class holds all parameters needed to access the model on disk.
class OnDeviceModelParams {
public:
static std::unique_ptr<OnDeviceModelParams> Create(
const std::string& model_filename,
const uint32_t max_num_matches_to_return);
std::ifstream* GetModelFileStream() { return &model_filestream_; }
uint32_t score_size() const { return score_size_; }
uint32_t address_size() const { return address_size_; }
uint32_t max_num_matches_to_return() const {
return max_num_matches_to_return_;
}
~OnDeviceModelParams();
private:
OnDeviceModelParams();
OnDeviceModelParams(const OnDeviceModelParams&) = delete;
OnDeviceModelParams& operator=(const OnDeviceModelParams&) = delete;
std::ifstream model_filestream_;
uint32_t score_size_;
uint32_t address_size_;
uint32_t max_num_matches_to_return_;
};
static void InsertCandidateToQueue(const MatchCandidate& candidate,
CandidateQueue* leaf_queue,
CandidateQueue* non_leaf_queue);
static uint32_t GetMinScoreFromQueues(OnDeviceModelParams* params,
const CandidateQueue& queue_1,
const CandidateQueue& queue_2);
// Finds start node which matches given prefix, returns true if found and
// the start node using param match_candidate.
static bool FindStartNode(OnDeviceModelParams* params,
const std::string& prefix,
MatchCandidate* match_candidate);
// Reads tree node from given match candidate, convert all possible
// suggestions and children of this node into structure MatchCandidate.
static std::vector<MatchCandidate> ReadTreeNode(
OnDeviceModelParams* params,
const MatchCandidate& current);
// Reads block max_score_as_root at the beginning of the node from the given
// address. If there is a leaf score at the end of the block, return the leaf
// score using param leaf_candidate;
static uint32_t ReadMaxScoreAsRoot(OnDeviceModelParams* params,
uint32_t address,
MatchCandidate* leaf_candidate,
bool* is_successful);
// Reads a child block and move ifstream cursor to next child; returns false
// when reaching the end of the node or ifstream read error happens.
static bool ReadNextChild(OnDeviceModelParams* params,
MatchCandidate* candidate);
// Performs a search starting from the address specified by start_match and
// returns max_num_matches_to_return_ number of complete suggestions with
// highest scores.
static std::vector<std::pair<std::string, uint32_t>> DoSearch(
OnDeviceModelParams* params,
const MatchCandidate& start_match);
// Reads next num_bytes from the file stream.
static bool ReadNextNumBytes(OnDeviceModelParams* params,
uint32_t num_bytes,
char* buf);
static uint32_t ReadNextNumBytesAsInt(OnDeviceModelParams* params,
uint32_t num_bytes,
bool* is_successful);
// Checks if size of score and size of address read from the model file are
// valid.
// For score, we use size of 2 bytes (15 bits), 3 bytes (23 bits) or 4 bytes
// (31 bits); For address, we use size of 3 bytes (23 bits) or 4 bytes
// (31 bits).
static bool AreSizesValid(OnDeviceModelParams* params);
static bool OpenModelFileStream(OnDeviceModelParams* params,
const std::string& model_filename,
const uint32_t start_address);
static void MaybeCloseModelFileStream(OnDeviceModelParams* params);
}; };
#endif // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_MODEL_H_ #endif // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_MODEL_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment