Commit ddffebdb authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Mojo bindings for LearningTaskController

This CL adds unused mojo bindings for LearningTaskController.  The
service side owns a LearningTaskController to which it forwards
calls.  It doesn't have to worry about tracking in-flight
observations since it just drops the LearningTaskController when the
connection to the client is dropped.

The client side just forwards requests to the service.  Both are
more or less just type adapters from ::mojom to the real types.

Change-Id: I707987c8ee2c8bff1915e58d0ae60b0752074674
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1513204Reviewed-by: default avatarDaniel Cheng <dcheng@chromium.org>
Reviewed-by: default avatarDan Sanders <sandersd@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#642347}
parent 771f50ea
...@@ -8,7 +8,8 @@ import("//testing/test.gni") ...@@ -8,7 +8,8 @@ import("//testing/test.gni")
component("impl") { component("impl") {
output_name = "media_learning_mojo_impl" output_name = "media_learning_mojo_impl"
sources = [ sources = [
"dummy.cc", "mojo_learning_task_controller_service.cc",
"mojo_learning_task_controller_service.h",
] ]
defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ] defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ]
...@@ -33,7 +34,9 @@ component("impl") { ...@@ -33,7 +34,9 @@ component("impl") {
source_set("unit_tests") { source_set("unit_tests") {
testonly = true testonly = true
sources = [] sources = [
"mojo_learning_task_controller_service_unittest.cc",
]
deps = [ deps = [
":impl", ":impl",
......
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Because mac requires a non-empty .a, and removing the component requires
// quite a bit more fiddling with other gn files. Since there will be new
// things in this directory very shortly, it's much easier just to add this.
void media_learning_mojo_do_nothing() {}
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/mojo/mojo_learning_task_controller_service.h"
#include <utility>
#include "media/learning/common/learning_task_controller.h"
namespace media {
namespace learning {
// Somewhat arbitrary upper limit on the number of in-flight observations that
// we'll allow a client to have.
static const size_t kMaxInFlightObservations = 16;
MojoLearningTaskControllerService::MojoLearningTaskControllerService(
const LearningTask& task,
std::unique_ptr<::media::learning::LearningTaskController> impl)
: task_(task), impl_(std::move(impl)) {}
MojoLearningTaskControllerService::~MojoLearningTaskControllerService() =
default;
void MojoLearningTaskControllerService::BeginObservation(
const base::UnguessableToken& id,
const FeatureVector& features) {
// Drop the observation if it doesn't match the feature description size.
if (features.size() != task_.feature_descriptions.size())
return;
// Don't allow the client to send too many in-flight observations.
if (in_flight_observations_.size() >= kMaxInFlightObservations)
return;
in_flight_observations_.insert(id);
// Since we own |impl_|, we don't need to keep track of in-flight
// observations. We'll release |impl_| on destruction, which cancels them.
impl_->BeginObservation(id, features);
}
void MojoLearningTaskControllerService::CompleteObservation(
const base::UnguessableToken& id,
const ObservationCompletion& completion) {
auto iter = in_flight_observations_.find(id);
if (iter == in_flight_observations_.end())
return;
in_flight_observations_.erase(iter);
impl_->CompleteObservation(id, completion);
}
void MojoLearningTaskControllerService::CancelObservation(
const base::UnguessableToken& id) {
auto iter = in_flight_observations_.find(id);
if (iter == in_flight_observations_.end())
return;
in_flight_observations_.erase(iter);
impl_->CancelObservation(id);
}
} // namespace learning
} // namespace media
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_
#define MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_
#include <memory>
#include <set>
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/mojo/public/mojom/learning_task_controller.mojom.h"
namespace media {
namespace learning {
class LearningTaskController;
// Mojo service that talks to a local LearningTaskController.
class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService
: public mojom::LearningTaskController {
public:
// |impl| is the underlying controller that we'll send requests to.
explicit MojoLearningTaskControllerService(
const LearningTask& task,
std::unique_ptr<::media::learning::LearningTaskController> impl);
~MojoLearningTaskControllerService() override;
// mojom::LearningTaskController
void BeginObservation(const base::UnguessableToken& id,
const FeatureVector& features) override;
void CompleteObservation(const base::UnguessableToken& id,
const ObservationCompletion& completion) override;
void CancelObservation(const base::UnguessableToken& id) override;
protected:
const LearningTask task_;
// Underlying controller to which we proxy calls.
std::unique_ptr<::media::learning::LearningTaskController> impl_;
std::set<base::UnguessableToken> in_flight_observations_;
DISALLOW_COPY_AND_ASSIGN(MojoLearningTaskControllerService);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include "base/bind.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/test/scoped_task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/mojo_learning_task_controller_service.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class MojoLearningTaskControllerServiceTest : public ::testing::Test {
public:
class FakeLearningTaskController : public LearningTaskController {
public:
void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) override {
begin_args_.id_ = id;
begin_args_.features_ = features;
}
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override {
complete_args_.id_ = id;
complete_args_.completion_ = completion;
}
void CancelObservation(base::UnguessableToken id) override {
cancel_args_.id_ = id;
}
struct {
base::UnguessableToken id_;
FeatureVector features_;
} begin_args_;
struct {
base::UnguessableToken id_;
ObservationCompletion completion_;
} complete_args_;
struct {
base::UnguessableToken id_;
} cancel_args_;
};
public:
MojoLearningTaskControllerServiceTest() = default;
~MojoLearningTaskControllerServiceTest() override = default;
void SetUp() override {
std::unique_ptr<FakeLearningTaskController> controller =
std::make_unique<FakeLearningTaskController>();
controller_raw_ = controller.get();
// Add two features.
task_.feature_descriptions.push_back({});
task_.feature_descriptions.push_back({});
// Tell |learning_controller_| to forward to the fake learner impl.
service_ = std::make_unique<MojoLearningTaskControllerService>(
task_, std::move(controller));
}
LearningTask task_;
// Mojo stuff.
base::test::ScopedTaskEnvironment scoped_task_environment_;
FakeLearningTaskController* controller_raw_ = nullptr;
// The learner under test.
std::unique_ptr<MojoLearningTaskControllerService> service_;
};
TEST_F(MojoLearningTaskControllerServiceTest, BeginComplete) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
service_->BeginObservation(id, features);
EXPECT_EQ(id, controller_raw_->begin_args_.id_);
EXPECT_EQ(features, controller_raw_->begin_args_.features_);
ObservationCompletion completion(TargetValue(1234));
service_->CompleteObservation(id, completion);
EXPECT_EQ(id, controller_raw_->complete_args_.id_);
EXPECT_EQ(completion.target_value,
controller_raw_->complete_args_.completion_.target_value);
}
TEST_F(MojoLearningTaskControllerServiceTest, BeginCancel) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
service_->BeginObservation(id, features);
EXPECT_EQ(id, controller_raw_->begin_args_.id_);
EXPECT_EQ(features, controller_raw_->begin_args_.features_);
service_->CancelObservation(id);
EXPECT_EQ(id, controller_raw_->cancel_args_.id_);
}
TEST_F(MojoLearningTaskControllerServiceTest, TooFewFeaturesIsIgnored) {
// A FeatureVector with too few elements should be ignored.
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector short_features = {FeatureValue(123)};
service_->BeginObservation(id, short_features);
EXPECT_NE(id, controller_raw_->begin_args_.id_);
EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u);
}
TEST_F(MojoLearningTaskControllerServiceTest, TooManyFeaturesIsIgnored) {
// A FeatureVector with too many elements should be ignored.
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector long_features = {FeatureValue(123), FeatureValue(456),
FeatureValue(789)};
service_->BeginObservation(id, long_features);
EXPECT_NE(id, controller_raw_->begin_args_.id_);
EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u);
}
TEST_F(MojoLearningTaskControllerServiceTest, CompleteWithoutBeginFails) {
base::UnguessableToken id = base::UnguessableToken::Create();
ObservationCompletion completion(TargetValue(1234));
service_->CompleteObservation(id, completion);
EXPECT_NE(id, controller_raw_->complete_args_.id_);
}
TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) {
base::UnguessableToken id = base::UnguessableToken::Create();
service_->CancelObservation(id);
EXPECT_NE(id, controller_raw_->cancel_args_.id_);
}
} // namespace learning
} // namespace media
...@@ -8,6 +8,11 @@ source_set("cpp") { ...@@ -8,6 +8,11 @@ source_set("cpp") {
"//media/learning/mojo/public/cpp:unit_tests", "//media/learning/mojo/public/cpp:unit_tests",
] ]
sources = [
"mojo_learning_task_controller.cc",
"mojo_learning_task_controller.h",
]
defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ] defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ]
deps = [ deps = [
...@@ -20,6 +25,10 @@ source_set("cpp") { ...@@ -20,6 +25,10 @@ source_set("cpp") {
source_set("unit_tests") { source_set("unit_tests") {
testonly = true testonly = true
sources = [
"mojo_learning_task_controller_unittest.cc",
]
deps = [ deps = [
"//base", "//base",
"//base/test:test_support", "//base/test:test_support",
......
...@@ -39,4 +39,14 @@ bool StructTraits<media::learning::mojom::TargetValueDataView, ...@@ -39,4 +39,14 @@ bool StructTraits<media::learning::mojom::TargetValueDataView,
return true; return true;
} }
// static
bool StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion>::
Read(media::learning::mojom::ObservationCompletionDataView data,
media::learning::ObservationCompletion* out_observation_completion) {
if (!data.ReadTargetValue(&out_observation_completion->target_value))
return false;
out_observation_completion->weight = data.weight();
return true;
}
} // namespace mojo } // namespace mojo
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <vector> #include <vector>
#include "media/learning/common/learning_task_controller.h"
#include "media/learning/common/value.h" #include "media/learning/common/value.h"
#include "media/learning/mojo/public/mojom/learning_types.mojom.h" #include "media/learning/mojo/public/mojom/learning_types.mojom.h"
#include "mojo/public/cpp/bindings/struct_traits.h" #include "mojo/public/cpp/bindings/struct_traits.h"
...@@ -52,6 +53,23 @@ class StructTraits<media::learning::mojom::TargetValueDataView, ...@@ -52,6 +53,23 @@ class StructTraits<media::learning::mojom::TargetValueDataView,
media::learning::TargetValue* out_target_value); media::learning::TargetValue* out_target_value);
}; };
template <>
class StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion> {
public:
static media::learning::TargetValue target_value(
const media::learning::ObservationCompletion& e) {
return e.target_value;
}
static media::learning::WeightType weight(
const media::learning::ObservationCompletion& e) {
return e.weight;
}
static bool Read(
media::learning::mojom::ObservationCompletionDataView data,
media::learning::ObservationCompletion* out_observation_completion);
};
} // namespace mojo } // namespace mojo
#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_ #endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
#include <utility>
#include "mojo/public/cpp/bindings/binding.h"
namespace media {
namespace learning {
MojoLearningTaskController::MojoLearningTaskController(
mojom::LearningTaskControllerPtr controller_ptr)
: controller_ptr_(std::move(controller_ptr)) {}
MojoLearningTaskController::~MojoLearningTaskController() = default;
void MojoLearningTaskController::BeginObservation(
base::UnguessableToken id,
const FeatureVector& features) {
// We don't need to keep track of in-flight observations, since the service
// side handles it for us.
controller_ptr_->BeginObservation(id, features);
}
void MojoLearningTaskController::CompleteObservation(
base::UnguessableToken id,
const ObservationCompletion& completion) {
controller_ptr_->CompleteObservation(id, completion);
}
void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) {
controller_ptr_->CancelObservation(id);
}
} // namespace learning
} // namespace media
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_
#define MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_
#include <utility>
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/learning_task_controller.h"
#include "media/learning/mojo/public/mojom/learning_task_controller.mojom.h"
namespace media {
namespace learning {
// LearningTaskController implementation to forward to a remote impl via mojo.
class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
: public LearningTaskController {
public:
explicit MojoLearningTaskController(
mojom::LearningTaskControllerPtr controller_ptr);
~MojoLearningTaskController() override;
// LearningTaskController
void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) override;
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override;
private:
mojom::LearningTaskControllerPtr controller_ptr_;
DISALLOW_COPY_AND_ASSIGN(MojoLearningTaskController);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include "base/bind.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/test/scoped_task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
#include "mojo/public/cpp/bindings/binding.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class MojoLearningTaskControllerTest : public ::testing::Test {
public:
// Impl of a mojom::LearningTaskController that remembers call arguments.
class FakeMojoLearningTaskController : public mojom::LearningTaskController {
public:
void BeginObservation(const base::UnguessableToken& id,
const FeatureVector& features) override {
begin_args_.id_ = id;
begin_args_.features_ = features;
}
void CompleteObservation(const base::UnguessableToken& id,
const ObservationCompletion& completion) override {
complete_args_.id_ = id;
complete_args_.completion_ = completion;
}
void CancelObservation(const base::UnguessableToken& id) override {
cancel_args_.id_ = id;
}
struct {
base::UnguessableToken id_;
FeatureVector features_;
} begin_args_;
struct {
base::UnguessableToken id_;
ObservationCompletion completion_;
} complete_args_;
struct {
base::UnguessableToken id_;
} cancel_args_;
};
public:
MojoLearningTaskControllerTest()
: learning_controller_binding_(&fake_learning_controller_) {}
~MojoLearningTaskControllerTest() override = default;
void SetUp() override {
// Create a fake learner provider mojo impl.
mojom::LearningTaskControllerPtr learning_controller_ptr;
learning_controller_binding_.Bind(
mojo::MakeRequest(&learning_controller_ptr));
// Tell |learning_controller_| to forward to the fake learner impl.
learning_controller_ = std::make_unique<MojoLearningTaskController>(
std::move(learning_controller_ptr));
}
// Mojo stuff.
base::test::ScopedTaskEnvironment scoped_task_environment_;
FakeMojoLearningTaskController fake_learning_controller_;
mojo::Binding<mojom::LearningTaskController> learning_controller_binding_;
// The learner under test.
std::unique_ptr<MojoLearningTaskController> learning_controller_;
};
TEST_F(MojoLearningTaskControllerTest, Begin) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
learning_controller_->BeginObservation(id, features);
scoped_task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_);
EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
}
TEST_F(MojoLearningTaskControllerTest, Complete) {
base::UnguessableToken id = base::UnguessableToken::Create();
ObservationCompletion completion(TargetValue(1234));
learning_controller_->CompleteObservation(id, completion);
scoped_task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.complete_args_.id_);
EXPECT_EQ(completion.target_value,
fake_learning_controller_.complete_args_.completion_.target_value);
}
TEST_F(MojoLearningTaskControllerTest, Cancel) {
base::UnguessableToken id = base::UnguessableToken::Create();
learning_controller_->CancelObservation(id);
scoped_task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_);
}
} // namespace learning
} // namespace media
...@@ -7,7 +7,7 @@ import("//mojo/public/tools/bindings/mojom.gni") ...@@ -7,7 +7,7 @@ import("//mojo/public/tools/bindings/mojom.gni")
mojom("mojom") { mojom("mojom") {
sources = [ sources = [
"learning_session.mojom", "learning_task_controller.mojom",
"learning_types.mojom", "learning_types.mojom",
] ]
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
module media.learning.mojom;
import "media/learning/mojo/public/mojom/learning_types.mojom";
// Learning tasks, to prevent sending the task name string in AcquireLearner.
enum LearningTaskType {
// There are no tasks yet.
kPlaceHolderTask,
};
// media/learning/public/learning_session.h
interface LearningSession {
// Add |example| to |task_type|.
AddExample(LearningTaskType task_type, LabelledExample example);
};
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
module media.learning.mojom;
import "mojo/public/mojom/base/unguessable_token.mojom";
import "media/learning/mojo/public/mojom/learning_types.mojom";
// Client for a single learning task. Intended to be the primary API for client
// code that generates FeatureVectors / requests predictions for a single task.
// The API supports sending in an observed FeatureVector without a target value,
// so that framework-provided features (FeatureProvider) can be snapshotted at
// the right time. One doesn't generally want to wait until the TargetValue is
// observed to do that.
//
// Typically, this interface will allow non-browser processes to communicate
// with the learning framework in the browser.
interface LearningTaskController {
// Start a new observation. Call this at the time one would try to predict
// the TargetValue. This lets the framework snapshot any framework-provided
// feature values at prediction time. Later, if you want to turn these
// features into an example for training a model, then call
// CompleteObservation with the same id and an ObservationCompletion.
// Otherwise, call CancelObservation with |id|. It's also okay to destroy the
// controller with outstanding observations; these will be cancelled.
BeginObservation(mojo_base.mojom.UnguessableToken id,
array<FeatureValue> features);
// Complete observation |id| by providing |completion|.
CompleteObservation(mojo_base.mojom.UnguessableToken id,
ObservationCompletion completion);
// Cancel observation |id|. Deleting |this| will do the same.
CancelObservation(mojo_base.mojom.UnguessableToken id);
};
...@@ -19,3 +19,9 @@ struct LabelledExample { ...@@ -19,3 +19,9 @@ struct LabelledExample {
array<FeatureValue> features; array<FeatureValue> features;
TargetValue target_value; TargetValue target_value;
}; };
// learning::ObservationCompletion (common/learning_task_controller.h)
struct ObservationCompletion {
TargetValue target_value;
uint64 weight = 1;
};
mojom = "//media/learning/mojo/public/mojom/learning_types.mojom" mojom = "//media/learning/mojo/public/mojom/learning_types.mojom"
public_headers = [ public_headers = [
"//media/learning/common/labelled_example.h", "//media/learning/common/labelled_example.h",
"//media/learning/common/learning_task_controller.h",
"//media/learning/common/value.h", "//media/learning/common/value.h",
] ]
traits_headers = [ "//media/learning/mojo/public/cpp/learning_mojom_traits.h" ] traits_headers = [ "//media/learning/mojo/public/cpp/learning_mojom_traits.h" ]
...@@ -15,4 +16,5 @@ type_mappings = [ ...@@ -15,4 +16,5 @@ type_mappings = [
"media.learning.mojom.LabelledExample=media::learning::LabelledExample", "media.learning.mojom.LabelledExample=media::learning::LabelledExample",
"media.learning.mojom.FeatureValue=media::learning::FeatureValue", "media.learning.mojom.FeatureValue=media::learning::FeatureValue",
"media.learning.mojom.TargetValue=media::learning::TargetValue", "media.learning.mojom.TargetValue=media::learning::TargetValue",
"media.learning.mojom.ObservationCompletion=media::learning::ObservationCompletion",
] ]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment