Commit c97535bd authored by Philippe Hamel, committed by Commit Bot

Ranker owns predictor objects and uses a config for initialization.

Ranker also takes care of logging and deals with field trials internally.

Bug: 786472, 778468
Change-Id: Ie583616643ab8ad169df4739a4a18b81a950553f
Reviewed-on: https://chromium-review.googlesource.com/788331
Commit-Queue: Philippe Hamel <hamelphi@chromium.org>
Reviewed-by: Roger McFarlane <rogerm@chromium.org>
Reviewed-by: Donn Denman <donnd@chromium.org>
Reviewed-by: Steven Holte <holte@chromium.org>
Cr-Commit-Position: refs/heads/master@{#524426}
parent 11126aa0
...@@ -193,7 +193,6 @@ public abstract class ChromeFeatureList { ...@@ -193,7 +193,6 @@ public abstract class ChromeFeatureList {
"ContentSuggestionsThumbnailDominantColor"; "ContentSuggestionsThumbnailDominantColor";
public static final String CONTEXTUAL_SEARCH_ML_TAP_SUPPRESSION = public static final String CONTEXTUAL_SEARCH_ML_TAP_SUPPRESSION =
"ContextualSearchMlTapSuppression"; "ContextualSearchMlTapSuppression";
public static final String CONTEXTUAL_SEARCH_RANKER_QUERY = "ContextualSearchRankerQuery";
public static final String CONTEXTUAL_SEARCH_SECOND_TAP = "ContextualSearchSecondTap"; public static final String CONTEXTUAL_SEARCH_SECOND_TAP = "ContextualSearchSecondTap";
public static final String CONTEXTUAL_SUGGESTIONS_CAROUSEL = "ContextualSuggestionsCarousel"; public static final String CONTEXTUAL_SUGGESTIONS_CAROUSEL = "ContextualSuggestionsCarousel";
public static final String COPYLESS_PASTE = "CopylessPaste"; public static final String COPYLESS_PASTE = "CopylessPaste";
......
...@@ -54,6 +54,11 @@ public interface ContextualSearchRankerLogger { ...@@ -54,6 +54,11 @@ public interface ContextualSearchRankerLogger {
*/ */
void logFeature(Feature feature, Object value); void logFeature(Feature feature, Object value);
/**
* Returns whether or not AssistRanker query is enabled.
*/
boolean isQueryEnabled();
/** /**
* Logs an outcome value at training time that indicates an ML label as a key/value pair. * Logs an outcome value at training time that indicates an ML label as a key/value pair.
* @param feature The feature to log. * @param feature The feature to log.
......
...@@ -111,6 +111,12 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL ...@@ -111,6 +111,12 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL
mIsLoggingReadyForPage = true; mIsLoggingReadyForPage = true;
mBasePageWebContents = basePageWebContents; mBasePageWebContents = basePageWebContents;
mHasInferenceOccurred = false; mHasInferenceOccurred = false;
nativeSetupLoggingAndRanker(mNativePointer, basePageWebContents);
}
@Override
public boolean isQueryEnabled() {
return nativeIsQueryEnabled(mNativePointer);
} }
@Override @Override
...@@ -138,7 +144,6 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL ...@@ -138,7 +144,6 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL
mHasInferenceOccurred = true; mHasInferenceOccurred = true;
if (isEnabled() && mBasePageWebContents != null && mFeaturesToLog != null if (isEnabled() && mBasePageWebContents != null && mFeaturesToLog != null
&& !mFeaturesToLog.isEmpty()) { && !mFeaturesToLog.isEmpty()) {
nativeSetupLoggingAndRanker(mNativePointer, mBasePageWebContents);
for (Map.Entry<Feature, Object> entry : mFeaturesToLog.entrySet()) { for (Map.Entry<Feature, Object> entry : mFeaturesToLog.entrySet()) {
logObject(entry.getKey(), entry.getValue()); logObject(entry.getKey(), entry.getValue());
} }
...@@ -260,4 +265,5 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL ...@@ -260,4 +265,5 @@ public class ContextualSearchRankerLoggerImpl implements ContextualSearchRankerL
// Returns an AssistRankerPrediction integer value. // Returns an AssistRankerPrediction integer value.
private native int nativeRunInference(long nativeContextualSearchRankerLoggerImpl); private native int nativeRunInference(long nativeContextualSearchRankerLoggerImpl);
private native void nativeWriteLogAndReset(long nativeContextualSearchRankerLoggerImpl); private native void nativeWriteLogAndReset(long nativeContextualSearchRankerLoggerImpl);
private native boolean nativeIsQueryEnabled(long nativeContextualSearchRankerLoggerImpl);
} }
...@@ -10,7 +10,6 @@ import android.text.TextUtils; ...@@ -10,7 +10,6 @@ import android.text.TextUtils;
import org.chromium.base.Log; import org.chromium.base.Log;
import org.chromium.base.VisibleForTesting; import org.chromium.base.VisibleForTesting;
import org.chromium.chrome.browser.ChromeActivity; import org.chromium.chrome.browser.ChromeActivity;
import org.chromium.chrome.browser.ChromeFeatureList;
import org.chromium.chrome.browser.compositor.bottombar.OverlayPanel; import org.chromium.chrome.browser.compositor.bottombar.OverlayPanel;
import org.chromium.chrome.browser.preferences.PrefServiceBridge; import org.chromium.chrome.browser.preferences.PrefServiceBridge;
import org.chromium.chrome.browser.tab.Tab; import org.chromium.chrome.browser.tab.Tab;
...@@ -380,7 +379,7 @@ public class ContextualSearchSelectionController { ...@@ -380,7 +379,7 @@ public class ContextualSearchSelectionController {
// Make sure Tap Suppression features are consistent. // Make sure Tap Suppression features are consistent.
assert !ContextualSearchFieldTrial.isContextualSearchMlTapSuppressionEnabled() assert !ContextualSearchFieldTrial.isContextualSearchMlTapSuppressionEnabled()
|| ChromeFeatureList.isEnabled(ChromeFeatureList.CONTEXTUAL_SEARCH_RANKER_QUERY) || rankerLogger.isQueryEnabled()
: "Tap Suppression requires the Ranker Query feature to be enabled!"; : "Tap Suppression requires the Ranker Query feature to be enabled!";
// If we're suppressing based on heuristics then Ranker doesn't need to know about it. // If we're suppressing based on heuristics then Ranker doesn't need to know about it.
......
...@@ -85,7 +85,6 @@ const base::Feature* kFeaturesExposedToJava[] = { ...@@ -85,7 +85,6 @@ const base::Feature* kFeaturesExposedToJava[] = {
&kContentSuggestionsSettings, &kContentSuggestionsSettings,
&kContentSuggestionsThumbnailDominantColor, &kContentSuggestionsThumbnailDominantColor,
&kContextualSearchMlTapSuppression, &kContextualSearchMlTapSuppression,
&kContextualSearchRankerQuery,
&kContextualSearchSecondTap, &kContextualSearchSecondTap,
&kContextualSuggestionsCarousel, &kContextualSuggestionsCarousel,
&kCustomContextMenu, &kCustomContextMenu,
...@@ -251,9 +250,6 @@ const base::Feature kContentSuggestionsThumbnailDominantColor{ ...@@ -251,9 +250,6 @@ const base::Feature kContentSuggestionsThumbnailDominantColor{
const base::Feature kContextualSearchMlTapSuppression{ const base::Feature kContextualSearchMlTapSuppression{
"ContextualSearchMlTapSuppression", base::FEATURE_DISABLED_BY_DEFAULT}; "ContextualSearchMlTapSuppression", base::FEATURE_DISABLED_BY_DEFAULT};
const base::Feature kContextualSearchRankerQuery{
"ContextualSearchRankerQuery", base::FEATURE_DISABLED_BY_DEFAULT};
const base::Feature kContextualSearchSecondTap{ const base::Feature kContextualSearchSecondTap{
"ContextualSearchSecondTap", base::FEATURE_DISABLED_BY_DEFAULT}; "ContextualSearchSecondTap", base::FEATURE_DISABLED_BY_DEFAULT};
......
...@@ -44,7 +44,6 @@ extern const base::Feature kContentSuggestionsScrollToLoad; ...@@ -44,7 +44,6 @@ extern const base::Feature kContentSuggestionsScrollToLoad;
extern const base::Feature kContentSuggestionsSettings; extern const base::Feature kContentSuggestionsSettings;
extern const base::Feature kContentSuggestionsThumbnailDominantColor; extern const base::Feature kContentSuggestionsThumbnailDominantColor;
extern const base::Feature kContextualSearchMlTapSuppression; extern const base::Feature kContextualSearchMlTapSuppression;
extern const base::Feature kContextualSearchRankerQuery;
extern const base::Feature kContextualSearchSecondTap; extern const base::Feature kContextualSearchSecondTap;
extern const base::Feature kContextualSuggestionsCarousel; extern const base::Feature kContextualSuggestionsCarousel;
extern const base::Feature kCustomContextMenu; extern const base::Feature kCustomContextMenu;
......
...@@ -7,23 +7,18 @@ ...@@ -7,23 +7,18 @@
#include "base/android/jni_string.h" #include "base/android/jni_string.h"
#include "base/android/scoped_java_ref.h" #include "base/android/scoped_java_ref.h"
#include "base/feature_list.h" #include "base/feature_list.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_functions.h"
#include "base/metrics/metrics_hashes.h"
#include "base/strings/stringprintf.h"
#include "chrome/browser/android/chrome_feature_list.h" #include "chrome/browser/android/chrome_feature_list.h"
#include "chrome/browser/assist_ranker/assist_ranker_service_factory.h" #include "chrome/browser/assist_ranker/assist_ranker_service_factory.h"
#include "chrome/browser/browser_process.h" #include "chrome/browser/browser_process.h"
#include "components/assist_ranker/assist_ranker_service_impl.h" #include "components/assist_ranker/assist_ranker_service_impl.h"
#include "components/assist_ranker/binary_classifier_predictor.h" #include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/predictor_config_definitions.h"
#include "components/assist_ranker/proto/ranker_example.pb.h" #include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/keyed_service/core/keyed_service.h" #include "components/keyed_service/core/keyed_service.h"
#include "components/ukm/content/source_url_recorder.h" #include "components/ukm/content/source_url_recorder.h"
#include "content/public/browser/web_contents.h" #include "content/public/browser/web_contents.h"
#include "jni/ContextualSearchRankerLoggerImpl_jni.h" #include "jni/ContextualSearchRankerLoggerImpl_jni.h"
#include "services/metrics/public/cpp/ukm_entry_builder.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "url/gurl.h"
namespace content { namespace content {
class BrowserContext; class BrowserContext;
...@@ -31,111 +26,61 @@ class BrowserContext; ...@@ -31,111 +26,61 @@ class BrowserContext;
namespace { namespace {
const char kContextualSearchRankerModelUrlParamName[] =
"contextual-search-ranker-model-url";
const char kContextualSearchModelFilename[] = "contextual_search_model";
const char kContextualSearchUmaPrefix[] = "Search.ContextualSearch.Ranker";
const char kContextualSearchRankerDidPredict[] = "OutcomeRankerDidPredict"; const char kContextualSearchRankerDidPredict[] = "OutcomeRankerDidPredict";
const char kContextualSearchRankerPrediction[] = "OutcomeRankerPrediction"; const char kContextualSearchRankerPrediction[] = "OutcomeRankerPrediction";
const base::FeatureParam<std::string> kModelUrl{
&chrome::android::kContextualSearchRankerQuery,
kContextualSearchRankerModelUrlParamName, ""};
// TODO(donnd, hamelphi): move hex-hash-string to Ranker.
std::string HexHashFeatureName(const std::string& feature_name) {
uint64_t feature_key = base::HashMetricName(feature_name);
return base::StringPrintf("%016" PRIx64, feature_key);
}
} // namespace } // namespace
ContextualSearchRankerLoggerImpl::ContextualSearchRankerLoggerImpl(JNIEnv* env, ContextualSearchRankerLoggerImpl::ContextualSearchRankerLoggerImpl(JNIEnv* env,
jobject obj) jobject obj)
: source_id_(ukm::kInvalidSourceId), : java_object_(env, obj) {}
builder_(nullptr),
predictor_(nullptr),
browser_context_(nullptr),
ranker_example_(nullptr),
has_predicted_decision_(false),
java_object_(nullptr) {
java_object_.Reset(env, obj);
}
ContextualSearchRankerLoggerImpl::~ContextualSearchRankerLoggerImpl() { ContextualSearchRankerLoggerImpl::~ContextualSearchRankerLoggerImpl() {}
java_object_ = nullptr;
}
void ContextualSearchRankerLoggerImpl::SetupLoggingAndRanker( void ContextualSearchRankerLoggerImpl::SetupLoggingAndRanker(
JNIEnv* env, JNIEnv* env,
jobject obj, jobject obj,
const base::android::JavaParamRef<jobject>& java_web_contents) { const base::android::JavaParamRef<jobject>& java_web_contents) {
content::WebContents* web_contents = web_contents_ = content::WebContents::FromJavaWebContents(java_web_contents);
content::WebContents::FromJavaWebContents(java_web_contents); if (!web_contents_)
if (!web_contents)
return; return;
source_id_ = ukm::GetSourceIdForWebContentsDocument(web_contents); SetupRankerPredictor();
ResetUkmEntry();
if (IsRankerQueryEnabled()) {
SetupRankerPredictor(web_contents);
// Start building example data based on features to be gathered and logged. // Start building example data based on features to be gathered and logged.
ranker_example_.reset(new assist_ranker::RankerExample()); ranker_example_ = std::make_unique<assist_ranker::RankerExample>();
} else {
// TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589).
VLOG(1) << "SetupLoggingAndRanker got IsRankerQueryEnabled false.";
}
} }
void ContextualSearchRankerLoggerImpl::ResetUkmEntry() { void ContextualSearchRankerLoggerImpl::SetupRankerPredictor() {
// Releasing the old entry triggers logging.
builder_ =
ukm::UkmRecorder::Get()->GetEntryBuilder(source_id_, "ContextualSearch");
}
void ContextualSearchRankerLoggerImpl::SetupRankerPredictor(
content::WebContents* web_contents) {
// Set up the Ranker predictor.
if (IsRankerQueryEnabled()) {
// Create one predictor for the current BrowserContext. // Create one predictor for the current BrowserContext.
content::BrowserContext* browser_context = if (browser_context_) {
web_contents->GetBrowserContext(); DCHECK(browser_context_ == web_contents_->GetBrowserContext());
if (browser_context == browser_context_)
return; return;
}
browser_context_ = web_contents_->GetBrowserContext();
browser_context_ = browser_context;
assist_ranker::AssistRankerService* assist_ranker_service = assist_ranker::AssistRankerService* assist_ranker_service =
assist_ranker::AssistRankerServiceFactory::GetForBrowserContext( assist_ranker::AssistRankerServiceFactory::GetForBrowserContext(
browser_context); browser_context_);
DCHECK(assist_ranker_service); if (assist_ranker_service) {
std::string model_string(kModelUrl.Get());
DCHECK(model_string.size());
// TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589).
VLOG(0) << "Model URL: " << model_string;
predictor_ = assist_ranker_service->FetchBinaryClassifierPredictor( predictor_ = assist_ranker_service->FetchBinaryClassifierPredictor(
GURL(model_string), kContextualSearchModelFilename, assist_ranker::GetContextualSearchPredictorConfig());
kContextualSearchUmaPrefix);
DCHECK(predictor_);
} }
} }
void ContextualSearchRankerLoggerImpl::LogFeature(
const std::string& feature_name,
int value) {
auto& features = *ranker_example_->mutable_features();
features[feature_name].set_int32_value(value);
}
void ContextualSearchRankerLoggerImpl::LogLong( void ContextualSearchRankerLoggerImpl::LogLong(
JNIEnv* env, JNIEnv* env,
jobject obj, jobject obj,
const base::android::JavaParamRef<jstring>& j_feature, const base::android::JavaParamRef<jstring>& j_feature,
jlong j_long) { jlong j_long) {
std::string feature = base::android::ConvertJavaStringToUTF8(env, j_feature); std::string feature = base::android::ConvertJavaStringToUTF8(env, j_feature);
if (builder_) LogFeature(feature, j_long);
builder_->AddMetric(feature.c_str(), j_long);
// Also write to Ranker if we're logging data needed to predict a decision.
if (IsRankerQueryEnabled() && !has_predicted_decision_) {
std::string hex_feature_key(HexHashFeatureName(feature));
auto& features = *ranker_example_->mutable_features();
features[hex_feature_key].set_int32_value(j_long);
}
} }
AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference( AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference(
...@@ -144,19 +89,17 @@ AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference( ...@@ -144,19 +89,17 @@ AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference(
has_predicted_decision_ = true; has_predicted_decision_ = true;
bool prediction = false; bool prediction = false;
bool was_able_to_predict = false; bool was_able_to_predict = false;
if (IsRankerQueryEnabled()) { if (IsQueryEnabledInternal()) {
was_able_to_predict = predictor_->Predict(*ranker_example_, &prediction); was_able_to_predict = predictor_->Predict(*ranker_example_, &prediction);
// Log to UMA whether we were able to predict or not. // Log to UMA whether we were able to predict or not.
base::UmaHistogramBoolean("Search.ContextualSearchRankerWasAbleToPredict", base::UmaHistogramBoolean("Search.ContextualSearchRankerWasAbleToPredict",
was_able_to_predict); was_able_to_predict);
// Log the Ranker decision to UKM, including whether we were able to make // TODO(chrome-ranker-team): this should be logged internally by Ranker.
// any prediction. LogFeature(kContextualSearchRankerDidPredict,
if (builder_) { static_cast<int>(was_able_to_predict));
builder_->AddMetric(kContextualSearchRankerDidPredict,
was_able_to_predict);
if (was_able_to_predict) { if (was_able_to_predict) {
builder_->AddMetric(kContextualSearchRankerPrediction, prediction); LogFeature(kContextualSearchRankerPrediction,
} static_cast<int>(prediction));
} }
} }
AssistRankerPrediction prediction_enum; AssistRankerPrediction prediction_enum;
...@@ -170,21 +113,28 @@ AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference( ...@@ -170,21 +113,28 @@ AssistRankerPrediction ContextualSearchRankerLoggerImpl::RunInference(
prediction_enum = ASSIST_RANKER_PREDICTION_UNAVAILABLE; prediction_enum = ASSIST_RANKER_PREDICTION_UNAVAILABLE;
} }
// TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589). // TODO(donnd): remove when behind-the-flag bug fixed (crbug.com/786589).
VLOG(0) << "prediction: " << prediction_enum; DVLOG(0) << "prediction: " << prediction_enum;
return prediction_enum; return prediction_enum;
} }
void ContextualSearchRankerLoggerImpl::WriteLogAndReset(JNIEnv* env, void ContextualSearchRankerLoggerImpl::WriteLogAndReset(JNIEnv* env,
jobject obj) { jobject obj) {
if (predictor_ && ranker_example_) {
ukm::SourceId source_id =
ukm::GetSourceIdForWebContentsDocument(web_contents_);
predictor_->LogExampleToUkm(*ranker_example_.get(), source_id);
}
has_predicted_decision_ = false; has_predicted_decision_ = false;
// Set up another builder for the next record (in case it's needed). ranker_example_ = std::make_unique<assist_ranker::RankerExample>();
ResetUkmEntry(); }
ranker_example_.reset();
bool ContextualSearchRankerLoggerImpl::IsQueryEnabled(JNIEnv* env,
jobject obj) {
return IsQueryEnabledInternal();
} }
bool ContextualSearchRankerLoggerImpl::IsRankerQueryEnabled() { bool ContextualSearchRankerLoggerImpl::IsQueryEnabledInternal() {
return base::FeatureList::IsEnabled( return predictor_ && predictor_->is_query_enabled();
chrome::android::kContextualSearchRankerQuery);
} }
// Java wrapper boilerplate // Java wrapper boilerplate
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#define CHROME_BROWSER_ANDROID_CONTEXTUALSEARCH_CONTEXTUAL_SEARCH_RANKER_LOGGER_IMPL_H_ #define CHROME_BROWSER_ANDROID_CONTEXTUALSEARCH_CONTEXTUAL_SEARCH_RANKER_LOGGER_IMPL_H_
#include "base/android/jni_android.h" #include "base/android/jni_android.h"
#include "base/memory/weak_ptr.h"
namespace content { namespace content {
class BrowserContext; class BrowserContext;
...@@ -17,10 +18,6 @@ class BinaryClassifierPredictor; ...@@ -17,10 +18,6 @@ class BinaryClassifierPredictor;
class RankerExample; class RankerExample;
} // namespace assist_ranker } // namespace assist_ranker
namespace ukm {
class UkmEntryBuilder;
} // namespace ukm
// A Java counterpart will be generated for this enum. // A Java counterpart will be generated for this enum.
// GENERATED_JAVA_ENUM_PACKAGE: org.chromium.chrome.browser.contextualsearch // GENERATED_JAVA_ENUM_PACKAGE: org.chromium.chrome.browser.contextualsearch
enum AssistRankerPrediction { enum AssistRankerPrediction {
...@@ -65,28 +62,30 @@ class ContextualSearchRankerLoggerImpl { ...@@ -65,28 +62,30 @@ class ContextualSearchRankerLoggerImpl {
// ready to start logging the next set of data. // ready to start logging the next set of data.
void WriteLogAndReset(JNIEnv* env, jobject obj); void WriteLogAndReset(JNIEnv* env, jobject obj);
private: // Returns whether or not AssistRanker query is enabled.
// Log the current UKM entry (if any) and start a new one. bool IsQueryEnabled(JNIEnv* env, jobject obj);
// TODO(donnd): write a test using TestAutoSetUkmRecorder.
void ResetUkmEntry();
// Sets up the Ranker Predictor for the given |web_contents|. private:
void SetupRankerPredictor(content::WebContents* web_contents); // Returns whether or not AssistRanker query is enabled.
bool IsQueryEnabledInternal();
// Whether querying Ranker for model loading and prediction is enabled. // Adds feature to the RankerExample.
bool IsRankerQueryEnabled(); void LogFeature(const std::string& feature_name, int value);
// The UKM source ID being used for this session. // Sets up the Ranker Predictor for the given |web_contents|.
int32_t source_id_; void SetupRankerPredictor();
// The entry builder for the current record, or nullptr if not yet configured. // The WebContents object used to produce the source_id for UKMs, and to get
std::unique_ptr<ukm::UkmEntryBuilder> builder_; // browser_context when fetching the predictor. The object is not owned by
// ContextualSearchRankerLoggerImpl.
content::WebContents* web_contents_ = nullptr;
// The Ranker Predictor for whether a tap gesture should be suppressed or not. // The Ranker Predictor for whether a tap gesture should be suppressed or not.
std::unique_ptr<assist_ranker::BinaryClassifierPredictor> predictor_; base::WeakPtr<assist_ranker::BinaryClassifierPredictor> predictor_;
// The |BrowserContext| currently associated with the above predictor. // The |BrowserContext| currently associated with the above predictor.
content::BrowserContext* browser_context_; // The object not owned by ContextualSearchRankerLoggerImpl.
content::BrowserContext* browser_context_ = nullptr;
// The current RankerExample or null. // The current RankerExample or null.
// Set of features from one example of a Tap to predict a suppression // Set of features from one example of a Tap to predict a suppression
...@@ -94,7 +93,7 @@ class ContextualSearchRankerLoggerImpl { ...@@ -94,7 +93,7 @@ class ContextualSearchRankerLoggerImpl {
std::unique_ptr<assist_ranker::RankerExample> ranker_example_; std::unique_ptr<assist_ranker::RankerExample> ranker_example_;
// Whether Ranker has predicted the decision yet. // Whether Ranker has predicted the decision yet.
bool has_predicted_decision_; bool has_predicted_decision_ = false;
// The linked Java object. // The linked Java object.
base::android::ScopedJavaGlobalRef<jobject> java_object_; base::android::ScopedJavaGlobalRef<jobject> java_object_;
......
...@@ -15,6 +15,10 @@ static_library("assist_ranker") { ...@@ -15,6 +15,10 @@ static_library("assist_ranker") {
"fake_ranker_model_loader.h", "fake_ranker_model_loader.h",
"generic_logistic_regression_inference.cc", "generic_logistic_regression_inference.cc",
"generic_logistic_regression_inference.h", "generic_logistic_regression_inference.h",
"predictor_config.cc",
"predictor_config.h",
"predictor_config_definitions.cc",
"predictor_config_definitions.h",
"ranker_example_util.cc", "ranker_example_util.cc",
"ranker_example_util.h", "ranker_example_util.h",
"ranker_model.cc", "ranker_model.cc",
...@@ -32,6 +36,7 @@ static_library("assist_ranker") { ...@@ -32,6 +36,7 @@ static_library("assist_ranker") {
"//components/data_use_measurement/core", "//components/data_use_measurement/core",
"//components/keyed_service/core", "//components/keyed_service/core",
"//net", "//net",
"//services/metrics/public/cpp:metrics_cpp",
"//url", "//url",
] ]
} }
...@@ -40,6 +45,7 @@ source_set("unit_tests") { ...@@ -40,6 +45,7 @@ source_set("unit_tests") {
testonly = true testonly = true
sources = [ sources = [
"base_predictor_unittest.cc",
"binary_classifier_predictor_unittest.cc", "binary_classifier_predictor_unittest.cc",
"generic_logistic_regression_inference_unittest.cc", "generic_logistic_regression_inference_unittest.cc",
"ranker_example_util_unittest.cc", "ranker_example_util_unittest.cc",
...@@ -51,6 +57,7 @@ source_set("unit_tests") { ...@@ -51,6 +57,7 @@ source_set("unit_tests") {
":assist_ranker", ":assist_ranker",
"//base", "//base",
"//components/assist_ranker/proto", "//components/assist_ranker/proto",
"//components/ukm:test_support",
"//net:test_support", "//net:test_support",
"//testing/gtest", "//testing/gtest",
] ]
......
...@@ -2,5 +2,7 @@ include_rules = [ ...@@ -2,5 +2,7 @@ include_rules = [
"+components/data_use_measurement/core", "+components/data_use_measurement/core",
"+components/keyed_service/core", "+components/keyed_service/core",
"+components/metrics", "+components/metrics",
"+components/ukm",
"+net", "+net",
"+services/metrics/public",
] ]
\ No newline at end of file
...@@ -9,32 +9,25 @@ ...@@ -9,32 +9,25 @@
#include <string> #include <string>
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "components/keyed_service/core/keyed_service.h" #include "components/keyed_service/core/keyed_service.h"
class GURL;
namespace assist_ranker { namespace assist_ranker {
class BinaryClassifierPredictor; class BinaryClassifierPredictor;
struct PredictorConfig;
// TODO(crbug.com/778468) : Refactor this so that the service owns the predictor
// objects and enforce model uniqueness through internal registration in order
// to avoid cache collisions.
//
// Service that provides Predictor objects. // Service that provides Predictor objects.
class AssistRankerService : public KeyedService { class AssistRankerService : public KeyedService {
public: public:
AssistRankerService() = default; AssistRankerService() = default;
// Returns a binary classification model. |model_filename| is the filename of // Returns a binary classification model given a PredictorConfig.
// the cached model. It should be unique to this predictor to avoid cache // The predictor is instantiated the first time a predictor is fetched. The
// collision. |model_url| represents a unique ID for the desired model (see // next calls to fetch will return a pointer to the already instantiated
// ranker_model_loader.h for more details). |uma_prefix| is used to log // predictor.
// histograms related to the loading of the model. virtual base::WeakPtr<BinaryClassifierPredictor>
virtual std::unique_ptr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor(const PredictorConfig& config) = 0;
FetchBinaryClassifierPredictor(GURL model_url,
const std::string& model_filename,
const std::string& uma_prefix) = 0;
private: private:
DISALLOW_COPY_AND_ASSIGN(AssistRankerService); DISALLOW_COPY_AND_ASSIGN(AssistRankerService);
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "components/assist_ranker/assist_ranker_service_impl.h" #include "components/assist_ranker/assist_ranker_service_impl.h"
#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/binary_classifier_predictor.h" #include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/ranker_model_loader_impl.h" #include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_context_getter.h"
...@@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl( ...@@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl(
AssistRankerServiceImpl::~AssistRankerServiceImpl() {} AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
std::unique_ptr<BinaryClassifierPredictor> base::WeakPtr<BinaryClassifierPredictor>
AssistRankerServiceImpl::FetchBinaryClassifierPredictor( AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
GURL model_url, const PredictorConfig& config) {
const std::string& model_filename,
const std::string& uma_prefix) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return BinaryClassifierPredictor::Create(url_request_context_getter_.get(), const std::string& model_name = config.model_name;
GetModelPath(model_filename), auto predictor_it = predictor_map_.find(model_name);
model_url, uma_prefix); if (predictor_it != predictor_map_.end()) {
DVLOG(1) << "Predictor " << model_name << " already initialized.";
return base::AsWeakPtr(
static_cast<BinaryClassifierPredictor*>(predictor_it->second.get()));
}
// The predictor does not exist yet, so we create one.
DVLOG(1) << "Initializing predictor: " << model_name;
std::unique_ptr<BinaryClassifierPredictor> predictor =
BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
url_request_context_getter_.get());
base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
base::AsWeakPtr(predictor.get());
predictor_map_[model_name] = std::move(predictor);
return weak_ptr;
} }
base::FilePath AssistRankerServiceImpl::GetModelPath( base::FilePath AssistRankerServiceImpl::GetModelPath(
......
...@@ -7,13 +7,13 @@ ...@@ -7,13 +7,13 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/sequence_checker.h" #include "base/sequence_checker.h"
#include "components/assist_ranker/assist_ranker_service.h" #include "components/assist_ranker/assist_ranker_service.h"
#include "components/assist_ranker/predictor_config.h"
class GURL;
namespace net { namespace net {
class URLRequestContextGetter; class URLRequestContextGetter;
...@@ -21,6 +21,7 @@ class URLRequestContextGetter; ...@@ -21,6 +21,7 @@ class URLRequestContextGetter;
namespace assist_ranker { namespace assist_ranker {
class BasePredictor;
class BinaryClassifierPredictor; class BinaryClassifierPredictor;
class AssistRankerServiceImpl : public AssistRankerService { class AssistRankerServiceImpl : public AssistRankerService {
...@@ -31,10 +32,8 @@ class AssistRankerServiceImpl : public AssistRankerService { ...@@ -31,10 +32,8 @@ class AssistRankerServiceImpl : public AssistRankerService {
~AssistRankerServiceImpl() override; ~AssistRankerServiceImpl() override;
// AssistRankerService... // AssistRankerService...
std::unique_ptr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor( base::WeakPtr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor(
GURL model_url, const PredictorConfig& config) override;
const std::string& model_filename,
const std::string& uma_prefix) override;
private: private:
// Returns the full path to the model cache. // Returns the full path to the model cache.
...@@ -46,6 +45,9 @@ class AssistRankerServiceImpl : public AssistRankerService { ...@@ -46,6 +45,9 @@ class AssistRankerServiceImpl : public AssistRankerService {
// Base path where models are stored. // Base path where models are stored.
const base::FilePath base_path_; const base::FilePath base_path_;
std::unordered_map<std::string, std::unique_ptr<BasePredictor>>
predictor_map_;
SEQUENCE_CHECKER(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_);
DISALLOW_COPY_AND_ASSIGN(AssistRankerServiceImpl); DISALLOW_COPY_AND_ASSIGN(AssistRankerServiceImpl);
......
...@@ -4,23 +4,40 @@ ...@@ -4,23 +4,40 @@
#include "components/assist_ranker/base_predictor.h" #include "components/assist_ranker/base_predictor.h"
#include "base/feature_list.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/assist_ranker/proto/ranker_model.pb.h" #include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_example_util.h"
#include "components/assist_ranker/ranker_model.h" #include "components/assist_ranker/ranker_model.h"
#include "services/metrics/public/cpp/ukm_entry_builder.h"
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "url/gurl.h"
namespace assist_ranker { namespace assist_ranker {
// Builds a predictor from |config|, which is copied into |config_|.
// |is_query_enabled_| is set from the state of the base::Feature named by
// |config.field_trial|; when no feature is supplied the flag keeps its
// in-class initial value.
BasePredictor::BasePredictor(const PredictorConfig& config) : config_(config) {
  // TODO(chrome-ranker-team): validate config.
  if (!config_.field_trial) {
    DVLOG(0) << "No field trial specified";
    return;
  }
  is_query_enabled_ = base::FeatureList::IsEnabled(*config_.field_trial);
}

BasePredictor::~BasePredictor() {}
void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) { void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) {
if (!is_query_enabled_)
return;
if (model_loader_) { if (model_loader_) {
DLOG(ERROR) << "This predictor already has a model loader."; DVLOG(0) << "This predictor already has a model loader.";
return; return;
} }
// Take ownership of the model loader. // Take ownership of the model loader.
model_loader_ = std::move(model_loader); model_loader_ = std::move(model_loader);
// Kick off the initial load from cache. // Kick off the initial model load.
model_loader_->NotifyOfRankerActivity(); model_loader_->NotifyOfRankerActivity();
} }
...@@ -31,10 +48,80 @@ void BasePredictor::OnModelAvailable( ...@@ -31,10 +48,80 @@ void BasePredictor::OnModelAvailable(
} }
bool BasePredictor::IsReady() { bool BasePredictor::IsReady() {
if (!is_ready_) if (!is_ready_ && model_loader_)
model_loader_->NotifyOfRankerActivity(); model_loader_->NotifyOfRankerActivity();
return is_ready_; return is_ready_;
} }
// Adds a single feature to |ukm_builder| as an integer metric named
// |feature_name|. Nothing is added when the builder is null, when the name is
// not in the configured whitelist, or when FeatureToInt() cannot convert the
// value. The whitelist pointer is dereferenced unchecked here; the visible
// caller, LogExampleToUkm(), verifies it first.
void BasePredictor::LogFeatureToUkm(const std::string& feature_name,
                                    const Feature& feature,
                                    ukm::UkmEntryBuilder* ukm_builder) {
  if (ukm_builder == nullptr)
    return;

  const bool is_whitelisted =
      base::ContainsKey(*config_.feature_whitelist, feature_name);
  if (!is_whitelisted) {
    DVLOG(1) << "Feature not whitelisted: " << feature_name;
    return;
  }

  int int_value;
  if (!FeatureToInt(feature, &int_value)) {
    DVLOG(0) << "Could not convert feature to int: " << feature_name;
    return;
  }

  DVLOG(3) << "Logging: " << feature_name << ": " << int_value;
  ukm_builder->AddMetric(feature_name.c_str(), int_value);
}
void BasePredictor::LogExampleToUkm(const RankerExample& example,
ukm::SourceId source_id) {
if (config_.log_type != LOG_UKM) {
DVLOG(0) << "Wrong log type in predictor config: " << config_.log_type;
return;
}
if (!config_.feature_whitelist) {
DVLOG(0) << "No whitelist specified.";
return;
}
if (config_.feature_whitelist->empty()) {
DVLOG(0) << "Empty whitelist, examples will not be logged.";
return;
}
// Releasing the builder will trigger logging.
std::unique_ptr<ukm::UkmEntryBuilder> builder =
ukm::UkmRecorder::Get()->GetEntryBuilder(source_id, config_.logging_name);
if (builder) {
for (const auto& feature_kv : example.features()) {
LogFeatureToUkm(feature_kv.first, feature_kv.second, builder.get());
}
} else {
DVLOG(0) << "Could not get UKM Entry Builder.";
}
}
// Returns the model name from the predictor config as a std::string.
std::string BasePredictor::GetModelName() const {
  return std::string(config_.model_name);
}
// Returns the model URL read from the config's URL feature param, or a
// default-constructed GURL when the config carries no such param.
GURL BasePredictor::GetModelUrl() const {
  if (config_.field_trial_url_param)
    return GURL(config_.field_trial_url_param->Get());

  DVLOG(1) << "No URL specified.";
  return GURL();
}
// Returns the example to feed to inference. When the loaded model's metadata
// states that its input feature names are hex hashes, the feature names of
// |example| are hashed via HashExampleFeatureNames(); otherwise a copy of
// |example| is returned. Dereferences |ranker_model_| without a null check.
RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
  const auto& model_proto = ranker_model_->proto();
  const bool names_are_hashed =
      model_proto.has_metadata() &&
      model_proto.metadata().input_features_names_are_hex_hashes();
  if (!names_are_hashed)
    return example;

  return HashExampleFeatureNames(example);
}
} // namespace assist_ranker } // namespace assist_ranker
...@@ -9,30 +9,53 @@ ...@@ -9,30 +9,53 @@
#include <string> #include <string>
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/predictor_config.h"
#include "components/assist_ranker/ranker_model_loader.h" #include "components/assist_ranker/ranker_model_loader.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
class GURL;
namespace ukm {
class UkmEntryBuilder;
}
namespace assist_ranker { namespace assist_ranker {
class Feature;
class RankerExample;
class RankerModel; class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as // Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model. Sub-classes of BasePredictor // encapsulate the logic for loading the model and logging. Sub-classes of
// implement an interface that depends on the nature of the supported model. // BasePredictor implement an interface that depends on the nature of the
// Subclasses of BasePredictor will also need to implement an Initialize method // supported model. Subclasses of BasePredictor will also need to implement an
// that will be called once the model is available, and a static validation // Initialize method that will be called once the model is available, and a
// function with the following signature: // static validation function with the following signature:
// //
// static RankerModelStatus ValidateModel(const RankerModel& model); // static RankerModelStatus ValidateModel(const RankerModel& model);
class BasePredictor { class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
public: public:
BasePredictor(); BasePredictor(const PredictorConfig& config);
virtual ~BasePredictor(); virtual ~BasePredictor();
// Returns true if the predictor is ready to make predictions.
bool IsReady(); bool IsReady();
// Returns true if the base::Feature associated with this model is enabled.
bool is_query_enabled() const { return is_query_enabled_; }
// Logs the features of |example| to UKM using the given source_id.
void LogExampleToUkm(const RankerExample& example, ukm::SourceId source_id);
// Returns the model URL.
GURL GetModelUrl() const;
// Returns the model name.
std::string GetModelName() const;
protected: protected:
// The model used for prediction. // Preprocessing applied to an example before prediction. The original
std::unique_ptr<RankerModel> ranker_model_; // RankerExample is not modified, so it is safe to use it later for logging.
RankerExample PreprocessExample(const RankerExample& example);
// Called when the RankerModelLoader has finished loading the model. Returns // Called when the RankerModelLoader has finished loading the model. Returns
// true only if the model was successfully loaded and is ready to predict. // true only if the model was successfully loaded and is ready to predict.
...@@ -43,9 +66,17 @@ class BasePredictor { ...@@ -43,9 +66,17 @@ class BasePredictor {
// Called once the model loader has successfully loaded the model. // Called once the model loader has successfully loaded the model.
void OnModelAvailable(std::unique_ptr<RankerModel> model); void OnModelAvailable(std::unique_ptr<RankerModel> model);
std::unique_ptr<RankerModelLoader> model_loader_; std::unique_ptr<RankerModelLoader> model_loader_;
// The model used for prediction.
std::unique_ptr<RankerModel> ranker_model_;
private: private:
void LogFeatureToUkm(const std::string& feature_name,
const Feature& feature,
ukm::UkmEntryBuilder* ukm_builder);
bool is_ready_ = false; bool is_ready_ = false;
bool is_query_enabled_ = false;
PredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(BasePredictor); DISALLOW_COPY_AND_ASSIGN(BasePredictor);
}; };
......
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/base_predictor.h"
#include <memory>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/memory/ptr_util.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/scoped_task_environment.h"
#include "components/assist_ranker/fake_ranker_model_loader.h"
#include "components/assist_ranker/predictor_config.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "components/ukm/test_ukm_recorder.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
namespace assist_ranker {
using ::assist_ranker::testing::FakeRankerModelLoader;
namespace {
// Predictor config for testing.
// String fields of the test PredictorConfig.
const char kTestModelName[] = "test_model";
const char kTestLoggingName[] = "TestLoggingName";
const char kTestUmaPrefixName[] = "Test.Ranker";
// Name and default value of the feature param that carries the model URL.
const char kTestUrlParamName[] = "ranker-model-url";
const char kTestDefaultModelUrl[] = "https://foo.bar/model.bin";
// Feature names used by the logging test. The first four are whitelisted
// below; kFeatureNotWhitelisted is deliberately left out.
const char kBoolFeature[] = "bool_feature";
const char kIntFeature[] = "int_feature";
const char kFloatFeature[] = "float_feature";
const char kStringFeature[] = "string_feature";
const char kFeatureNotWhitelisted[] = "not_whitelisted";
// URL attached to the UKM source created by BasePredictorTest::GetSourceId().
const char kTestNavigationUrl[] = "https://foo.com";
// Whitelist handed to the predictor config by pointer.
const base::flat_set<std::string> kFeatureWhitelist({kBoolFeature, kIntFeature,
kFloatFeature,
kStringFeature});
// Feature gating the test predictor; enabled unless a test disables it.
const base::Feature kTestRankerQuery{"TestRankerQuery",
base::FEATURE_ENABLED_BY_DEFAULT};
// String param of kTestRankerQuery whose default is kTestDefaultModelUrl.
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};
// Config used by FakePredictor::Create(). It stores the addresses of the
// three objects defined just above, so it must stay after them.
const PredictorConfig kTestPredictorConfig = PredictorConfig{
kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl};
// Class that implements virtual functions of the base class.
// Minimal concrete BasePredictor: it implements the virtual hooks with
// trivial bodies so the base-class behaviour can be tested on its own.
class FakePredictor : public BasePredictor {
 public:
  // Returns a FakePredictor built from kTestPredictorConfig.
  static std::unique_ptr<FakePredictor> Create();
  ~FakePredictor() override {}
  // Validation will always succeed.
  static RankerModelStatus ValidateModel(const RankerModel& model);

 protected:
  // Not implementing any inference logic.
  bool Initialize() override { return true; }

 private:
  // Private so instances are only obtained through Create(); explicit so a
  // PredictorConfig never converts implicitly into a predictor.
  explicit FakePredictor(const PredictorConfig& config);
  DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};
// Forwards |config| to the BasePredictor constructor; no state of its own.
FakePredictor::FakePredictor(const PredictorConfig& config)
    : BasePredictor(config) {
}
// Accepts every model: |model| is not inspected and OK is always returned.
RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
  const RankerModelStatus always_ok = RankerModelStatus::OK;
  return always_ok;
}
std::unique_ptr<FakePredictor> FakePredictor::Create() {
std::unique_ptr<FakePredictor> predictor(
new FakePredictor(kTestPredictorConfig));
auto ranker_model = base::MakeUnique<RankerModel>();
auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>(
base::BindRepeating(&FakePredictor::ValidateModel),
base::BindRepeating(&FakePredictor::OnModelAvailable,
base::Unretained(predictor.get())),
std::move(ranker_model));
predictor->LoadModel(std::move(fake_model_loader));
return predictor;
}
} // namespace
// Fixture for BasePredictor tests. It owns a task environment, a test UKM
// recorder and a scoped feature list, and exposes small helpers on top of
// them.
class BasePredictorTest : public ::testing::Test {
protected:
BasePredictorTest() = default;
// Disables Query for the test predictor.
void DisableQuery();
// Returns a new UKM source id whose URL is set to kTestNavigationUrl on the
// test recorder.
ukm::SourceId GetSourceId();
// Returns the fixture's test recorder (not owned by the caller).
ukm::TestUkmRecorder* GetTestUkmRecorder() { return &test_ukm_recorder_; }
private:
// Members are constructed in declaration order; keep this order.
// Sets up the task scheduling/task-runner environment for each test.
base::test::ScopedTaskEnvironment scoped_task_environment_;
// Sets itself as the global UkmRecorder on construction.
ukm::TestAutoSetUkmRecorder test_ukm_recorder_;
// Manages the enabling/disabling of features within the scope of a test.
base::test::ScopedFeatureList scoped_feature_list_;
DISALLOW_COPY_AND_ASSIGN(BasePredictorTest);
};
// Allocates a new UKM source id, sets its URL to kTestNavigationUrl on the
// test recorder, and returns the id.
ukm::SourceId BasePredictorTest::GetSourceId() {
  const ukm::SourceId new_id = ukm::UkmRecorder::GetNewSourceID();
  const GURL navigation_url(kTestNavigationUrl);
  test_ukm_recorder_.UpdateSourceURL(new_id, navigation_url);
  return new_id;
}
// Turns kTestRankerQuery off for the lifetime of |scoped_feature_list_|.
// Call this before creating the predictor: the feature state is read in the
// BasePredictor constructor.
void BasePredictorTest::DisableQuery() {
  scoped_feature_list_.InitAndDisableFeature(kTestRankerQuery);
}
// With the feature left at its enabled default, a freshly created predictor
// reports the configured name and URL, has querying enabled, and is ready.
TEST_F(BasePredictorTest, BaseTest) {
  const std::unique_ptr<FakePredictor> fake = FakePredictor::Create();

  EXPECT_EQ(kTestModelName, fake->GetModelName());
  EXPECT_EQ(kTestDefaultModelUrl, fake->GetModelUrl());

  EXPECT_TRUE(fake->is_query_enabled());
  EXPECT_TRUE(fake->IsReady());
}
// With the feature disabled before creation, the predictor still reports the
// configured name and URL but has querying disabled and never becomes ready.
TEST_F(BasePredictorTest, QueryDisabled) {
  DisableQuery();
  const std::unique_ptr<FakePredictor> fake = FakePredictor::Create();

  EXPECT_EQ(kTestModelName, fake->GetModelName());
  EXPECT_EQ(kTestDefaultModelUrl, fake->GetModelUrl());

  EXPECT_FALSE(fake->is_query_enabled());
  EXPECT_FALSE(fake->IsReady());
}
// Logs an example holding one feature of each value type plus one
// non-whitelisted feature, then checks that exactly one UKM entry was
// recorded and which metrics it contains.
TEST_F(BasePredictorTest, LogExampleToUkm) {
  auto predictor = FakePredictor::Create();
  RankerExample example;
  auto& features = *example.mutable_features();
  features[kBoolFeature].set_bool_value(true);
  features[kIntFeature].set_int32_value(42);
  features[kFloatFeature].set_float_value(42.0f);
  features[kStringFeature].set_string_value("42");
  // This feature will not be logged.
  features[kFeatureNotWhitelisted].set_bool_value(false);

  predictor->LogExampleToUkm(example, GetSourceId());

  EXPECT_EQ(1U, GetTestUkmRecorder()->sources_count());
  EXPECT_EQ(1U, GetTestUkmRecorder()->entries_count());
  std::vector<const ukm::mojom::UkmEntry*> entries =
      GetTestUkmRecorder()->GetEntriesByName(kTestLoggingName);
  // Fatal assertion: everything below indexes entries[0], which would read
  // out of bounds if no entry had been recorded.
  ASSERT_EQ(1U, entries.size());
  const ukm::mojom::UkmEntry* entry = entries[0];

  GetTestUkmRecorder()->ExpectEntryMetric(entry, kBoolFeature, 1);
  GetTestUkmRecorder()->ExpectEntryMetric(entry, kIntFeature, 42);
  // TODO(crbug.com/794187) Float and string features are not logged yet.
  EXPECT_FALSE(GetTestUkmRecorder()->EntryHasMetric(entry, kFloatFeature));
  EXPECT_FALSE(GetTestUkmRecorder()->EntryHasMetric(entry, kStringFeature));
  EXPECT_FALSE(
      GetTestUkmRecorder()->EntryHasMetric(entry, kFeatureNotWhitelisted));
}
} // namespace assist_ranker
...@@ -15,26 +15,33 @@ ...@@ -15,26 +15,33 @@
#include "components/assist_ranker/ranker_model.h" #include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h" #include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
namespace assist_ranker { namespace assist_ranker {
// Forwards |config| to BasePredictor. The inference module is created later,
// in Initialize().
BinaryClassifierPredictor::BinaryClassifierPredictor(
    const PredictorConfig& config)
    : BasePredictor(config) {}

BinaryClassifierPredictor::~BinaryClassifierPredictor() {}
// static // static
std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
net::URLRequestContextGetter* request_context_getter, const PredictorConfig& config,
const base::FilePath& model_path, const base::FilePath& model_path,
GURL model_url, net::URLRequestContextGetter* request_context_getter) {
const std::string& uma_prefix) {
std::unique_ptr<BinaryClassifierPredictor> predictor( std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor()); new BinaryClassifierPredictor(config));
if (!predictor->is_query_enabled()) {
DVLOG(1) << "Query disabled, bypassing model loading.";
return predictor;
}
const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url;
auto model_loader = base::MakeUnique<RankerModelLoaderImpl>( auto model_loader = base::MakeUnique<RankerModelLoaderImpl>(
base::Bind(&BinaryClassifierPredictor::ValidateModel), base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::Bind(&BinaryClassifierPredictor::OnModelAvailable, base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())), base::Unretained(predictor.get())),
request_context_getter, model_path, model_url, uma_prefix); request_context_getter, model_path, model_url, config.uma_prefix);
predictor->LoadModel(std::move(model_loader)); predictor->LoadModel(std::move(model_loader));
return predictor; return predictor;
} }
...@@ -42,18 +49,23 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( ...@@ -42,18 +49,23 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
bool BinaryClassifierPredictor::Predict(const RankerExample& example, bool BinaryClassifierPredictor::Predict(const RankerExample& example,
bool* prediction) { bool* prediction) {
if (!IsReady()) { if (!IsReady()) {
DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
return false; return false;
} }
*prediction = inference_module_->Predict(example);
*prediction = inference_module_->Predict(PreprocessExample(example));
DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
return true; return true;
} }
bool BinaryClassifierPredictor::PredictScore(const RankerExample& example, bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
float* prediction) { float* prediction) {
if (!IsReady()) { if (!IsReady()) {
DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
return false; return false;
} }
*prediction = inference_module_->PredictScore(example); *prediction = inference_module_->PredictScore(PreprocessExample(example));
DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << prediction;
return true; return true;
} }
...@@ -61,17 +73,22 @@ bool BinaryClassifierPredictor::PredictScore(const RankerExample& example, ...@@ -61,17 +73,22 @@ bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
// Accepts only models whose proto holds a logistic regression model; any
// other model case is reported as INCOMPATIBLE.
RankerModelStatus BinaryClassifierPredictor::ValidateModel(
    const RankerModel& model) {
  if (model.proto().model_case() != RankerModelProto::kLogisticRegression) {
    DVLOG(0) << "Model is incompatible.";
    return RankerModelStatus::INCOMPATIBLE;
  }
  return RankerModelStatus::OK;
}
bool BinaryClassifierPredictor::Initialize() { bool BinaryClassifierPredictor::Initialize() {
// TODO(hamelphi): move the GLRM proto up one layer in the proto in order to if (ranker_model_->proto().model_case() ==
// be independent of the client feature. RankerModelProto::kLogisticRegression) {
inference_module_.reset(new GenericLogisticRegressionInference( inference_module_.reset(new GenericLogisticRegressionInference(
ranker_model_->proto().logistic_regression())); ranker_model_->proto().logistic_regression()));
return true; return true;
}
DVLOG(0) << "Could not initialize inference module.";
return false;
} }
} // namespace assist_ranker } // namespace assist_ranker
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include "components/assist_ranker/base_predictor.h" #include "components/assist_ranker/base_predictor.h"
#include "components/assist_ranker/proto/ranker_example.pb.h" #include "components/assist_ranker/proto/ranker_example.pb.h"
class GURL;
namespace base { namespace base {
class FilePath; class FilePath;
} }
...@@ -28,11 +26,13 @@ class BinaryClassifierPredictor : public BasePredictor { ...@@ -28,11 +26,13 @@ class BinaryClassifierPredictor : public BasePredictor {
public: public:
~BinaryClassifierPredictor() override; ~BinaryClassifierPredictor() override;
// Returns a new predictor instance with the given |config| and initializes
// its model loader. The |request_context_getter| is passed to the
// predictor's model_loader which holds it as scoped_refptr.
static std::unique_ptr<BinaryClassifierPredictor> Create( static std::unique_ptr<BinaryClassifierPredictor> Create(
net::URLRequestContextGetter* request_context_getter, const PredictorConfig& config,
const base::FilePath& model_path, const base::FilePath& model_path,
GURL model_url, net::URLRequestContextGetter* request_context_getter) WARN_UNUSED_RESULT;
const std::string& uma_prefix);
// Fills in a boolean decision given a RankerExample. Returns false if a // Fills in a boolean decision given a RankerExample. Returns false if a
// prediction could not be made (e.g. the model is not loaded yet). // prediction could not be made (e.g. the model is not loaded yet).
...@@ -53,7 +53,7 @@ class BinaryClassifierPredictor : public BasePredictor { ...@@ -53,7 +53,7 @@ class BinaryClassifierPredictor : public BasePredictor {
private: private:
friend class BinaryClassifierPredictorTest; friend class BinaryClassifierPredictorTest;
BinaryClassifierPredictor(); BinaryClassifierPredictor(const PredictorConfig& config);
// TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to // TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to
// generalize to other models. // generalize to other models.
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/bind_helpers.h" #include "base/bind_helpers.h"
#include "base/feature_list.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "components/assist_ranker/fake_ranker_model_loader.h" #include "components/assist_ranker/fake_ranker_model_loader.h"
#include "components/assist_ranker/proto/ranker_model.pb.h" #include "components/assist_ranker/proto/ranker_model.pb.h"
...@@ -21,11 +22,14 @@ using ::assist_ranker::testing::FakeRankerModelLoader; ...@@ -21,11 +22,14 @@ using ::assist_ranker::testing::FakeRankerModelLoader;
class BinaryClassifierPredictorTest : public ::testing::Test { class BinaryClassifierPredictorTest : public ::testing::Test {
public: public:
std::unique_ptr<BinaryClassifierPredictor> InitPredictor( std::unique_ptr<BinaryClassifierPredictor> InitPredictor(
std::unique_ptr<RankerModel> ranker_model); std::unique_ptr<RankerModel> ranker_model,
const PredictorConfig& config);
// This model will return the value of |feature| as a prediction. // This model will return the value of |feature| as a prediction.
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel(); GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
PredictorConfig GetConfig();
protected: protected:
const std::string feature_ = "feature"; const std::string feature_ = "feature";
const float threshold_ = 0.5; const float threshold_ = 0.5;
...@@ -33,10 +37,11 @@ class BinaryClassifierPredictorTest : public ::testing::Test { ...@@ -33,10 +37,11 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
std::unique_ptr<BinaryClassifierPredictor> std::unique_ptr<BinaryClassifierPredictor>
BinaryClassifierPredictorTest::InitPredictor( BinaryClassifierPredictorTest::InitPredictor(
std::unique_ptr<RankerModel> ranker_model) { std::unique_ptr<RankerModel> ranker_model,
const PredictorConfig& config) {
std::unique_ptr<BinaryClassifierPredictor> predictor( std::unique_ptr<BinaryClassifierPredictor> predictor(
new BinaryClassifierPredictor()); new BinaryClassifierPredictor(config));
auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>( auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::Bind(&BinaryClassifierPredictor::ValidateModel), base::Bind(&BinaryClassifierPredictor::ValidateModel),
base::Bind(&BinaryClassifierPredictor::OnModelAvailable, base::Bind(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())), base::Unretained(predictor.get())),
...@@ -45,6 +50,20 @@ BinaryClassifierPredictorTest::InitPredictor( ...@@ -45,6 +50,20 @@ BinaryClassifierPredictorTest::InitPredictor(
return predictor; return predictor;
} }
// Feature referenced by the test config returned from GetConfig(); enabled by
// default.
const base::Feature kTestRankerQuery{"TestRankerQuery",
base::FEATURE_ENABLED_BY_DEFAULT};
// String param attached to kTestRankerQuery; its default value is the
// placeholder URL below.
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, "url-param-name", "https://default.model.url"};
// Returns a config for the tests: placeholder names, logging turned off
// (LOG_NONE), the empty whitelist, and the test feature and URL param.
PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
  return PredictorConfig("model_name", "logging_name", "uma_prefix", LOG_NONE,
                         GetEmptyWhitelist(), &kTestRankerQuery,
                         &kTestRankerUrl);
}
GenericLogisticRegressionModel GenericLogisticRegressionModel
BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() { BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
GenericLogisticRegressionModel lr_model; GenericLogisticRegressionModel lr_model;
...@@ -58,7 +77,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() { ...@@ -58,7 +77,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) { TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) {
auto ranker_model = base::MakeUnique<RankerModel>(); auto ranker_model = base::MakeUnique<RankerModel>();
auto predictor = InitPredictor(std::move(ranker_model)); auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_FALSE(predictor->IsReady()); EXPECT_FALSE(predictor->IsReady());
RankerExample ranker_example; RankerExample ranker_example;
...@@ -78,7 +97,7 @@ TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) { ...@@ -78,7 +97,7 @@ TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) {
->mutable_translate() ->mutable_translate()
->mutable_translate_logistic_regression_model() ->mutable_translate_logistic_regression_model()
->set_bias(1); ->set_bias(1);
auto predictor = InitPredictor(std::move(ranker_model)); auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_FALSE(predictor->IsReady()); EXPECT_FALSE(predictor->IsReady());
RankerExample ranker_example; RankerExample ranker_example;
...@@ -94,7 +113,7 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) { ...@@ -94,7 +113,7 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
auto ranker_model = base::MakeUnique<RankerModel>(); auto ranker_model = base::MakeUnique<RankerModel>();
*ranker_model->mutable_proto()->mutable_logistic_regression() = *ranker_model->mutable_proto()->mutable_logistic_regression() =
GetSimpleLogisticRegressionModel(); GetSimpleLogisticRegressionModel();
auto predictor = InitPredictor(std::move(ranker_model)); auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_TRUE(predictor->IsReady()); EXPECT_TRUE(predictor->IsReady());
RankerExample ranker_example; RankerExample ranker_example;
......
...@@ -29,7 +29,7 @@ float GenericLogisticRegressionInference::PredictScore( ...@@ -29,7 +29,7 @@ float GenericLogisticRegressionInference::PredictScore(
const FeatureWeight& feature_weight = weight_it.second; const FeatureWeight& feature_weight = weight_it.second;
switch (feature_weight.feature_type_case()) { switch (feature_weight.feature_type_case()) {
case FeatureWeight::FEATURE_TYPE_NOT_SET: { case FeatureWeight::FEATURE_TYPE_NOT_SET: {
DVLOG(1) << "Feature type not set for " << feature_name; DVLOG(0) << "Feature type not set for " << feature_name;
break; break;
} }
case FeatureWeight::kScalar: { case FeatureWeight::kScalar: {
...@@ -37,6 +37,8 @@ float GenericLogisticRegressionInference::PredictScore( ...@@ -37,6 +37,8 @@ float GenericLogisticRegressionInference::PredictScore(
if (GetFeatureValueAsFloat(feature_name, example, &value)) { if (GetFeatureValueAsFloat(feature_name, example, &value)) {
const float weight = feature_weight.scalar(); const float weight = feature_weight.scalar();
activation += value * weight; activation += value * weight;
} else {
DVLOG(1) << "Feature not in example: " << feature_name;
} }
break; break;
} }
...@@ -50,19 +52,22 @@ float GenericLogisticRegressionInference::PredictScore( ...@@ -50,19 +52,22 @@ float GenericLogisticRegressionInference::PredictScore(
} else { } else {
// If the category is not found, use the default weight. // If the category is not found, use the default weight.
activation += feature_weight.one_hot().default_weight(); activation += feature_weight.one_hot().default_weight();
DVLOG(1) << "Unknown feature value for " << feature_name << ": "
<< value;
} }
} else { } else {
// If the feature is missing, use the default weight. // If the feature is missing, use the default weight.
activation += feature_weight.one_hot().default_weight(); activation += feature_weight.one_hot().default_weight();
DVLOG(1) << "Feature not in example: " << feature_name;
} }
break; break;
} }
case FeatureWeight::kSparse: { case FeatureWeight::kSparse: {
DVLOG(1) << "Sparse features not implemented yet."; DVLOG(0) << "Sparse features not implemented yet.";
break; break;
} }
case FeatureWeight::kBucketized: { case FeatureWeight::kBucketized: {
DVLOG(1) << "Bucketized features not implemented yet."; DVLOG(0) << "Bucketized features not implemented yet.";
break; break;
} }
} }
......
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/predictor_config.h"
namespace assist_ranker {
// Returns a pointer to an empty feature whitelist. The set is allocated on
// first call, never deleted, and the same pointer is returned on every call.
const base::flat_set<std::string>* GetEmptyWhitelist() {
  static const base::flat_set<std::string>* const kEmptyWhitelist =
      new base::flat_set<std::string>();
  return kEmptyWhitelist;
}
} // namespace assist_ranker
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
#include <string>
#include "base/containers/flat_set.h"
#include "base/metrics/field_trial_params.h"
namespace assist_ranker {
// TODO(chrome-ranker-team): Implement other logging types.
// Logging backend selected by a PredictorConfig.
enum LogType {
// No logging backend selected.
LOG_NONE = 0,
// Log through UKM; BasePredictor::LogExampleToUkm() only logs for this value.
LOG_UKM = 1,
};
// Empty feature whitelist used for testing.
const base::flat_set<std::string>* GetEmptyWhitelist();
// This struct holds the config options for logging, loading and field trial
// for a predictor.
// Plain bundle of settings for one predictor. Every argument is stored as
// given: the strings and the three objects are kept as raw pointers and are
// not copied.
// NOTE(review): the pointed-to data presumably has to outlive every copy of
// the config — confirm against callers.
struct PredictorConfig {
// Stores each argument in the member of the same name.
PredictorConfig(const char* model_name,
const char* logging_name,
const char* uma_prefix,
const LogType log_type,
const base::flat_set<std::string>* feature_whitelist,
const base::Feature* field_trial,
const base::FeatureParam<std::string>* field_trial_url_param)
: model_name(model_name),
logging_name(logging_name),
uma_prefix(uma_prefix),
log_type(log_type),
feature_whitelist(feature_whitelist),
field_trial(field_trial),
field_trial_url_param(field_trial_url_param) {}
// Name of the model, as returned by BasePredictor::GetModelName().
const char* model_name;
// Entry name used when logging examples to UKM.
const char* logging_name;
// Prefix handed to the model loader when the predictor is created.
const char* uma_prefix;
// Which logging backend the predictor uses.
const LogType log_type;
// Names of the features allowed to be logged. May be null, in which case
// nothing is logged.
const base::flat_set<std::string>* feature_whitelist;
// Feature that turns querying on or off. May be null, in which case querying
// stays disabled.
const base::Feature* field_trial;
// Feature param holding the model URL. May be null, in which case the
// predictor reports an empty URL.
const base::FeatureParam<std::string>* field_trial_url_param;
};
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/predictor_config_definitions.h"
namespace assist_ranker {
#if defined(OS_ANDROID)
// Feature passed as |field_trial| in the Contextual Search predictor config;
// disabled by default.
const base::Feature kContextualSearchRankerQuery{
"ContextualSearchRankerQuery", base::FEATURE_DISABLED_BY_DEFAULT};
namespace {
// String fields of the Contextual Search predictor config.
const char kContextualSearchModelName[] = "contextual_search_model";
const char kContextualSearchLoggingName[] = "ContextualSearch";
const char kContextualSearchUmaPrefixName[] = "Search.ContextualSearch.Ranker";
// Default value of the model URL feature param defined below.
const char kContextualSearchDefaultModelUrl[] =
"https://www.gstatic.com/chrome/intelligence/assist/ranker/models/"
"contextual_search/test_ranker_model_20171109_short_words_v2.pb.bin";
// Returns the feature param that carries the Contextual Search model URL.
// The param is attached to kContextualSearchRankerQuery and defaults to
// kContextualSearchDefaultModelUrl. It is allocated on first call and never
// deleted; every call returns the same pointer.
const base::FeatureParam<std::string>*
GetContextualSearchRankerUrlFeatureParam() {
  static const base::FeatureParam<std::string>* const kUrlParam =
      new base::FeatureParam<std::string>(
          &kContextualSearchRankerQuery, "contextual-search-ranker-model-url",
          kContextualSearchDefaultModelUrl);
  return kUrlParam;
}
// This list needs to be kept in sync with tools/metrics/ukm/ukm.xml.
// Only features within this list will be logged to UKM.
// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
// the UKM generated API.
// Returns the set of feature names that may be logged for Contextual Search.
// The set is allocated on first call and never deleted; every call returns
// the same pointer.
const base::flat_set<std::string>* GetContextualSearchFeatureWhitelist() {
  static const base::flat_set<std::string>* const kWhitelist =
      new base::flat_set<std::string>({
          "DidOptIn",
          "DurationAfterScrollMs",
          "IsEntity",
          "IsLongWord",
          "IsShortWord",
          "IsWordEdge",
          "OutcomeWasCardsDataShown",
          "OutcomeWasPanelOpened",
          "OutcomeWasQuickActionClicked",
          "OutcomeWasQuickAnswerSeen",
          "Previous28DayCtrPercent",
          "Previous28DayImpressionsCount",
          "PreviousWeekCtrPercent",
          "PreviousWeekImpressionsCount",
          "ScreenTopDps",
          "TapDuration",
          "WasScreenBottom",
      });
  return kWhitelist;
}
} // namespace
// Returns a copy of the PredictorConfig for the Contextual Search predictor.
//
// The config is assembled from the file-local model/logging/UMA names, the
// LOG_UKM log type, the UKM feature whitelist, the
// kContextualSearchRankerQuery feature and the model-URL feature param.
const PredictorConfig GetContextualSearchPredictorConfig() {
  // Built once on first call and intentionally never deleted. Holding the
  // pointer (rather than copying the heap object into a by-value static)
  // avoids an orphaned allocation and an extra copy, and matches the
  // leaked-pointer pattern used by the other getters in this file.
  static const PredictorConfig* const contextual_search_config =
      new PredictorConfig(
          kContextualSearchModelName, kContextualSearchLoggingName,
          kContextualSearchUmaPrefixName, LOG_UKM,
          GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
          GetContextualSearchRankerUrlFeatureParam());
  return *contextual_search_config;
}
#endif // OS_ANDROID
} // namespace assist_ranker
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "base/feature_list.h"
#include "base/metrics/field_trial_params.h"
#include "build/build_config.h"
#include "components/assist_ranker/predictor_config.h"
namespace assist_ranker {
#if defined(OS_ANDROID)
// Feature flag named "ContextualSearchRankerQuery" (disabled by default). It is
// the feature handed to the Contextual Search PredictorConfig.
extern const base::Feature kContextualSearchRankerQuery;
// Returns the PredictorConfig used for the Contextual Search predictor.
const PredictorConfig GetContextualSearchPredictorConfig();
#endif // OS_ANDROID
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
...@@ -32,5 +32,9 @@ message RankerExample { ...@@ -32,5 +32,9 @@ message RankerExample {
// This field represents the ground truth that the ranker is // This field represents the ground truth that the ranker is
// expected to predict, and is typically derived from user feedback. It is // expected to predict, and is typically derived from user feedback. It is
// used for training only and is not required for inference. // used for training only and is not required for inference.
// NOTE: this field will not be logged. If you want to log an outcome, add it
// to the features field before calling LogExample.
// TODO(chrome-ranker-team) Add a metadata field to log metrics that are not
// used as model input.
optional Feature target = 2; optional Feature target = 2;
} }
\ No newline at end of file
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "components/assist_ranker/ranker_example_util.h" #include "components/assist_ranker/ranker_example_util.h"
#include "base/format_macros.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/metrics/metrics_hashes.h"
#include "base/strings/stringprintf.h"
namespace assist_ranker { namespace assist_ranker {
...@@ -42,6 +45,26 @@ bool GetFeatureValueAsFloat(const std::string& key, ...@@ -42,6 +45,26 @@ bool GetFeatureValueAsFloat(const std::string& key,
return true; return true;
} }
// Writes an integer representation of |feature| to |int_value|.
//
// Returns true when a value was written: bool features become 0 or 1 and
// int32 features are copied as-is. Returns false, leaving |int_value|
// untouched, for float and string features, whose conversions are not
// implemented yet. Any other feature-type case is treated as unreachable.
bool FeatureToInt(const Feature& feature, int* int_value) {
  const auto type = feature.feature_type_case();

  if (type == Feature::kBoolValue) {
    *int_value = feature.bool_value() ? 1 : 0;
    return true;
  }

  if (type == Feature::kInt32Value) {
    *int_value = feature.int32_value();
    return true;
  }

  if (type == Feature::kFloatValue || type == Feature::kStringValue) {
    // TODO(crbug.com/794187): Implement this.
    return false;
  }

  NOTREACHED();
  return false;
}
bool GetOneHotValue(const std::string& key, bool GetOneHotValue(const std::string& key,
const RankerExample& example, const RankerExample& example,
std::string* value) { std::string* value) {
...@@ -60,4 +83,20 @@ bool GetOneHotValue(const std::string& key, ...@@ -60,4 +83,20 @@ bool GetOneHotValue(const std::string& key,
return true; return true;
} }
// Hashes |feature_name| with base::HashMetricName() and renders the 64-bit
// result as 16 zero-padded lowercase hex digits.
std::string HashFeatureName(const std::string& feature_name) {
  return base::StringPrintf("%016" PRIx64,
                            base::HashMetricName(feature_name));
}
// Builds a new RankerExample from |example| in which each feature is stored
// under HashFeatureName(original name) instead of its original name. Feature
// values are copied unchanged, and the target is copied over as well.
RankerExample HashExampleFeatureNames(const RankerExample& example) {
  RankerExample result;

  auto* hashed_features = result.mutable_features();
  for (const auto& name_and_value : example.features()) {
    const std::string hashed_name = HashFeatureName(name_and_value.first);
    (*hashed_features)[hashed_name] = name_and_value.second;
  }

  *result.mutable_target() = example.target();
  return result;
}
} // namespace assist_ranker } // namespace assist_ranker
...@@ -25,6 +25,10 @@ bool GetFeatureValueAsFloat(const std::string& key, ...@@ -25,6 +25,10 @@ bool GetFeatureValueAsFloat(const std::string& key,
const RankerExample& example, const RankerExample& example,
float* value) WARN_UNUSED_RESULT; float* value) WARN_UNUSED_RESULT;
// Converts |feature| to an int stored in |int_value| and returns true on
// success. Bool and int32 features are currently supported. The intent is for
// float values to be multiplied by a given constant that defines the precision
// and for string features to be hashed, but neither conversion is implemented
// yet (crbug.com/794187); both return false.
bool FeatureToInt(const Feature& feature, int* int_value) WARN_UNUSED_RESULT;
// Extract category from one-hot feature. Returns true and fills // Extract category from one-hot feature. Returns true and fills
// in |value| if the feature is found and is of type string_value. Returns false // in |value| if the feature is found and is of type string_value. Returns false
// otherwise. // otherwise.
...@@ -32,6 +36,15 @@ bool GetOneHotValue(const std::string& key, ...@@ -32,6 +36,15 @@ bool GetOneHotValue(const std::string& key,
const RankerExample& example, const RankerExample& example,
std::string* value) WARN_UNUSED_RESULT; std::string* value) WARN_UNUSED_RESULT;
// Converts a string to a hex hash string: |feature_name| is hashed and the
// hash is returned formatted as a 16-digit, zero-padded hex string.
std::string HashFeatureName(const std::string& feature_name);
// Hashes feature names to a hex string. Returns a copy of |example| in which
// every feature name is replaced by HashFeatureName(name); feature values and
// the target are carried over unchanged.
// Features logged through UKM will apply this transformation when logging
// features, so models trained on UKM data are expected to have hashed input
// feature names.
RankerExample HashExampleFeatureNames(const RankerExample& example);
} // namespace assist_ranker } // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_RANKER_EXAMPLE_UTIL_H_ #endif // COMPONENTS_ASSIST_RANKER_RANKER_EXAMPLE_UTIL_H_
...@@ -103,4 +103,53 @@ TEST_F(RankerExampleUtilTest, GetOneHotValue) { ...@@ -103,4 +103,53 @@ TEST_F(RankerExampleUtilTest, GetOneHotValue) {
EXPECT_FALSE(GetOneHotValue("foo", example_, &value)); EXPECT_FALSE(GetOneHotValue("foo", example_, &value));
} }
// Checks FeatureToInt() on each feature type: bool and int32 convert
// successfully, float and string are rejected.
TEST_F(RankerExampleUtilTest, FeatureToInt) {
  Feature input;
  int converted;

  // Bool features map to 1 / 0.
  input.set_bool_value(true);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(1, converted);

  input.set_bool_value(false);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(0, converted);

  // Int32 features pass through unchanged, including negative values.
  input.set_int32_value(42);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(42, converted);

  input.set_int32_value(-3);
  EXPECT_TRUE(FeatureToInt(input, &converted));
  EXPECT_EQ(-3, converted);

  // Float and string values are not implemented yet.
  input.set_float_value(12.345f);
  EXPECT_FALSE(FeatureToInt(input, &converted));

  input.set_string_value("foo");
  EXPECT_FALSE(FeatureToInt(input, &converted));
}
// Checks that HashExampleFeatureNames() re-keys every feature by its hashed
// name while leaving the feature count and values intact.
TEST_F(RankerExampleUtilTest, HashExampleFeatureNames) {
  const RankerExample hashed = HashExampleFeatureNames(example_);

  // Hashed example has the same number of features.
  EXPECT_EQ(example_.features().size(), hashed.features().size());

  // The original (unhashed) names can no longer be looked up...
  EXPECT_FALSE(SafeGetFeature(bool_name_, hashed, nullptr));
  EXPECT_FALSE(SafeGetFeature(int32_name_, hashed, nullptr));
  EXPECT_FALSE(SafeGetFeature(float_name_, hashed, nullptr));
  EXPECT_FALSE(SafeGetFeature(one_hot_name_, hashed, nullptr));

  // ...but the hashed name can.
  EXPECT_TRUE(SafeGetFeature(HashFeatureName(bool_name_), hashed, nullptr));

  // Values have not changed.
  float hashed_float;
  EXPECT_TRUE(GetFeatureValueAsFloat(HashFeatureName(float_name_), hashed,
                                     &hashed_float));
  EXPECT_EQ(float_value_, hashed_float);

  std::string hashed_one_hot;
  EXPECT_TRUE(GetOneHotValue(HashFeatureName(one_hot_name_), hashed,
                             &hashed_one_hot));
  EXPECT_EQ(one_hot_value_, hashed_one_hot);
}
} // namespace assist_ranker } // namespace assist_ranker
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "services/metrics/public/interfaces/ukm_interface.mojom.h" #include "services/metrics/public/interfaces/ukm_interface.mojom.h"
#include "url/gurl.h" #include "url/gurl.h"
class ContextualSearchRankerLoggerImpl;
class DocumentWritePageLoadMetricsObserver; class DocumentWritePageLoadMetricsObserver;
class FromGWSPageLoadMetricsLogger; class FromGWSPageLoadMetricsLogger;
class PluginInfoHostImpl; class PluginInfoHostImpl;
...@@ -27,6 +26,10 @@ class UkmPageLoadMetricsObserver; ...@@ -27,6 +26,10 @@ class UkmPageLoadMetricsObserver;
class UseCounterPageLoadMetricsObserver; class UseCounterPageLoadMetricsObserver;
class LocalNetworkRequestsPageLoadMetricsObserver; class LocalNetworkRequestsPageLoadMetricsObserver;
namespace assist_ranker {
class BasePredictor;
}
namespace blink { namespace blink {
class AutoplayUmaHelper; class AutoplayUmaHelper;
} }
...@@ -85,7 +88,7 @@ class METRICS_EXPORT UkmRecorder { ...@@ -85,7 +88,7 @@ class METRICS_EXPORT UkmRecorder {
virtual void UpdateSourceURL(SourceId source_id, const GURL& url) = 0; virtual void UpdateSourceURL(SourceId source_id, const GURL& url) = 0;
private: private:
friend ContextualSearchRankerLoggerImpl; friend assist_ranker::BasePredictor;
friend DelegatingUkmRecorder; friend DelegatingUkmRecorder;
friend DocumentWritePageLoadMetricsObserver; friend DocumentWritePageLoadMetricsObserver;
friend FromGWSPageLoadMetricsLogger; friend FromGWSPageLoadMetricsLogger;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment