Commit 477234d9 authored by Mehrdad Hessar, committed by Commit Bot

This CL renames the TFLitePredictor class to InProcessTFLitePredictor.

Bug: 1116626
Change-Id: I3b0542f0b9a670c797e6e61329864f2fb9a8abff
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2358059
Reviewed-by: Michael Crouse <mcrouse@chromium.org>
Commit-Queue: Mehrdad Hessar <mehrdadh@google.com>
Cr-Commit-Position: refs/heads/master@{#798877}
parent dc08a996
chrome/browser/tflite_experiment/tflite_experiment_keyed_service.cc
@@ -14,7 +14,7 @@ TFLiteExperimentKeyedService::TFLiteExperimentKeyedService(
   if (!model_path)
     return;
-  predictor_ = std::make_unique<machine_learning::TFLitePredictor>(
+  predictor_ = std::make_unique<machine_learning::InProcessTFLitePredictor>(
       model_path.value(),
       tflite_experiment::switches::GetTFLitePredictorNumThreads());
   predictor_->Initialize();
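For readers following the rename, here is a minimal sketch of the call pattern the keyed service uses above. The model path and thread count are illustrative placeholders, not values from this CL:

```cpp
// Minimal usage sketch (illustrative, not part of this CL). Assumes a
// valid .tflite model path. Initialize() loads the model, builds the
// interpreter, and allocates tensors, returning kTfLiteOk on success.
auto predictor = std::make_unique<machine_learning::InProcessTFLitePredictor>(
    "/tmp/model.tflite", /*num_threads=*/4);
if (predictor->Initialize() != kTfLiteOk) {
  // The model failed to load or the interpreter could not be built.
  return;
}
```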
chrome/browser/tflite_experiment/tflite_experiment_keyed_service.h
@@ -5,7 +5,7 @@
 #ifndef CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_
 #define CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_

-#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h"
+#include "chrome/services/machine_learning/in_process_tflite_predictor.h"
 #include "components/keyed_service/core/keyed_service.h"

 namespace content {

@@ -20,14 +20,14 @@ class TFLiteExperimentKeyedService : public KeyedService {
       content::BrowserContext* browser_context);
   ~TFLiteExperimentKeyedService() override;

-  machine_learning::TFLitePredictor* tflite_predictor() {
+  machine_learning::InProcessTFLitePredictor* tflite_predictor() {
     return predictor_.get();
   }

  private:
   // The predictor owned by this keyed service capable of
   // running a TFLite model.
-  std::unique_ptr<machine_learning::TFLitePredictor> predictor_;
+  std::unique_ptr<machine_learning::InProcessTFLitePredictor> predictor_;
 };

 #endif  // CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_
chrome/browser/tflite_experiment/tflite_experiment_observer.cc
@@ -22,8 +22,8 @@ constexpr int32_t kTFLitePredictorEvaluationLoop = 10;
 namespace {

-// Returns the TFLitePredictor.
-machine_learning::TFLitePredictor* GetTFLitePredictorFromWebContents(
+// Returns the InProcessTFLitePredictor.
+machine_learning::InProcessTFLitePredictor* GetTFLitePredictorFromWebContents(
     content::WebContents* web_contents) {
   if (!web_contents)
     return nullptr;
chrome/browser/tflite_experiment/tflite_experiment_observer.h
@@ -9,7 +9,7 @@
 #include "base/macros.h"
 #include "base/timer/timer.h"
-#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h"
+#include "chrome/services/machine_learning/in_process_tflite_predictor.h"
 #include "content/public/browser/web_contents_observer.h"
 #include "content/public/browser/web_contents_user_data.h"

@@ -49,7 +49,7 @@ class TFLiteExperimentObserver
       const std::string&);

   // The predictor is capable of running a TFLite model.
-  machine_learning::TFLitePredictor* tflite_predictor_ = nullptr;
+  machine_learning::InProcessTFLitePredictor* tflite_predictor_ = nullptr;

   // True when |tflite_predictor_| ran model evaluation. It forces
   // the observer to run tflite prediction only once.
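The run-once flag referenced in the comment above presumably gates evaluation along these lines. This is a sketch only; the |has_evaluated_| and MaybeRunPrediction() names are assumptions, not code from this CL:

```cpp
// Hypothetical sketch of the run-once guard described in the header
// comment; member and method names are assumed, not taken from this CL.
void TFLiteExperimentObserver::MaybeRunPrediction() {
  if (has_evaluated_ || !tflite_predictor_)
    return;  // Evaluate at most once, and only with a live predictor.
  has_evaluated_ = true;
  tflite_predictor_->Evaluate();
}
```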
chrome/services/machine_learning/BUILD.gn
@@ -18,8 +18,8 @@ source_set("machine_learning") {
   if (build_with_tflite_lib) {
     sources += [
-      "machine_learning_tflite_predictor.cc",
-      "machine_learning_tflite_predictor.h",
+      "in_process_tflite_predictor.cc",
+      "in_process_tflite_predictor.h",
     ]
     deps += [

@@ -52,7 +52,7 @@ source_set("unit_tests") {
   ]
   if (build_with_tflite_lib) {
-    sources += [ "machine_learning_tflite_predictor_unittest.cc" ]
+    sources += [ "in_process_tflite_predictor_unittest.cc" ]
   }
   deps = [
chrome/services/machine_learning/machine_learning_tflite_predictor.cc → chrome/services/machine_learning/in_process_tflite_predictor.cc
@@ -2,18 +2,19 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.

-#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h"
+#include "chrome/services/machine_learning/in_process_tflite_predictor.h"

 #include "base/check.h"

 namespace machine_learning {

-TFLitePredictor::TFLitePredictor(std::string filename, int32_t num_threads)
+InProcessTFLitePredictor::InProcessTFLitePredictor(std::string filename,
+                                                   int32_t num_threads)
     : model_file_name_(filename), num_threads_(num_threads) {}

-TFLitePredictor::~TFLitePredictor() = default;
+InProcessTFLitePredictor::~InProcessTFLitePredictor() = default;

-TfLiteStatus TFLitePredictor::Initialize() {
+TfLiteStatus InProcessTFLitePredictor::Initialize() {
   if (!LoadModel())
     return kTfLiteError;
   if (!BuildInterpreter())

@@ -24,11 +25,11 @@ TfLiteStatus TFLitePredictor::Initialize() {
   return status;
 }

-TfLiteStatus TFLitePredictor::Evaluate() {
+TfLiteStatus InProcessTFLitePredictor::Evaluate() {
   return TfLiteInterpreterInvoke(interpreter_.get());
 }

-bool TFLitePredictor::LoadModel() {
+bool InProcessTFLitePredictor::LoadModel() {
   if (model_file_name_.empty())
     return false;

@@ -42,7 +43,7 @@ bool TFLitePredictor::LoadModel() {
   return true;
 }

-bool TFLitePredictor::BuildInterpreter() {
+bool InProcessTFLitePredictor::BuildInterpreter() {
   // We create the pointer using this approach since |TfLiteInterpreterOptions|
   // is a structure without the delete operator.
   options_ = std::unique_ptr<TfLiteInterpreterOptions,
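The comment in BuildInterpreter() deserves a note: |TfLiteInterpreterOptions| is an opaque C struct, so it must be released through its paired C function rather than `delete`, and the smart pointer carries a custom deleter. A self-contained sketch of that pattern follows; the TFLite C API calls are real, while the MakeOptions() wrapper is illustrative:

```cpp
#include <cstdint>
#include <memory>

#include "tensorflow/lite/c/c_api.h"

// unique_ptr with a function-pointer deleter: the opaque options struct
// is freed by TfLiteInterpreterOptionsDelete(), never by operator delete.
using OptionsPtr = std::unique_ptr<TfLiteInterpreterOptions,
                                   void (*)(TfLiteInterpreterOptions*)>;

OptionsPtr MakeOptions(int32_t num_threads) {
  OptionsPtr options(TfLiteInterpreterOptionsCreate(),
                     &TfLiteInterpreterOptionsDelete);
  // Configure the interpreter's thread pool before the interpreter is built.
  TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
  return options;
}
```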
@@ -65,73 +66,77 @@ bool TFLitePredictor::BuildInterpreter() {
   return true;
 }

-TfLiteStatus TFLitePredictor::AllocateTensors() {
+TfLiteStatus InProcessTFLitePredictor::AllocateTensors() {
   TfLiteStatus status = TfLiteInterpreterAllocateTensors(interpreter_.get());
   DCHECK(status == kTfLiteOk);
   return status;
 }

-int32_t TFLitePredictor::GetInputTensorCount() const {
+int32_t InProcessTFLitePredictor::GetInputTensorCount() const {
   if (interpreter_ == nullptr)
     return 0;
   return TfLiteInterpreterGetInputTensorCount(interpreter_.get());
 }

-int32_t TFLitePredictor::GetOutputTensorCount() const {
+int32_t InProcessTFLitePredictor::GetOutputTensorCount() const {
   if (interpreter_ == nullptr)
     return 0;
   return TfLiteInterpreterGetOutputTensorCount(interpreter_.get());
 }

-TfLiteTensor* TFLitePredictor::GetInputTensor(int32_t index) const {
+TfLiteTensor* InProcessTFLitePredictor::GetInputTensor(int32_t index) const {
   if (interpreter_ == nullptr)
     return nullptr;
   return TfLiteInterpreterGetInputTensor(interpreter_.get(), index);
 }

-const TfLiteTensor* TFLitePredictor::GetOutputTensor(int32_t index) const {
+const TfLiteTensor* InProcessTFLitePredictor::GetOutputTensor(
+    int32_t index) const {
   if (interpreter_ == nullptr)
     return nullptr;
   return TfLiteInterpreterGetOutputTensor(interpreter_.get(), index);
 }

-bool TFLitePredictor::IsInitialized() const {
+bool InProcessTFLitePredictor::IsInitialized() const {
   return initialized_;
 }

-int32_t TFLitePredictor::GetInputTensorNumDims(int32_t tensor_index) const {
+int32_t InProcessTFLitePredictor::GetInputTensorNumDims(
+    int32_t tensor_index) const {
   TfLiteTensor* tensor = GetInputTensor(tensor_index);
   return TfLiteTensorNumDims(tensor);
 }

-int32_t TFLitePredictor::GetInputTensorDim(int32_t tensor_index,
-                                           int32_t dim_index) const {
+int32_t InProcessTFLitePredictor::GetInputTensorDim(int32_t tensor_index,
+                                                    int32_t dim_index) const {
   TfLiteTensor* tensor = GetInputTensor(tensor_index);
   return TfLiteTensorDim(tensor, dim_index);
 }

-void* TFLitePredictor::GetInputTensorData(int32_t tensor_index) const {
+void* InProcessTFLitePredictor::GetInputTensorData(int32_t tensor_index) const {
   TfLiteTensor* tensor = GetInputTensor(tensor_index);
   return TfLiteTensorData(tensor);
 }

-int32_t TFLitePredictor::GetOutputTensorNumDims(int32_t tensor_index) const {
+int32_t InProcessTFLitePredictor::GetOutputTensorNumDims(
+    int32_t tensor_index) const {
   const TfLiteTensor* tensor = GetOutputTensor(tensor_index);
   return TfLiteTensorNumDims(tensor);
 }

-int32_t TFLitePredictor::GetOutputTensorDim(int32_t tensor_index,
-                                            int32_t dim_index) const {
+int32_t InProcessTFLitePredictor::GetOutputTensorDim(int32_t tensor_index,
+                                                     int32_t dim_index) const {
   const TfLiteTensor* tensor = GetOutputTensor(tensor_index);
   return TfLiteTensorDim(tensor, dim_index);
 }

-void* TFLitePredictor::GetOutputTensorData(int32_t tensor_index) const {
-  const TfLiteTensor* tensor = GetInputTensor(tensor_index);
+void* InProcessTFLitePredictor::GetOutputTensorData(
+    int32_t tensor_index) const {
+  const TfLiteTensor* tensor = GetOutputTensor(tensor_index);
   return TfLiteTensorData(tensor);
 }

-int32_t TFLitePredictor::GetTFLiteNumThreads() const {
+int32_t InProcessTFLitePredictor::GetTFLiteNumThreads() const {
   return num_threads_;
 }
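Taken together, the accessors above support a plain fill-evaluate-read cycle. A hedged end-to-end sketch follows; the model path, thread count, and float tensor type are assumptions, while the calls match the class as defined above:

```cpp
// Sketch of driving InProcessTFLitePredictor end to end; assumes the
// model's first input and output tensors hold floats.
machine_learning::InProcessTFLitePredictor predictor("/tmp/model.tflite",
                                                     /*num_threads=*/4);
if (predictor.Initialize() != kTfLiteOk)
  return;

// Compute the element count of input tensor 0 from its dimensions.
int32_t input_len = 1;
for (int32_t d = 0; d < predictor.GetInputTensorNumDims(0); ++d)
  input_len *= predictor.GetInputTensorDim(0, d);

// Fill the input buffer, run the model, then read the output buffer.
float* input = static_cast<float*>(predictor.GetInputTensorData(0));
for (int32_t i = 0; i < input_len; ++i)
  input[i] = 0.0f;
if (predictor.Evaluate() == kTfLiteOk) {
  const float* output =
      static_cast<const float*>(predictor.GetOutputTensorData(0));
  // |output| now points at the model's prediction scores.
}
```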
chrome/services/machine_learning/machine_learning_tflite_predictor.h → chrome/services/machine_learning/in_process_tflite_predictor.h
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.

-#ifndef CHROME_SERVICES_MACHINE_LEARNING_MACHINE_LEARNING_TFLITE_PREDICTOR_H_
-#define CHROME_SERVICES_MACHINE_LEARNING_MACHINE_LEARNING_TFLITE_PREDICTOR_H_
+#ifndef CHROME_SERVICES_MACHINE_LEARNING_IN_PROCESS_TFLITE_PREDICTOR_H_
+#define CHROME_SERVICES_MACHINE_LEARNING_IN_PROCESS_TFLITE_PREDICTOR_H_

 #include <functional>
 #include <string>

@@ -20,10 +20,10 @@
 namespace machine_learning {

 // TFLite predictor class around the TFLite C API for TFLite model evaluation.
-class TFLitePredictor {
+class InProcessTFLitePredictor {
  public:
-  TFLitePredictor(std::string filename, int32_t num_threads);
-  ~TFLitePredictor();
+  InProcessTFLitePredictor(std::string filename, int32_t num_threads);
+  ~InProcessTFLitePredictor();

   // Loads the model, builds the TFLite interpreter, and allocates tensors.
   TfLiteStatus Initialize();

@@ -94,4 +94,4 @@ class TFLitePredictor {

 }  // namespace machine_learning

-#endif  // CHROME_SERVICES_MACHINE_LEARNING_MACHINE_LEARNING_TFLITE_PREDICTOR_H_
+#endif  // CHROME_SERVICES_MACHINE_LEARNING_IN_PROCESS_TFLITE_PREDICTOR_H_
chrome/services/machine_learning/machine_learning_tflite_predictor_unittest.cc → chrome/services/machine_learning/in_process_tflite_predictor_unittest.cc
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.

-#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h"
+#include "chrome/services/machine_learning/in_process_tflite_predictor.h"

 #include <string>

@@ -15,7 +15,7 @@
 namespace machine_learning {

-class TFLitePredictorTest : public ::testing::Test {
+class InProcessTFLitePredictorTest : public ::testing::Test {
  public:
   const int32_t kTFLiteNumThreads = 4;
   const int32_t kInputTensorNum = 1;

@@ -32,8 +32,8 @@ class TFLitePredictorTest : public ::testing::Test {
   const int32_t kOutputTensorDim0 = 1;
   const int32_t kOutputTensorDim1 = 10;

-  TFLitePredictorTest() = default;
-  ~TFLitePredictorTest() override = default;
+  InProcessTFLitePredictorTest() = default;
+  ~InProcessTFLitePredictorTest() override = default;

   // Returns the TFLite test model path.
   std::string GetTFLiteTestPath() {

@@ -50,18 +50,18 @@
   }
 };

-TEST_F(TFLitePredictorTest, TFLiteInitializationTest) {
+TEST_F(InProcessTFLitePredictorTest, TFLiteInitializationTest) {
   // Initialize the model.
   std::string model_path = GetTFLiteTestPath();
-  TFLitePredictor predictor(model_path, kTFLiteNumThreads);
+  InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
   TfLiteStatus status = predictor.Initialize();
   EXPECT_EQ(status, kTfLiteOk);
 }

-TEST_F(TFLitePredictorTest, TFLiteTensorsCountTest) {
+TEST_F(InProcessTFLitePredictorTest, TFLiteTensorsCountTest) {
   // Initialize the model.
   std::string model_path = GetTFLiteTestPath();
-  TFLitePredictor predictor(model_path, kTFLiteNumThreads);
+  InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
   TfLiteStatus status = predictor.Initialize();
   EXPECT_EQ(status, kTfLiteOk);

@@ -69,10 +69,10 @@ TEST_F(TFLitePredictorTest, TFLiteTensorsCountTest) {
   EXPECT_EQ(predictor.GetOutputTensorCount(), kOutputTensorNum);
 }

-TEST_F(TFLitePredictorTest, TFLiteTensorsTest) {
+TEST_F(InProcessTFLitePredictorTest, TFLiteTensorsTest) {
   // Initialize the model.
   std::string model_path = GetTFLiteTestPath();
-  TFLitePredictor predictor(model_path, kTFLiteNumThreads);
+  InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
   TfLiteStatus status = predictor.Initialize();
   EXPECT_EQ(status, kTfLiteOk);

@@ -91,7 +91,7 @@ TEST_F(TFLitePredictorTest, TFLiteTensorsTest) {
   EXPECT_EQ(TfLiteTensorDim(outputTensor, 1), kOutputTensorDim1);
 }

-TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) {
+TEST_F(InProcessTFLitePredictorTest, TFLiteEvaluationTest) {
   const int kOutputSize = 10;
   float expectedOutput[kOutputSize] = {
       -0.4936581, -0.32497078, -0.1705023, -0.38193324, 0.36136785,

@@ -99,7 +99,7 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) {
   // Initialize the model.
   std::string model_path = GetTFLiteTestPath();
-  TFLitePredictor predictor(model_path, kTFLiteNumThreads);
+  InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
   predictor.Initialize();

   // Initialize the model input tensor.

@@ -125,10 +125,10 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) {
     EXPECT_NEAR(expectedOutput[i], outputData[i], 1e-5);
 }

-TEST_F(TFLitePredictorTest, TFLiteInterpreterThreadsSet) {
+TEST_F(InProcessTFLitePredictorTest, TFLiteInterpreterThreadsSet) {
   // Initialize the model.
   std::string model_path = GetTFLiteTestPath();
-  TFLitePredictor predictor(model_path, kTFLiteNumThreads);
+  InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
   EXPECT_EQ(kTFLiteNumThreads, predictor.GetTFLiteNumThreads());
 }