Commit b14d5c35 authored by Donn Denman's avatar Donn Denman Committed by Commit Bot

Use dynamic thresholds for Ranker predictions.

Adds the ability to specify a threshold replacement to apply when predicting using an Assist-Ranker model.

For Contextual Search the threshold is read from a
FieldTrial variations parameter associated with the existing
ContextualSearchRankerQuery Feature.

For Translate, threshold setting is supported but unused.

BUG=899134

Change-Id: Ia0b664ce1e0949f755e7d8898fb2769fc91ae536
Reviewed-on: https://chromium-review.googlesource.com/c/1312296
Commit-Queue: Donn Denman <donnd@chromium.org>
Reviewed-by: default avatarCharles . <charleszhao@chromium.org>
Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Cr-Commit-Position: refs/heads/master@{#612324}
parent f19bb479
......@@ -124,6 +124,10 @@ GURL BasePredictor::GetModelUrl() const {
return GURL(config_.field_trial_url_param->Get());
}
float BasePredictor::GetPredictThresholdReplacement() const {
return config_.field_trial_threshold_replacement_param;
}
RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
if (ranker_model_->proto().has_metadata() &&
ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
......
......@@ -22,6 +22,10 @@ class UkmEntryBuilder;
namespace assist_ranker {
// Value to use for when no prediction threshold replacement should be applied.
// See |GetPredictThresholdReplacement| method.
const float kNoPredictThresholdReplacement = 0.0;
class Feature;
class RankerExample;
class RankerModel;
......@@ -29,7 +33,7 @@ class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model and logging. Sub-classes of
// BasePredictor implement an interface that depends on the nature of the
// suported model. Subclasses of BasePredictor will also need to implement an
// supported model. Subclasses of BasePredictor will also need to implement an
// Initialize method that will be called once the model is available, and a
// static validation function with the following signature:
//
......@@ -49,6 +53,9 @@ class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
// Returns the model URL.
GURL GetModelUrl() const;
// Returns the threshold to use for prediction, or
// kNoPredictThresholdReplacement to leave it unchanged.
float GetPredictThresholdReplacement() const;
// Returns the model name.
std::string GetModelName() const;
......
......@@ -54,37 +54,41 @@ const base::Feature kTestRankerQuery{"TestRankerQuery",
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};
const PredictorConfig kTestPredictorConfig = PredictorConfig{
kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl};
const PredictorConfig kTestPredictorConfig =
PredictorConfig{kTestModelName, kTestLoggingName,
kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery,
&kTestRankerUrl, kNoPredictThresholdReplacement};
// Class that implements virtual functions of the base class.
class FakePredictor : public BasePredictor {
public:
static std::unique_ptr<FakePredictor> Create();
// Creates a |FakePredictor| using the default config (from this file).
static std::unique_ptr<FakePredictor> Create() {
return Create(kTestPredictorConfig);
}
// Creates a |FakePredictor| using the |PredictorConfig| passed in
// |predictor_config|.
static std::unique_ptr<FakePredictor> Create(
PredictorConfig predictor_config);
~FakePredictor() override{};
// Validation will always succeed.
static RankerModelStatus ValidateModel(const RankerModel& model);
static RankerModelStatus ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}
protected:
// Not implementing any inference logic.
bool Initialize() override { return true; };
private:
FakePredictor(const PredictorConfig& config);
FakePredictor(const PredictorConfig& config) : BasePredictor(config) {}
DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};
FakePredictor::FakePredictor(const PredictorConfig& config)
: BasePredictor(config) {}
RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}
std::unique_ptr<FakePredictor> FakePredictor::Create() {
std::unique_ptr<FakePredictor> predictor(
new FakePredictor(kTestPredictorConfig));
std::unique_ptr<FakePredictor> FakePredictor::Create(
PredictorConfig predictor_config) {
std::unique_ptr<FakePredictor> predictor(new FakePredictor(predictor_config));
auto ranker_model = std::make_unique<RankerModel>();
auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::BindRepeating(&FakePredictor::ValidateModel),
......@@ -184,4 +188,14 @@ TEST_F(BasePredictorTest, LogExampleToUkm) {
GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted));
}
TEST_F(BasePredictorTest, GetPredictThresholdReplacement) {
float altered_threshold = 0.78f; // Arbitrary value.
const PredictorConfig altered_threshold_config{
kTestModelName, kTestLoggingName, kTestUmaPrefixName,
LOG_UKM, &kFeatureWhitelist, &kTestRankerQuery,
&kTestRankerUrl, altered_threshold};
auto predictor = FakePredictor::Create(altered_threshold_config);
EXPECT_EQ(altered_threshold, predictor->GetPredictThresholdReplacement());
}
} // namespace assist_ranker
......@@ -36,6 +36,8 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url;
DVLOG(1) << "Using predict threshold replacement: "
<< predictor->GetPredictThresholdReplacement();
auto model_loader = std::make_unique<RankerModelLoaderImpl>(
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
......@@ -52,7 +54,13 @@ bool BinaryClassifierPredictor::Predict(const RankerExample& example,
return false;
}
*prediction = inference_module_->Predict(PreprocessExample(example));
float predict_threshold_replacement = GetPredictThresholdReplacement();
if (predict_threshold_replacement != kNoPredictThresholdReplacement) {
*prediction = inference_module_->PredictScore(PreprocessExample(example)) >=
predict_threshold_replacement;
} else {
*prediction = inference_module_->Predict(PreprocessExample(example));
}
DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
return true;
}
......
......@@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
PredictorConfig GetConfig();
PredictorConfig GetConfig(float predictor_threshold_replacement);
protected:
const std::string feature_ = "feature";
......@@ -66,9 +67,14 @@ const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, "url-param-name", "https://default.model.url"};
PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
return GetConfig(kNoPredictThresholdReplacement);
}
PredictorConfig BinaryClassifierPredictorTest::GetConfig(
float predictor_threshold_replacement) {
PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
GetEmptyWhitelist(), &kTestRankerQuery,
&kTestRankerUrl);
&kTestRankerUrl, predictor_threshold_replacement);
return config;
}
......@@ -171,4 +177,30 @@ TEST_F(BinaryClassifierPredictorTest,
EXPECT_LT(float_response, threshold_);
}
TEST_F(BinaryClassifierPredictorTest,
GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
auto ranker_model = std::make_unique<RankerModel>();
auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
glr = GetSimpleLogisticRegressionModel();
glr.clear_weights();
glr.set_is_preprocessed_model(true);
(*glr.mutable_fullname_weights())[feature_] = weight_;
float high_threshold = 0.9; // Some high threshold.
auto predictor =
InitPredictor(std::move(ranker_model), GetConfig(high_threshold));
EXPECT_TRUE(predictor->IsReady());
RankerExample ranker_example;
auto& features = *ranker_example.mutable_features();
features[feature_].set_bool_value(true);
bool bool_response;
EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
EXPECT_FALSE(bool_response);
float float_response;
EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
EXPECT_GT(float_response, threshold_);
EXPECT_LT(float_response, high_threshold);
}
} // namespace assist_ranker
......@@ -23,7 +23,7 @@ class GenericLogisticRegressionInference {
// Returns a boolean decision given a RankerExample. Uses the same logic as
// PredictScore, and then applies the model decision threshold.
bool Predict(const RankerExample& example);
// Returns a score between 0 and 1 give a RankerExample.
// Returns a score between 0 and 1 given a RankerExample.
float PredictScore(const RankerExample& example);
private:
......
......@@ -30,21 +30,25 @@ struct PredictorConfig {
const LogType log_type,
const base::flat_set<std::string>* feature_whitelist,
const base::Feature* field_trial,
const base::FeatureParam<std::string>* field_trial_url_param)
const base::FeatureParam<std::string>* field_trial_url_param,
float field_trial_threshold_replacement_param)
: model_name(model_name),
logging_name(logging_name),
uma_prefix(uma_prefix),
log_type(log_type),
feature_whitelist(feature_whitelist),
field_trial(field_trial),
field_trial_url_param(field_trial_url_param) {}
const char* model_name;
const char* logging_name;
const char* uma_prefix;
field_trial_url_param(field_trial_url_param),
field_trial_threshold_replacement_param(
field_trial_threshold_replacement_param) {}
const char* const model_name;
const char* const logging_name;
const char* const uma_prefix;
const LogType log_type;
const base::flat_set<std::string>* feature_whitelist;
const base::Feature* field_trial;
const base::FeatureParam<std::string>* field_trial_url_param;
const float field_trial_threshold_replacement_param;
};
} // namespace assist_ranker
......
......@@ -3,6 +3,7 @@
// found in the LICENSE file.
#include "components/assist_ranker/predictor_config_definitions.h"
#include "components/assist_ranker/base_predictor.h"
namespace assist_ranker {
......@@ -28,6 +29,15 @@ GetContextualSearchRankerUrlFeatureParam() {
return kContextualSearchRankerUrl;
}
float GetContextualSearchRankerThresholdFeatureParam() {
static auto* kContextualSearchRankerThreshold =
new base::FeatureParam<double>(
&kContextualSearchRankerQuery,
"contextual-search-ranker-predict-threshold",
kNoPredictThresholdReplacement);
return static_cast<float>(kContextualSearchRankerThreshold->Get());
}
// NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml!
// Only features within this list will be logged to UKM.
// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
......@@ -77,7 +87,8 @@ const PredictorConfig GetContextualSearchPredictorConfig() {
kContextualSearchModelName, kContextualSearchLoggingName,
kContextualSearchUmaPrefixName, LOG_UKM,
GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
GetContextualSearchRankerUrlFeatureParam()));
GetContextualSearchRankerUrlFeatureParam(),
GetContextualSearchRankerThresholdFeatureParam()));
return kContextualSearchPredictorConfig;
}
#endif // OS_ANDROID
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment