Commit c97535bd authored by Philippe Hamel, committed by Commit Bot

Ranker owns predictor objects and uses a config for initialization.

Ranker also takes care of logging and deals with field trials internally.

Bug: 786472, 778468
Change-Id: Ie583616643ab8ad169df4739a4a18b81a950553f
Reviewed-on: https://chromium-review.googlesource.com/788331
Commit-Queue: Philippe Hamel <hamelphi@chromium.org>
Reviewed-by: Roger McFarlane <rogerm@chromium.org>
Reviewed-by: Donn Denman <donnd@chromium.org>
Reviewed-by: Steven Holte <holte@chromium.org>
Cr-Commit-Position: refs/heads/master@{#524426}
parent 11126aa0
......@@ -193,7 +193,6 @@ public abstract class ChromeFeatureList {
"ContentSuggestionsThumbnailDominantColor";
public static final String CONTEXTUAL_SEARCH_ML_TAP_SUPPRESSION =
"ContextualSearchMlTapSuppression";
public static final String CONTEXTUAL_SEARCH_RANKER_QUERY = "ContextualSearchRankerQuery";
public static final String CONTEXTUAL_SEARCH_SECOND_TAP = "ContextualSearchSecondTap";
public static final String CONTEXTUAL_SUGGESTIONS_CAROUSEL = "ContextualSuggestionsCarousel";
public static final String COPYLESS_PASTE = "CopylessPaste";
......
......@@ -54,6 +54,11 @@ public interface ContextualSearchRankerLogger {
*/
void logFeature(Feature feature, Object value);
/**
* Returns whether or not AssistRanker query is enabled.
*/
boolean isQueryEnabled();
/**
* Logs an outcome value at training time that indicates an ML label as a key/value pair.
* @param feature The feature to log.
......
......@@ -111,6 +111,12 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL
mIsLoggingReadyForPage = true;
mBasePageWebContents = basePageWebContents;
mHasInferenceOccurred = false;
nativeSetupLoggingAndRanker(mNativePointer, basePageWebContents);
}
/**
 * Returns whether or not AssistRanker query is enabled, as reported by the native
 * logger identified by mNativePointer.
 */
@Override
public boolean isQueryEnabled() {
return nativeIsQueryEnabled(mNativePointer);
}
@Override
......@@ -138,7 +144,6 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL
mHasInferenceOccurred = true;
if (isEnabled() && mBasePageWebContents != null && mFeaturesToLog != null
&& !mFeaturesToLog.isEmpty()) {
nativeSetupLoggingAndRanker(mNativePointer, mBasePageWebContents);
for (Map.Entry<Feature, Object> entry : mFeaturesToLog.entrySet()) {
logObject(entry.getKey(), entry.getValue());
}
......@@ -260,4 +265,5 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL
// Returns an AssistRankerPrediction integer value.
private native int nativeRunInference(long nativeContextualSearchRankerLoggerImpl);
private native void nativeWriteLogAndReset(long nativeContextualSearchRankerLoggerImpl);
private native boolean nativeIsQueryEnabled(long nativeContextualSearchRankerLoggerImpl);
}
......@@ -10,7 +10,6 @@ import android.text.TextUtils;
import org.chromium.base.Log;
import org.chromium.base.VisibleForTesting;
import org.chromium.chrome.browser.ChromeActivity;
import org.chromium.chrome.browser.ChromeFeatureList;
import org.chromium.chrome.browser.compositor.bottombar.OverlayPanel;
import org.chromium.chrome.browser.preferences.PrefServiceBridge;
import org.chromium.chrome.browser.tab.Tab;
......@@ -380,7 +379,7 @@ public class ContextualSearchSelectionController {
// Make sure Tap Suppression features are consistent.
assert !ContextualSearchFieldTrial.isContextualSearchMlTapSuppressionEnabled()
|| ChromeFeatureList.isEnabled(ChromeFeatureList.CONTEXTUAL_SEARCH_RANKER_QUERY)
|| rankerLogger.isQueryEnabled()
: "Tap Suppression requires the Ranker Query feature to be enabled!";
// If we're suppressing based on heuristics then Ranker doesn't need to know about it.
......
......@@ -85,7 +85,6 @@ const base::Feature* kFeaturesExposedToJava[] = {
&kContentSuggestionsSettings,
&kContentSuggestionsThumbnailDominantColor,
&kContextualSearchMlTapSuppression,
&kContextualSearchRankerQuery,
&kContextualSearchSecondTap,
&kContextualSuggestionsCarousel,
&kCustomContextMenu,
......@@ -251,9 +250,6 @@ const base::Feature kContentSuggestionsThumbnailDominantColor{
const base::Feature kContextualSearchMlTapSuppression{
"ContextualSearchMlTapSuppression", base::FEATURE_DISABLED_BY_DEFAULT};
const base::Feature kContextualSearchRankerQuery{
"ContextualSearchRankerQuery", base::FEATURE_DISABLED_BY_DEFAULT};
const base::Feature kContextualSearchSecondTap{
"ContextualSearchSecondTap", base::FEATURE_DISABLED_BY_DEFAULT};
......
......@@ -44,7 +44,6 @@ extern const base::Feature kContentSuggestionsScrollToLoad;
extern const base::Feature kContentSuggestionsSettings;
extern const base::Feature kContentSuggestionsThumbnailDominantColor;
extern const base::Feature kContextualSearchMlTapSuppression;
extern const base::Feature kContextualSearchRankerQuery;
extern const base::Feature kContextualSearchSecondTap;
extern const base::Feature kContextualSuggestionsCarousel;
extern const base::Feature kCustomContextMenu;
......
......@@ -7,23 +7,18 @@
#include "base/android/jni_string.h"
#include "base/android/scoped_java_ref.h"
#include "base/feature_list.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/metrics_hashes.h"
#include "base/strings/stringprintf.h"
#include "chrome/browser/android/chrome_feature_list.h"
#include "chrome/browser/assist_ranker/assist_ranker_service_factory.h"
#include "chrome/browser/browser_process.h"
#include "components/assist_ranker/assist_ranker_service_impl.h"
#include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/predictor_config_definitions.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/ukm/content/source_url_recorder.h"
#include "content/public/browser/web_contents.h"
#include "jni/ContextualSearchRankerLoggerImpl_jni.h"
#include "services/metrics/public/cpp/ukm_entry_builder.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "url/gurl.h"
namespace content {
class BrowserContext;
......@@ -31,111 +26,61 @@ class BrowserContext;
namespace {
const char kContextualSearchRankerModelUrlParamName[] =
"contextual-search-ranker-model-url";
const char kContextualSearchModelFilename[] = "contextual_search_model";
const char kContextualSearchUmaPrefix[] = "Search.ContextualSearch.Ranker";
const char kContextualSearchRankerDidPredict[] = "OutcomeRankerDidPredict";
const char kContextualSearchRankerPrediction[] = "OutcomeRankerPrediction";
const base::FeatureParam<std::string> kModelUrl{
&chrome::android::kContextualSearchRankerQuery,
kContextualSearchRankerModelUrlParamName, ""};
// TODO(donnd, hamelphi): move hex-hash-string to Ranker.
// Returns base::HashMetricName(|feature_name|) rendered as a zero-padded,
// 16-digit lowercase hexadecimal string.
std::string HexHashFeatureName(const std::string& feature_name) {
  const uint64_t hashed_name = base::HashMetricName(feature_name);
  return base::StringPrintf("%016" PRIx64, hashed_name);
}
} // namespace
ContextualSearchRankerLoggerImpl::ContextualSearchRankerLoggerImpl(JNIEnv* env,
jobject obj)
: source_id_(ukm::kInvalidSourceId),
builder_(nullptr),
predictor_(nullptr),
browser_context_(nullptr),
ranker_example_(nullptr),
has_predicted_decision_(false),
java_object_(nullptr) {
java_object_.Reset(env, obj);
}
: java_object_(env, obj) {}
ContextualSearchRankerLoggerImpl::~ContextualSearchRankerLoggerImpl() {
java_object_ = nullptr;
}
ContextualSearchRankerLoggerImpl::~ContextualSearchRankerLoggerImpl() {}
void ContextualSearchRankerLoggerImpl::SetupLoggingAndRanker(
JNIEnv* env,
jobject obj,
const base::android::JavaParamRef<jobject>& java_web_contents) {
content::WebContents* web_contents =
content::WebContents::FromJavaWebContents(java_web_contents);
if (!web_contents)
web_contents_ = content::WebContents::FromJavaWebContents(java_web_contents);
if (!web_contents_)
return;
source_id_ = ukm::GetSourceIdForWebContentsDocument(web_contents);
ResetUkmEntry();
if (IsRankerQueryEnabled()) {
SetupRankerPredictor(web_contents);
// Start building example data based on features to be gathered and logged.
ranker_example_.reset(new assist_ranker::RankerExample());
} else {
// TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589).
VLOG(1) << "SetupLoggingAndRanker got IsRankerQueryEnabled false.";
}
SetupRankerPredictor();
// Start building example data based on features to be gathered and logged.
ranker_example_ = std::make_unique<assist_ranker::RankerExample>();
}
void ContextualSearchRankerLoggerImpl::ResetUkmEntry() {
// Releasing the old entry triggers logging.
builder_ =
ukm::UkmRecorder::Get()->GetEntryBuilder(source_id_, "ContextualSearch");
}
void ContextualSearchRankerLoggerImpl::SetupRankerPredictor() {
// Create one predictor for the current BrowserContext.
if (browser_context_) {
DCHECK(browser_context_ == web_contents_->GetBrowserContext());
return;
}
browser_context_ = web_contents_->GetBrowserContext();
void ContextualSearchRankerLoggerImpl::SetupRankerPredictor(
content::WebContents* web_contents) {
// Set up the Ranker predictor.
if (IsRankerQueryEnabled()) {
// Create one predictor for the current BrowserContext.
content::BrowserContext* browser_context =
web_contents->GetBrowserContext();
if (browser_context == browser_context_)
return;
browser_context_ = browser_context;
assist_ranker::AssistRankerService* assist_ranker_service =
assist_ranker::AssistRankerServiceFactory::GetForBrowserContext(
browser_context);
DCHECK(assist_ranker_service);
std::string model_string(kModelUrl.Get());
DCHECK(model_string.size());
// TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589).
VLOG(0) << "Model URL: " << model_string;
assist_ranker::AssistRankerService* assist_ranker_service =
assist_ranker::AssistRankerServiceFactory::GetForBrowserContext(
browser_context_);
if (assist_ranker_service) {
predictor_ = assist_ranker_service->FetchBinaryClassifierPredictor(
GURL(model_string), kContextualSearchModelFilename,
kContextualSearchUmaPrefix);
DCHECK(predictor_);
assist_ranker::GetContextualSearchPredictorConfig());
}
}
// Records |value| as the int32 value of |feature_name| in the RankerExample
// that is currently being built. Does nothing when no example exists.
void ContextualSearchRankerLoggerImpl::LogFeature(
    const std::string& feature_name,
    int value) {
  // |ranker_example_| stays null until SetupLoggingAndRanker() has been given
  // a valid WebContents, so guard against dereferencing it before then.
  if (!ranker_example_)
    return;
  auto& features = *ranker_example_->mutable_features();
  features[feature_name].set_int32_value(value);
}
void ContextualSearchRankerLoggerImpl::LogLong(
JNIEnv* env,
jobject obj,
const base::android::JavaParamRef<jstring>& j_feature,
jlong j_long) {
std::string feature = base::android::ConvertJavaStringToUTF8(env, j_feature);
if (builder_)
builder_->AddMetric(feature.c_str(), j_long);
// Also write to Ranker if we're logging data needed to predict a decision.
if (IsRankerQueryEnabled() && !has_predicted_decision_) {
std::string hex_feature_key(HexHashFeatureName(feature));
auto& features = *ranker_example_->mutable_features();
features[hex_feature_key].set_int32_value(j_long);
}
LogFeature(feature, j_long);
}
AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference(
......@@ -144,19 +89,17 @@ AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference(
has_predicted_decision_ = true;
bool prediction = false;
bool was_able_to_predict = false;
if (IsRankerQueryEnabled()) {
if (IsQueryEnabledInternal()) {
was_able_to_predict = predictor_->Predict(*ranker_example_, &prediction);
// Log to UMA whether we were able to predict or not.
base::UmaHistogramBoolean("Search.ContextualSearchRankerWasAbleToPredict",
was_able_to_predict);
// Log the Ranker decision to UKM, including whether we were able to make
// any prediction.
if (builder_) {
builder_->AddMetric(kContextualSearchRankerDidPredict,
was_able_to_predict);
if (was_able_to_predict) {
builder_->AddMetric(kContextualSearchRankerPrediction, prediction);
}
// TODO(chrome-ranker-team): this should be logged internally by Ranker.
LogFeature(kContextualSearchRankerDidPredict,
static_cast<int>(was_able_to_predict));
if (was_able_to_predict) {
LogFeature(kContextualSearchRankerPrediction,
static_cast<int>(prediction));
}
}
AssistRankerPrediction prediction_enum;
......@@ -170,21 +113,28 @@ AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference(
prediction_enum = ASSIST_RANKER_PREDICTION_UNAVAILABLE;
}
// TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589).
VLOG(0) << "prediction: " << prediction_enum;
DVLOG(0) << "prediction: " << prediction_enum;
return prediction_enum;
}
void ContextualSearchRankerLoggerImpl::WriteLogAndReset(JNIEnv* env,
jobject obj) {
if (predictor_ && ranker_example_) {
ukm::SourceId source_id =
ukm::GetSourceIdForWebContentsDocument(web_contents_);
predictor_->LogExampleToUkm(*ranker_example_.get(), source_id);
}
has_predicted_decision_ = false;
// Set up another builder for the next record (in case it's needed).
ResetUkmEntry();
ranker_example_.reset();
ranker_example_ = std::make_unique<assist_ranker::RankerExample>();
}
// JNI-facing wrapper: reports whether Ranker querying is enabled by forwarding
// to IsQueryEnabledInternal(). |env| and |obj| are not used.
bool ContextualSearchRankerLoggerImpl::IsQueryEnabled(JNIEnv* env,
jobject obj) {
return IsQueryEnabledInternal();
}
bool ContextualSearchRankerLoggerImpl::IsRankerQueryEnabled() {
return base::FeatureList::IsEnabled(
chrome::android::kContextualSearchRankerQuery);
bool ContextualSearchRankerLoggerImpl::IsQueryEnabledInternal() {
return predictor_ && predictor_->is_query_enabled();
}
// Java wrapper boilerplate
......
......@@ -6,6 +6,7 @@
#define CHROME_BROWSER_ANDROID_CONTEXTUALSEARCH_CONTEXTUAL_SEARCH_RANKER_LOGGER_IMPL_H_
#include "base/android/jni_android.h"
#include "base/memory/weak_ptr.h"
namespace content {
class BrowserContext;
......@@ -17,10 +18,6 @@ class BinaryClassifierPredictor;
class RankerExample;
} // namespace assist_ranker
namespace ukm {
class UkmEntryBuilder;
} // namespace ukm
// A Java counterpart will be generated for this enum.
// GENERATED_JAVA_ENUM_PACKAGE: org.chromium.chrome.browser.contextualsearch
enum AssistRankerPrediction {
......@@ -65,28 +62,30 @@ class ContextualSearchRankerLoggerImpl {
// ready to start logging the next set of data.
void WriteLogAndReset(JNIEnv* env, jobject obj);
private:
// Log the current UKM entry (if any) and start a new one.
// TODO(donnd): write a test using TestAutoSetUkmRecorder.
void ResetUkmEntry();
// Returns whether or not AssistRanker query is enabled.
bool IsQueryEnabled(JNIEnv* env, jobject obj);
// Sets up the Ranker Predictor for the given |web_contents|.
void SetupRankerPredictor(content::WebContents* web_contents);
private:
// Returns whether or not AssistRanker query is enabled.
bool IsQueryEnabledInternal();
// Whether querying Ranker for model loading and prediction is enabled.
bool IsRankerQueryEnabled();
// Adds feature to the RankerExample.
void LogFeature(const std::string& feature_name, int value);
// The UKM source ID being used for this session.
int32_t source_id_;
// Sets up the Ranker Predictor for the current |web_contents_|.
void SetupRankerPredictor();
// The entry builder for the current record, or nullptr if not yet configured.
std::unique_ptr<ukm::UkmEntryBuilder> builder_;
// The WebContents object used to produce the source_id for UKMs, and to get
// browser_context when fetching the predictor. The object is not owned by
// ContextualSearchRankerLoggerImpl.
content::WebContents* web_contents_ = nullptr;
// The Ranker Predictor for whether a tap gesture should be suppressed or not.
std::unique_ptr<assist_ranker::BinaryClassifierPredictor> predictor_;
base::WeakPtr<assist_ranker::BinaryClassifierPredictor> predictor_;
// The |BrowserContext| currently associated with the above predictor.
content::BrowserContext* browser_context_;
// The object is not owned by ContextualSearchRankerLoggerImpl.
content::BrowserContext* browser_context_ = nullptr;
// The current RankerExample or null.
// Set of features from one example of a Tap to predict a suppression
......@@ -94,7 +93,7 @@ class ContextualSearchRankerLoggerImpl {
std::unique_ptr<assist_ranker::RankerExample> ranker_example_;
// Whether Ranker has predicted the decision yet.
bool has_predicted_decision_;
bool has_predicted_decision_ = false;
// The linked Java object.
base::android::ScopedJavaGlobalRef<jobject> java_object_;
......
......@@ -15,6 +15,10 @@ static_library("assist_ranker") {
"fake_ranker_model_loader.h",
"generic_logistic_regression_inference.cc",
"generic_logistic_regression_inference.h",
"predictor_config.cc",
"predictor_config.h",
"predictor_config_definitions.cc",
"predictor_config_definitions.h",
"ranker_example_util.cc",
"ranker_example_util.h",
"ranker_model.cc",
......@@ -32,6 +36,7 @@ static_library("assist_ranker") {
"//components/data_use_measurement/core",
"//components/keyed_service/core",
"//net",
"//services/metrics/public/cpp:metrics_cpp",
"//url",
]
}
......@@ -40,6 +45,7 @@ source_set("unit_tests") {
testonly = true
sources = [
"base_predictor_unittest.cc",
"binary_classifier_predictor_unittest.cc",
"generic_logistic_regression_inference_unittest.cc",
"ranker_example_util_unittest.cc",
......@@ -51,6 +57,7 @@ source_set("unit_tests") {
":assist_ranker",
"//base",
"//components/assist_ranker/proto",
"//components/ukm:test_support",
"//net:test_support",
"//testing/gtest",
]
......
......@@ -2,5 +2,7 @@ include_rules = [
"+components/data_use_measurement/core",
"+components/keyed_service/core",
"+components/metrics",
"+components/ukm",
"+net",
"+services/metrics/public",
]
\ No newline at end of file
......@@ -9,32 +9,25 @@
#include <string>
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "components/keyed_service/core/keyed_service.h"
class GURL;
namespace assist_ranker {
class BinaryClassifierPredictor;
struct PredictorConfig;
// TODO(crbug.com/778468) : Refactor this so that the service owns the predictor
// objects and enforce model uniqueness through internal registration in order
// to avoid cache collisions.
//
// Service that provides Predictor objects.
class AssistRankerService : public KeyedService {
public:
AssistRankerService() = default;
// Returns a binary classification model. |model_filename| is the filename of
// the cached model. It should be unique to this predictor to avoid cache
// collision. |model_url| represents a unique ID for the desired model (see
// ranker_model_loader.h for more details). |uma_prefix| is used to log
// histograms related to the loading of the model.
virtual std::unique_ptr<BinaryClassifierPredictor>
FetchBinaryClassifierPredictor(GURL model_url,
const std::string& model_filename,
const std::string& uma_prefix) = 0;
// Returns a binary classification model given a PredictorConfig.
// The predictor is instantiated the first time a predictor is fetched. The
// next calls to fetch will return a pointer to the already instantiated
// predictor.
virtual base::WeakPtr<BinaryClassifierPredictor>
FetchBinaryClassifierPredictor(const PredictorConfig& config) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(AssistRankerService);
......
......@@ -3,7 +3,7 @@
// found in the LICENSE file.
#include "components/assist_ranker/assist_ranker_service_impl.h"
#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
......@@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl(
AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
std::unique_ptr<BinaryClassifierPredictor>
base::WeakPtr<BinaryClassifierPredictor>
AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
GURL model_url,
const std::string& model_filename,
const std::string& uma_prefix) {
const PredictorConfig& config) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return BinaryClassifierPredictor::Create(url_request_context_getter_.get(),
GetModelPath(model_filename),
model_url, uma_prefix);
const std::string& model_name = config.model_name;
auto predictor_it = predictor_map_.find(model_name);
if (predictor_it != predictor_map_.end()) {
DVLOG(1) << "Predictor " << model_name << " already initialized.";
return base::AsWeakPtr(
static_cast<BinaryClassifierPredictor*>(predictor_it->second.get()));
}
// The predictor does not exist yet, so we create one.
DVLOG(1) << "Initializing predictor: " << model_name;
std::unique_ptr<BinaryClassifierPredictor> predictor =
BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
url_request_context_getter_.get());
base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
base::AsWeakPtr(predictor.get());
predictor_map_[model_name] = std::move(predictor);
return weak_ptr;
}
base::FilePath AssistRankerServiceImpl::GetModelPath(
......
......@@ -7,13 +7,13 @@
#include <memory>
#include <string>
#include <unordered_map>
#include "base/files/file_path.h"
#include "base/memory/ref_counted.h"
#include "base/sequence_checker.h"
#include "components/assist_ranker/assist_ranker_service.h"
class GURL;
#include "components/assist_ranker/predictor_config.h"
namespace net {
class URLRequestContextGetter;
......@@ -21,6 +21,7 @@ class URLRequestContextGetter;
namespace assist_ranker {
class BasePredictor;
class BinaryClassifierPredictor;
class AssistRankerServiceImpl : public AssistRankerService {
......@@ -31,10 +32,8 @@ class AssistRankerServiceImpl : public AssistRankerService {
~AssistRankerServiceImpl() override;
// AssistRankerService...
std::unique_ptr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor(
GURL model_url,
const std::string& model_filename,
const std::string& uma_prefix) override;
base::WeakPtr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor(
const PredictorConfig& config) override;
private:
// Returns the full path to the model cache.
......@@ -46,6 +45,9 @@ class AssistRankerServiceImpl : public AssistRankerService {
// Base path where models are stored.
const base::FilePath base_path_;
std::unordered_map<std::string, std::unique_ptr<BasePredictor>>
predictor_map_;
SEQUENCE_CHECKER(sequence_checker_);
DISALLOW_COPY_AND_ASSIGN(AssistRankerServiceImpl);
......
......@@ -4,23 +4,40 @@
#include "components/assist_ranker/base_predictor.h"
#include "base/feature_list.h"
#include "base/memory/ptr_util.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_example_util.h"
#include "components/assist_ranker/ranker_model.h"
#include "services/metrics/public/cpp/ukm_entry_builder.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "url/gurl.h"
namespace assist_ranker {
BasePredictor::BasePredictor() {}
// Stores a copy of |config| and, when the config names a field trial feature,
// caches whether that feature is enabled. Without one, querying stays
// disabled.
BasePredictor::BasePredictor(const PredictorConfig& config) : config_(config) {
  // TODO(chrome-ranker-team): validate config.
  if (!config_.field_trial) {
    DVLOG(0) << "No field trial specified";
    return;
  }
  is_query_enabled_ = base::FeatureList::IsEnabled(*config_.field_trial);
}
BasePredictor::~BasePredictor() {}
void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) {
if (!is_query_enabled_)
return;
if (model_loader_) {
DLOG(ERROR) << "This predictor already has a model loader.";
DVLOG(0) << "This predictor already has a model loader.";
return;
}
// Take ownership of the model loader.
model_loader_ = std::move(model_loader);
// Kick off the initial load from cache.
// Kick off the initial model load.
model_loader_->NotifyOfRankerActivity();
}
......@@ -31,10 +48,80 @@ void BasePredictor::OnModelAvailable(
}
bool BasePredictor::IsReady() {
if (!is_ready_)
if (!is_ready_ && model_loader_)
model_loader_->NotifyOfRankerActivity();
return is_ready_;
}
// Adds |feature| to |ukm_builder| as a metric named |feature_name|, provided
// the name is in the config's whitelist and the value converts to an int.
// The caller must have checked that |config_.feature_whitelist| is non-null.
void BasePredictor::LogFeatureToUkm(const std::string& feature_name,
                                    const Feature& feature,
                                    ukm::UkmEntryBuilder* ukm_builder) {
  if (!ukm_builder)
    return;

  const bool is_whitelisted =
      base::ContainsKey(*config_.feature_whitelist, feature_name);
  if (!is_whitelisted) {
    DVLOG(1) << "Feature not whitelisted: " << feature_name;
    return;
  }

  int int_value;
  if (!FeatureToInt(feature, &int_value)) {
    DVLOG(0) << "Could not convert feature to int: " << feature_name;
    return;
  }

  DVLOG(3) << "Logging: " << feature_name << ": " << int_value;
  ukm_builder->AddMetric(feature_name.c_str(), int_value);
}
void BasePredictor::LogExampleToUkm(const RankerExample& example,
ukm::SourceId source_id) {
if (config_.log_type != LOG_UKM) {
DVLOG(0) << "Wrong log type in predictor config: " << config_.log_type;
return;
}
if (!config_.feature_whitelist) {
DVLOG(0) << "No whitelist specified.";
return;
}
if (config_.feature_whitelist->empty()) {
DVLOG(0) << "Empty whitelist, examples will not be logged.";
return;
}
// Releasing the builder will trigger logging.
std::unique_ptr<ukm::UkmEntryBuilder> builder =
ukm::UkmRecorder::Get()->GetEntryBuilder(source_id, config_.logging_name);
if (builder) {
for (const auto& feature_kv : example.features()) {
LogFeatureToUkm(feature_kv.first, feature_kv.second, builder.get());
}
} else {
DVLOG(0) << "Could not get UKM Entry Builder.";
}
}
// Returns a copy of the model name stored in this predictor's config.
std::string BasePredictor::GetModelName() const {
return config_.model_name;
}
// Returns the model URL read from the config's URL field-trial parameter, or
// a default-constructed GURL when the config has no such parameter.
GURL BasePredictor::GetModelUrl() const {
  if (config_.field_trial_url_param)
    return GURL(config_.field_trial_url_param->Get());

  DVLOG(1) << "No URL specified.";
  return GURL();
}
// Returns the example to feed to the model: a copy of |example| with hashed
// feature names when the model metadata says its input feature names are hex
// hashes, otherwise an unmodified copy. |example| itself is never changed.
RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
  const bool model_has_metadata = ranker_model_->proto().has_metadata();
  if (!model_has_metadata ||
      !ranker_model_->proto()
           .metadata()
           .input_features_names_are_hex_hashes()) {
    return example;
  }
  return HashExampleFeatureNames(example);
}
} // namespace assist_ranker
......@@ -9,30 +9,53 @@
#include <string>
#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/predictor_config.h"
#include "components/assist_ranker/ranker_model_loader.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
class GURL;
namespace ukm {
class UkmEntryBuilder;
}
namespace assist_ranker {
class Feature;
class RankerExample;
class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model. Sub-classes of BasePredictor
// implement an interface that depends on the nature of the suported model.
// Subclasses of BasePredictor will also need to implement an Initialize method
// that will be called once the model is available, and a static validation
// function with the following signature:
// encapsulate the logic for loading the model and logging. Sub-classes of
// BasePredictor implement an interface that depends on the nature of the
// supported model. Subclasses of BasePredictor will also need to implement an
// Initialize method that will be called once the model is available, and a
// static validation function with the following signature:
//
// static RankerModelStatus ValidateModel(const RankerModel& model);
class BasePredictor {
class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
public:
BasePredictor();
BasePredictor(const PredictorConfig& config);
virtual ~BasePredictor();
// Returns true if the predictor is ready to make predictions.
bool IsReady();
// Returns true if the base::Feature associated with this model is enabled.
bool is_query_enabled() const { return is_query_enabled_; }
// Logs the features of |example| to UKM using the given source_id.
void LogExampleToUkm(const RankerExample& example, ukm::SourceId source_id);
// Returns the model URL.
GURL GetModelUrl() const;
// Returns the model name.
std::string GetModelName() const;
protected:
// The model used for prediction.
std::unique_ptr<RankerModel> ranker_model_;
// Preprocessing applied to an example before prediction. The original
// RankerExample is not modified, so it is safe to use it later for logging.
RankerExample PreprocessExample(const RankerExample& example);
// Called when the RankerModelLoader has finished loading the model. Returns
// true only if the model was successfully loaded and is ready to predict.
......@@ -43,9 +66,17 @@ class BasePredictor {
// Called once the model loader has successfully loaded the model.
void OnModelAvailable(std::unique_ptr<RankerModel> model);
std::unique_ptr<RankerModelLoader> model_loader_;
// The model used for prediction.
std::unique_ptr<RankerModel> ranker_model_;
private:
void LogFeatureToUkm(const std::string& feature_name,
const Feature& feature,
ukm::UkmEntryBuilder* ukm_builder);
bool is_ready_ = false;
bool is_query_enabled_ = false;
PredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(BasePredictor);
};
......
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/base_predictor.h"
#include <memory>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/memory/ptr_util.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/scoped_task_environment.h"
#include "components/assist_ranker/fake_ranker_model_loader.h"
#include "components/assist_ranker/predictor_config.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "components/ukm/test_ukm_recorder.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
namespace assist_ranker {
using ::assist_ranker::testing::FakeRankerModelLoader;
namespace {
// Predictor config for testing.
// Name strings bundled into kTestPredictorConfig below.
const char kTestModelName[] = "test_model";
const char kTestLoggingName[] = "TestLoggingName";
const char kTestUmaPrefixName[] = "Test.Ranker";
// Name and default value of the feature parameter kTestRankerUrl.
const char kTestUrlParamName[] = "ranker-model-url";
const char kTestDefaultModelUrl[] = "https://foo.bar/model.bin";
// Feature names used when building examples. Every name except
// kFeatureNotWhitelisted is listed in kFeatureWhitelist.
const char kBoolFeature[] = "bool_feature";
const char kIntFeature[] = "int_feature";
const char kFloatFeature[] = "float_feature";
const char kStringFeature[] = "string_feature";
const char kFeatureNotWhitelisted[] = "not_whitelisted";
// URL that GetSourceId() registers for the UKM source ids it creates.
const char kTestNavigationUrl[] = "https://foo.com";
// Whitelist of feature names handed to the test predictor.
const base::flat_set<std::string> kFeatureWhitelist({kBoolFeature, kIntFeature,
kFloatFeature,
kStringFeature});
// Feature controlling the test predictor's query state; enabled by default and
// turned off by BasePredictorTest::DisableQuery().
const base::Feature kTestRankerQuery{"TestRankerQuery",
base::FEATURE_ENABLED_BY_DEFAULT};
// String parameter of kTestRankerQuery that defaults to kTestDefaultModelUrl.
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};
// Config used by FakePredictor::Create() for every predictor in these tests.
const PredictorConfig kTestPredictorConfig = PredictorConfig{
kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl};
// Class that implements virtual functions of the base class so that the
// BasePredictor logic can be exercised without any inference code.
class FakePredictor : public BasePredictor {
 public:
  // Creates a FakePredictor configured with kTestPredictorConfig and gives it
  // a FakeRankerModelLoader to load from.
  static std::unique_ptr<FakePredictor> Create();
  ~FakePredictor() override {}
  // Validation will always succeed.
  static RankerModelStatus ValidateModel(const RankerModel& model);

 protected:
  // Not implementing any inference logic.
  bool Initialize() override { return true; }

 private:
  // Private so that instances are only created through Create(). Marked
  // explicit so a PredictorConfig never converts to a predictor implicitly.
  explicit FakePredictor(const PredictorConfig& config);
  DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};
// Forwards |config| to BasePredictor; FakePredictor has no state of its own.
FakePredictor::FakePredictor(const PredictorConfig& config)
: BasePredictor(config) {}
// Accepts any |model|: always returns RankerModelStatus::OK without looking at
// it.
RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}
std::unique_ptr<FakePredictor> FakePredictor::Create() {
std::unique_ptr<FakePredictor> predictor(
new FakePredictor(kTestPredictorConfig));
auto ranker_model = base::MakeUnique<RankerModel>();
auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>(
base::BindRepeating(&FakePredictor::ValidateModel),
base::BindRepeating(&FakePredictor::OnModelAvailable,
base::Unretained(predictor.get())),
std::move(ranker_model));
predictor->LoadModel(std::move(fake_model_loader));
return predictor;
}
} // namespace
// Test fixture for BasePredictor. Owns the task environment, a test UKM
// recorder and a scoped feature list used to turn the query feature off.
class BasePredictorTest : public ::testing::Test {
protected:
BasePredictorTest() = default;
// Disables Query for the test predictor. Call this before creating the
// predictor, because BasePredictor reads the feature state in its constructor.
void DisableQuery();
// Returns a new UKM source id after registering kTestNavigationUrl for it
// with the test recorder.
ukm::SourceId GetSourceId();
// Returns the test recorder so that tests can inspect what was logged.
ukm::TestUkmRecorder* GetTestUkmRecorder() { return &test_ukm_recorder_; }
private:
// Sets up the task scheduling/task-runner environment for each test.
base::test::ScopedTaskEnvironment scoped_task_environment_;
// Sets itself as the global UkmRecorder on construction.
ukm::TestAutoSetUkmRecorder test_ukm_recorder_;
// Manages the enabling/disabling of features within the scope of a test.
base::test::ScopedFeatureList scoped_feature_list_;
DISALLOW_COPY_AND_ASSIGN(BasePredictorTest);
};
// Allocates a fresh UKM source id and associates it with kTestNavigationUrl on
// the test recorder before returning it.
ukm::SourceId BasePredictorTest::GetSourceId() {
  const ukm::SourceId new_id = ukm::UkmRecorder::GetNewSourceID();
  const GURL navigation_url(kTestNavigationUrl);
  test_ukm_recorder_.UpdateSourceURL(new_id, navigation_url);
  return new_id;
}
// Turns the kTestRankerQuery feature off for the rest of the test. No feature
// is force-enabled (first list is empty).
void BasePredictorTest::DisableQuery() {
  scoped_feature_list_.InitWithFeatures(
      /* enabled_features */ {},
      /* disabled_features */ {kTestRankerQuery});
}
// With default feature state the predictor exposes the configured name and
// default URL, has query enabled and becomes ready once the fake model loads.
TEST_F(BasePredictorTest, BaseTest) {
  const std::unique_ptr<FakePredictor> fake_predictor = FakePredictor::Create();

  EXPECT_EQ(kTestModelName, fake_predictor->GetModelName());
  EXPECT_EQ(kTestDefaultModelUrl, fake_predictor->GetModelUrl());

  EXPECT_TRUE(fake_predictor->is_query_enabled());
  EXPECT_TRUE(fake_predictor->IsReady());
}
// With the query feature disabled the predictor still reports its name and
// URL, but is neither query-enabled nor ready.
TEST_F(BasePredictorTest, QueryDisabled) {
  DisableQuery();
  const std::unique_ptr<FakePredictor> fake_predictor = FakePredictor::Create();

  EXPECT_EQ(kTestModelName, fake_predictor->GetModelName());
  EXPECT_EQ(kTestDefaultModelUrl, fake_predictor->GetModelUrl());

  EXPECT_FALSE(fake_predictor->is_query_enabled());
  EXPECT_FALSE(fake_predictor->IsReady());
}
// Logs an example containing one feature of each type plus one feature that
// is not whitelisted, then checks which of them reached the UKM recorder.
TEST_F(BasePredictorTest, LogExampleToUkm) {
  auto predictor = FakePredictor::Create();
  RankerExample example;
  auto& features = *example.mutable_features();
  features[kBoolFeature].set_bool_value(true);
  features[kIntFeature].set_int32_value(42);
  features[kFloatFeature].set_float_value(42.0f);
  features[kStringFeature].set_string_value("42");
  // This feature will not be logged.
  features[kFeatureNotWhitelisted].set_bool_value(false);
  predictor->LogExampleToUkm(example, GetSourceId());

  EXPECT_EQ(1U, GetTestUkmRecorder()->sources_count());
  EXPECT_EQ(1U, GetTestUkmRecorder()->entries_count());
  std::vector<const ukm::mojom::UkmEntry*> entries =
      GetTestUkmRecorder()->GetEntriesByName(kTestLoggingName);
  // Fatal assertion: everything below indexes entries[0], which would read
  // out of bounds if no entry had been recorded.
  ASSERT_EQ(1U, entries.size());
  GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kBoolFeature, 1);
  GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kIntFeature, 42);
  // TODO(crbug.com/794187) Float and string features are not logged yet.
  EXPECT_FALSE(GetTestUkmRecorder()->EntryHasMetric(entries[0], kFloatFeature));
  EXPECT_FALSE(
      GetTestUkmRecorder()->EntryHasMetric(entries[0], kStringFeature));
  EXPECT_FALSE(
      GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted));
}
} // namespace assist_ranker
......@@ -15,26 +15,33 @@
#include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
namespace assist_ranker {
// Constructs a predictor without a configuration.
BinaryClassifierPredictor::BinaryClassifierPredictor() {}

// Constructs a predictor whose base class is initialised from |config|.
BinaryClassifierPredictor::BinaryClassifierPredictor(
    const PredictorConfig& config)
    : BasePredictor(config) {}

// Nothing to release explicitly; members clean up after themselves.
BinaryClassifierPredictor::~BinaryClassifierPredictor() {}
// static
std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
net::URLRequestContextGetter* request_context_getter,
const PredictorConfig& config,
const base::FilePath& model_path,
GURL model_url,
const std::string& uma_prefix) {
net::URLRequestContextGetter* request_context_getter) {
std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor());
new BinaryClassifierPredictor(config));
if (!predictor->is_query_enabled()) {
DVLOG(1) << "Query disabled, bypassing model loading.";
return predictor;
}
const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url;
auto model_loader = base::MakeUnique<RankerModelLoaderImpl>(
base::Bind(&BinaryClassifierPredictor::ValidateModel),
base::Bind(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
request_context_getter, model_path, model_url, uma_prefix);
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
request_context_getter, model_path, model_url, config.uma_prefix);
predictor->LoadModel(std::move(model_loader));
return predictor;
}
......@@ -42,18 +49,23 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
// Fills |*prediction| with the boolean decision for |example|. Returns false
// and leaves |*prediction| untouched when the predictor is not ready.
bool BinaryClassifierPredictor::Predict(const RankerExample& example,
                                        bool* prediction) {
  if (!IsReady()) {
    DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
    return false;
  }
  // Run inference once, on the preprocessed example. A second call on the raw
  // example would only produce a value that is immediately overwritten.
  *prediction = inference_module_->Predict(PreprocessExample(example));
  DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
  return true;
}
// Fills |*prediction| with the score computed for |example|. Returns false
// and leaves |*prediction| untouched when the predictor is not ready.
bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
                                             float* prediction) {
  if (!IsReady()) {
    DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
    return false;
  }
  // Run inference once, on the preprocessed example. A second call on the raw
  // example would only produce a value that is immediately overwritten.
  *prediction = inference_module_->PredictScore(PreprocessExample(example));
  // Log the score itself; streaming |prediction| would print the pointer.
  DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
  return true;
}
......@@ -61,17 +73,22 @@ bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
// A model is accepted only when its proto holds a logistic regression model;
// anything else is reported as INCOMPATIBLE.
RankerModelStatus BinaryClassifierPredictor::ValidateModel(
    const RankerModel& model) {
  const bool is_logistic_regression =
      model.proto().model_case() == RankerModelProto::kLogisticRegression;
  if (is_logistic_regression)
    return RankerModelStatus::OK;

  DVLOG(0) << "Model is incompatible.";
  return RankerModelStatus::INCOMPATIBLE;
}
// Builds the inference module from the loaded model. Returns true when the
// model is a logistic regression model and the module was created, false
// otherwise (in which case |inference_module_| is left unchanged).
bool BinaryClassifierPredictor::Initialize() {
  // TODO(hamelphi): move the GLRM proto up one layer in the proto in order to
  // be independent of the client feature.
  // Check the model type first so an unsupported model is never used to build
  // an inference module.
  if (ranker_model_->proto().model_case() !=
      RankerModelProto::kLogisticRegression) {
    DVLOG(0) << "Could not initialize inference module.";
    return false;
  }
  inference_module_.reset(new GenericLogisticRegressionInference(
      ranker_model_->proto().logistic_regression()));
  return true;
}
} // namespace assist_ranker
......@@ -9,8 +9,6 @@
#include "components/assist_ranker/base_predictor.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
class GURL;
namespace base {
class FilePath;
}
......@@ -28,11 +26,13 @@ class BinaryClassifierPredictor : public BasePredictor {
public:
~BinaryClassifierPredictor() override;
// Returns a new predictor instance with the given |config| and initializes
// its model loader. The |request_context_getter| is passed to the
// predictor's model_loader, which holds it as a scoped_refptr.
static std::unique_ptr<BinaryClassifierPredictor> Create(
net::URLRequestContextGetter* request_context_getter,
const PredictorConfig& config,
const base::FilePath& model_path,
GURL model_url,
const std::string& uma_prefix);
net::URLRequestContextGetter* request_context_getter) WARN_UNUSED_RESULT;
// Fills in a boolean decision given a RankerExample. Returns false if a
// prediction could not be made (e.g. the model is not loaded yet).
......@@ -53,7 +53,7 @@ class BinaryClassifierPredictor : public BasePredictor {
private:
friend class BinaryClassifierPredictorTest;
BinaryClassifierPredictor();
BinaryClassifierPredictor(const PredictorConfig& config);
// TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to
// generalize to other models.
......
......@@ -8,6 +8,7 @@
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/feature_list.h"
#include "base/memory/ptr_util.h"
#include "components/assist_ranker/fake_ranker_model_loader.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
......@@ -21,11 +22,14 @@ using ::assist_ranker::testing::FakeRankerModelLoader;
class BinaryClassifierPredictorTest : public ::testing::Test {
public:
std::unique_ptr<BinaryClassifierPredictor> InitPredictor(
std::unique_ptr<RankerModel> ranker_model);
std::unique_ptr<RankerModel> ranker_model,
const PredictorConfig& config);
// This model will return the value of |feature| as a prediction.
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
PredictorConfig GetConfig();
protected:
const std::string feature_ = "feature";
const float threshold_ = 0.5;
......@@ -33,10 +37,11 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
std::unique_ptr<BinaryClassifierPredictor>
BinaryClassifierPredictorTest::InitPredictor(
std::unique_ptr<RankerModel> ranker_model) {
std::unique_ptr<RankerModel> ranker_model,
const PredictorConfig& config) {
std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor());
auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>(
new BinaryClassifierPredictor(config));
auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::Bind(&BinaryClassifierPredictor::ValidateModel),
base::Bind(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
......@@ -45,6 +50,20 @@ BinaryClassifierPredictorTest::InitPredictor(
return predictor;
}
// Feature used by the test configuration; enabled by default.
const base::Feature kTestRankerQuery{"TestRankerQuery",
base::FEATURE_ENABLED_BY_DEFAULT};
// String parameter attached to kTestRankerQuery, named "url-param-name", with
// "https://default.model.url" as its third argument (presumably the default
// value - confirm against base::FeatureParam).
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, "url-param-name", "https://default.model.url"};
// Returns the configuration used by these tests: placeholder names, no
// logging, an empty feature whitelist, and the test feature with its URL
// parameter.
PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
  return PredictorConfig("model_name", "logging_name", "uma_prefix", LOG_NONE,
                         GetEmptyWhitelist(), &kTestRankerQuery,
                         &kTestRankerUrl);
}
GenericLogisticRegressionModel
BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
GenericLogisticRegressionModel lr_model;
......@@ -58,7 +77,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) {
auto ranker_model = base::MakeUnique<RankerModel>();
auto predictor = InitPredictor(std::move(ranker_model));
auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_FALSE(predictor->IsReady());
RankerExample ranker_example;
......@@ -78,7 +97,7 @@ TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) {
->mutable_translate()
->mutable_translate_logistic_regression_model()
->set_bias(1);
auto predictor = InitPredictor(std::move(ranker_model));
auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_FALSE(predictor->IsReady());
RankerExample ranker_example;
......@@ -94,7 +113,7 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
auto ranker_model = base::MakeUnique<RankerModel>();
*ranker_model->mutable_proto()->mutable_logistic_regression() =
GetSimpleLogisticRegressionModel();
auto predictor = InitPredictor(std::move(ranker_model));
auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_TRUE(predictor->IsReady());
RankerExample ranker_example;
......
......@@ -29,7 +29,7 @@ float GenericLogisticRegressionInference::PredictScore(
const FeatureWeight& feature_weight = weight_it.second;
switch (feature_weight.feature_type_case()) {
case FeatureWeight::FEATURE_TYPE_NOT_SET: {
DVLOG(1) << "Feature type not set for " << feature_name;
DVLOG(0) << "Feature type not set for " << feature_name;
break;
}
case FeatureWeight::kScalar: {
......@@ -37,6 +37,8 @@ float GenericLogisticRegressionInference::PredictScore(
if (GetFeatureValueAsFloat(feature_name, example, &value)) {
const float weight = feature_weight.scalar();
activation += value * weight;
} else {
DVLOG(1) << "Feature not in example: " << feature_name;
}
break;
}
......@@ -50,19 +52,22 @@ float GenericLogisticRegressionInference::PredictScore(
} else {
// If the category is not found, use the default weight.
activation += feature_weight.one_hot().default_weight();
DVLOG(1) << "Unknown feature value for " << feature_name << ": "
<< value;
}
} else {
// If the feature is missing, use the default weight.
activation += feature_weight.one_hot().default_weight();
DVLOG(1) << "Feature not in example: " << feature_name;
}
break;
}
case FeatureWeight::kSparse: {
DVLOG(1) << "Sparse features not implemented yet.";
DVLOG(0) << "Sparse features not implemented yet.";
break;
}
case FeatureWeight::kBucketized: {
DVLOG(1) << "Bucketized features not implemented yet.";
DVLOG(0) << "Bucketized features not implemented yet.";
break;
}
}
......
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/predictor_config.h"
namespace assist_ranker {
// Returns a pointer to a single shared empty whitelist. The set is allocated
// on the first call and is never deleted.
const base::flat_set<std::string>* GetEmptyWhitelist() {
  static const base::flat_set<std::string>* const kEmptyWhitelist =
      new base::flat_set<std::string>();
  return kEmptyWhitelist;
}
} // namespace assist_ranker
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
#include <string>
#include "base/containers/flat_set.h"
#include "base/metrics/field_trial_params.h"
namespace assist_ranker {
// Logging type stored in PredictorConfig::log_type.
// TODO(chrome-ranker-team): Implement other logging types.
enum LogType {
// Presumably: no logging for this predictor - confirm against BasePredictor.
LOG_NONE = 0,
// Presumably: examples are logged through UKM - confirm against BasePredictor.
LOG_UKM = 1,
};
// Empty feature whitelist used for testing.
const base::flat_set<std::string>* GetEmptyWhitelist();
// This struct holds the config options for logging, loading and field trial
// for a predictor.
struct PredictorConfig {
PredictorConfig(const char* model_name,
const char* logging_name,
const char* uma_prefix,
const LogType log_type,
const base::flat_set<std::string>* feature_whitelist,
const base::Feature* field_trial,
const base::FeatureParam<std::string>* field_trial_url_param)
: model_name(model_name),
logging_name(logging_name),
uma_prefix(uma_prefix),
log_type(log_type),
feature_whitelist(feature_whitelist),
field_trial(field_trial),
field_trial_url_param(field_trial_url_param) {}
const char* model_name;
const char* logging_name;
const char* uma_prefix;
const LogType log_type;
const base::flat_set<std::string>* feature_whitelist;
const base::Feature* field_trial;
const base::FeatureParam<std::string>* field_trial_url_param;
};
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/predictor_config_definitions.h"
namespace assist_ranker {
#if defined(OS_ANDROID)
// Feature passed as the field trial of the Contextual Search predictor config
// defined in this file. Disabled by default.
const base::Feature kContextualSearchRankerQuery{
"ContextualSearchRankerQuery", base::FEATURE_DISABLED_BY_DEFAULT};
namespace {
// Values passed to the Contextual Search PredictorConfig below: model name,
// logging name and UMA prefix, in that order.
const char kContextualSearchModelName[] = "contextual_search_model";
const char kContextualSearchLoggingName[] = "ContextualSearch";
const char kContextualSearchUmaPrefixName[] = "Search.ContextualSearch.Ranker";
// Model URL given to the "contextual-search-ranker-model-url" feature
// parameter as its third argument (presumably the fallback when the parameter
// is not overridden - confirm against base::FeatureParam).
const char kContextualSearchDefaultModelUrl[] =
"https://www.gstatic.com/chrome/intelligence/assist/ranker/models/"
"contextual_search/test_ranker_model_20171109_short_words_v2.pb.bin";
// Returns the feature parameter that carries the Contextual Search model URL.
// The parameter object is allocated on the first call and is never deleted.
const base::FeatureParam<std::string>*
GetContextualSearchRankerUrlFeatureParam() {
  static const base::FeatureParam<std::string>* const url_param =
      new base::FeatureParam<std::string>(
          &kContextualSearchRankerQuery, "contextual-search-ranker-model-url",
          kContextualSearchDefaultModelUrl);
  return url_param;
}
// Returns the names of the Contextual Search features that may be logged to
// UKM; features outside this list are not logged.
// Keep this list in sync with tools/metrics/ukm/ukm.xml.
// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
// the UKM generated API.
const base::flat_set<std::string>* GetContextualSearchFeatureWhitelist() {
  // Allocated on the first call and never deleted.
  static const base::flat_set<std::string>* const feature_names =
      new base::flat_set<std::string>({"DidOptIn",
                                       "DurationAfterScrollMs",
                                       "IsEntity",
                                       "IsLongWord",
                                       "IsShortWord",
                                       "IsWordEdge",
                                       "OutcomeWasCardsDataShown",
                                       "OutcomeWasPanelOpened",
                                       "OutcomeWasQuickActionClicked",
                                       "OutcomeWasQuickAnswerSeen",
                                       "Previous28DayCtrPercent",
                                       "Previous28DayImpressionsCount",
                                       "PreviousWeekCtrPercent",
                                       "PreviousWeekImpressionsCount",
                                       "ScreenTopDps",
                                       "TapDuration",
                                       "WasScreenBottom"});
  return feature_names;
}
} // namespace
// Returns a copy of the Contextual Search predictor configuration: UKM
// logging, the Contextual Search feature whitelist, and the
// kContextualSearchRankerQuery feature with its model URL parameter.
const PredictorConfig GetContextualSearchPredictorConfig() {
  // Constructed in place on the first call. PredictorConfig only holds
  // pointers and an enum, so no heap allocation is needed to keep it alive.
  static const PredictorConfig kContextualSearchPredictorConfig(
      kContextualSearchModelName, kContextualSearchLoggingName,
      kContextualSearchUmaPrefixName, LOG_UKM,
      GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
      GetContextualSearchRankerUrlFeatureParam());
  return kContextualSearchPredictorConfig;
}
#endif // OS_ANDROID
} // namespace assist_ranker
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "base/feature_list.h"
#include "base/metrics/field_trial_params.h"
#include "build/build_config.h"
#include "components/assist_ranker/predictor_config.h"
namespace assist_ranker {
#if defined(OS_ANDROID)
// Feature associated with the Contextual Search predictor. Defined in the
// matching .cc file.
extern const base::Feature kContextualSearchRankerQuery;
// Returns the predictor configuration for Contextual Search, by value.
const PredictorConfig GetContextualSearchPredictorConfig();
#endif // OS_ANDROID
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
......@@ -32,5 +32,9 @@ message RankerExample {
// This field represents the ground truth that the ranker is
// expected to predict, and is typically derived from user feedback. It is
// used for training only and is not required for inference.
// NOTE: this field will not be logged. If you want to log an outcome, add it
// to the features field before calling LogExample.
// TODO(chrome-ranker-team) Add a metadata field to log metrics that are not
// used as model input.
optional Feature target = 2;
}
\ No newline at end of file
......@@ -3,7 +3,10 @@
// found in the LICENSE file.
#include "components/assist_ranker/ranker_example_util.h"
#include "base/format_macros.h"
#include "base/logging.h"
#include "base/metrics/metrics_hashes.h"
#include "base/strings/stringprintf.h"
namespace assist_ranker {
......@@ -42,6 +45,26 @@ bool GetFeatureValueAsFloat(const std::string& key,
return true;
}
// Writes an integer representation of |feature| into |*int_value|. Returns
// true for bool and int32 features; returns false, leaving |*int_value|
// untouched, for float and string features, which are not supported yet.
bool FeatureToInt(const Feature& feature, int* int_value) {
  switch (feature.feature_type_case()) {
    case Feature::kBoolValue: {
      *int_value = feature.bool_value() ? 1 : 0;
      return true;
    }
    case Feature::kInt32Value: {
      *int_value = feature.int32_value();
      return true;
    }
    // TODO(crbug.com/794187): Implement float and string conversion.
    case Feature::kFloatValue:
    case Feature::kStringValue:
      return false;
    default:
      break;
  }
  // Any other feature type is unexpected.
  NOTREACHED();
  return false;
}
bool GetOneHotValue(const std::string& key,
const RankerExample& example,
std::string* value) {
......@@ -60,4 +83,20 @@ bool GetOneHotValue(const std::string& key,
return true;
}
// Returns the metric-name hash of |feature_name| formatted as a 16-digit,
// zero-padded lowercase hex string.
std::string HashFeatureName(const std::string& feature_name) {
  const uint64_t name_hash = base::HashMetricName(feature_name);
  std::string hex_hash = base::StringPrintf("%016" PRIx64, name_hash);
  return hex_hash;
}
// Returns a copy of |example| in which every feature name is replaced by
// HashFeatureName(name). Feature values are copied unchanged, and the target
// is copied only when |example| has one.
RankerExample HashExampleFeatureNames(const RankerExample& example) {
  RankerExample hashed_example;
  auto& output_features = *hashed_example.mutable_features();
  for (const auto& feature : example.features()) {
    output_features[HashFeatureName(feature.first)] = feature.second;
  }
  // mutable_target() marks the optional field as present, so only touch it
  // when the input has a target; otherwise the copy would gain an empty one.
  if (example.has_target())
    *hashed_example.mutable_target() = example.target();
  return hashed_example;
}
} // namespace assist_ranker
......@@ -25,6 +25,10 @@ bool GetFeatureValueAsFloat(const std::string& key,
const RankerExample& example,
float* value) WARN_UNUSED_RESULT;
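// Fills |int_value| with an integer representation of |feature| and returns
// true on success. Bool and int32 features are supported. Float and string
// features are not implemented yet (crbug.com/794187) and return false; the
// plan is to multiply float values by a given constant that defines the
// precision and to hash string features.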
bool FeatureToInt(const Feature& feature, int* int_value) WARN_UNUSED_RESULT;
// Extract category from one-hot feature. Returns true and fills
// in |value| if the feature is found and is of type string_value. Returns false
// otherwise.
......@@ -32,6 +36,15 @@ bool GetOneHotValue(const std::string& key,
const RankerExample& example,
std::string* value) WARN_UNUSED_RESULT;
// Converts a string to a hex hash string.
std::string HashFeatureName(const std::string& feature_name);
// Hashes feature names to a hex string.
// Features logged through UKM will apply this transformation when logging
// features, so models trained on UKM data are expected to have hashed input
// feature names.
RankerExample HashExampleFeatureNames(const RankerExample& example);
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_RANKER_EXAMPLE_UTIL_H_
......@@ -103,4 +103,53 @@ TEST_F(RankerExampleUtilTest, GetOneHotValue) {
EXPECT_FALSE(GetOneHotValue("foo", example_, &value));
}
// Bool and int32 features convert to the expected integers; float and string
// features are rejected.
TEST_F(RankerExampleUtilTest, FeatureToInt) {
  Feature input;
  int converted;

  // Booleans map to 1 and 0.
  input.set_bool_value(true);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(1, converted);
  input.set_bool_value(false);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(0, converted);

  // Integers are passed through, including negative ones.
  input.set_int32_value(42);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(42, converted);
  input.set_int32_value(-3);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(-3, converted);

  // Float and string values are not implemented yet.
  input.set_float_value(12.345f);
  EXPECT_FALSE(FeatureToInt(input, &converted));
  input.set_string_value("foo");
  EXPECT_FALSE(FeatureToInt(input, &converted));
}
// Hashing keeps the number of features and their values but replaces every
// feature name with its hashed form.
TEST_F(RankerExampleUtilTest, HashExampleFeatureNames) {
  const RankerExample hashed = HashExampleFeatureNames(example_);

  // Hashed example has the same number of features.
  EXPECT_EQ(example_.features().size(), hashed.features().size());

  // The original names can no longer be looked up...
  EXPECT_FALSE(SafeGetFeature(bool_name_, hashed, nullptr));
  EXPECT_FALSE(SafeGetFeature(int32_name_, hashed, nullptr));
  EXPECT_FALSE(SafeGetFeature(float_name_, hashed, nullptr));
  EXPECT_FALSE(SafeGetFeature(one_hot_name_, hashed, nullptr));
  // ...while a hashed name can.
  const std::string hashed_bool_name = HashFeatureName(bool_name_);
  EXPECT_TRUE(SafeGetFeature(hashed_bool_name, hashed, nullptr));

  // Values have not changed.
  float retrieved_float;
  EXPECT_TRUE(GetFeatureValueAsFloat(HashFeatureName(float_name_), hashed,
                                     &retrieved_float));
  EXPECT_EQ(float_value_, retrieved_float);

  std::string retrieved_category;
  EXPECT_TRUE(GetOneHotValue(HashFeatureName(one_hot_name_), hashed,
                             &retrieved_category));
  EXPECT_EQ(one_hot_value_, retrieved_category);
}
} // namespace assist_ranker
......@@ -17,7 +17,6 @@
#include "services/metrics/public/interfaces/ukm_interface.mojom.h"
#include "url/gurl.h"
class ContextualSearchRankerLoggerImpl;
class DocumentWritePageLoadMetricsObserver;
class FromGWSPageLoadMetricsLogger;
class PluginInfoHostImpl;
......@@ -27,6 +26,10 @@ class UkmPageLoadMetricsObserver;
class UseCounterPageLoadMetricsObserver;
class LocalNetworkRequestsPageLoadMetricsObserver;
namespace assist_ranker {
class BasePredictor;
}
namespace blink {
class AutoplayUmaHelper;
}
......@@ -85,7 +88,7 @@ class METRICS_EXPORT UkmRecorder {
virtual void UpdateSourceURL(SourceId source_id, const GURL& url) = 0;
private:
friend ContextualSearchRankerLoggerImpl;
friend assist_ranker::BasePredictor;
friend DelegatingUkmRecorder;
friend DocumentWritePageLoadMetricsObserver;
friend FromGWSPageLoadMetricsLogger;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment