Commit ee42fe4b authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Add LearningTaskController::UpdateDefaultTarget

UpdateDefaultTarget can be used to update the default target value
that will be used to complete the observation if the controller
is destroyed before CompleteObservation is called.  A default value
may be added, changed, or removed.

Change-Id: I700ad31cb56e60a15e31c43367377336f4c1ad74
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1863593
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarDaniel Cheng <dcheng@chromium.org>
Reviewed-by: default avatarThomas Guilbert <tguilbert@chromium.org>
Cr-Commit-Position: refs/heads/master@{#707972}
parent c85a0c24
...@@ -34,6 +34,9 @@ class MockLearningTaskController : public LearningTaskController { ...@@ -34,6 +34,9 @@ class MockLearningTaskController : public LearningTaskController {
void(base::UnguessableToken id, void(base::UnguessableToken id,
const ObservationCompletion& completion)); const ObservationCompletion& completion));
MOCK_METHOD1(CancelObservation, void(base::UnguessableToken id)); MOCK_METHOD1(CancelObservation, void(base::UnguessableToken id));
MOCK_METHOD2(UpdateDefaultTarget,
void(base::UnguessableToken id,
const base::Optional<TargetValue>& default_target));
const LearningTask& GetLearningTask() { return task_; } const LearningTask& GetLearningTask() { return task_; }
......
...@@ -79,6 +79,15 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController { ...@@ -79,6 +79,15 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
// Notify the LearningTaskController that no completion will be sent. // Notify the LearningTaskController that no completion will be sent.
virtual void CancelObservation(base::UnguessableToken id) = 0; virtual void CancelObservation(base::UnguessableToken id) = 0;
// Update the default target value for |id|. This can change a previously
// specified default value to something else, add one where one wasn't
// specified before, or un-set it. In the last case, the observation will be
// cancelled rather than completed if |this| is destroyed, just as if no
// default value was given.
virtual void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) = 0;
// Returns the LearningTask associated with |this|. // Returns the LearningTask associated with |this|.
virtual const LearningTask& GetLearningTask() = 0; virtual const LearningTask& GetLearningTask() = 0;
......
...@@ -79,6 +79,15 @@ class WeakLearningTaskController : public LearningTaskController { ...@@ -79,6 +79,15 @@ class WeakLearningTaskController : public LearningTaskController {
id); id);
} }
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override {
if (!weak_session_)
return;
outstanding_observations_[id] = default_target;
}
const LearningTask& GetLearningTask() override { return task_; } const LearningTask& GetLearningTask() override { return task_; }
base::WeakPtr<LearningSessionImpl> weak_session_; base::WeakPtr<LearningSessionImpl> weak_session_;
......
...@@ -61,6 +61,14 @@ class LearningSessionImplTest : public testing::Test { ...@@ -61,6 +61,14 @@ class LearningSessionImplTest : public testing::Test {
cancelled_id_ = id; cancelled_id_ = id;
} }
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override {
// Should not be called, since LearningTaskControllerImpl doesn't support
// default values.
updated_id_ = id;
}
const LearningTask& GetLearningTask() override { const LearningTask& GetLearningTask() override {
NOTREACHED(); NOTREACHED();
return LearningTask::Empty(); return LearningTask::Empty();
...@@ -74,6 +82,9 @@ class LearningSessionImplTest : public testing::Test { ...@@ -74,6 +82,9 @@ class LearningSessionImplTest : public testing::Test {
// Most recently cancelled id. // Most recently cancelled id.
base::UnguessableToken cancelled_id_; base::UnguessableToken cancelled_id_;
// Id of most recently changed default target value.
base::Optional<base::UnguessableToken> updated_id_;
}; };
class FakeFeatureProvider : public FeatureProvider { class FakeFeatureProvider : public FeatureProvider {
...@@ -252,10 +263,58 @@ TEST_F(LearningSessionImplTest, ...@@ -252,10 +263,58 @@ TEST_F(LearningSessionImplTest,
EXPECT_EQ(task_controllers_[0]->id_, id); EXPECT_EQ(task_controllers_[0]->id_, id);
EXPECT_FALSE(task_controllers_[0]->default_target_); EXPECT_FALSE(task_controllers_[0]->default_target_);
// Should result in completes the observation. // Should complete the observation.
controller.reset();
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target);
}
TEST_F(LearningSessionImplTest, ChangeDefaultTargetToValue) {
session_->RegisterTask(task_0_);
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
task_environment_.RunUntilIdle();
// Start an observation without a default, then add one.
base::UnguessableToken id = base::UnguessableToken::Create();
controller->BeginObservation(id, FeatureVector(), base::nullopt);
TargetValue default_target(123);
controller->UpdateDefaultTarget(id, default_target);
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->id_, id);
// Should complete the observation.
controller.reset(); controller.reset();
task_environment_.RunUntilIdle(); task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target); EXPECT_EQ(task_controllers_[0]->example_.target_value, default_target);
// Shouldn't notify the underlying controller.
EXPECT_FALSE(task_controllers_[0]->updated_id_);
}
TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) {
session_->RegisterTask(task_0_);
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
task_environment_.RunUntilIdle();
// Start an observation with a default, then remove it.
base::UnguessableToken id = base::UnguessableToken::Create();
TargetValue default_target(123);
controller->BeginObservation(id, FeatureVector(), default_target);
controller->UpdateDefaultTarget(id, base::nullopt);
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->id_, id);
// Should cancel the observation.
controller.reset();
task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->cancelled_id_, id);
// Shouldn't notify the underlying controller.
EXPECT_FALSE(task_controllers_[0]->updated_id_);
} }
} // namespace learning } // namespace learning
......
...@@ -82,6 +82,12 @@ void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) { ...@@ -82,6 +82,12 @@ void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) {
helper_->CancelObservation(id); helper_->CancelObservation(id);
} }
void LearningTaskControllerImpl::UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) {
NOTREACHED();
}
const LearningTask& LearningTaskControllerImpl::GetLearningTask() { const LearningTask& LearningTaskControllerImpl::GetLearningTask() {
return task_; return task_;
} }
......
...@@ -58,6 +58,9 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl ...@@ -58,6 +58,9 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
void CompleteObservation(base::UnguessableToken id, void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override; const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override; void CancelObservation(base::UnguessableToken id) override;
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override; const LearningTask& GetLearningTask() override;
private: private:
......
...@@ -62,5 +62,15 @@ void MojoLearningTaskControllerService::CancelObservation( ...@@ -62,5 +62,15 @@ void MojoLearningTaskControllerService::CancelObservation(
impl_->CancelObservation(id); impl_->CancelObservation(id);
} }
void MojoLearningTaskControllerService::UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) {
auto iter = in_flight_observations_.find(id);
if (iter == in_flight_observations_.end())
return;
impl_->UpdateDefaultTarget(id, default_target);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -35,6 +35,9 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService ...@@ -35,6 +35,9 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService
void CompleteObservation(const base::UnguessableToken& id, void CompleteObservation(const base::UnguessableToken& id,
const ObservationCompletion& completion) override; const ObservationCompletion& completion) override;
void CancelObservation(const base::UnguessableToken& id) override; void CancelObservation(const base::UnguessableToken& id) override;
void UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) override;
protected: protected:
const LearningTask task_; const LearningTask task_;
......
...@@ -39,6 +39,13 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test { ...@@ -39,6 +39,13 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
cancel_args_.id_ = id; cancel_args_.id_ = id;
} }
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override {
update_default_args_.id_ = id;
update_default_args_.default_target_ = default_target;
}
const LearningTask& GetLearningTask() override { const LearningTask& GetLearningTask() override {
return LearningTask::Empty(); return LearningTask::Empty();
} }
...@@ -57,6 +64,11 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test { ...@@ -57,6 +64,11 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
struct { struct {
base::UnguessableToken id_; base::UnguessableToken id_;
} cancel_args_; } cancel_args_;
struct {
base::UnguessableToken id_;
base::Optional<TargetValue> default_target_;
} update_default_args_;
}; };
public: public:
...@@ -159,5 +171,27 @@ TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) { ...@@ -159,5 +171,27 @@ TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) {
EXPECT_NE(id, controller_raw_->cancel_args_.id_); EXPECT_NE(id, controller_raw_->cancel_args_.id_);
} }
TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToValue) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
service_->BeginObservation(id, features, base::nullopt);
TargetValue default_target(987);
service_->UpdateDefaultTarget(id, default_target);
EXPECT_EQ(id, controller_raw_->update_default_args_.id_);
EXPECT_EQ(default_target,
controller_raw_->update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToNoValue) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetValue default_target(987);
service_->BeginObservation(id, features, default_target);
service_->UpdateDefaultTarget(id, base::nullopt);
EXPECT_EQ(id, controller_raw_->update_default_args_.id_);
EXPECT_EQ(base::nullopt,
controller_raw_->update_default_args_.default_target_);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -37,6 +37,12 @@ void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) { ...@@ -37,6 +37,12 @@ void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) {
controller_->CancelObservation(id); controller_->CancelObservation(id);
} }
void MojoLearningTaskController::UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) {
controller_->UpdateDefaultTarget(id, default_target);
}
const LearningTask& MojoLearningTaskController::GetLearningTask() { const LearningTask& MojoLearningTaskController::GetLearningTask() {
return task_; return task_;
} }
......
...@@ -36,6 +36,9 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController ...@@ -36,6 +36,9 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
void CompleteObservation(base::UnguessableToken id, void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override; const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override; void CancelObservation(base::UnguessableToken id) override;
void UpdateDefaultTarget(
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override; const LearningTask& GetLearningTask() override;
private: private:
......
...@@ -41,6 +41,13 @@ class MojoLearningTaskControllerTest : public ::testing::Test { ...@@ -41,6 +41,13 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
cancel_args_.id_ = id; cancel_args_.id_ = id;
} }
void UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) override {
update_default_args_.id_ = id;
update_default_args_.default_target_ = default_target;
}
struct { struct {
base::UnguessableToken id_; base::UnguessableToken id_;
FeatureVector features_; FeatureVector features_;
...@@ -55,6 +62,11 @@ class MojoLearningTaskControllerTest : public ::testing::Test { ...@@ -55,6 +62,11 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
struct { struct {
base::UnguessableToken id_; base::UnguessableToken id_;
} cancel_args_; } cancel_args_;
struct {
base::UnguessableToken id_;
base::Optional<TargetValue> default_target_;
} update_default_args_;
}; };
public: public:
...@@ -108,6 +120,34 @@ TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) { ...@@ -108,6 +120,34 @@ TEST_F(MojoLearningTaskControllerTest, BeginWithDefaultTarget) {
fake_learning_controller_.begin_args_.default_target_); fake_learning_controller_.begin_args_.default_target_);
} }
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToValue) {
// Test if we can update the default target to a non-nullopt.
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
learning_controller_->BeginObservation(id, features, base::nullopt);
TargetValue default_target(987);
learning_controller_->UpdateDefaultTarget(id, default_target);
task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.update_default_args_.id_);
EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
EXPECT_EQ(default_target,
fake_learning_controller_.update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, UpdateDefaultTargetToNoValue) {
// Test if we can update the default target to nullopt.
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetValue default_target(987);
learning_controller_->BeginObservation(id, features, default_target);
learning_controller_->UpdateDefaultTarget(id, base::nullopt);
task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.update_default_args_.id_);
EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
EXPECT_EQ(base::nullopt,
fake_learning_controller_.update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerTest, Complete) { TEST_F(MojoLearningTaskControllerTest, Complete) {
base::UnguessableToken id = base::UnguessableToken::Create(); base::UnguessableToken id = base::UnguessableToken::Create();
ObservationCompletion completion(TargetValue(1234)); ObservationCompletion completion(TargetValue(1234));
......
...@@ -36,4 +36,9 @@ interface LearningTaskController { ...@@ -36,4 +36,9 @@ interface LearningTaskController {
// Cancel observation |id|. Deleting |this| will do the same. // Cancel observation |id|. Deleting |this| will do the same.
CancelObservation(mojo_base.mojom.UnguessableToken id); CancelObservation(mojo_base.mojom.UnguessableToken id);
// Update the default target for |id| to |default_target|. May also unset it,
// so that the observation will be cancelled if the controller is destroyed.
UpdateDefaultTarget(mojo_base.mojom.UnguessableToken id,
TargetValue? default_target);
}; };
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment