Commit 0b0d56fd authored by Charles Zhao, committed by Commit Bot

Add bucketization support for GLR model.

(1) Add generic logistic regression support for preprocessed models.

(2) Add validation logic to the binary classifier to prevent mixing
preprocessed_model fields with non-preprocessed_model fields.
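
For context, a bucketized feature contributes the weight of the bucket
its float value falls into: N sorted boundaries define N + 1 buckets. A
minimal sketch of the lookup (illustrative only; the real bucketization
happens in ExamplePreprocessor, and upper-exclusive boundaries are an
assumption here):

  // Sketch: with boundaries {0.3, 0.7} and weights {-1.0, 1.0, 3.0},
  // a value of 0.98 passes both boundaries and selects bucket 2 (3.0).
  float BucketWeight(const BucketizedWeights& feature, float value) {
    int bucket = 0;
    for (const float boundary : feature.boundaries()) {
      if (value < boundary)
        break;
      ++bucket;
    }
    // weights() holds boundaries_size() + 1 entries, one per bucket.
    return feature.weights(bucket);
  }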

Change-Id: Idf3f531c8d7bd45d31a45f6e54be0752a02334d1
Reviewed-on: https://chromium-review.googlesource.com/885541
Commit-Queue: Charles . <charleszhao@chromium.org>
Reviewed-by: Philippe Hamel <hamelphi@chromium.org>
Cr-Commit-Position: refs/heads/master@{#532316}
parent aa8afe0f
@@ -75,6 +75,25 @@ RankerModelStatus BinaryClassifierPredictor::ValidateModel(
     DVLOG(0) << "Model is incompatible.";
     return RankerModelStatus::INCOMPATIBLE;
   }
+  const GenericLogisticRegressionModel& glr =
+      model.proto().logistic_regression();
+  if (glr.is_preprocessed_model()) {
+    if (glr.fullname_weights().empty() || !glr.weights().empty()) {
+      DVLOG(0) << "Model is incompatible. Preprocessed model should use "
+                  "fullname_weights.";
+      return RankerModelStatus::INCOMPATIBLE;
+    }
+    if (!glr.preprocessor_config().feature_indices().empty()) {
+      DVLOG(0) << "Preprocessed model doesn't need feature indices.";
+      return RankerModelStatus::INCOMPATIBLE;
+    }
+  } else {
+    if (!glr.fullname_weights().empty() || glr.weights().empty()) {
+      DVLOG(0) << "Model is incompatible. Non-preprocessed model should use "
+                  "weights.";
+      return RankerModelStatus::INCOMPATIBLE;
+    }
+  }
   return RankerModelStatus::OK;
 }
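In short, ValidateModel() now accepts exactly one weight representation per
model. A summary of the combinations (field names from the proto change at
the end of this diff):

  // is_preprocessed_model == true:  fullname_weights non-empty, weights
  //                                 empty, no preprocessor feature_indices.
  // is_preprocessed_model == false: weights non-empty, fullname_weights
  //                                 empty.
  // Any other combination returns RankerModelStatus::INCOMPATIBLE.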
@@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
  protected:
   const std::string feature_ = "feature";
+  const float weight_ = 1.0;
   const float threshold_ = 0.5;
 };
@@ -68,7 +69,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
   GenericLogisticRegressionModel lr_model;
   lr_model.set_bias(-0.5);
   lr_model.set_threshold(threshold_);
-  (*lr_model.mutable_weights())[feature_].set_scalar(1.0);
+  (*lr_model.mutable_weights())[feature_].set_scalar(weight_);
   return lr_model;
 }
@@ -132,4 +133,33 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
   EXPECT_LT(float_response, threshold_);
 }
 
+TEST_F(BinaryClassifierPredictorTest,
+       GenericLogisticRegressionPreprocessedModel) {
+  auto ranker_model = std::make_unique<RankerModel>();
+  auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
+  glr = GetSimpleLogisticRegressionModel();
+  glr.clear_weights();
+  glr.set_is_preprocessed_model(true);
+  (*glr.mutable_fullname_weights())[feature_] = weight_;
+
+  auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
+  EXPECT_TRUE(predictor->IsReady());
+
+  RankerExample ranker_example;
+  auto& features = *ranker_example.mutable_features();
+  features[feature_].set_bool_value(true);
+  bool bool_response;
+  EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
+  EXPECT_TRUE(bool_response);
+  float float_response;
+  EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
+  EXPECT_GT(float_response, threshold_);
+
+  features[feature_].set_bool_value(false);
+  EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
+  EXPECT_FALSE(bool_response);
+  EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
+  EXPECT_LT(float_response, threshold_);
+}
+
 }  // namespace assist_ranker
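For reference, the expectations in this test follow from the fixture values
(bias -0.5, weight 1.0, threshold 0.5), assuming the preprocessor converts
the bool feature to a float of 1.0 or 0.0:

  // feature true:  Sigmoid(-0.5 + 1.0 * 1.0) = Sigmoid(0.5)  ~= 0.62 > 0.5
  // feature false: Sigmoid(-0.5 + 1.0 * 0.0) = Sigmoid(-0.5) ~= 0.38 < 0.5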
@@ -7,6 +7,7 @@
 #include <cmath>
 
 #include "base/logging.h"
+#include "components/assist_ranker/example_preprocessing.h"
 #include "components/assist_ranker/ranker_example_util.h"
 
 namespace assist_ranker {
@@ -24,54 +25,69 @@ GenericLogisticRegressionInference::GenericLogisticRegressionInference(
 float GenericLogisticRegressionInference::PredictScore(
     const RankerExample& example) {
   float activation = 0.0f;
-  for (const auto& weight_it : proto_.weights()) {
-    const std::string& feature_name = weight_it.first;
-    const FeatureWeight& feature_weight = weight_it.second;
-    switch (feature_weight.feature_type_case()) {
-      case FeatureWeight::FEATURE_TYPE_NOT_SET: {
-        DVLOG(0) << "Feature type not set for " << feature_name;
-        break;
-      }
-      case FeatureWeight::kScalar: {
-        float value;
-        if (GetFeatureValueAsFloat(feature_name, example, &value)) {
-          const float weight = feature_weight.scalar();
-          activation += value * weight;
-        } else {
-          DVLOG(1) << "Feature not in example: " << feature_name;
-        }
-        break;
-      }
-      case FeatureWeight::kOneHot: {
-        std::string value;
-        if (GetOneHotValue(feature_name, example, &value)) {
-          const auto& category_weights = feature_weight.one_hot().weights();
-          auto category_it = category_weights.find(value);
-          if (category_it != category_weights.end()) {
-            activation += category_it->second;
-          } else {
-            // If the category is not found, use the default weight.
-            activation += feature_weight.one_hot().default_weight();
-            DVLOG(1) << "Unknown feature value for " << feature_name << ": "
-                     << value;
-          }
-        } else {
-          // If the feature is missing, use the default weight.
-          activation += feature_weight.one_hot().default_weight();
-          DVLOG(1) << "Feature not in example: " << feature_name;
-        }
-        break;
-      }
-      case FeatureWeight::kSparse: {
-        DVLOG(0) << "Sparse features not implemented yet.";
-        break;
-      }
-      case FeatureWeight::kBucketized: {
-        DVLOG(0) << "Bucketized features not implemented yet.";
-        break;
-      }
-    }
-  }
+  if (!proto_.is_preprocessed_model()) {
+    for (const auto& weight_it : proto_.weights()) {
+      const std::string& feature_name = weight_it.first;
+      const FeatureWeight& feature_weight = weight_it.second;
+      switch (feature_weight.feature_type_case()) {
+        case FeatureWeight::FEATURE_TYPE_NOT_SET: {
+          DVLOG(0) << "Feature type not set for " << feature_name;
+          break;
+        }
+        case FeatureWeight::kScalar: {
+          float value;
+          if (GetFeatureValueAsFloat(feature_name, example, &value)) {
+            const float weight = feature_weight.scalar();
+            activation += value * weight;
+          } else {
+            DVLOG(1) << "Feature not in example: " << feature_name;
+          }
+          break;
+        }
+        case FeatureWeight::kOneHot: {
+          std::string value;
+          if (GetOneHotValue(feature_name, example, &value)) {
+            const auto& category_weights = feature_weight.one_hot().weights();
+            auto category_it = category_weights.find(value);
+            if (category_it != category_weights.end()) {
+              activation += category_it->second;
+            } else {
+              // If the category is not found, use the default weight.
+              activation += feature_weight.one_hot().default_weight();
+              DVLOG(1) << "Unknown feature value for " << feature_name << ": "
+                       << value;
+            }
+          } else {
+            // If the feature is missing, use the default weight.
+            activation += feature_weight.one_hot().default_weight();
+            DVLOG(1) << "Feature not in example: " << feature_name;
+          }
+          break;
+        }
+        case FeatureWeight::kSparse: {
+          DVLOG(0) << "Sparse features not implemented yet.";
+          break;
+        }
+        case FeatureWeight::kBucketized: {
+          DVLOG(0) << "Bucketized features not implemented yet.";
+          break;
+        }
+      }
+    }
+  } else {
+    RankerExample processed_example = example;
+    ExamplePreprocessor(proto_.preprocessor_config())
+        .Process(&processed_example);
+    for (const auto& field : ExampleFloatIterator(processed_example)) {
+      if (field.error != ExamplePreprocessor::kSuccess)
+        continue;
+      const auto& find_weight = proto_.fullname_weights().find(field.fullname);
+      if (find_weight != proto_.fullname_weights().end()) {
+        activation += find_weight->second * field.value;
+      }
+    }
+  }
   return Sigmoid(proto_.bias() + activation);
 }
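Both branches reduce to score = Sigmoid(bias + sum of weight * value); only
the weight lookup differs. A minimal caller sketch (the feature name and
values are hypothetical; the class and methods are from this diff):

  GenericLogisticRegressionModel model_proto;
  model_proto.set_bias(0.0f);
  (*model_proto.mutable_weights())["some_feature"].set_scalar(0.5f);
  GenericLogisticRegressionInference inference(model_proto);
  RankerExample example;
  (*example.mutable_features())["some_feature"].set_float_value(1.2f);
  // For a preprocessed model, PredictScore() would instead preprocess a
  // copy of the example and sum fullname_weights[fullname] * value.
  float score = inference.PredictScore(example);  // Sigmoid(0.6) ~= 0.65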
@@ -3,10 +3,13 @@
 // found in the LICENSE file.
 
 #include "components/assist_ranker/generic_logistic_regression_inference.h"
+#include "components/assist_ranker/example_preprocessing.h"
 #include "testing/gtest/include/gtest/gtest.h"
+#include "third_party/protobuf/src/google/protobuf/map.h"
 
 namespace assist_ranker {
 
+using ::google::protobuf::Map;
+
 class GenericLogisticRegressionInferenceTest : public testing::Test {
  protected:
@@ -19,11 +22,33 @@ class GenericLogisticRegressionInferenceTest : public testing::Test {
     weights[scalar1_name_].set_scalar(scalar1_weight_);
     weights[scalar2_name_].set_scalar(scalar2_weight_);
     weights[scalar3_name_].set_scalar(scalar3_weight_);
+
     auto* one_hot_feat = weights[one_hot_name_].mutable_one_hot();
     one_hot_feat->set_default_weight(one_hot_default_weight_);
-    (*one_hot_feat->mutable_weights())[one_hot_elem1_name_] = elem1_weight_;
-    (*one_hot_feat->mutable_weights())[one_hot_elem2_name_] = elem2_weight_;
-    (*one_hot_feat->mutable_weights())[one_hot_elem3_name_] = elem3_weight_;
+    (*one_hot_feat->mutable_weights())[one_hot_elem1_name_] =
+        one_hot_elem1_weight_;
+    (*one_hot_feat->mutable_weights())[one_hot_elem2_name_] =
+        one_hot_elem2_weight_;
+    (*one_hot_feat->mutable_weights())[one_hot_elem3_name_] =
+        one_hot_elem3_weight_;
+
+    SparseWeights* sparse_feat = weights[sparse_name_].mutable_sparse();
+    sparse_feat->set_default_weight(sparse_default_weight_);
+    (*sparse_feat->mutable_weights())[sparse_elem1_name_] =
+        sparse_elem1_weight_;
+    (*sparse_feat->mutable_weights())[sparse_elem2_name_] =
+        sparse_elem2_weight_;
+
+    BucketizedWeights* bucketized_feat =
+        weights[bucketized_name_].mutable_bucketized();
+    bucketized_feat->set_default_weight(bucketization_default_weight_);
+    for (const float boundary : bucketization_boundaries_) {
+      bucketized_feat->add_boundaries(boundary);
+    }
+    for (const float weight : bucketization_weights_) {
+      bucketized_feat->add_weights(weight);
+    }
+
     return proto;
   }
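Note the shape invariant this fixture sets up, assuming upper-exclusive
boundaries:

  // boundaries {0.3, 0.7} -> buckets (-inf, 0.3), [0.3, 0.7), [0.7, +inf)
  // weights    {-1.0, 1.0, 3.0}: one per bucket, so boundaries_size() + 1
  // entries.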
@@ -31,19 +56,31 @@ class GenericLogisticRegressionInferenceTest : public testing::Test {
   const std::string scalar2_name_ = "scalar_feature2";
   const std::string scalar3_name_ = "scalar_feature3";
   const std::string one_hot_name_ = "one_hot_feature";
-  const std::string one_hot_elem1_name_ = "elem1";
-  const std::string one_hot_elem2_name_ = "elem2";
-  const std::string one_hot_elem3_name_ = "elem3";
+  const std::string one_hot_elem1_name_ = "one_hot_elem1";
+  const std::string one_hot_elem2_name_ = "one_hot_elem2";
+  const std::string one_hot_elem3_name_ = "one_hot_elem3";
   const float bias_ = 1.5f;
   const float threshold_ = 0.6f;
   const float scalar1_weight_ = 0.8f;
   const float scalar2_weight_ = -2.4f;
   const float scalar3_weight_ = 0.01f;
-  const float elem1_weight_ = -1.0f;
-  const float elem2_weight_ = 5.0f;
-  const float elem3_weight_ = -1.5f;
+  const float one_hot_elem1_weight_ = -1.0f;
+  const float one_hot_elem2_weight_ = 5.0f;
+  const float one_hot_elem3_weight_ = -1.5f;
   const float one_hot_default_weight_ = 10.0f;
   const float epsilon_ = 0.001f;
+
+  const std::string sparse_name_ = "sparse_feature";
+  const std::string sparse_elem1_name_ = "sparse_elem1";
+  const std::string sparse_elem2_name_ = "sparse_elem2";
+  const float sparse_elem1_weight_ = -2.2f;
+  const float sparse_elem2_weight_ = 3.1f;
+  const float sparse_default_weight_ = 4.4f;
+
+  const std::string bucketized_name_ = "bucketized_feature";
+  const float bucketization_boundaries_[2] = {0.3f, 0.7f};
+  const float bucketization_weights_[3] = {-1.0f, 1.0f, 3.0f};
+  const float bucketization_default_weight_ = -3.3f;
 };
 
 TEST_F(GenericLogisticRegressionInferenceTest, BaseTest) {
@@ -59,7 +96,7 @@ TEST_F(GenericLogisticRegressionInferenceTest, BaseTest) {
   float score = predictor.PredictScore(example);
   float expected_score =
       Sigmoid(bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
-              0.666f * scalar3_weight_ + elem1_weight_);
+              0.666f * scalar3_weight_ + one_hot_elem1_weight_);
   EXPECT_NEAR(expected_score, score, epsilon_);
   EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
 }
@@ -99,7 +136,7 @@ TEST_F(GenericLogisticRegressionInferenceTest, UnknownFeatures) {
   auto predictor = GenericLogisticRegressionInference(GetProto());
   float score = predictor.PredictScore(example);
   // Unknown features will be ignored.
-  float expected_score = Sigmoid(bias_ + elem2_weight_);
+  float expected_score = Sigmoid(bias_ + one_hot_elem2_weight_);
   EXPECT_NEAR(expected_score, score, epsilon_);
 }
@@ -177,4 +214,77 @@ TEST_F(GenericLogisticRegressionInferenceTest, NoThreshold) {
   EXPECT_FALSE(predictor.Predict(example));
 }
 
+TEST_F(GenericLogisticRegressionInferenceTest, PreprocessedModel) {
+  GenericLogisticRegressionModel proto = GetProto();
+  proto.set_is_preprocessed_model(true);
+  // Clear the weights to make sure the inference is done by fullname_weights.
+  proto.clear_weights();
+
+  // Build fullname weights.
+  Map<std::string, float>& weights = *proto.mutable_fullname_weights();
+  weights[scalar1_name_] = scalar1_weight_;
+  weights[scalar2_name_] = scalar2_weight_;
+  weights[scalar3_name_] = scalar3_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      one_hot_name_, one_hot_elem1_name_)] = one_hot_elem1_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      one_hot_name_, one_hot_elem2_name_)] = one_hot_elem2_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      one_hot_name_, one_hot_elem3_name_)] = one_hot_elem3_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      sparse_name_, sparse_elem1_name_)] = sparse_elem1_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      sparse_name_, sparse_elem2_name_)] = sparse_elem2_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "0")] =
+      bucketization_weights_[0];
+  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "1")] =
+      bucketization_weights_[1];
+  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "2")] =
+      bucketization_weights_[2];
+  weights[ExamplePreprocessor::FeatureFullname(
+      ExamplePreprocessor::kMissingFeatureDefaultName, one_hot_name_)] =
+      one_hot_default_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      ExamplePreprocessor::kMissingFeatureDefaultName, sparse_name_)] =
+      sparse_default_weight_;
+  weights[ExamplePreprocessor::FeatureFullname(
+      ExamplePreprocessor::kMissingFeatureDefaultName, bucketized_name_)] =
+      bucketization_default_weight_;
+
+  // Build preprocessor_config.
+  ExamplePreprocessorConfig& config = *proto.mutable_preprocessor_config();
+  config.add_missing_features(one_hot_name_);
+  config.add_missing_features(sparse_name_);
+  config.add_missing_features(bucketized_name_);
+  (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
+      bucketization_boundaries_[0]);
+  (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
+      bucketization_boundaries_[1]);
+
+  auto predictor = GenericLogisticRegressionInference(proto);
+
+  // Build example.
+  RankerExample example;
+  Map<std::string, Feature>& features = *example.mutable_features();
+  features[scalar1_name_].set_bool_value(true);
+  features[scalar2_name_].set_int32_value(42);
+  features[scalar3_name_].set_float_value(0.666f);
+  features[one_hot_name_].set_string_value(one_hot_elem1_name_);
+  features[sparse_name_].mutable_string_list()->add_string_value(
+      sparse_elem1_name_);
+  features[sparse_name_].mutable_string_list()->add_string_value(
+      sparse_elem2_name_);
+  features[bucketized_name_].set_float_value(0.98f);
+
+  // Inference.
+  float score = predictor.PredictScore(example);
+  float expected_score = Sigmoid(
+      bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
+      0.666f * scalar3_weight_ + one_hot_elem1_weight_ + sparse_elem1_weight_ +
+      sparse_elem2_weight_ + bucketization_weights_[2]);
+  EXPECT_NEAR(expected_score, score, epsilon_);
+  EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
+}
+
 }  // namespace assist_ranker
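A worked check of the new terms in expected_score, using the fixture values
and assuming upper-exclusive buckets:

  // 0.98 vs boundaries {0.3f, 0.7f}: 0.98 >= 0.7 -> bucket index 2, so
  // fullname_weights[FeatureFullname(bucketized_name_, "2")] = 3.0f applies.
  // The sparse string list carries both elements, so sparse_elem1_weight_
  // and sparse_elem2_weight_ are both added. All three preprocessed
  // features are present in the example, so none of the
  // kMissingFeatureDefaultName default weights fire.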
@@ -8,6 +8,8 @@ syntax = "proto2";
 
 option optimize_for = LITE_RUNTIME;
 
+import "example_preprocessor.proto";
+
 package assist_ranker;
 
 message SparseWeights {
@@ -69,4 +71,15 @@ message GenericLogisticRegressionModel {
   // Map of weights keyed by feature name. Features can be scalar, one-hot,
   // sparse or bucketized.
   map<string, FeatureWeight> weights = 3;
-}
\ No newline at end of file
+
+  // If this is a preprocessed model, use preprocessor_config to preprocess
+  // the input and fullname_weights to calculate the score.
+  optional bool is_preprocessed_model = 4;
+
+  // Map from feature fullname to its weight.
+  map<string, float> fullname_weights = 5;
+
+  // Config for the preprocessor (without feature_indices; there is no need
+  // for vectorization, since the inference uses ExampleFloatIterator
+  // instead).
+  optional ExamplePreprocessorConfig preprocessor_config = 6;
+}
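For illustration, a minimal preprocessed model populated through the new
fields (the feature name and values are hypothetical; field names are from
this message):

  GenericLogisticRegressionModel model;
  model.set_bias(0.1f);
  model.set_threshold(0.5f);
  model.set_is_preprocessed_model(true);
  // Scalar features are keyed directly by name; one-hot, sparse, and
  // bucketized entries use ExamplePreprocessor::FeatureFullname(...).
  (*model.mutable_fullname_weights())["query_length"] = 0.7f;
  // weights (field 3) stays empty and preprocessor_config carries no
  // feature_indices, so BinaryClassifierPredictor::ValidateModel() above
  // accepts the model.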