Commit c5f5ffb7 authored by Charles Zhao's avatar Charles Zhao Committed by Commit Bot

Fix HandwritingModelLoader.

(1) calling HandwritingModelLoader(spec, receiver, callback).Load() is
a bad idea because the .Load() is async and the temporal object could be
released very soon. This Cl changes the whole class to be stateless, so
that calling HandwritingModelLoader::Load(spec, receiver, callback) won't suffer the same problem.

(2) unit tests are fixed accordingly.

Bug: 1054628
Change-Id: I56459b14e1fc4a6608d5ed264e6c650d1e076f69
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2415914
Commit-Queue: Charles . <charleszhao@chromium.org>
Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Reviewed-by: default avatarCharles . <charleszhao@chromium.org>
Reviewed-by: default avatarXinglong Luan <alanlxl@chromium.org>
Cr-Commit-Position: refs/heads/master@{#810102}
parent 08c26569
...@@ -13,7 +13,12 @@ namespace chromeos { ...@@ -13,7 +13,12 @@ namespace chromeos {
namespace machine_learning { namespace machine_learning {
namespace { namespace {
using chromeos::machine_learning::mojom::LoadHandwritingModelResult; using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpecPtr;
using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
using HandwritingRecognizer =
mojo::PendingReceiver<mojom::HandwritingRecognizer>;
using LoadHandwritingModelCallback = ::chromeos::machine_learning::mojom::
MachineLearningService::LoadHandwritingModelCallback;
// Records CrOSActionRecorder event. // Records CrOSActionRecorder event.
void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) { void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) {
...@@ -22,6 +27,8 @@ void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) { ...@@ -22,6 +27,8 @@ void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) {
LoadHandwritingModelResult::LOAD_MODEL_FILES_ERROR); LoadHandwritingModelResult::LOAD_MODEL_FILES_ERROR);
} }
constexpr char kOndeviceHandwritingSwitch[] = "ondevice_handwriting";
constexpr char kLibHandwritingDlcId[] = "libhandwriting";
// A list of supported language code. // A list of supported language code.
constexpr char kLanguageCodeEn[] = "en"; constexpr char kLanguageCodeEn[] = "en";
constexpr char kLanguageCodeGesture[] = "gesture_in_context"; constexpr char kLanguageCodeGesture[] = "gesture_in_context";
...@@ -30,10 +37,8 @@ constexpr char kLanguageCodeGesture[] = "gesture_in_context"; ...@@ -30,10 +37,8 @@ constexpr char kLanguageCodeGesture[] = "gesture_in_context";
// kOndeviceHandwritingSwitch. // kOndeviceHandwritingSwitch.
bool HandwritingSwitchHasValue(const std::string& value) { bool HandwritingSwitchHasValue(const std::string& value) {
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
return command_line->HasSwitch( return command_line->HasSwitch(kOndeviceHandwritingSwitch) &&
HandwritingModelLoader::kOndeviceHandwritingSwitch) && command_line->GetSwitchValueASCII(kOndeviceHandwritingSwitch) == value;
command_line->GetSwitchValueASCII(
HandwritingModelLoader::kOndeviceHandwritingSwitch) == value;
} }
// Returns true if switch kOndeviceHandwritingSwitch is set to use_rootfs. // Returns true if switch kOndeviceHandwritingSwitch is set to use_rootfs.
...@@ -46,69 +51,45 @@ bool IsLibHandwritingDlcEnabled() { ...@@ -46,69 +51,45 @@ bool IsLibHandwritingDlcEnabled() {
return HandwritingSwitchHasValue("use_dlc"); return HandwritingSwitchHasValue("use_dlc");
} }
} // namespace // Called when InstallDlc completes.
// Returns an error if the `result.error` is not dlcservice::kErrorNone.
constexpr char HandwritingModelLoader::kOndeviceHandwritingSwitch[]; // Calls mlservice to LoadHandwritingModel otherwise.
constexpr char HandwritingModelLoader::kLibHandwritingDlcId[]; void OnInstallDlcComplete(
HandwritingRecognizerSpecPtr spec,
HandwritingModelLoader::HandwritingModelLoader( HandwritingRecognizer receiver,
mojom::HandwritingRecognizerSpecPtr spec, LoadHandwritingModelCallback callback,
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver, const chromeos::DlcserviceClient::InstallResult& result) {
mojom::MachineLearningService::LoadHandwritingModelCallback callback) // Call LoadHandwritingModelWithSpec if no error was found.
: dlc_client_(chromeos::DlcserviceClient::Get()), if (result.error == dlcservice::kErrorNone) {
spec_(std::move(spec)),
receiver_(std::move(receiver)),
callback_(std::move(callback)),
weak_ptr_factory_(this) {}
HandwritingModelLoader::~HandwritingModelLoader() = default;
void HandwritingModelLoader::Load() {
// Returns FEATURE_NOT_SUPPORTED_ERROR if both rootfs and dlc are not enabled.
if (!IsLibHandwritingRootfsEnabled() && !IsLibHandwritingDlcEnabled()) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
std::move(callback_).Run(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
return;
}
// Returns LANGUAGE_NOT_SUPPORTED_ERROR if the language is not supported yet.
if (spec_->language != kLanguageCodeEn &&
spec_->language != kLanguageCodeGesture) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
std::move(callback_).Run(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
return;
}
// Load from rootfs if enabled.
if (IsLibHandwritingRootfsEnabled()) {
ServiceConnection::GetInstance()->LoadHandwritingModel( ServiceConnection::GetInstance()->LoadHandwritingModel(
std::move(spec_), std::move(receiver_), std::move(callback_)); std::move(spec), std::move(receiver), std::move(callback));
return; return;
} }
// Gets existing dlc list and based on the presence of libhandwriting RecordLoadHandwritingModelResult(
// either returns an error or installs the libhandwriting dlc. LoadHandwritingModelResult::DLC_INSTALL_ERROR);
dlc_client_->GetExistingDlcs( std::move(callback).Run(LoadHandwritingModelResult::DLC_INSTALL_ERROR);
base::BindOnce(&HandwritingModelLoader::OnGetExistingDlcsComplete,
weak_ptr_factory_.GetWeakPtr()));
} }
void HandwritingModelLoader::OnGetExistingDlcsComplete( // Called when the existing-dlc-list is returned.
// Returns an error if libhandwriting is not in the existing-dlc-list.
// Calls InstallDlc otherwise.
void OnGetExistingDlcsComplete(
HandwritingRecognizerSpecPtr spec,
HandwritingRecognizer receiver,
LoadHandwritingModelCallback callback,
DlcserviceClient* const dlc_client,
const std::string& err, const std::string& err,
const dlcservice::DlcsWithContent& dlcs_with_content) { const dlcservice::DlcsWithContent& dlcs_with_content) {
// Loop over dlcs_with_content, and installs libhandwriting if already exists. // Loop over dlcs_with_content, and installs libhandwriting if already exists.
// Since we don't want to trigger downloading here, we only install(mount) // Since we don't want to trigger downloading here, we only install(mount)
// the handwriting dlc if it is already on device. // the handwriting dlc if it is already on device.
for (const auto& dlc_info : dlcs_with_content.dlc_infos()) { for (const auto& dlc_info : dlcs_with_content.dlc_infos()) {
if (dlc_info.id() == HandwritingModelLoader::kLibHandwritingDlcId) { if (dlc_info.id() == kLibHandwritingDlcId) {
dlc_client_->Install( dlc_client->Install(
kLibHandwritingDlcId, kLibHandwritingDlcId,
base::BindOnce(&HandwritingModelLoader::OnInstallDlcComplete, base::BindOnce(&OnInstallDlcComplete, std::move(spec),
weak_ptr_factory_.GetWeakPtr()), std::move(receiver), std::move(callback)),
chromeos::DlcserviceClient::IgnoreProgress); chromeos::DlcserviceClient::IgnoreProgress);
return; return;
} }
...@@ -117,21 +98,46 @@ void HandwritingModelLoader::OnGetExistingDlcsComplete( ...@@ -117,21 +98,46 @@ void HandwritingModelLoader::OnGetExistingDlcsComplete(
// Returns error if the handwriting dlc is not on the device. // Returns error if the handwriting dlc is not on the device.
RecordLoadHandwritingModelResult( RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::DLC_DOES_NOT_EXIST); LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
std::move(callback_).Run(LoadHandwritingModelResult::DLC_DOES_NOT_EXIST); std::move(callback).Run(LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
} }
void HandwritingModelLoader::OnInstallDlcComplete( } // namespace
const chromeos::DlcserviceClient::InstallResult& result) {
// Call LoadHandwritingModelWithSpec if no error was found. void LoadHandwritingModelFromRootfsOrDlc(HandwritingRecognizerSpecPtr spec,
if (result.error == dlcservice::kErrorNone) { HandwritingRecognizer receiver,
LoadHandwritingModelCallback callback,
DlcserviceClient* const dlc_client) {
// Returns FEATURE_NOT_SUPPORTED_ERROR if both rootfs and dlc are not enabled.
if (!IsLibHandwritingRootfsEnabled() && !IsLibHandwritingDlcEnabled()) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
std::move(callback).Run(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
return;
}
// Returns LANGUAGE_NOT_SUPPORTED_ERROR if the language is not supported yet.
if (spec->language != kLanguageCodeEn &&
spec->language != kLanguageCodeGesture) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
std::move(callback).Run(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
return;
}
// Load from rootfs if enabled.
if (IsLibHandwritingRootfsEnabled()) {
ServiceConnection::GetInstance()->LoadHandwritingModel( ServiceConnection::GetInstance()->LoadHandwritingModel(
std::move(spec_), std::move(receiver_), std::move(callback_)); std::move(spec), std::move(receiver), std::move(callback));
return; return;
} }
RecordLoadHandwritingModelResult( // Gets existing dlc list and based on the presence of libhandwriting
LoadHandwritingModelResult::DLC_INSTALL_ERROR); // either returns an error or installs the libhandwriting dlc.
std::move(callback_).Run(LoadHandwritingModelResult::DLC_INSTALL_ERROR); dlc_client->GetExistingDlcs(
base::BindOnce(&OnGetExistingDlcsComplete, std::move(spec),
std::move(receiver), std::move(callback), dlc_client));
} }
} // namespace machine_learning } // namespace machine_learning
......
...@@ -14,64 +14,31 @@ ...@@ -14,64 +14,31 @@
namespace chromeos { namespace chromeos {
namespace machine_learning { namespace machine_learning {
// Class that decides either to load handwriting model from rootfs or dlc. // Helper function decides either to load handwriting model from rootfs or dlc.
// New Handwriting clients should call this helper instead of calling // New Handwriting clients should call this helper instead of calling
// ServiceConnection::GetInstance()->LoadHandwritingModel. // ServiceConnection::GetInstance()->LoadHandwritingModel.
// Three typical examples of the callstack are: // Three typical examples of the callstack are:
// Case 1: handwriting in enabled on rootfs. // Case 1: handwriting in enabled on rootfs.
// client calls HandwritingModelLoader("en", receiver, callback).Load() // client calls LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback)
// which calls LoadHandwritingModel -> handwriting model loaded from rootfs. // which calls LoadHandwritingModel -> handwriting model loaded from rootfs.
// Case 2: handwriting is enabled for dlc and dlc is already on the device. // Case 2: handwriting is enabled for dlc and dlc is already on the device.
// client calls HandwritingModelLoader("en", receiver, callback).Load() // client calls LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback)
// which calls -> GetExistingDlcs -> libhandwriting dlc already exists // which calls -> GetExistingDlcs -> libhandwriting dlc already exists
// -> InstallDlc -> LoadHandwritingModel // -> InstallDlc -> LoadHandwritingModel
// The correct handwriting model will be loaded and bond to the receiver. // The correct handwriting model will be loaded and bond to the receiver.
// Case 3: handwriting is enabled for dlc and dlc is not on the device yet. // Case 3: handwriting is enabled for dlc and dlc is not on the device yet.
// client calls HandwritingModelLoader("en", receiver, callback).Load() // client calls LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback)
// which calls -> GetExistingDlcs -> NO libhandwriting dlc exists // which calls -> GetExistingDlcs -> NO libhandwriting dlc exists
// -> Return error DLC_NOT_EXISTED. // -> Return error DLC_NOT_EXISTED.
// Then it will be the client's duty to install the dlc and then calls // Then it will be the client's duty to install the dlc and then calls
// HandwritingModelLoader("en", receiver, callback).Load() again. // LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback) again.
class HandwritingModelLoader { //
public: // `dlc_client` should only be replaced with non-default value in unit tests.
HandwritingModelLoader( void LoadHandwritingModelFromRootfsOrDlc(
mojom::HandwritingRecognizerSpecPtr spec, mojom::HandwritingRecognizerSpecPtr spec,
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver, mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelCallback callback); mojom::MachineLearningService::LoadHandwritingModelCallback callback,
DlcserviceClient* dlc_client = chromeos::DlcserviceClient::Get());
~HandwritingModelLoader();
// Load handwriting model based on the comandline flag and language.
void Load();
static constexpr char kOndeviceHandwritingSwitch[] = "ondevice_handwriting";
static constexpr char kLibHandwritingDlcId[] = "libhandwriting";
private:
friend class HandwritingModelLoaderTest;
// Called when the existing-dlc-list is returned.
// Returns an error if libhandwriting is not in the existing-dlc-list.
// Calls InstallDlc otherwise.
void OnGetExistingDlcsComplete(
const std::string& err,
const dlcservice::DlcsWithContent& dlcs_with_content);
// Called when InstallDlc completes.
// Returns an error if the `result.error` is not dlcservice::kErrorNone.
// Calls mlservice to LoadHandwritingModel otherwise.
void OnInstallDlcComplete(
const chromeos::DlcserviceClient::InstallResult& result);
DlcserviceClient* dlc_client_;
mojom::HandwritingRecognizerSpecPtr spec_;
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver_;
mojom::MachineLearningService::LoadHandwritingModelCallback callback_;
base::WeakPtrFactory<HandwritingModelLoader> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(HandwritingModelLoader);
};
} // namespace machine_learning } // namespace machine_learning
} // namespace chromeos } // namespace chromeos
......
...@@ -18,20 +18,16 @@ namespace chromeos { ...@@ -18,20 +18,16 @@ namespace chromeos {
namespace machine_learning { namespace machine_learning {
using chromeos::machine_learning::mojom::LoadHandwritingModelResult; using chromeos::machine_learning::mojom::LoadHandwritingModelResult;
constexpr char kOndeviceHandwritingSwitch[] = "ondevice_handwriting";
constexpr char kLibHandwritingDlcId[] = "libhandwriting";
class HandwritingModelLoaderTest : public testing::Test { class HandwritingModelLoaderTest : public testing::Test {
protected: protected:
void SetUp() override { void SetUp() override {
ServiceConnection::UseFakeServiceConnectionForTesting( ServiceConnection::UseFakeServiceConnectionForTesting(
&fake_service_connection_); &fake_service_connection_);
result_ = LoadHandwritingModelResult::DEPRECATED_MODEL_SPEC_ERROR; result_ = LoadHandwritingModelResult::DEPRECATED_MODEL_SPEC_ERROR;
language_ = "en";
loader_ = std::make_unique<HandwritingModelLoader>(
mojom::HandwritingRecognizerSpec::New("en"),
recognizer_.BindNewPipeAndPassReceiver(),
base::BindOnce(
&HandwritingModelLoaderTest::OnHandwritingModelLoaderComplete,
base::Unretained(this)));
loader_->dlc_client_ = &fake_client_;
} }
// Callback that called when loader_->Load() is over to save the returned // Callback that called when loader_->Load() is over to save the returned
...@@ -44,14 +40,19 @@ class HandwritingModelLoaderTest : public testing::Test { ...@@ -44,14 +40,19 @@ class HandwritingModelLoaderTest : public testing::Test {
// Runs loader_->Load() and check the returned result as expected. // Runs loader_->Load() and check the returned result as expected.
void ExpectLoadHandwritingModelResult( void ExpectLoadHandwritingModelResult(
const LoadHandwritingModelResult expected_result) { const LoadHandwritingModelResult expected_result) {
loader_->Load(); LoadHandwritingModelFromRootfsOrDlc(
mojom::HandwritingRecognizerSpec::New(language_),
recognizer_.BindNewPipeAndPassReceiver(),
base::BindOnce(
&HandwritingModelLoaderTest::OnHandwritingModelLoaderComplete,
base::Unretained(this)),
&fake_client_);
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
EXPECT_EQ(result_, expected_result); EXPECT_EQ(result_, expected_result);
} }
void SetLanguage(const std::string& language) { void SetLanguage(const std::string& language) { language_ = language; }
loader_->spec_->language = language;
}
// Creates a dlc list with one dlc inside. // Creates a dlc list with one dlc inside.
void AddDlcsWithContent(const std::string& dlc_id) { void AddDlcsWithContent(const std::string& dlc_id) {
...@@ -69,7 +70,7 @@ class HandwritingModelLoaderTest : public testing::Test { ...@@ -69,7 +70,7 @@ class HandwritingModelLoaderTest : public testing::Test {
// Sets "ondevice_handwriting" value. // Sets "ondevice_handwriting" value.
void SetSwitchValue(const std::string& switch_value) { void SetSwitchValue(const std::string& switch_value) {
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII( base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
HandwritingModelLoader::kOndeviceHandwritingSwitch, switch_value); kOndeviceHandwritingSwitch, switch_value);
} }
private: private:
...@@ -80,8 +81,8 @@ class HandwritingModelLoaderTest : public testing::Test { ...@@ -80,8 +81,8 @@ class HandwritingModelLoaderTest : public testing::Test {
FakeDlcserviceClient fake_client_; FakeDlcserviceClient fake_client_;
FakeServiceConnectionImpl fake_service_connection_; FakeServiceConnectionImpl fake_service_connection_;
LoadHandwritingModelResult result_; LoadHandwritingModelResult result_;
std::string language_;
mojo::Remote<mojom::HandwritingRecognizer> recognizer_; mojo::Remote<mojom::HandwritingRecognizer> recognizer_;
std::unique_ptr<HandwritingModelLoader> loader_;
}; };
TEST_F(HandwritingModelLoaderTest, HandwritingNotEnabled) { TEST_F(HandwritingModelLoaderTest, HandwritingNotEnabled) {
...@@ -122,7 +123,7 @@ TEST_F(HandwritingModelLoaderTest, LoadingWithoutDlcOnDevice) { ...@@ -122,7 +123,7 @@ TEST_F(HandwritingModelLoaderTest, LoadingWithoutDlcOnDevice) {
TEST_F(HandwritingModelLoaderTest, DlcInstalledWithError) { TEST_F(HandwritingModelLoaderTest, DlcInstalledWithError) {
SetSwitchValue("use_dlc"); SetSwitchValue("use_dlc");
AddDlcsWithContent(HandwritingModelLoader::kLibHandwritingDlcId); AddDlcsWithContent(kLibHandwritingDlcId);
SetInstallError("random error"); SetInstallError("random error");
// InstallDlc error should return DLC_INSTALL_ERROR. // InstallDlc error should return DLC_INSTALL_ERROR.
...@@ -133,7 +134,7 @@ TEST_F(HandwritingModelLoaderTest, DlcInstalledWithError) { ...@@ -133,7 +134,7 @@ TEST_F(HandwritingModelLoaderTest, DlcInstalledWithError) {
TEST_F(HandwritingModelLoaderTest, DlcInstalledWithoutError) { TEST_F(HandwritingModelLoaderTest, DlcInstalledWithoutError) {
SetSwitchValue("use_dlc"); SetSwitchValue("use_dlc");
AddDlcsWithContent(HandwritingModelLoader::kLibHandwritingDlcId); AddDlcsWithContent(kLibHandwritingDlcId);
SetInstallError(dlcservice::kErrorNone); SetInstallError(dlcservice::kErrorNone);
// InstallDlc without an error should return success. // InstallDlc without an error should return success.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment