Commit b14d5c35, authored by Donn Denman, committed by Commit Bot

Use dynamic thresholds for Ranker predictions.

Adds the ability to specify a threshold replacement to apply when predicting using an Assist-Ranker model.

For Contextual Search the threshold is read from a
FieldTrial variations parameter associated with the existing
ContextualSearchRankerQuery Feature.

For Translate, threshold setting is supported but unused.

BUG=899134

Change-Id: Ia0b664ce1e0949f755e7d8898fb2769fc91ae536
Reviewed-on: https://chromium-review.googlesource.com/c/1312296
Commit-Queue: Donn Denman <donnd@chromium.org>
Reviewed-by: Charles . <charleszhao@chromium.org>
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
Cr-Commit-Position: refs/heads/master@{#612324}
parent f19bb479
...@@ -124,6 +124,10 @@ GURL BasePredictor::GetModelUrl() const { ...@@ -124,6 +124,10 @@ GURL BasePredictor::GetModelUrl() const {
return GURL(config_.field_trial_url_param->Get()); return GURL(config_.field_trial_url_param->Get());
} }
// Reports the threshold replacement carried by this predictor's config. The
// stored value is returned as-is; callers compare it against
// kNoPredictThresholdReplacement to decide whether a replacement applies.
float BasePredictor::GetPredictThresholdReplacement() const {
  const float threshold_replacement =
      config_.field_trial_threshold_replacement_param;
  return threshold_replacement;
}
RankerExample BasePredictor::PreprocessExample(const RankerExample& example) { RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
if (ranker_model_->proto().has_metadata() && if (ranker_model_->proto().has_metadata() &&
ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) { ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
......
...@@ -22,6 +22,10 @@ class UkmEntryBuilder; ...@@ -22,6 +22,10 @@ class UkmEntryBuilder;
namespace assist_ranker { namespace assist_ranker {
// Sentinel meaning "do not replace the prediction threshold": when
// |GetPredictThresholdReplacement| returns this value, the threshold is left
// unchanged. Declared constexpr with a float literal so that it is a genuine
// compile-time constant and no double-to-float conversion is involved.
constexpr float kNoPredictThresholdReplacement = 0.0f;
class Feature; class Feature;
class RankerExample; class RankerExample;
class RankerModel; class RankerModel;
...@@ -29,7 +33,7 @@ class RankerModel; ...@@ -29,7 +33,7 @@ class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as // Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model and logging. Sub-classes of // encapsulate the logic for loading the model and logging. Sub-classes of
// BasePredictor implement an interface that depends on the nature of the // BasePredictor implement an interface that depends on the nature of the
// suported model. Subclasses of BasePredictor will also need to implement an // supported model. Subclasses of BasePredictor will also need to implement an
// Initialize method that will be called once the model is available, and a // Initialize method that will be called once the model is available, and a
// static validation function with the following signature: // static validation function with the following signature:
// //
...@@ -49,6 +53,9 @@ class BasePredictor : public base::SupportsWeakPtr<BasePredictor> { ...@@ -49,6 +53,9 @@ class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
// Returns the model URL. // Returns the model URL.
GURL GetModelUrl() const; GURL GetModelUrl() const;
// Returns the threshold to use for prediction, or
// kNoPredictThresholdReplacement to leave it unchanged.
// The value is read from the |field_trial_threshold_replacement_param| member
// of this predictor's config and is returned without modification.
float GetPredictThresholdReplacement() const;
// Returns the model name. // Returns the model name.
std::string GetModelName() const; std::string GetModelName() const;
......
...@@ -54,37 +54,41 @@ const base::Feature kTestRankerQuery{"TestRankerQuery", ...@@ -54,37 +54,41 @@ const base::Feature kTestRankerQuery{"TestRankerQuery",
const base::FeatureParam<std::string> kTestRankerUrl{ const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl}; &kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};
const PredictorConfig kTestPredictorConfig = PredictorConfig{ const PredictorConfig kTestPredictorConfig =
kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM, PredictorConfig{kTestModelName, kTestLoggingName,
&kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl}; kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery,
&kTestRankerUrl, kNoPredictThresholdReplacement};
// Class that implements virtual functions of the base class. // Class that implements virtual functions of the base class.
class FakePredictor : public BasePredictor { class FakePredictor : public BasePredictor {
public: public:
static std::unique_ptr<FakePredictor> Create(); // Creates a |FakePredictor| using the default config (from this file).
static std::unique_ptr<FakePredictor> Create() {
return Create(kTestPredictorConfig);
}
// Creates a |FakePredictor| using the |PredictorConfig| passed in
// |predictor_config|.
static std::unique_ptr<FakePredictor> Create(
PredictorConfig predictor_config);
~FakePredictor() override{}; ~FakePredictor() override{};
// Validation will always succeed. // Validation will always succeed.
static RankerModelStatus ValidateModel(const RankerModel& model); static RankerModelStatus ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}
protected: protected:
// Not implementing any inference logic. // Not implementing any inference logic.
bool Initialize() override { return true; }; bool Initialize() override { return true; };
private: private:
FakePredictor(const PredictorConfig& config); FakePredictor(const PredictorConfig& config) : BasePredictor(config) {}
DISALLOW_COPY_AND_ASSIGN(FakePredictor); DISALLOW_COPY_AND_ASSIGN(FakePredictor);
}; };
FakePredictor::FakePredictor(const PredictorConfig& config) std::unique_ptr<FakePredictor> FakePredictor::Create(
: BasePredictor(config) {} PredictorConfig predictor_config) {
std::unique_ptr<FakePredictor> predictor(new FakePredictor(predictor_config));
RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}
std::unique_ptr<FakePredictor> FakePredictor::Create() {
std::unique_ptr<FakePredictor> predictor(
new FakePredictor(kTestPredictorConfig));
auto ranker_model = std::make_unique<RankerModel>(); auto ranker_model = std::make_unique<RankerModel>();
auto fake_model_loader = std::make_unique<FakeRankerModelLoader>( auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::BindRepeating(&FakePredictor::ValidateModel), base::BindRepeating(&FakePredictor::ValidateModel),
...@@ -184,4 +188,14 @@ TEST_F(BasePredictorTest, LogExampleToUkm) { ...@@ -184,4 +188,14 @@ TEST_F(BasePredictorTest, LogExampleToUkm) {
GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted)); GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted));
} }
// Verifies that a threshold replacement supplied through a PredictorConfig is
// reported unchanged by GetPredictThresholdReplacement().
TEST_F(BasePredictorTest, GetPredictThresholdReplacement) {
  const float expected_threshold = 0.78f;  // Arbitrary value.

  // Same settings as the default test config, except for the threshold.
  const PredictorConfig config_with_threshold{
      kTestModelName,     kTestLoggingName,  kTestUmaPrefixName, LOG_UKM,
      &kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl,
      expected_threshold};

  auto fake_predictor = FakePredictor::Create(config_with_threshold);
  const float reported_threshold =
      fake_predictor->GetPredictThresholdReplacement();
  EXPECT_EQ(expected_threshold, reported_threshold);
}
} // namespace assist_ranker } // namespace assist_ranker
...@@ -36,6 +36,8 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( ...@@ -36,6 +36,8 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
const GURL& model_url = predictor->GetModelUrl(); const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName(); DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url; DVLOG(1) << "Model URL: " << model_url;
DVLOG(1) << "Using predict threshold replacement: "
<< predictor->GetPredictThresholdReplacement();
auto model_loader = std::make_unique<RankerModelLoaderImpl>( auto model_loader = std::make_unique<RankerModelLoaderImpl>(
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel), base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable, base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
...@@ -52,7 +54,13 @@ bool BinaryClassifierPredictor::Predict(const RankerExample& example, ...@@ -52,7 +54,13 @@ bool BinaryClassifierPredictor::Predict(const RankerExample& example,
return false; return false;
} }
*prediction = inference_module_->Predict(PreprocessExample(example)); float predict_threshold_replacement = GetPredictThresholdReplacement();
if (predict_threshold_replacement != kNoPredictThresholdReplacement) {
*prediction = inference_module_->PredictScore(PreprocessExample(example)) >=
predict_threshold_replacement;
} else {
*prediction = inference_module_->Predict(PreprocessExample(example));
}
DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction; DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
return true; return true;
} }
......
...@@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test { ...@@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel(); GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
PredictorConfig GetConfig(); PredictorConfig GetConfig();
PredictorConfig GetConfig(float predictor_threshold_replacement);
protected: protected:
const std::string feature_ = "feature"; const std::string feature_ = "feature";
...@@ -66,9 +67,14 @@ const base::FeatureParam<std::string> kTestRankerUrl{ ...@@ -66,9 +67,14 @@ const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, "url-param-name", "https://default.model.url"}; &kTestRankerQuery, "url-param-name", "https://default.model.url"};
PredictorConfig BinaryClassifierPredictorTest::GetConfig() { PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
return GetConfig(kNoPredictThresholdReplacement);
}
PredictorConfig BinaryClassifierPredictorTest::GetConfig(
float predictor_threshold_replacement) {
PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE, PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
GetEmptyWhitelist(), &kTestRankerQuery, GetEmptyWhitelist(), &kTestRankerQuery,
&kTestRankerUrl); &kTestRankerUrl, predictor_threshold_replacement);
return config; return config;
} }
...@@ -171,4 +177,30 @@ TEST_F(BinaryClassifierPredictorTest, ...@@ -171,4 +177,30 @@ TEST_F(BinaryClassifierPredictorTest,
EXPECT_LT(float_response, threshold_); EXPECT_LT(float_response, threshold_);
} }
// Exercises a predictor whose config carries a threshold replacement: the
// score is expected to lie above the fixture's |threshold_| yet below the
// replacement, so the boolean prediction must come out false.
TEST_F(BinaryClassifierPredictorTest,
       GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
  // Start from the simple model, then mark it preprocessed and key its single
  // weight by full feature name instead of the regular weights map.
  auto model = std::make_unique<RankerModel>();
  auto& regression = *model->mutable_proto()->mutable_logistic_regression();
  regression = GetSimpleLogisticRegressionModel();
  regression.clear_weights();
  regression.set_is_preprocessed_model(true);
  (*regression.mutable_fullname_weights())[feature_] = weight_;

  float replacement_threshold = 0.9;  // Some high threshold.
  auto predictor =
      InitPredictor(std::move(model), GetConfig(replacement_threshold));
  EXPECT_TRUE(predictor->IsReady());

  // An example with the single known feature switched on.
  RankerExample example;
  (*example.mutable_features())[feature_].set_bool_value(true);

  // The boolean decision is made against the replacement threshold.
  bool decision;
  EXPECT_TRUE(predictor->Predict(example, &decision));
  EXPECT_FALSE(decision);

  // The raw score sits strictly between the two thresholds.
  float score;
  EXPECT_TRUE(predictor->PredictScore(example, &score));
  EXPECT_GT(score, threshold_);
  EXPECT_LT(score, replacement_threshold);
}
} // namespace assist_ranker } // namespace assist_ranker
...@@ -23,7 +23,7 @@ class GenericLogisticRegressionInference { ...@@ -23,7 +23,7 @@ class GenericLogisticRegressionInference {
// Returns a boolean decision given a RankerExample. Uses the same logic as // Returns a boolean decision given a RankerExample. Uses the same logic as
// PredictScore, and then applies the model decision threshold. // PredictScore, and then applies the model decision threshold.
bool Predict(const RankerExample& example); bool Predict(const RankerExample& example);
// Returns a score between 0 and 1 give a RankerExample. // Returns a score between 0 and 1 given a RankerExample.
float PredictScore(const RankerExample& example); float PredictScore(const RankerExample& example);
private: private:
......
...@@ -30,21 +30,25 @@ struct PredictorConfig { ...@@ -30,21 +30,25 @@ struct PredictorConfig {
const LogType log_type, const LogType log_type,
const base::flat_set<std::string>* feature_whitelist, const base::flat_set<std::string>* feature_whitelist,
const base::Feature* field_trial, const base::Feature* field_trial,
const base::FeatureParam<std::string>* field_trial_url_param) const base::FeatureParam<std::string>* field_trial_url_param,
float field_trial_threshold_replacement_param)
: model_name(model_name), : model_name(model_name),
logging_name(logging_name), logging_name(logging_name),
uma_prefix(uma_prefix), uma_prefix(uma_prefix),
log_type(log_type), log_type(log_type),
feature_whitelist(feature_whitelist), feature_whitelist(feature_whitelist),
field_trial(field_trial), field_trial(field_trial),
field_trial_url_param(field_trial_url_param) {} field_trial_url_param(field_trial_url_param),
const char* model_name; field_trial_threshold_replacement_param(
const char* logging_name; field_trial_threshold_replacement_param) {}
const char* uma_prefix; const char* const model_name;
const char* const logging_name;
const char* const uma_prefix;
const LogType log_type; const LogType log_type;
const base::flat_set<std::string>* feature_whitelist; const base::flat_set<std::string>* feature_whitelist;
const base::Feature* field_trial; const base::Feature* field_trial;
const base::FeatureParam<std::string>* field_trial_url_param; const base::FeatureParam<std::string>* field_trial_url_param;
const float field_trial_threshold_replacement_param;
}; };
} // namespace assist_ranker } // namespace assist_ranker
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
// found in the LICENSE file. // found in the LICENSE file.
#include "components/assist_ranker/predictor_config_definitions.h" #include "components/assist_ranker/predictor_config_definitions.h"
#include "components/assist_ranker/base_predictor.h"
namespace assist_ranker { namespace assist_ranker {
...@@ -28,6 +29,15 @@ GetContextualSearchRankerUrlFeatureParam() { ...@@ -28,6 +29,15 @@ GetContextualSearchRankerUrlFeatureParam() {
return kContextualSearchRankerUrl; return kContextualSearchRankerUrl;
} }
float GetContextualSearchRankerThresholdFeatureParam() {
static auto* kContextualSearchRankerThreshold =
new base::FeatureParam<double>(
&kContextualSearchRankerQuery,
"contextual-search-ranker-predict-threshold",
kNoPredictThresholdReplacement);
return static_cast<float>(kContextualSearchRankerThreshold->Get());
}
// NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml! // NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml!
// Only features within this list will be logged to UKM. // Only features within this list will be logged to UKM.
// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through // TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
...@@ -77,7 +87,8 @@ const PredictorConfig GetContextualSearchPredictorConfig() { ...@@ -77,7 +87,8 @@ const PredictorConfig GetContextualSearchPredictorConfig() {
kContextualSearchModelName, kContextualSearchLoggingName, kContextualSearchModelName, kContextualSearchLoggingName,
kContextualSearchUmaPrefixName, LOG_UKM, kContextualSearchUmaPrefixName, LOG_UKM,
GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery, GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
GetContextualSearchRankerUrlFeatureParam())); GetContextualSearchRankerUrlFeatureParam(),
GetContextualSearchRankerThresholdFeatureParam()));
return kContextualSearchPredictorConfig; return kContextualSearchPredictorConfig;
} }
#endif // OS_ANDROID #endif // OS_ANDROID
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment