Commit bb47b546 authored by Jon Napper's avatar Jon Napper Committed by Commit Bot

Added quantized NN classifier for AssistRanker.

Added support for 8-bit quantized NN classifier models for AssistRanker.
The model must be dequantized prior to inferencing.

Bug: 907727
Change-Id: Ie52a55eb531eac9b3617dfd47bf27aa91b528347
Reviewed-on: https://chromium-review.googlesource.com/c/1370229
Reviewed-by: Charles . <charleszhao@chromium.org>
Commit-Queue: Jon Napper <napper@chromium.org>
Cr-Commit-Position: refs/heads/master@{#615767}
parent 471e5cf6
...@@ -23,6 +23,8 @@ static_library("assist_ranker") { ...@@ -23,6 +23,8 @@ static_library("assist_ranker") {
"predictor_config.h", "predictor_config.h",
"predictor_config_definitions.cc", "predictor_config_definitions.cc",
"predictor_config_definitions.h", "predictor_config_definitions.h",
"quantized_nn_classifier.cc",
"quantized_nn_classifier.h",
"ranker_example_util.cc", "ranker_example_util.cc",
"ranker_example_util.h", "ranker_example_util.h",
"ranker_model.cc", "ranker_model.cc",
...@@ -59,6 +61,7 @@ source_set("unit_tests") { ...@@ -59,6 +61,7 @@ source_set("unit_tests") {
"nn_classifier_test_util.cc", "nn_classifier_test_util.cc",
"nn_classifier_test_util.h", "nn_classifier_test_util.h",
"nn_classifier_unittest.cc", "nn_classifier_unittest.cc",
"quantized_nn_classifier_unittest.cc",
"ranker_example_util_unittest.cc", "ranker_example_util_unittest.cc",
"ranker_model_loader_impl_unittest.cc", "ranker_model_loader_impl_unittest.cc",
"ranker_model_unittest.cc", "ranker_model_unittest.cc",
......
...@@ -9,6 +9,7 @@ proto_library("proto") { ...@@ -9,6 +9,7 @@ proto_library("proto") {
"example_preprocessor.proto", "example_preprocessor.proto",
"generic_logistic_regression_model.proto", "generic_logistic_regression_model.proto",
"nn_classifier.proto", "nn_classifier.proto",
"quantized_nn_classifier.proto",
"ranker_example.proto", "ranker_example.proto",
"ranker_model.proto", "ranker_model.proto",
"translate_ranker_model.proto", "translate_ranker_model.proto",
......
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package assist_ranker;
// The weights and biases for a single quantized neural-network layer. All
// weight and bias values are quantized to unsigned 8-bit values and are
// converted back to floats using:
//   value = x * (high - low) / 256 + low.
message QuantizedNNLayer {
// The weight matrix for the layer: one bytes entry per weight row, with one
// quantized 8-bit weight per byte.
repeated bytes weights = 1;
// The bias vector for the layer, one quantized 8-bit value per byte.
optional bytes biases = 2;
// The low value of the quantization range; dequantized output for byte 0.
optional float low = 3;
// The high value of the quantization range, used to compute the
// dequantization scale (high - low) / 256. Must be strictly greater than
// |low| for the model to validate.
optional float high = 4;
}
// Defines the model weights and biases for a single layer neural network.
message QuantizedNNClassifierModel {
// The quantized hidden (first) layer of the network.
optional QuantizedNNLayer hidden_layer = 1;
// The quantized output (logits) layer of the network.
optional QuantizedNNLayer logits_layer = 2;
}
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/quantized_nn_classifier.h"
#include "base/logging.h"
#include "components/assist_ranker/nn_classifier.h"
namespace assist_ranker {
namespace quantized_nn_classifier {
namespace {
// Dequantizes a string of unsigned 8-bit values into |v| using the
// specified scaling factor and base value: output = scale * byte + low.
void DequantizeVector(const std::string& s,
                      float scale,
                      float low,
                      FloatVector* v) {
  auto* out = v->mutable_values();
  for (const unsigned char ch : s)
    out->Add(scale * ch + low);
}
// Expands one quantized layer into its floating-point equivalent in |layer|.
void DequantizeLayer(const QuantizedNNLayer& quantized, NNLayer* layer) {
  const float base = quantized.low();
  // The biases and every weight row share one quantization range.
  const float step = (quantized.high() - base) / 256;
  DequantizeVector(quantized.biases(), step, base, layer->mutable_biases());
  for (const std::string& row : quantized.weights()) {
    DequantizeVector(row, step, base, layer->mutable_weights()->Add());
  }
}
// A layer's quantization range is usable only when it is non-empty: the high
// bound must lie strictly above the low bound.
bool ValidateLayer(const QuantizedNNLayer& layer) {
  return layer.high() > layer.low();
}
} // namespace
// Builds a floating-point NNClassifierModel from |quantized| by expanding
// both of its layers.
NNClassifierModel Dequantize(const QuantizedNNClassifierModel& quantized) {
  NNClassifierModel result;
  DequantizeLayer(quantized.hidden_layer(), result.mutable_hidden_layer());
  DequantizeLayer(quantized.logits_layer(), result.mutable_logits_layer());
  return result;
}
// Checks both quantization ranges first, then defers the structural
// (dimension) checks to nn_classifier::Validate() on the dequantized model.
bool Validate(const QuantizedNNClassifierModel& quantized) {
  const bool ranges_ok = ValidateLayer(quantized.hidden_layer()) &&
                         ValidateLayer(quantized.logits_layer());
  if (!ranges_ok)
    return false;
  return nn_classifier::Validate(Dequantize(quantized));
}
} // namespace quantized_nn_classifier
} // namespace assist_ranker
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_QUANTIZED_NN_CLASSIFIER_H_
#define COMPONENTS_ASSIST_RANKER_QUANTIZED_NN_CLASSIFIER_H_

#include "components/assist_ranker/proto/nn_classifier.pb.h"
#include "components/assist_ranker/proto/quantized_nn_classifier.pb.h"

namespace assist_ranker {
namespace quantized_nn_classifier {

// Verifies that the dimensions and quantization high / low values are valid.
// Returns true if valid, false otherwise.
bool Validate(const QuantizedNNClassifierModel& quantized);

// Dequantizes the weights and biases in a quantized NN classifier model. This
// must be done before inferencing.
NNClassifierModel Dequantize(const QuantizedNNClassifierModel& quantized);

}  // namespace quantized_nn_classifier
}  // namespace assist_ranker

#endif  // COMPONENTS_ASSIST_RANKER_QUANTIZED_NN_CLASSIFIER_H_
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/quantized_nn_classifier.h"
#include "base/logging.h"
#include "components/assist_ranker/nn_classifier.h"
#include "components/assist_ranker/nn_classifier_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace assist_ranker {
namespace quantized_nn_classifier {
namespace {
using ::google::protobuf::RepeatedFieldBackInserter;
using ::google::protobuf::RepeatedPtrField;
using ::std::copy;
using ::std::vector;
// Fills |layer| with the given quantized biases, weight rows, and the
// dequantization range [low, high].
void CreateLayer(const vector<int>& biases,
                 const vector<vector<int>>& weights,
                 float low,
                 float high,
                 QuantizedNNLayer* layer) {
  layer->set_low(low);
  layer->set_high(high);
  // Each int is narrowed to a single byte of the serialized vector.
  layer->set_biases(std::string(biases.begin(), biases.end()));
  for (const auto& row : weights) {
    layer->mutable_weights()->Add(std::string(row.begin(), row.end()));
  }
}
// Creates a QuantizedNNClassifierModel proto from a trained set of biases
// and weights, applying the same quantization range to both layers.
QuantizedNNClassifierModel CreateModel(
    const vector<int>& hidden_biases,
    const vector<vector<int>>& hidden_weights,
    const vector<int>& logits_biases,
    const vector<vector<int>>& logits_weights,
    float low,
    float high) {
  QuantizedNNClassifierModel quantized;
  CreateLayer(hidden_biases, hidden_weights, low, high,
              quantized.mutable_hidden_layer());
  CreateLayer(logits_biases, logits_weights, low, high,
              quantized.mutable_logits_layer());
  return quantized;
}
// With low = 0 and high = 128 the dequantization scale is
// (128 - 0) / 256 = 0.5, so every quantized byte below should map to exactly
// byte * 0.5 in the dequantized model.
TEST(QuantizedNNClassifierTest, Dequantize) {
  const QuantizedNNClassifierModel quantized = CreateModel(
      // Hidden biases.
      {{8, 16, 32}},
      // Hidden weights.
      {{2, 4, 6}, {10, 4, 8}},
      // Logits biases.
      {2},
      // Logits weights.
      {{4}, {2}, {6}},
      // Low.
      0,
      // High.
      128);
  ASSERT_TRUE(Validate(quantized));
  const NNClassifierModel model = Dequantize(quantized);
  // Expected values are exactly half of the quantized bytes above.
  const NNClassifierModel expected = nn_classifier::CreateModel(
      // Hidden biases.
      {{4, 8, 16}},
      // Hidden weights.
      {{1, 2, 3}, {5, 2, 4}},
      // Logits biases.
      {1},
      // Logits weights.
      {{2}, {1}, {3}});
  // Comparing serialized forms checks every field of both layers at once.
  EXPECT_EQ(model.SerializeAsString(), expected.SerializeAsString());
}
// End-to-end check: a quantized model of a trained XOR network must, after
// dequantization, still produce the expected inference scores.
TEST(QuantizedNNClassifierTest, XorTest) {
  // Creates a NN with a single hidden layer of 5 units that solves XOR.
  // Creates a QuantizedNNClassifierModel containing the trained biases and
  // weights.
  const QuantizedNNClassifierModel quantized = CreateModel(
      // Hidden biases.
      {{110, 139, 175, 55, 106}},
      // Hidden weights.
      {{228, 127, 97, 217, 158}, {55, 219, 80, 199, 152}},
      // Logits biases.
      {74},
      // Logits weights.
      {{255}, {211}, {53}, {0}, {86}},
      // Low.
      -2.96390629,
      // High.
      2.8636384);
  ASSERT_TRUE(Validate(quantized));
  const NNClassifierModel model = Dequantize(quantized);
  ASSERT_TRUE(nn_classifier::Validate(model));
  // Positive scores for the (0,1) / (1,0) inputs, negative otherwise — the
  // XOR truth table.
  EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 0}, {-2.7032}));
  EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 1}, {2.80681}));
  EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 0}, {2.64435}));
  EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 1}, {-3.17825}));
}
// Exercises Validate() on an empty model, a well-formed model, several
// dimension mismatches, and an inverted quantization range.
TEST(QuantizedNNClassifierTest, ValidateQuantizedNNClassifierModel) {
  // Empty model.
  QuantizedNNClassifierModel model;
  EXPECT_FALSE(Validate(model));

  // Valid model.
  model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
                      0, 1);
  EXPECT_TRUE(Validate(model));

  // Hidden bias incorrect size.
  model =
      CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
  EXPECT_FALSE(Validate(model));

  // Hidden weight vector incorrect size.
  model =
      CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
  EXPECT_FALSE(Validate(model));

  // Logits weights incorrect size.
  model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}}, 0, 1);
  EXPECT_FALSE(Validate(model));

  // Empty logits bias.
  model =
      CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}}, 0, 1);
  EXPECT_FALSE(Validate(model));

  // Low / high incorrect (low must be strictly less than high).
  model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
                      1, 0);
  EXPECT_FALSE(Validate(model));
}
} // namespace
} // namespace quantized_nn_classifier
} // namespace assist_ranker
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment