Commit 0b0d56fd authored by Charles Zhao's avatar Charles Zhao Committed by Commit Bot

Add bucketization support for GLR model.

(1) Add generic logistic regression to support preprocessed models.

(2) Also add validation logic in the binary classifier to prevent misusing a
preprocessed_model as a non-preprocessed_model.

Change-Id: Idf3f531c8d7bd45d31a45f6e54be0752a02334d1
Reviewed-on: https://chromium-review.googlesource.com/885541
Commit-Queue: Charles . <charleszhao@chromium.org>
Reviewed-by: default avatarPhilippe Hamel <hamelphi@chromium.org>
Cr-Commit-Position: refs/heads/master@{#532316}
parent aa8afe0f
......@@ -75,6 +75,25 @@ RankerModelStatus BinaryClassifierPredictor::ValidateModel(
DVLOG(0) << "Model is incompatible.";
return RankerModelStatus::INCOMPATIBLE;
}
const GenericLogisticRegressionModel& glr =
model.proto().logistic_regression();
if (glr.is_preprocessed_model()) {
if (glr.fullname_weights().empty() || !glr.weights().empty()) {
DVLOG(0) << "Model is incompatible. Preprocessed model should use "
"fullname_weights.";
return RankerModelStatus::INCOMPATIBLE;
}
if (!glr.preprocessor_config().feature_indices().empty()) {
DVLOG(0) << "Preprocessed model doesn't need feature indices.";
return RankerModelStatus::INCOMPATIBLE;
}
} else {
if (!glr.fullname_weights().empty() || glr.weights().empty()) {
DVLOG(0) << "Model is incompatible. Non-preprocessed model should use "
"weights.";
return RankerModelStatus::INCOMPATIBLE;
}
}
return RankerModelStatus::OK;
}
......
......@@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
protected:
const std::string feature_ = "feature";
const float weight_ = 1.0;
const float threshold_ = 0.5;
};
......@@ -68,7 +69,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
GenericLogisticRegressionModel lr_model;
lr_model.set_bias(-0.5);
lr_model.set_threshold(threshold_);
(*lr_model.mutable_weights())[feature_].set_scalar(1.0);
(*lr_model.mutable_weights())[feature_].set_scalar(weight_);
return lr_model;
}
......@@ -132,4 +133,33 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
EXPECT_LT(float_response, threshold_);
}
// Exercises the binary classifier through a *preprocessed* GLR model: the
// regular weights map is cleared and the single feature's weight is supplied
// via fullname_weights instead, with is_preprocessed_model set.
TEST_F(BinaryClassifierPredictorTest,
       GenericLogisticRegressionPreprocessedModel) {
  auto model = std::make_unique<RankerModel>();
  GenericLogisticRegressionModel& glr_proto =
      *model->mutable_proto()->mutable_logistic_regression();
  glr_proto = GetSimpleLogisticRegressionModel();
  glr_proto.clear_weights();
  glr_proto.set_is_preprocessed_model(true);
  (*glr_proto.mutable_fullname_weights())[feature_] = weight_;

  auto predictor = InitPredictor(std::move(model), GetConfig());
  EXPECT_TRUE(predictor->IsReady());

  RankerExample example;
  auto& feature_map = *example.mutable_features();

  // With the feature present (true), the score should exceed the threshold
  // and the predicted class should be positive.
  feature_map[feature_].set_bool_value(true);
  bool predicted_class;
  EXPECT_TRUE(predictor->Predict(example, &predicted_class));
  EXPECT_TRUE(predicted_class);
  float score;
  EXPECT_TRUE(predictor->PredictScore(example, &score));
  EXPECT_GT(score, threshold_);

  // Flipping the feature to false should drop the score below the threshold
  // and flip the predicted class.
  feature_map[feature_].set_bool_value(false);
  EXPECT_TRUE(predictor->Predict(example, &predicted_class));
  EXPECT_FALSE(predicted_class);
  EXPECT_TRUE(predictor->PredictScore(example, &score));
  EXPECT_LT(score, threshold_);
}
} // namespace assist_ranker
......@@ -7,6 +7,7 @@
#include <cmath>
#include "base/logging.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "components/assist_ranker/ranker_example_util.h"
namespace assist_ranker {
......@@ -24,6 +25,7 @@ GenericLogisticRegressionInference::GenericLogisticRegressionInference(
float GenericLogisticRegressionInference::PredictScore(
const RankerExample& example) {
float activation = 0.0f;
if (!proto_.is_preprocessed_model()) {
for (const auto& weight_it : proto_.weights()) {
const std::string& feature_name = weight_it.first;
const FeatureWeight& feature_weight = weight_it.second;
......@@ -72,6 +74,20 @@ float GenericLogisticRegressionInference::PredictScore(
}
}
}
} else {
RankerExample processed_example = example;
ExamplePreprocessor(proto_.preprocessor_config())
.Process(&processed_example);
for (const auto& field : ExampleFloatIterator(processed_example)) {
if (field.error != ExamplePreprocessor::kSuccess)
continue;
const auto& find_weight = proto_.fullname_weights().find(field.fullname);
if (find_weight != proto_.fullname_weights().end()) {
activation += find_weight->second * field.value;
}
}
}
return Sigmoid(proto_.bias() + activation);
}
......
......@@ -3,10 +3,13 @@
// found in the LICENSE file.
#include "components/assist_ranker/generic_logistic_regression_inference.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/protobuf/src/google/protobuf/map.h"
namespace assist_ranker {
using ::google::protobuf::Map;
class GenericLogisticRegressionInferenceTest : public testing::Test {
protected:
......@@ -19,11 +22,33 @@ class GenericLogisticRegressionInferenceTest : public testing::Test {
weights[scalar1_name_].set_scalar(scalar1_weight_);
weights[scalar2_name_].set_scalar(scalar2_weight_);
weights[scalar3_name_].set_scalar(scalar3_weight_);
auto* one_hot_feat = weights[one_hot_name_].mutable_one_hot();
one_hot_feat->set_default_weight(one_hot_default_weight_);
(*one_hot_feat->mutable_weights())[one_hot_elem1_name_] = elem1_weight_;
(*one_hot_feat->mutable_weights())[one_hot_elem2_name_] = elem2_weight_;
(*one_hot_feat->mutable_weights())[one_hot_elem3_name_] = elem3_weight_;
(*one_hot_feat->mutable_weights())[one_hot_elem1_name_] =
one_hot_elem1_weight_;
(*one_hot_feat->mutable_weights())[one_hot_elem2_name_] =
one_hot_elem2_weight_;
(*one_hot_feat->mutable_weights())[one_hot_elem3_name_] =
one_hot_elem3_weight_;
SparseWeights* sparse_feat = weights[sparse_name_].mutable_sparse();
sparse_feat->set_default_weight(sparse_default_weight_);
(*sparse_feat->mutable_weights())[sparse_elem1_name_] =
sparse_elem1_weight_;
(*sparse_feat->mutable_weights())[sparse_elem2_name_] =
sparse_elem2_weight_;
BucketizedWeights* bucketized_feat =
weights[bucketized_name_].mutable_bucketized();
bucketized_feat->set_default_weight(bucketization_default_weight_);
for (const float boundary : bucketization_boundaries_) {
bucketized_feat->add_boundaries(boundary);
}
for (const float weight : bucketization_weights_) {
bucketized_feat->add_weights(weight);
}
return proto;
}
......@@ -31,19 +56,31 @@ class GenericLogisticRegressionInferenceTest : public testing::Test {
const std::string scalar2_name_ = "scalar_feature2";
const std::string scalar3_name_ = "scalar_feature3";
const std::string one_hot_name_ = "one_hot_feature";
const std::string one_hot_elem1_name_ = "elem1";
const std::string one_hot_elem2_name_ = "elem2";
const std::string one_hot_elem3_name_ = "elem3";
const std::string one_hot_elem1_name_ = "one_hot_elem1";
const std::string one_hot_elem2_name_ = "one_hot_elem2";
const std::string one_hot_elem3_name_ = "one_hot_elem3";
const float bias_ = 1.5f;
const float threshold_ = 0.6f;
const float scalar1_weight_ = 0.8f;
const float scalar2_weight_ = -2.4f;
const float scalar3_weight_ = 0.01f;
const float elem1_weight_ = -1.0f;
const float elem2_weight_ = 5.0f;
const float elem3_weight_ = -1.5f;
const float one_hot_elem1_weight_ = -1.0f;
const float one_hot_elem2_weight_ = 5.0f;
const float one_hot_elem3_weight_ = -1.5f;
const float one_hot_default_weight_ = 10.0f;
const float epsilon_ = 0.001f;
const std::string sparse_name_ = "sparse_feature";
const std::string sparse_elem1_name_ = "sparse_elem1";
const std::string sparse_elem2_name_ = "sparse_elem2";
const float sparse_elem1_weight_ = -2.2f;
const float sparse_elem2_weight_ = 3.1f;
const float sparse_default_weight_ = 4.4f;
const std::string bucketized_name_ = "bucketized_feature";
const float bucketization_boundaries_[2] = {0.3f, 0.7f};
const float bucketization_weights_[3] = {-1.0f, 1.0f, 3.0f};
const float bucketization_default_weight_ = -3.3f;
};
TEST_F(GenericLogisticRegressionInferenceTest, BaseTest) {
......@@ -59,7 +96,7 @@ TEST_F(GenericLogisticRegressionInferenceTest, BaseTest) {
float score = predictor.PredictScore(example);
float expected_score =
Sigmoid(bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
0.666f * scalar3_weight_ + elem1_weight_);
0.666f * scalar3_weight_ + one_hot_elem1_weight_);
EXPECT_NEAR(expected_score, score, epsilon_);
EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
}
......@@ -99,7 +136,7 @@ TEST_F(GenericLogisticRegressionInferenceTest, UnknownFeatures) {
auto predictor = GenericLogisticRegressionInference(GetProto());
float score = predictor.PredictScore(example);
// Unknown features will be ignored.
float expected_score = Sigmoid(bias_ + elem2_weight_);
float expected_score = Sigmoid(bias_ + one_hot_elem2_weight_);
EXPECT_NEAR(expected_score, score, epsilon_);
}
......@@ -177,4 +214,77 @@ TEST_F(GenericLogisticRegressionInferenceTest, NoThreshold) {
EXPECT_FALSE(predictor.Predict(example));
}
// Verifies inference for a model flagged as preprocessed: the score must be
// computed from fullname_weights (after the example is run through the
// ExamplePreprocessor) rather than from the regular per-feature weights map.
// NOTE(review): the test name has a typo ("Preprossessed"); renaming it would
// change the gtest filter name, so it is left as-is here.
TEST_F(GenericLogisticRegressionInferenceTest, PreprossessedModel) {
  GenericLogisticRegressionModel proto = GetProto();
  proto.set_is_preprocessed_model(true);
  // Clear the weights to make sure the inference is done by fullname_weights.
  proto.clear_weights();

  // Build fullname weights. Scalar features are keyed by their plain name;
  // one-hot elements, sparse elements, bucket indices and missing-feature
  // defaults are keyed by the compound name produced by
  // ExamplePreprocessor::FeatureFullname.
  Map<std::string, float>& weights = *proto.mutable_fullname_weights();
  weights[scalar1_name_] = scalar1_weight_;
  weights[scalar2_name_] = scalar2_weight_;
  weights[scalar3_name_] = scalar3_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      one_hot_name_, one_hot_elem1_name_)] = one_hot_elem1_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      one_hot_name_, one_hot_elem2_name_)] = one_hot_elem2_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      one_hot_name_, one_hot_elem3_name_)] = one_hot_elem3_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      sparse_name_, sparse_elem1_name_)] = sparse_elem1_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      sparse_name_, sparse_elem2_name_)] = sparse_elem2_weight_;
  // Bucketized feature: one weight per bucket index ("0", "1", "2").
  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "0")] =
      bucketization_weights_[0];
  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "1")] =
      bucketization_weights_[1];
  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "2")] =
      bucketization_weights_[2];
  // Default weights for features declared in missing_features below. All
  // three features are present in the example, so these defaults should not
  // contribute to the expected score computed at the bottom of the test.
  weights[ExamplePreprocessor::FeatureFullname(
      ExamplePreprocessor::kMissingFeatureDefaultName, one_hot_name_)] =
      one_hot_default_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      ExamplePreprocessor::kMissingFeatureDefaultName, sparse_name_)] =
      sparse_default_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      ExamplePreprocessor::kMissingFeatureDefaultName, bucketized_name_)] =
      bucketization_default_weight_;

  // Build preprocessor_config: missing-feature tracking for the three
  // compound features, plus the bucket boundaries for the bucketized one.
  ExamplePreprocessorConfig& config = *proto.mutable_preprocessor_config();
  config.add_missing_features(one_hot_name_);
  config.add_missing_features(sparse_name_);
  config.add_missing_features(bucketized_name_);
  (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
      bucketization_boundaries_[0]);
  (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
      bucketization_boundaries_[1]);

  auto predictor = GenericLogisticRegressionInference(proto);

  // Build example. Every feature referenced by the model is set, so no
  // missing-feature default weight applies.
  RankerExample example;
  Map<std::string, Feature>& features = *example.mutable_features();
  features[scalar1_name_].set_bool_value(true);
  features[scalar2_name_].set_int32_value(42);
  features[scalar3_name_].set_float_value(0.666f);
  features[one_hot_name_].set_string_value(one_hot_elem1_name_);
  features[sparse_name_].mutable_string_list()->add_string_value(
      sparse_elem1_name_);
  features[sparse_name_].mutable_string_list()->add_string_value(
      sparse_elem2_name_);
  // 0.98 lies above the last boundary in bucketization_boundaries_, which is
  // why bucketization_weights_[2] appears in the expected score below.
  features[bucketized_name_].set_float_value(0.98f);

  // Inference: hand-computed activation of the active fullname weights,
  // passed through the sigmoid with the model bias.
  float score = predictor.PredictScore(example);
  float expected_score = Sigmoid(
      bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
      0.666f * scalar3_weight_ + one_hot_elem1_weight_ + sparse_elem1_weight_ +
      sparse_elem2_weight_ + bucketization_weights_[2]);
  EXPECT_NEAR(expected_score, score, epsilon_);
  EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
}
} // namespace assist_ranker
......@@ -8,6 +8,8 @@ syntax = "proto2";
option optimize_for = LITE_RUNTIME;
import "example_preprocessor.proto";
package assist_ranker;
message SparseWeights {
......@@ -69,4 +71,15 @@ message GenericLogisticRegressionModel {
// Map of weights keyed by feature name. Features can be scalar, one-hot,
// sparse or bucketized.
map<string, FeatureWeight> weights = 3;
// If it's a preprocessed_model, then use preprocessor_config to preprocess
// the input and fullname_weights to calculate the score.
optional bool is_preprocessed_model = 4;
// Map from feature fullname to its weight.
map<string, float> fullname_weights = 5;
// Config for preprocessor (without feature_indices; there is no need for
// vectorization, since the inference model uses ExampleFloatIterator instead).
optional ExamplePreprocessorConfig preprocessor_config = 6;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment