Commit 477234d9 authored by Mehrdad Hessar's avatar Mehrdad Hessar Committed by Commit Bot

This CL renames the TFLitePredictor class to InProcessTFLitePredictor.

Bug: 1116626
Change-Id: I3b0542f0b9a670c797e6e61329864f2fb9a8abff
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2358059Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Commit-Queue: Mehrdad Hessar <mehrdadh@google.com>
Cr-Commit-Position: refs/heads/master@{#798877}
parent dc08a996
...@@ -14,7 +14,7 @@ TFLiteExperimentKeyedService::TFLiteExperimentKeyedService( ...@@ -14,7 +14,7 @@ TFLiteExperimentKeyedService::TFLiteExperimentKeyedService(
if (!model_path) if (!model_path)
return; return;
predictor_ = std::make_unique<machine_learning::TFLitePredictor>( predictor_ = std::make_unique<machine_learning::InProcessTFLitePredictor>(
model_path.value(), model_path.value(),
tflite_experiment::switches::GetTFLitePredictorNumThreads()); tflite_experiment::switches::GetTFLitePredictorNumThreads());
predictor_->Initialize(); predictor_->Initialize();
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#ifndef CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_ #ifndef CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_
#define CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_ #define CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_
#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h" #include "chrome/services/machine_learning/in_process_tflite_predictor.h"
#include "components/keyed_service/core/keyed_service.h" #include "components/keyed_service/core/keyed_service.h"
namespace content { namespace content {
...@@ -20,14 +20,14 @@ class TFLiteExperimentKeyedService : public KeyedService { ...@@ -20,14 +20,14 @@ class TFLiteExperimentKeyedService : public KeyedService {
content::BrowserContext* browser_context); content::BrowserContext* browser_context);
~TFLiteExperimentKeyedService() override; ~TFLiteExperimentKeyedService() override;
machine_learning::TFLitePredictor* tflite_predictor() { machine_learning::InProcessTFLitePredictor* tflite_predictor() {
return predictor_.get(); return predictor_.get();
} }
private: private:
// The predictor owned by this keyed service capable of // The predictor owned by this keyed service capable of
// running a TFLite model. // running a TFLite model.
std::unique_ptr<machine_learning::TFLitePredictor> predictor_; std::unique_ptr<machine_learning::InProcessTFLitePredictor> predictor_;
}; };
#endif // CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_ #endif // CHROME_BROWSER_TFLITE_EXPERIMENT_TFLITE_EXPERIMENT_KEYED_SERVICE_H_
...@@ -22,8 +22,8 @@ constexpr int32_t kTFLitePredictorEvaluationLoop = 10; ...@@ -22,8 +22,8 @@ constexpr int32_t kTFLitePredictorEvaluationLoop = 10;
namespace { namespace {
// Returns the TFLitePredictor. // Returns the InProcessTFLitePredictor.
machine_learning::TFLitePredictor* GetTFLitePredictorFromWebContents( machine_learning::InProcessTFLitePredictor* GetTFLitePredictorFromWebContents(
content::WebContents* web_contents) { content::WebContents* web_contents) {
if (!web_contents) if (!web_contents)
return nullptr; return nullptr;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/timer/timer.h" #include "base/timer/timer.h"
#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h" #include "chrome/services/machine_learning/in_process_tflite_predictor.h"
#include "content/public/browser/web_contents_observer.h" #include "content/public/browser/web_contents_observer.h"
#include "content/public/browser/web_contents_user_data.h" #include "content/public/browser/web_contents_user_data.h"
...@@ -49,7 +49,7 @@ class TFLiteExperimentObserver ...@@ -49,7 +49,7 @@ class TFLiteExperimentObserver
const std::string&); const std::string&);
// The predictor is capable of running a TFLite model. // The predictor is capable of running a TFLite model.
machine_learning::TFLitePredictor* tflite_predictor_ = nullptr; machine_learning::InProcessTFLitePredictor* tflite_predictor_ = nullptr;
// True when |tflite_predictor_| ran model evaluation. It forces // True when |tflite_predictor_| ran model evaluation. It forces
// the observer to run tflite prediction only once. // the observer to run tflite prediction only once.
......
...@@ -18,8 +18,8 @@ source_set("machine_learning") { ...@@ -18,8 +18,8 @@ source_set("machine_learning") {
if (build_with_tflite_lib) { if (build_with_tflite_lib) {
sources += [ sources += [
"machine_learning_tflite_predictor.cc", "in_process_tflite_predictor.cc",
"machine_learning_tflite_predictor.h", "in_process_tflite_predictor.h",
] ]
deps += [ deps += [
...@@ -52,7 +52,7 @@ source_set("unit_tests") { ...@@ -52,7 +52,7 @@ source_set("unit_tests") {
] ]
if (build_with_tflite_lib) { if (build_with_tflite_lib) {
sources += [ "machine_learning_tflite_predictor_unittest.cc" ] sources += [ "in_process_tflite_predictor_unittest.cc" ]
} }
deps = [ deps = [
......
...@@ -2,18 +2,19 @@ ...@@ -2,18 +2,19 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h" #include "chrome/services/machine_learning/in_process_tflite_predictor.h"
#include "base/check.h" #include "base/check.h"
namespace machine_learning { namespace machine_learning {
TFLitePredictor::TFLitePredictor(std::string filename, int32_t num_threads) InProcessTFLitePredictor::InProcessTFLitePredictor(std::string filename,
int32_t num_threads)
: model_file_name_(filename), num_threads_(num_threads) {} : model_file_name_(filename), num_threads_(num_threads) {}
TFLitePredictor::~TFLitePredictor() = default; InProcessTFLitePredictor::~InProcessTFLitePredictor() = default;
TfLiteStatus TFLitePredictor::Initialize() { TfLiteStatus InProcessTFLitePredictor::Initialize() {
if (!LoadModel()) if (!LoadModel())
return kTfLiteError; return kTfLiteError;
if (!BuildInterpreter()) if (!BuildInterpreter())
...@@ -24,11 +25,11 @@ TfLiteStatus TFLitePredictor::Initialize() { ...@@ -24,11 +25,11 @@ TfLiteStatus TFLitePredictor::Initialize() {
return status; return status;
} }
TfLiteStatus TFLitePredictor::Evaluate() { TfLiteStatus InProcessTFLitePredictor::Evaluate() {
return TfLiteInterpreterInvoke(interpreter_.get()); return TfLiteInterpreterInvoke(interpreter_.get());
} }
bool TFLitePredictor::LoadModel() { bool InProcessTFLitePredictor::LoadModel() {
if (model_file_name_.empty()) if (model_file_name_.empty())
return false; return false;
...@@ -42,7 +43,7 @@ bool TFLitePredictor::LoadModel() { ...@@ -42,7 +43,7 @@ bool TFLitePredictor::LoadModel() {
return true; return true;
} }
bool TFLitePredictor::BuildInterpreter() { bool InProcessTFLitePredictor::BuildInterpreter() {
// We create the pointer using this approach since |TfLiteInterpreterOptions| // We create the pointer using this approach since |TfLiteInterpreterOptions|
// is a structure without the delete operator. // is a structure without the delete operator.
options_ = std::unique_ptr<TfLiteInterpreterOptions, options_ = std::unique_ptr<TfLiteInterpreterOptions,
...@@ -65,73 +66,77 @@ bool TFLitePredictor::BuildInterpreter() { ...@@ -65,73 +66,77 @@ bool TFLitePredictor::BuildInterpreter() {
return true; return true;
} }
TfLiteStatus TFLitePredictor::AllocateTensors() { TfLiteStatus InProcessTFLitePredictor::AllocateTensors() {
TfLiteStatus status = TfLiteInterpreterAllocateTensors(interpreter_.get()); TfLiteStatus status = TfLiteInterpreterAllocateTensors(interpreter_.get());
DCHECK(status == kTfLiteOk); DCHECK(status == kTfLiteOk);
return status; return status;
} }
int32_t TFLitePredictor::GetInputTensorCount() const { int32_t InProcessTFLitePredictor::GetInputTensorCount() const {
if (interpreter_ == nullptr) if (interpreter_ == nullptr)
return 0; return 0;
return TfLiteInterpreterGetInputTensorCount(interpreter_.get()); return TfLiteInterpreterGetInputTensorCount(interpreter_.get());
} }
int32_t TFLitePredictor::GetOutputTensorCount() const { int32_t InProcessTFLitePredictor::GetOutputTensorCount() const {
if (interpreter_ == nullptr) if (interpreter_ == nullptr)
return 0; return 0;
return TfLiteInterpreterGetOutputTensorCount(interpreter_.get()); return TfLiteInterpreterGetOutputTensorCount(interpreter_.get());
} }
TfLiteTensor* TFLitePredictor::GetInputTensor(int32_t index) const { TfLiteTensor* InProcessTFLitePredictor::GetInputTensor(int32_t index) const {
if (interpreter_ == nullptr) if (interpreter_ == nullptr)
return nullptr; return nullptr;
return TfLiteInterpreterGetInputTensor(interpreter_.get(), index); return TfLiteInterpreterGetInputTensor(interpreter_.get(), index);
} }
const TfLiteTensor* TFLitePredictor::GetOutputTensor(int32_t index) const { const TfLiteTensor* InProcessTFLitePredictor::GetOutputTensor(
int32_t index) const {
if (interpreter_ == nullptr) if (interpreter_ == nullptr)
return nullptr; return nullptr;
return TfLiteInterpreterGetOutputTensor(interpreter_.get(), index); return TfLiteInterpreterGetOutputTensor(interpreter_.get(), index);
} }
bool TFLitePredictor::IsInitialized() const { bool InProcessTFLitePredictor::IsInitialized() const {
return initialized_; return initialized_;
} }
int32_t TFLitePredictor::GetInputTensorNumDims(int32_t tensor_index) const { int32_t InProcessTFLitePredictor::GetInputTensorNumDims(
int32_t tensor_index) const {
TfLiteTensor* tensor = GetInputTensor(tensor_index); TfLiteTensor* tensor = GetInputTensor(tensor_index);
return TfLiteTensorNumDims(tensor); return TfLiteTensorNumDims(tensor);
} }
int32_t TFLitePredictor::GetInputTensorDim(int32_t tensor_index, int32_t InProcessTFLitePredictor::GetInputTensorDim(int32_t tensor_index,
int32_t dim_index) const { int32_t dim_index) const {
TfLiteTensor* tensor = GetInputTensor(tensor_index); TfLiteTensor* tensor = GetInputTensor(tensor_index);
return TfLiteTensorDim(tensor, dim_index); return TfLiteTensorDim(tensor, dim_index);
} }
void* TFLitePredictor::GetInputTensorData(int32_t tensor_index) const { void* InProcessTFLitePredictor::GetInputTensorData(int32_t tensor_index) const {
TfLiteTensor* tensor = GetInputTensor(tensor_index); TfLiteTensor* tensor = GetInputTensor(tensor_index);
return TfLiteTensorData(tensor); return TfLiteTensorData(tensor);
} }
int32_t TFLitePredictor::GetOutputTensorNumDims(int32_t tensor_index) const { int32_t InProcessTFLitePredictor::GetOutputTensorNumDims(
int32_t tensor_index) const {
const TfLiteTensor* tensor = GetOutputTensor(tensor_index); const TfLiteTensor* tensor = GetOutputTensor(tensor_index);
return TfLiteTensorNumDims(tensor); return TfLiteTensorNumDims(tensor);
} }
int32_t TFLitePredictor::GetOutputTensorDim(int32_t tensor_index, int32_t InProcessTFLitePredictor::GetOutputTensorDim(int32_t tensor_index,
int32_t dim_index) const { int32_t dim_index) const {
const TfLiteTensor* tensor = GetOutputTensor(tensor_index); const TfLiteTensor* tensor = GetOutputTensor(tensor_index);
return TfLiteTensorDim(tensor, dim_index); return TfLiteTensorDim(tensor, dim_index);
} }
void* TFLitePredictor::GetOutputTensorData(int32_t tensor_index) const { void* InProcessTFLitePredictor::GetOutputTensorData(
int32_t tensor_index) const {
const TfLiteTensor* tensor = GetInputTensor(tensor_index); const TfLiteTensor* tensor = GetInputTensor(tensor_index);
return TfLiteTensorData(tensor); return TfLiteTensorData(tensor);
} }
int32_t TFLitePredictor::GetTFLiteNumThreads() const { int32_t InProcessTFLitePredictor::GetTFLiteNumThreads() const {
return num_threads_; return num_threads_;
} }
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#ifndef CHROME_SERVICES_MACHINE_LEARNING_MACHINE_LEARNING_TFLITE_PREDICTOR_H_ #ifndef CHROME_SERVICES_MACHINE_LEARNING_IN_PROCESS_TFLITE_PREDICTOR_H_
#define CHROME_SERVICES_MACHINE_LEARNING_MACHINE_LEARNING_TFLITE_PREDICTOR_H_ #define CHROME_SERVICES_MACHINE_LEARNING_IN_PROCESS_TFLITE_PREDICTOR_H_
#include <functional> #include <functional>
#include <string> #include <string>
...@@ -20,10 +20,10 @@ ...@@ -20,10 +20,10 @@
namespace machine_learning { namespace machine_learning {
// TFLite predictor class around TFLite C API for TFLite model evaluation. // TFLite predictor class around TFLite C API for TFLite model evaluation.
class TFLitePredictor { class InProcessTFLitePredictor {
public: public:
TFLitePredictor(std::string filename, int32_t num_threads); InProcessTFLitePredictor(std::string filename, int32_t num_threads);
~TFLitePredictor(); ~InProcessTFLitePredictor();
// Loads model, build the TFLite interpreter and allocates tensors. // Loads model, build the TFLite interpreter and allocates tensors.
TfLiteStatus Initialize(); TfLiteStatus Initialize();
...@@ -94,4 +94,4 @@ class TFLitePredictor { ...@@ -94,4 +94,4 @@ class TFLitePredictor {
} // namespace machine_learning } // namespace machine_learning
#endif // CHROME_SERVICES_MACHINE_LEARNING_MACHINE_LEARNING_TFLITE_PREDICTOR_H_ #endif // CHROME_SERVICES_MACHINE_LEARNING_IN_PROCESS_TFLITE_PREDICTOR_H_
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "chrome/services/machine_learning/machine_learning_tflite_predictor.h" #include "chrome/services/machine_learning/in_process_tflite_predictor.h"
#include <string> #include <string>
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
namespace machine_learning { namespace machine_learning {
class TFLitePredictorTest : public ::testing::Test { class InProcessTFLitePredictorTest : public ::testing::Test {
public: public:
const int32_t kTFLiteNumThreads = 4; const int32_t kTFLiteNumThreads = 4;
const int32_t kInputTensorNum = 1; const int32_t kInputTensorNum = 1;
...@@ -32,8 +32,8 @@ class TFLitePredictorTest : public ::testing::Test { ...@@ -32,8 +32,8 @@ class TFLitePredictorTest : public ::testing::Test {
const int32_t kOutputTensorDim0 = 1; const int32_t kOutputTensorDim0 = 1;
const int32_t kOutputTensorDim1 = 10; const int32_t kOutputTensorDim1 = 10;
TFLitePredictorTest() = default; InProcessTFLitePredictorTest() = default;
~TFLitePredictorTest() override = default; ~InProcessTFLitePredictorTest() override = default;
// Returns TFLite test model path // Returns TFLite test model path
std::string GetTFLiteTestPath() { std::string GetTFLiteTestPath() {
...@@ -50,18 +50,18 @@ class TFLitePredictorTest : public ::testing::Test { ...@@ -50,18 +50,18 @@ class TFLitePredictorTest : public ::testing::Test {
} }
}; };
TEST_F(TFLitePredictorTest, TFLiteInitializationTest) { TEST_F(InProcessTFLitePredictorTest, TFLiteInitializationTest) {
// Initialize the model // Initialize the model
std::string model_path = GetTFLiteTestPath(); std::string model_path = GetTFLiteTestPath();
TFLitePredictor predictor(model_path, kTFLiteNumThreads); InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
TfLiteStatus status = predictor.Initialize(); TfLiteStatus status = predictor.Initialize();
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
} }
TEST_F(TFLitePredictorTest, TFLiteTensorsCountTest) { TEST_F(InProcessTFLitePredictorTest, TFLiteTensorsCountTest) {
// Initialize the model // Initialize the model
std::string model_path = GetTFLiteTestPath(); std::string model_path = GetTFLiteTestPath();
TFLitePredictor predictor(model_path, kTFLiteNumThreads); InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
TfLiteStatus status = predictor.Initialize(); TfLiteStatus status = predictor.Initialize();
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
...@@ -69,10 +69,10 @@ TEST_F(TFLitePredictorTest, TFLiteTensorsCountTest) { ...@@ -69,10 +69,10 @@ TEST_F(TFLitePredictorTest, TFLiteTensorsCountTest) {
EXPECT_EQ(predictor.GetOutputTensorCount(), kOutputTensorNum); EXPECT_EQ(predictor.GetOutputTensorCount(), kOutputTensorNum);
} }
TEST_F(TFLitePredictorTest, TFLiteTensorsTest) { TEST_F(InProcessTFLitePredictorTest, TFLiteTensorsTest) {
// Initialize the model // Initialize the model
std::string model_path = GetTFLiteTestPath(); std::string model_path = GetTFLiteTestPath();
TFLitePredictor predictor(model_path, kTFLiteNumThreads); InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
TfLiteStatus status = predictor.Initialize(); TfLiteStatus status = predictor.Initialize();
EXPECT_EQ(status, kTfLiteOk); EXPECT_EQ(status, kTfLiteOk);
...@@ -91,7 +91,7 @@ TEST_F(TFLitePredictorTest, TFLiteTensorsTest) { ...@@ -91,7 +91,7 @@ TEST_F(TFLitePredictorTest, TFLiteTensorsTest) {
EXPECT_EQ(TfLiteTensorDim(outputTensor, 1), kOutputTensorDim1); EXPECT_EQ(TfLiteTensorDim(outputTensor, 1), kOutputTensorDim1);
} }
TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) { TEST_F(InProcessTFLitePredictorTest, TFLiteEvaluationTest) {
int const kOutpuSize = 10; int const kOutpuSize = 10;
float expectedOutput[kOutpuSize] = { float expectedOutput[kOutpuSize] = {
-0.4936581, -0.32497078, -0.1705023, -0.38193324, 0.36136785, -0.4936581, -0.32497078, -0.1705023, -0.38193324, 0.36136785,
...@@ -99,7 +99,7 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) { ...@@ -99,7 +99,7 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) {
// Initialize the model // Initialize the model
std::string model_path = GetTFLiteTestPath(); std::string model_path = GetTFLiteTestPath();
TFLitePredictor predictor(model_path, kTFLiteNumThreads); InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
predictor.Initialize(); predictor.Initialize();
// Initialize model input tensor // Initialize model input tensor
...@@ -125,10 +125,10 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) { ...@@ -125,10 +125,10 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) {
EXPECT_NEAR(expectedOutput[i], outputData[i], 1e-5); EXPECT_NEAR(expectedOutput[i], outputData[i], 1e-5);
} }
TEST_F(TFLitePredictorTest, TFLiteInterpreterThreadsSet) { TEST_F(InProcessTFLitePredictorTest, TFLiteInterpreterThreadsSet) {
// Initialize the model // Initialize the model
std::string model_path = GetTFLiteTestPath(); std::string model_path = GetTFLiteTestPath();
TFLitePredictor predictor(model_path, kTFLiteNumThreads); InProcessTFLitePredictor predictor(model_path, kTFLiteNumThreads);
EXPECT_EQ(kTFLiteNumThreads, predictor.GetTFLiteNumThreads()); EXPECT_EQ(kTFLiteNumThreads, predictor.GetTFLiteNumThreads());
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment