Commit 620219f7 authored by Michael Crouse's avatar Michael Crouse Committed by Chromium LUCI CQ

[TFLite] Fix test for windows and create custom resolver.

This fixes some path and test data issues that cause tests to fail
when enabling TFLite in the build.

This also creates a TFLite resolver for Chrome with the set of ops
to be supported initially.

Bug: 1165517
Change-Id: Ib0acdcb380578e94f4b7172b5ebe27ab6faafc25
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2631186Reviewed-by: default avatarColin Blundell <blundell@chromium.org>
Reviewed-by: default avatarSophie Chang <sophiechang@chromium.org>
Commit-Queue: Michael Crouse <mcrouse@chromium.org>
Cr-Commit-Position: refs/heads/master@{#844164}
parent 840abe8b
...@@ -2468,7 +2468,7 @@ static_library("browser") { ...@@ -2468,7 +2468,7 @@ static_library("browser") {
"tflite_experiment/tflite_experiment_switches.h", "tflite_experiment/tflite_experiment_switches.h",
] ]
deps += [ "//chrome/services/machine_learning" ] public_deps += [ "//chrome/services/machine_learning" ]
} }
# Platforms that have a network diagnostics dialog. All others fall through # Platforms that have a network diagnostics dialog. All others fall through
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "chrome/browser/tflite_experiment/tflite_experiment_keyed_service.h" #include "chrome/browser/tflite_experiment/tflite_experiment_keyed_service.h"
#include "base/command_line.h" #include "base/command_line.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h" #include "base/files/file_util.h"
#include "base/json/json_reader.h" #include "base/json/json_reader.h"
#include "base/path_service.h" #include "base/path_service.h"
...@@ -25,9 +26,7 @@ ...@@ -25,9 +26,7 @@
#include "content/public/test/browser_test_utils.h" #include "content/public/test/browser_test_utils.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
constexpr char kTFLiteModelName[] = "simple_test.tflite";
constexpr char kNavigationURL[] = "https://google.com"; constexpr char kNavigationURL[] = "https://google.com";
constexpr char kTFLiteExperimentLogName[] = "tflite_experiment.log";
namespace { namespace {
// Fetch and calculate the total number of samples from all the bins for // Fetch and calculate the total number of samples from all the bins for
...@@ -90,16 +89,6 @@ IN_PROC_BROWSER_TEST_F(TFLiteExperimentKeyedServiceDisabledBrowserTest, ...@@ -90,16 +89,6 @@ IN_PROC_BROWSER_TEST_F(TFLiteExperimentKeyedServiceDisabledBrowserTest,
EXPECT_FALSE(tflite_experiment_keyed_service->tflite_predictor()); EXPECT_FALSE(tflite_experiment_keyed_service->tflite_predictor());
} }
IN_PROC_BROWSER_TEST_F(
TFLiteExperimentKeyedServiceDisabledBrowserTest,
TFLiteExperimentEnabledButTFLitePredictorDisabledOnNavigation) {
GURL navigation_url(kNavigationURL);
ui_test_utils::NavigateToURL(browser(), navigation_url);
WaitForTFLiteObserverToCallNullTFLitePredictor();
histogram_tester()->ExpectUniqueSample(
"TFLiteExperiment.Observer.TFLitePredictor.Null", true, 1);
}
class TFLiteExperimentKeyedServiceBrowserTest : public InProcessBrowserTest { class TFLiteExperimentKeyedServiceBrowserTest : public InProcessBrowserTest {
public: public:
TFLiteExperimentKeyedServiceBrowserTest() = default; TFLiteExperimentKeyedServiceBrowserTest() = default;
...@@ -112,21 +101,21 @@ class TFLiteExperimentKeyedServiceBrowserTest : public InProcessBrowserTest { ...@@ -112,21 +101,21 @@ class TFLiteExperimentKeyedServiceBrowserTest : public InProcessBrowserTest {
// Set TFLite model path. // Set TFLite model path.
base::PathService::Get(chrome::DIR_TEST_DATA, &g_test_data_directory); base::PathService::Get(chrome::DIR_TEST_DATA, &g_test_data_directory);
g_test_data_directory = g_test_data_directory =
g_test_data_directory.Append(FILE_PATH_LITERAL(kTFLiteModelName)); g_test_data_directory.Append(FILE_PATH_LITERAL("simple_test.tflite"));
cmd->AppendSwitchASCII(tflite_experiment::switches::kTFLiteModelPath, cmd->AppendSwitchASCII(tflite_experiment::switches::kTFLiteModelPath,
g_test_data_directory.value()); g_test_data_directory.MaybeAsASCII());
// Set TFLite experiment log path. // Set TFLite experiment log path.
cmd->AppendSwitchASCII( cmd->AppendSwitchASCII(
tflite_experiment::switches::kTFLiteExperimentLogPath, tflite_experiment::switches::kTFLiteExperimentLogPath,
GetTFLiteExperimentLogPath().value()); GetTFLiteExperimentLogPath().MaybeAsASCII());
} }
base::FilePath GetTFLiteExperimentLogPath() { base::FilePath GetTFLiteExperimentLogPath() {
base::FilePath g_test_data_directory; base::FilePath g_test_data_directory;
base::PathService::Get(chrome::DIR_TEST_DATA, &g_test_data_directory); base::PathService::Get(chrome::DIR_TEST_DATA, &g_test_data_directory);
g_test_data_directory = g_test_data_directory = g_test_data_directory.Append(
g_test_data_directory.Append(kTFLiteExperimentLogName); FILE_PATH_LITERAL("tflite_experiment.log"));
return g_test_data_directory; return g_test_data_directory;
} }
......
...@@ -27,6 +27,8 @@ source_set("machine_learning") { ...@@ -27,6 +27,8 @@ source_set("machine_learning") {
if (build_with_tflite_lib) { if (build_with_tflite_lib) {
sources += [ sources += [
"chrome_tflite_op_resolver.cc",
"chrome_tflite_op_resolver.h",
"in_process_tflite_predictor.cc", "in_process_tflite_predictor.cc",
"in_process_tflite_predictor.h", "in_process_tflite_predictor.h",
] ]
...@@ -65,6 +67,8 @@ source_set("unit_tests") { ...@@ -65,6 +67,8 @@ source_set("unit_tests") {
sources += [ "in_process_tflite_predictor_unittest.cc" ] sources += [ "in_process_tflite_predictor_unittest.cc" ]
} }
data = [ "//chrome/test/data/" ]
deps = [ deps = [
":machine_learning", ":machine_learning",
":metrics", ":metrics",
......
This diff is collapsed.
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_SERVICES_MACHINE_LEARNING_CHROME_TFLITE_OP_RESOLVER_H_
#define CHROME_SERVICES_MACHINE_LEARNING_CHROME_TFLITE_OP_RESOLVER_H_
#include "third_party/tflite/src/tensorflow/lite/model.h"
#include "third_party/tflite/src/tensorflow/lite/mutable_op_resolver.h"
namespace machine_learning {
// This class maintains all the currently supported TFLite
// operations for Chrome and registers them for use.
class ChromeTFLiteOpResolver : public tflite::MutableOpResolver {
public:
ChromeTFLiteOpResolver();
};
} // namespace machine_learning
#endif // CHROME_SERVICES_MACHINE_LEARNING_CHROME_TFLITE_OP_RESOLVER_H_
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
#include "chrome/services/machine_learning/in_process_tflite_predictor.h" #include "chrome/services/machine_learning/in_process_tflite_predictor.h"
#include "base/check.h" #include "base/check.h"
#include "chrome/services/machine_learning/chrome_tflite_op_resolver.h"
#include "third_party/tflite/src/tensorflow/lite/interpreter.h" #include "third_party/tflite/src/tensorflow/lite/interpreter.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/register.h"
#include "third_party/tflite/src/tensorflow/lite/model.h" #include "third_party/tflite/src/tensorflow/lite/model.h"
namespace machine_learning { namespace machine_learning {
...@@ -46,7 +46,7 @@ bool InProcessTFLitePredictor::LoadModel() { ...@@ -46,7 +46,7 @@ bool InProcessTFLitePredictor::LoadModel() {
} }
bool InProcessTFLitePredictor::BuildInterpreter() { bool InProcessTFLitePredictor::BuildInterpreter() {
tflite::ops::builtin::BuiltinOpResolver resolver; ChromeTFLiteOpResolver resolver;
tflite::InterpreterBuilder builder(*model_, resolver); tflite::InterpreterBuilder builder(*model_, resolver);
if (builder(&interpreter_, num_threads_) != kTfLiteOk || !interpreter_) if (builder(&interpreter_, num_threads_) != kTfLiteOk || !interpreter_)
......
...@@ -51,7 +51,7 @@ class InProcessTFLitePredictorTest : public ::testing::Test { ...@@ -51,7 +51,7 @@ class InProcessTFLitePredictorTest : public ::testing::Test {
.Append(FILE_PATH_LITERAL("data")) .Append(FILE_PATH_LITERAL("data"))
.Append(FILE_PATH_LITERAL("simple_test.tflite")); .Append(FILE_PATH_LITERAL("simple_test.tflite"));
EXPECT_TRUE(base::PathExists(model_file_path)); EXPECT_TRUE(base::PathExists(model_file_path));
return model_file_path.value().c_str(); return model_file_path.AsUTF8Unsafe();
} }
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment