Commit c5f5ffb7 authored by Charles Zhao's avatar Charles Zhao Committed by Commit Bot

Fix HandwritingModelLoader.

(1) calling HandwritingModelLoader(spec, receiver, callback).Load() is
a bad idea because the .Load() is async and the temporal object could be
released very soon. This Cl changes the whole class to be stateless, so
that calling HandwritingModelLoader::Load(spec, receiver, callback) won't suffer the same problem.

(2) unit tests are fixed accordingly.

Bug: 1054628
Change-Id: I56459b14e1fc4a6608d5ed264e6c650d1e076f69
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2415914
Commit-Queue: Charles . <charleszhao@chromium.org>
Reviewed-by: default avatarAndrew Moylan <amoylan@chromium.org>
Reviewed-by: default avatarCharles . <charleszhao@chromium.org>
Reviewed-by: default avatarXinglong Luan <alanlxl@chromium.org>
Cr-Commit-Position: refs/heads/master@{#810102}
parent 08c26569
......@@ -13,7 +13,12 @@ namespace chromeos {
namespace machine_learning {
namespace {
using chromeos::machine_learning::mojom::LoadHandwritingModelResult;
using ::chromeos::machine_learning::mojom::HandwritingRecognizerSpecPtr;
using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
using HandwritingRecognizer =
mojo::PendingReceiver<mojom::HandwritingRecognizer>;
using LoadHandwritingModelCallback = ::chromeos::machine_learning::mojom::
MachineLearningService::LoadHandwritingModelCallback;
// Records CrOSActionRecorder event.
void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) {
......@@ -22,6 +27,8 @@ void RecordLoadHandwritingModelResult(const LoadHandwritingModelResult val) {
LoadHandwritingModelResult::LOAD_MODEL_FILES_ERROR);
}
constexpr char kOndeviceHandwritingSwitch[] = "ondevice_handwriting";
constexpr char kLibHandwritingDlcId[] = "libhandwriting";
// A list of supported language code.
constexpr char kLanguageCodeEn[] = "en";
constexpr char kLanguageCodeGesture[] = "gesture_in_context";
......@@ -30,10 +37,8 @@ constexpr char kLanguageCodeGesture[] = "gesture_in_context";
// kOndeviceHandwritingSwitch.
bool HandwritingSwitchHasValue(const std::string& value) {
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
return command_line->HasSwitch(
HandwritingModelLoader::kOndeviceHandwritingSwitch) &&
command_line->GetSwitchValueASCII(
HandwritingModelLoader::kOndeviceHandwritingSwitch) == value;
return command_line->HasSwitch(kOndeviceHandwritingSwitch) &&
command_line->GetSwitchValueASCII(kOndeviceHandwritingSwitch) == value;
}
// Returns true if switch kOndeviceHandwritingSwitch is set to use_rootfs.
......@@ -46,69 +51,45 @@ bool IsLibHandwritingDlcEnabled() {
return HandwritingSwitchHasValue("use_dlc");
}
} // namespace
constexpr char HandwritingModelLoader::kOndeviceHandwritingSwitch[];
constexpr char HandwritingModelLoader::kLibHandwritingDlcId[];
HandwritingModelLoader::HandwritingModelLoader(
mojom::HandwritingRecognizerSpecPtr spec,
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelCallback callback)
: dlc_client_(chromeos::DlcserviceClient::Get()),
spec_(std::move(spec)),
receiver_(std::move(receiver)),
callback_(std::move(callback)),
weak_ptr_factory_(this) {}
HandwritingModelLoader::~HandwritingModelLoader() = default;
void HandwritingModelLoader::Load() {
// Returns FEATURE_NOT_SUPPORTED_ERROR if both rootfs and dlc are not enabled.
if (!IsLibHandwritingRootfsEnabled() && !IsLibHandwritingDlcEnabled()) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
std::move(callback_).Run(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
return;
}
// Returns LANGUAGE_NOT_SUPPORTED_ERROR if the language is not supported yet.
if (spec_->language != kLanguageCodeEn &&
spec_->language != kLanguageCodeGesture) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
std::move(callback_).Run(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
return;
}
// Load from rootfs if enabled.
if (IsLibHandwritingRootfsEnabled()) {
// Called when InstallDlc completes.
// Returns an error if the `result.error` is not dlcservice::kErrorNone.
// Calls mlservice to LoadHandwritingModel otherwise.
void OnInstallDlcComplete(
HandwritingRecognizerSpecPtr spec,
HandwritingRecognizer receiver,
LoadHandwritingModelCallback callback,
const chromeos::DlcserviceClient::InstallResult& result) {
// Call LoadHandwritingModelWithSpec if no error was found.
if (result.error == dlcservice::kErrorNone) {
ServiceConnection::GetInstance()->LoadHandwritingModel(
std::move(spec_), std::move(receiver_), std::move(callback_));
std::move(spec), std::move(receiver), std::move(callback));
return;
}
// Gets existing dlc list and based on the presence of libhandwriting
// either returns an error or installs the libhandwriting dlc.
dlc_client_->GetExistingDlcs(
base::BindOnce(&HandwritingModelLoader::OnGetExistingDlcsComplete,
weak_ptr_factory_.GetWeakPtr()));
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::DLC_INSTALL_ERROR);
std::move(callback).Run(LoadHandwritingModelResult::DLC_INSTALL_ERROR);
}
void HandwritingModelLoader::OnGetExistingDlcsComplete(
// Called when the existing-dlc-list is returned.
// Returns an error if libhandwriting is not in the existing-dlc-list.
// Calls InstallDlc otherwise.
void OnGetExistingDlcsComplete(
HandwritingRecognizerSpecPtr spec,
HandwritingRecognizer receiver,
LoadHandwritingModelCallback callback,
DlcserviceClient* const dlc_client,
const std::string& err,
const dlcservice::DlcsWithContent& dlcs_with_content) {
// Loop over dlcs_with_content, and installs libhandwriting if already exists.
// Since we don't want to trigger downloading here, we only install(mount)
// the handwriting dlc if it is already on device.
for (const auto& dlc_info : dlcs_with_content.dlc_infos()) {
if (dlc_info.id() == HandwritingModelLoader::kLibHandwritingDlcId) {
dlc_client_->Install(
if (dlc_info.id() == kLibHandwritingDlcId) {
dlc_client->Install(
kLibHandwritingDlcId,
base::BindOnce(&HandwritingModelLoader::OnInstallDlcComplete,
weak_ptr_factory_.GetWeakPtr()),
base::BindOnce(&OnInstallDlcComplete, std::move(spec),
std::move(receiver), std::move(callback)),
chromeos::DlcserviceClient::IgnoreProgress);
return;
}
......@@ -117,21 +98,46 @@ void HandwritingModelLoader::OnGetExistingDlcsComplete(
// Returns error if the handwriting dlc is not on the device.
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
std::move(callback_).Run(LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
std::move(callback).Run(LoadHandwritingModelResult::DLC_DOES_NOT_EXIST);
}
void HandwritingModelLoader::OnInstallDlcComplete(
const chromeos::DlcserviceClient::InstallResult& result) {
// Call LoadHandwritingModelWithSpec if no error was found.
if (result.error == dlcservice::kErrorNone) {
ServiceConnection::GetInstance()->LoadHandwritingModel(
std::move(spec_), std::move(receiver_), std::move(callback_));
} // namespace
void LoadHandwritingModelFromRootfsOrDlc(HandwritingRecognizerSpecPtr spec,
HandwritingRecognizer receiver,
LoadHandwritingModelCallback callback,
DlcserviceClient* const dlc_client) {
// Returns FEATURE_NOT_SUPPORTED_ERROR if both rootfs and dlc are not enabled.
if (!IsLibHandwritingRootfsEnabled() && !IsLibHandwritingDlcEnabled()) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
std::move(callback).Run(
LoadHandwritingModelResult::FEATURE_NOT_SUPPORTED_ERROR);
return;
}
// Returns LANGUAGE_NOT_SUPPORTED_ERROR if the language is not supported yet.
if (spec->language != kLanguageCodeEn &&
spec->language != kLanguageCodeGesture) {
RecordLoadHandwritingModelResult(
LoadHandwritingModelResult::DLC_INSTALL_ERROR);
std::move(callback_).Run(LoadHandwritingModelResult::DLC_INSTALL_ERROR);
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
std::move(callback).Run(
LoadHandwritingModelResult::LANGUAGE_NOT_SUPPORTED_ERROR);
return;
}
// Load from rootfs if enabled.
if (IsLibHandwritingRootfsEnabled()) {
ServiceConnection::GetInstance()->LoadHandwritingModel(
std::move(spec), std::move(receiver), std::move(callback));
return;
}
// Gets existing dlc list and based on the presence of libhandwriting
// either returns an error or installs the libhandwriting dlc.
dlc_client->GetExistingDlcs(
base::BindOnce(&OnGetExistingDlcsComplete, std::move(spec),
std::move(receiver), std::move(callback), dlc_client));
}
} // namespace machine_learning
......
......@@ -14,64 +14,31 @@
namespace chromeos {
namespace machine_learning {
// Class that decides either to load handwriting model from rootfs or dlc.
// Helper function decides either to load handwriting model from rootfs or dlc.
// New Handwriting clients should call this helper instead of calling
// ServiceConnection::GetInstance()->LoadHandwritingModel.
// Three typical examples of the callstack are:
// Case 1: handwriting in enabled on rootfs.
// client calls HandwritingModelLoader("en", receiver, callback).Load()
// client calls LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback)
// which calls LoadHandwritingModel -> handwriting model loaded from rootfs.
// Case 2: handwriting is enabled for dlc and dlc is already on the device.
// client calls HandwritingModelLoader("en", receiver, callback).Load()
// client calls LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback)
// which calls -> GetExistingDlcs -> libhandwriting dlc already exists
// -> InstallDlc -> LoadHandwritingModel
// The correct handwriting model will be loaded and bond to the receiver.
// Case 3: handwriting is enabled for dlc and dlc is not on the device yet.
// client calls HandwritingModelLoader("en", receiver, callback).Load()
// client calls LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback)
// which calls -> GetExistingDlcs -> NO libhandwriting dlc exists
// -> Return error DLC_NOT_EXISTED.
// Then it will be the client's duty to install the dlc and then calls
// HandwritingModelLoader("en", receiver, callback).Load() again.
class HandwritingModelLoader {
public:
HandwritingModelLoader(
// LoadHandwritingModelFromRootfsOrDlc("en", receiver, callback) again.
//
// `dlc_client` should only be replaced with non-default value in unit tests.
void LoadHandwritingModelFromRootfsOrDlc(
mojom::HandwritingRecognizerSpecPtr spec,
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver,
mojom::MachineLearningService::LoadHandwritingModelCallback callback);
~HandwritingModelLoader();
// Load handwriting model based on the comandline flag and language.
void Load();
static constexpr char kOndeviceHandwritingSwitch[] = "ondevice_handwriting";
static constexpr char kLibHandwritingDlcId[] = "libhandwriting";
private:
friend class HandwritingModelLoaderTest;
// Called when the existing-dlc-list is returned.
// Returns an error if libhandwriting is not in the existing-dlc-list.
// Calls InstallDlc otherwise.
void OnGetExistingDlcsComplete(
const std::string& err,
const dlcservice::DlcsWithContent& dlcs_with_content);
// Called when InstallDlc completes.
// Returns an error if the `result.error` is not dlcservice::kErrorNone.
// Calls mlservice to LoadHandwritingModel otherwise.
void OnInstallDlcComplete(
const chromeos::DlcserviceClient::InstallResult& result);
DlcserviceClient* dlc_client_;
mojom::HandwritingRecognizerSpecPtr spec_;
mojo::PendingReceiver<mojom::HandwritingRecognizer> receiver_;
mojom::MachineLearningService::LoadHandwritingModelCallback callback_;
base::WeakPtrFactory<HandwritingModelLoader> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(HandwritingModelLoader);
};
mojom::MachineLearningService::LoadHandwritingModelCallback callback,
DlcserviceClient* dlc_client = chromeos::DlcserviceClient::Get());
} // namespace machine_learning
} // namespace chromeos
......
......@@ -18,20 +18,16 @@ namespace chromeos {
namespace machine_learning {
using chromeos::machine_learning::mojom::LoadHandwritingModelResult;
constexpr char kOndeviceHandwritingSwitch[] = "ondevice_handwriting";
constexpr char kLibHandwritingDlcId[] = "libhandwriting";
class HandwritingModelLoaderTest : public testing::Test {
protected:
void SetUp() override {
ServiceConnection::UseFakeServiceConnectionForTesting(
&fake_service_connection_);
result_ = LoadHandwritingModelResult::DEPRECATED_MODEL_SPEC_ERROR;
loader_ = std::make_unique<HandwritingModelLoader>(
mojom::HandwritingRecognizerSpec::New("en"),
recognizer_.BindNewPipeAndPassReceiver(),
base::BindOnce(
&HandwritingModelLoaderTest::OnHandwritingModelLoaderComplete,
base::Unretained(this)));
loader_->dlc_client_ = &fake_client_;
language_ = "en";
}
// Callback that called when loader_->Load() is over to save the returned
......@@ -44,14 +40,19 @@ class HandwritingModelLoaderTest : public testing::Test {
// Runs loader_->Load() and check the returned result as expected.
void ExpectLoadHandwritingModelResult(
const LoadHandwritingModelResult expected_result) {
loader_->Load();
LoadHandwritingModelFromRootfsOrDlc(
mojom::HandwritingRecognizerSpec::New(language_),
recognizer_.BindNewPipeAndPassReceiver(),
base::BindOnce(
&HandwritingModelLoaderTest::OnHandwritingModelLoaderComplete,
base::Unretained(this)),
&fake_client_);
base::RunLoop().RunUntilIdle();
EXPECT_EQ(result_, expected_result);
}
void SetLanguage(const std::string& language) {
loader_->spec_->language = language;
}
void SetLanguage(const std::string& language) { language_ = language; }
// Creates a dlc list with one dlc inside.
void AddDlcsWithContent(const std::string& dlc_id) {
......@@ -69,7 +70,7 @@ class HandwritingModelLoaderTest : public testing::Test {
// Sets "ondevice_handwriting" value.
void SetSwitchValue(const std::string& switch_value) {
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
HandwritingModelLoader::kOndeviceHandwritingSwitch, switch_value);
kOndeviceHandwritingSwitch, switch_value);
}
private:
......@@ -80,8 +81,8 @@ class HandwritingModelLoaderTest : public testing::Test {
FakeDlcserviceClient fake_client_;
FakeServiceConnectionImpl fake_service_connection_;
LoadHandwritingModelResult result_;
std::string language_;
mojo::Remote<mojom::HandwritingRecognizer> recognizer_;
std::unique_ptr<HandwritingModelLoader> loader_;
};
TEST_F(HandwritingModelLoaderTest, HandwritingNotEnabled) {
......@@ -122,7 +123,7 @@ TEST_F(HandwritingModelLoaderTest, LoadingWithoutDlcOnDevice) {
TEST_F(HandwritingModelLoaderTest, DlcInstalledWithError) {
SetSwitchValue("use_dlc");
AddDlcsWithContent(HandwritingModelLoader::kLibHandwritingDlcId);
AddDlcsWithContent(kLibHandwritingDlcId);
SetInstallError("random error");
// InstallDlc error should return DLC_INSTALL_ERROR.
......@@ -133,7 +134,7 @@ TEST_F(HandwritingModelLoaderTest, DlcInstalledWithError) {
TEST_F(HandwritingModelLoaderTest, DlcInstalledWithoutError) {
SetSwitchValue("use_dlc");
AddDlcsWithContent(HandwritingModelLoader::kLibHandwritingDlcId);
AddDlcsWithContent(kLibHandwritingDlcId);
SetInstallError(dlcservice::kErrorNone);
// InstallDlc without an error should return success.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment