Commit ee42fe4b authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Add LearningTaskController::UpdateDefaultTarget

UpdateDefaultTarget can be used to update the default target value
that will be used to complete the observation if the controller
is destroyed before CompleteObservation is called.  A default value
may be added, changed, or removed.

Change-Id: I700ad31cb56e60a15e31c43367377336f4c1ad74
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1863593
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarDaniel Cheng <dcheng@chromium.org>
Reviewed-by: default avatarThomas Guilbert <tguilbert@chromium.org>
Cr-Commit-Position: refs/heads/master@{#707972}
parent c85a0c24
......@@ -34,6 +34,9 @@ class MockLearningTaskController : public LearningTaskController {
void(base::UnguessableToken id,
const ObservationCompletion& completion));
MOCK_METHOD1(CancelObservation, void(base::UnguessableToken id));
MOCK_METHOD2(UpdateDefaultTarget,
void(base::UnguessableToken id,
const base::Optional<TargetValue>& default_target));
const LearningTask& GetLearningTask() { return task_; }
......
......@@ -79,6 +79,15 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
// Notify the LearningTaskController that no completion will be sent.
virtual void CancelObservation(base::UnguessableToken id) = 0;
// Update the default target value for |id|. This can change a previously
// specified default value to something else, add one where one wasn't
// specified before, or un-set it. In the last case, the observation will be
// cancelled rather than completed if |this| is destroyed, just as if no
// default value was given.
virtual void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) = 0;
// Returns the LearningTask associated with |this|.
virtual const LearningTask& GetLearningTask() = 0;
......
......@@ -79,6 +79,15 @@ class WeakLearningTaskController : public LearningTaskController {
id);
}
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override {
if (!weak_session_)
return;
outstanding_observations_[id] = default_target;
}
const LearningTask& GetLearningTask() override { return task_; }
base::WeakPtr<LearningSessionImpl> weak_session_;
......
......@@ -61,6 +61,14 @@ class LearningSessionImplTest : public testing::Test {
cancelled_id_ = id;
}
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override {
// Should not be called, since LearningTaskControllerImpl doesn't support
// default values.
updated_id_ = id;
}
const LearningTask& GetLearningTask() override {
NOTREACHED();
return LearningTask::Empty();
......@@ -74,6 +82,9 @@ class LearningSessionImplTest : public testing::Test {
// Most recently cancelled id.
base::UnguessableToken cancelled_id_;
// Id of most recently changed default target value.
base::Optional<base::UnguessableToken> updated_id_;
};
class FakeFeatureProvider : public FeatureProvider {
......@@ -252,10 +263,58 @@ TEST_F(LearningSessionImplTest,
EXPECT_EQ(task_controllers_[0]->id_, id);
EXPECT_FALSE(task_controllers_[0]->default_target_);
// Should result in completes the observation.
// Should complete the observation.
controller.reset();
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target);
}
TEST_F(LearningSessionImplTest, ChangeDefaultTargetToValue) {
session_->RegisterTask(task_0_);
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
task_environment_.RunUntilIdle();
// Start an observation without a default, then add one.
base::UnguessableToken id = base::UnguessableToken::Create();
controller->BeginObservation(id, FeatureVector(), base::nullopt);
TargetValue default_target(123);
controller->UpdateDefaultTarget(id, default_target);
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->id_, id);
// Should complete the observation.
controller.reset();
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target);
// Shouldn't notify the underlying controller.
EXPECT_FALSE(task_controllers_[0]->updated_id_);
}
TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) {
session_->RegisterTask(task_0_);
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
task_environment_.RunUntilIdle();
// Start an observation with a default, then remove it.
base::UnguessableToken id = base::UnguessableToken::Create();
TargetValue default_target(123);
controller->BeginObservation(id, FeatureVector(), default_target);
controller->UpdateDefaultTarget(id, base::nullopt);
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->id_, id);
// Should cancel the observation.
controller.reset();
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->cancelled_id_, id);
// Shouldn't notify the underlying controller.
EXPECT_FALSE(task_controllers_[0]->updated_id_);
}
} // namespace learning
......
......@@ -82,6 +82,12 @@ void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) {
helper_->CancelObservation(id);
}
void LearningTaskControllerImpl::UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) {
NOTREACHED();
}
const LearningTask& LearningTaskControllerImpl::GetLearningTask() {
return task_;
}
......
......@@ -58,6 +58,9 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override;
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override;
private:
......
......@@ -62,5 +62,15 @@ void MojoLearningTaskControllerService::CancelObservation(
impl_->CancelObservation(id);
}
void MojoLearningTaskControllerService::UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) {
auto iter = in_flight_observations_.find(id);
if (iter == in_flight_observations_.end())
return;
impl_->UpdateDefaultTarget(id, default_target);
}
} // namespace learning
} // namespace media
......@@ -35,6 +35,9 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService
void CompleteObservation(const base::UnguessableToken& id,
const ObservationCompletion& completion) override;
void CancelObservation(const base::UnguessableToken& id) override;
void UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) override;
protected:
const LearningTask task_;
......
......@@ -39,6 +39,13 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
cancel_args_.id_ = id;
}
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override {
update_default_args_.id_ = id;
update_default_args_.default_target_ = default_target;
}
const LearningTask& GetLearningTask() override {
return LearningTask::Empty();
}
......@@ -57,6 +64,11 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
struct {
base::UnguessableToken id_;
} cancel_args_;
struct {
base::UnguessableToken id_;
base::Optional<TargetValue> default_target_;
} update_default_args_;
};
public:
......@@ -159,5 +171,27 @@ TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) {
EXPECT_NE(id, controller_raw_->cancel_args_.id_);
}
TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToValue) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
service_->BeginObservation(id, features, base::nullopt);
TargetValue default_target(987);
service_->UpdateDefaultTarget(id, default_target);
EXPECT_EQ(id, controller_raw_->update_default_args_.id_);
EXPECT_EQ(default_target,
controller_raw_->update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToNoValue) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetValue default_target(987);
service_->BeginObservation(id, features, default_target);
service_->UpdateDefaultTarget(id, base::nullopt);
EXPECT_EQ(id, controller_raw_->update_default_args_.id_);
EXPECT_EQ(base::nullopt,
controller_raw_->update_default_args_.default_target_);
}
} // namespace learning
} // namespace media
......@@ -37,6 +37,12 @@ void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) {
controller_->CancelObservation(id);
}
void MojoLearningTaskController::UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) {
controller_->UpdateDefaultTarget(id, default_target);
}
const LearningTask& MojoLearningTaskController::GetLearningTask() {
return task_;
}
......
......@@ -36,6 +36,9 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override;
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override;
private:
......
......@@ -41,6 +41,13 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
cancel_args_.id_ = id;
}
void UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) override {
update_default_args_.id_ = id;
update_default_args_.default_target_ = default_target;
}
struct {
base::UnguessableToken id_;
FeatureVector features_;
......@@ -55,6 +62,11 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
struct {
base::UnguessableToken id_;
} cancel_args_;
struct {
base::UnguessableToken id_;
base::Optional<TargetValue> default_target_;
} update_default_args_;
};
public:
......@@ -108,6 +120,34 @@ TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) {
fake_learning_controller_.begin_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToValue) {
// Test if we can update the default target to a non-nullopt.
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
learning_controller_->BeginObservation(id, features, base::nullopt);
TargetValue default_target(987);
learning_controller_->UpdateDefaultTarget(id, default_target);
task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.update_default_args_.id_);
EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
EXPECT_EQ(default_target,
fake_learning_controller_.update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToNoValue) {
// Test if we can update the default target to nullopt.
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetValue default_target(987);
learning_controller_->BeginObservation(id, features, default_target);
learning_controller_->UpdateDefaultTarget(id, base::nullopt);
task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.update_default_args_.id_);
EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
EXPECT_EQ(base::nullopt,
fake_learning_controller_.update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, Complete) {
base::UnguessableToken id = base::UnguessableToken::Create();
ObservationCompletion completion(TargetValue(1234));
......
......@@ -36,4 +36,9 @@ interface LearningTaskController {
// Cancel observation |id|. Deleting |this| will do the same.
CancelObservation(mojo_base.mojom.UnguessableToken id);
// Update the default target for |id| to |default_target|. May also unset it,
// so that the observation will be cancelled if the controller is destroyed.
UpdateDefaultTarget(mojo_base.mojom.UnguessableToken id,
TargetValue? default_target);
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment