Commit 8c1305f3 authored by Mehrdad Hessar's avatar Mehrdad Hessar Committed by Commit Bot

This CL adds a switch for TFLite interpreter number of threads.

This change enables us to set the number of threads used by the TFLite
predictor and measure performance in various scenarios.

Bug: 1115689
Change-Id: Ib5b0a1315fb426851e0b8d54f6e43128e3303877
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2352213
Reviewed-by: David Roger <droger@chromium.org>
Reviewed-by: default avatarMichael Crouse <mcrouse@chromium.org>
Reviewed-by: default avatarRyan Sturm <ryansturm@chromium.org>
Commit-Queue: Mehrdad Hessar <mehrdadh@google.com>
Cr-Commit-Position: refs/heads/master@{#797855}
parent bd181788
file://components/optimization_guide/OWNERS
# COMPONENT: Internals>OptimizationGuide
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
#include "base/optional.h" #include "base/optional.h"
#include "chrome/browser/tflite_experiment/tflite_experiment_switches.h" #include "chrome/browser/tflite_experiment/tflite_experiment_switches.h"
constexpr int32_t kTFLiteNumThreads = 4;
TFLiteExperimentKeyedService::TFLiteExperimentKeyedService( TFLiteExperimentKeyedService::TFLiteExperimentKeyedService(
content::BrowserContext* browser_context) { content::BrowserContext* browser_context) {
base::Optional<std::string> model_path = base::Optional<std::string> model_path =
...@@ -17,7 +15,8 @@ TFLiteExperimentKeyedService::TFLiteExperimentKeyedService( ...@@ -17,7 +15,8 @@ TFLiteExperimentKeyedService::TFLiteExperimentKeyedService(
return; return;
predictor_ = std::make_unique<machine_learning::TFLitePredictor>( predictor_ = std::make_unique<machine_learning::TFLitePredictor>(
model_path.value(), kTFLiteNumThreads); model_path.value(),
tflite_experiment::switches::GetTFLitePredictorNumThreads());
predictor_->Initialize(); predictor_->Initialize();
} }
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#include "chrome/browser/tflite_experiment/tflite_experiment_switches.h" #include "chrome/browser/tflite_experiment/tflite_experiment_switches.h"
#include "base/command_line.h" #include "base/command_line.h"
#include "base/strings/string_number_conversions.h"
constexpr int32_t kTFLitePredictorDefaultNumThreads = 4;
namespace tflite_experiment { namespace tflite_experiment {
namespace switches { namespace switches {
...@@ -15,6 +18,9 @@ const char kTFLiteModelPath[] = "tflite-model-path"; ...@@ -15,6 +18,9 @@ const char kTFLiteModelPath[] = "tflite-model-path";
// Specifies the TFLite experiment log file path. // Specifies the TFLite experiment log file path.
const char kTFLiteExperimentLogPath[] = "tflite-experiment-log-path"; const char kTFLiteExperimentLogPath[] = "tflite-experiment-log-path";
// Specifies number of threads used by TFLite predictor.
const char kTFLitePredictorNumThreads[] = "tflite-predictor-num-threads";
base::Optional<std::string> GetTFLiteModelPath() { base::Optional<std::string> GetTFLiteModelPath() {
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
if (command_line->HasSwitch(tflite_experiment::switches::kTFLiteModelPath)) { if (command_line->HasSwitch(tflite_experiment::switches::kTFLiteModelPath)) {
...@@ -34,5 +40,23 @@ base::Optional<std::string> GetTFLiteExperimentLogPath() { ...@@ -34,5 +40,23 @@ base::Optional<std::string> GetTFLiteExperimentLogPath() {
return base::nullopt; return base::nullopt;
} }
// Returns the TFLite predictor thread count requested on the command line,
// falling back to kTFLitePredictorDefaultNumThreads when the switch is
// absent or its value does not parse as an integer.
int32_t GetTFLitePredictorNumThreads() {
  base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
  if (!command_line->HasSwitch(
          tflite_experiment::switches::kTFLitePredictorNumThreads)) {
    return kTFLitePredictorDefaultNumThreads;
  }
  const std::string switch_value = command_line->GetSwitchValueASCII(
      tflite_experiment::switches::kTFLitePredictorNumThreads);
  int parsed_threads;
  if (!base::StringToInt(switch_value, &parsed_threads))
    return kTFLitePredictorDefaultNumThreads;
  return static_cast<int32_t>(parsed_threads);
}
} // namespace switches } // namespace switches
} // namespace tflite_experiment } // namespace tflite_experiment
...@@ -14,6 +14,7 @@ namespace switches { ...@@ -14,6 +14,7 @@ namespace switches {
extern const char kTFLiteModelPath[]; extern const char kTFLiteModelPath[];
extern const char kTFLiteExperimentLogPath[]; extern const char kTFLiteExperimentLogPath[];
extern const char kTFLitePredictorNumThreads[];
// Returns TFLite model path. // Returns TFLite model path.
base::Optional<std::string> GetTFLiteModelPath(); base::Optional<std::string> GetTFLiteModelPath();
...@@ -21,6 +22,9 @@ base::Optional<std::string> GetTFLiteModelPath(); ...@@ -21,6 +22,9 @@ base::Optional<std::string> GetTFLiteModelPath();
// Returns TFLite experiment log file path. // Returns TFLite experiment log file path.
base::Optional<std::string> GetTFLiteExperimentLogPath(); base::Optional<std::string> GetTFLiteExperimentLogPath();
// Returns TFLite predictor number of threads.
int32_t GetTFLitePredictorNumThreads();
} // namespace switches } // namespace switches
} // namespace tflite_experiment } // namespace tflite_experiment
......
...@@ -131,4 +131,8 @@ void* TFLitePredictor::GetOutputTensorData(int32_t tensor_index) const { ...@@ -131,4 +131,8 @@ void* TFLitePredictor::GetOutputTensorData(int32_t tensor_index) const {
return TfLiteTensorData(tensor); return TfLiteTensorData(tensor);
} }
// Returns the number of threads this predictor was configured with
// (the |num_threads_| value stored at construction).
int32_t TFLitePredictor::GetTFLiteNumThreads() const {
return num_threads_;
}
} // namespace machine_learning } // namespace machine_learning
...@@ -64,6 +64,9 @@ class TFLitePredictor { ...@@ -64,6 +64,9 @@ class TFLitePredictor {
// Returns data pointer to output tensor with index |tensor_index|. // Returns data pointer to output tensor with index |tensor_index|.
void* GetOutputTensorData(int32_t tensor_index) const; void* GetOutputTensorData(int32_t tensor_index) const;
// Returns TFLite interpreter number of threads.
int32_t GetTFLiteNumThreads() const;
private: private:
// Loads TFLite model. // Loads TFLite model.
bool LoadModel(); bool LoadModel();
...@@ -77,7 +80,7 @@ class TFLitePredictor { ...@@ -77,7 +80,7 @@ class TFLitePredictor {
std::string model_file_name_; std::string model_file_name_;
// Number of threads used by |interpreter_| for evaluating |model_|. // Number of threads used by |interpreter_| for evaluating |model_|.
int32_t num_threads_ = 1; int32_t num_threads_;
std::unique_ptr<TfLiteModel, std::function<void(TfLiteModel*)>> model_; std::unique_ptr<TfLiteModel, std::function<void(TfLiteModel*)>> model_;
std::unique_ptr<TfLiteInterpreterOptions, std::unique_ptr<TfLiteInterpreterOptions,
std::function<void(TfLiteInterpreterOptions*)>> std::function<void(TfLiteInterpreterOptions*)>>
......
...@@ -125,4 +125,11 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) { ...@@ -125,4 +125,11 @@ TEST_F(TFLitePredictorTest, TFLiteEvaluationTest) {
EXPECT_NEAR(expectedOutput[i], outputData[i], 1e-5); EXPECT_NEAR(expectedOutput[i], outputData[i], 1e-5);
} }
// Verifies that the thread count passed to the TFLitePredictor constructor
// is reported back unchanged by GetTFLiteNumThreads().
TEST_F(TFLitePredictorTest, TFLiteInterpreterThreadsSet) {
// Build a predictor against the test model with an explicit thread count.
std::string model_path = GetTFLiteTestPath();
TFLitePredictor predictor(model_path, kTFLiteNumThreads);
EXPECT_EQ(kTFLiteNumThreads, predictor.GetTFLiteNumThreads());
}
} // namespace machine_learning } // namespace machine_learning
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment