Commit 242ea367 authored by Chris Cunningham's avatar Chris Cunningham Committed by Commit Bot

Add PredictDistribution() to LearningTaskController

Plumbs through various implementations including mojo.
Adds typemaps for the new mojo types.

Necessary for making predections in the renderer. A follow up CL will
call the API to make smoothness predictions for MediaCapabilities.

Change-Id: Ib6e46b2ac7f375bd343e1b9cda389ecf82eb880f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2009843
Commit-Queue: Chrome Cunningham <chcunningham@chromium.org>
Auto-Submit: Chrome Cunningham <chcunningham@chromium.org>
Reviewed-by: default avatarRobert Sesek <rsesek@chromium.org>
Reviewed-by: default avatarKen Rockot <rockot@google.com>
Reviewed-by: default avatarFrank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#734184}
parent a7a4dd07
...@@ -37,8 +37,10 @@ class MockLearningTaskController : public LearningTaskController { ...@@ -37,8 +37,10 @@ class MockLearningTaskController : public LearningTaskController {
MOCK_METHOD2(UpdateDefaultTarget, MOCK_METHOD2(UpdateDefaultTarget,
void(base::UnguessableToken id, void(base::UnguessableToken id,
const base::Optional<TargetValue>& default_target)); const base::Optional<TargetValue>& default_target));
MOCK_METHOD2(PredictDistribution,
void(const FeatureVector& features, PredictionCB callback));
const LearningTask& GetLearningTask() { return task_; } const LearningTask& GetLearningTask() override { return task_; }
private: private:
LearningTask task_; LearningTask task_;
......
...@@ -61,6 +61,8 @@ class SmoothnessHelperTest : public testing::Test { ...@@ -61,6 +61,8 @@ class SmoothnessHelperTest : public testing::Test {
const base::Optional<TargetValue>& default_target)); const base::Optional<TargetValue>& default_target));
MOCK_METHOD0(GetLearningTask, const LearningTask&()); MOCK_METHOD0(GetLearningTask, const LearningTask&());
MOCK_METHOD2(PredictDistribution,
void(const FeatureVector& features, PredictionCB callback));
}; };
class MockClient : public SmoothnessHelper::Client { class MockClient : public SmoothnessHelper::Client {
......
...@@ -37,6 +37,8 @@ component("common") { ...@@ -37,6 +37,8 @@ component("common") {
"learning_task_controller.h", "learning_task_controller.h",
"media_learning_tasks.cc", "media_learning_tasks.cc",
"media_learning_tasks.h", "media_learning_tasks.h",
"target_histogram.cc",
"target_histogram.h",
"value.cc", "value.cc",
"value.h", "value.h",
] ]
...@@ -53,6 +55,7 @@ source_set("unit_tests") { ...@@ -53,6 +55,7 @@ source_set("unit_tests") {
"feature_dictionary_unittest.cc", "feature_dictionary_unittest.cc",
"labelled_example_unittest.cc", "labelled_example_unittest.cc",
"media_learning_tasks_unittest.cc", "media_learning_tasks_unittest.cc",
"target_histogram_unittest.cc",
"value_unittest.cc", "value_unittest.cc",
] ]
......
include_rules = [ include_rules = [
"+services/metrics", "+services/metrics",
# Needed for typemap to befriend TargetHistogram.
"+mojo/public",
] ]
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "base/unguessable_token.h" #include "base/unguessable_token.h"
#include "media/learning/common/labelled_example.h" #include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h" #include "media/learning/common/learning_task.h"
#include "media/learning/common/target_histogram.h"
#include "services/metrics/public/cpp/ukm_source_id.h" #include "services/metrics/public/cpp/ukm_source_id.h"
namespace media { namespace media {
...@@ -49,6 +50,9 @@ struct ObservationCompletion { ...@@ -49,6 +50,9 @@ struct ObservationCompletion {
// observed to do that. // observed to do that.
class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController { class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
public: public:
using PredictionCB = base::OnceCallback<void(
const base::Optional<TargetHistogram>& predicted)>;
LearningTaskController() = default; LearningTaskController() = default;
virtual ~LearningTaskController() = default; virtual ~LearningTaskController() = default;
...@@ -91,6 +95,12 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController { ...@@ -91,6 +95,12 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
// Returns the LearningTask associated with |this|. // Returns the LearningTask associated with |this|.
virtual const LearningTask& GetLearningTask() = 0; virtual const LearningTask& GetLearningTask() = 0;
// Asynchronously predicts distribution for given |features|. |callback| will
// receive a base::nullopt prediction when model is not available. |callback|
// may be called immediately without posting.
virtual void PredictDistribution(const FeatureVector& features,
PredictionCB callback) = 0;
private: private:
DISALLOW_COPY_AND_ASSIGN(LearningTaskController); DISALLOW_COPY_AND_ASSIGN(LearningTaskController);
}; };
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "media/learning/impl/target_histogram.h" #include "media/learning/common/target_histogram.h"
#include <sstream> #include <sstream>
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_ #ifndef MEDIA_LEARNING_COMMON_TARGET_HISTOGRAM_H_
#define MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_ #define MEDIA_LEARNING_COMMON_TARGET_HISTOGRAM_H_
#include <ostream> #include <ostream>
#include <string> #include <string>
...@@ -14,17 +14,22 @@ ...@@ -14,17 +14,22 @@
#include "media/learning/common/labelled_example.h" #include "media/learning/common/labelled_example.h"
#include "media/learning/common/value.h" #include "media/learning/common/value.h"
#include "mojo/public/cpp/bindings/struct_traits.h" // nogncheck
namespace media { namespace media {
namespace learning { namespace learning {
namespace mojom {
class TargetHistogramDataView;
}
// Histogram of target values that allows fractional counts. // Histogram of target values that allows fractional counts.
class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram { class COMPONENT_EXPORT(LEARNING_COMMON) TargetHistogram {
private: public:
// We use a flat_map since this will often have only one or two TargetValues, // We use a flat_map since this will often have only one or two TargetValues,
// such as "true" or "false". // such as "true" or "false".
using CountMap = base::flat_map<TargetValue, double>; using CountMap = base::flat_map<TargetValue, double>;
public:
TargetHistogram(); TargetHistogram();
TargetHistogram(const TargetHistogram& rhs); TargetHistogram(const TargetHistogram& rhs);
TargetHistogram(TargetHistogram&& rhs); TargetHistogram(TargetHistogram&& rhs);
...@@ -81,6 +86,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram { ...@@ -81,6 +86,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram {
std::string ToString() const; std::string ToString() const;
private: private:
friend struct mojo::StructTraits<
media::learning::mojom::TargetHistogramDataView,
media::learning::TargetHistogram>;
const CountMap& counts() const { return counts_; } const CountMap& counts() const { return counts_; }
// [value] == counts // [value] == counts
...@@ -89,10 +98,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram { ...@@ -89,10 +98,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram {
// Allow copy and assign. // Allow copy and assign.
}; };
COMPONENT_EXPORT(LEARNING_IMPL) COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const TargetHistogram& dist); std::ostream& operator<<(std::ostream& out, const TargetHistogram& dist);
} // namespace learning } // namespace learning
} // namespace media } // namespace media
#endif // MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_ #endif // MEDIA_LEARNING_COMMON_TARGET_HISTOGRAM_H_
// Copyright 2018 The Chromium Authors. All rights reserved. // Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#include "media/learning/impl/target_histogram.h" #include "media/learning/common/target_histogram.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
......
...@@ -38,8 +38,6 @@ component("impl") { ...@@ -38,8 +38,6 @@ component("impl") {
"random_number_generator.h", "random_number_generator.h",
"random_tree_trainer.cc", "random_tree_trainer.cc",
"random_tree_trainer.h", "random_tree_trainer.h",
"target_histogram.cc",
"target_histogram.h",
"training_algorithm.h", "training_algorithm.h",
"voting_ensemble.cc", "voting_ensemble.cc",
"voting_ensemble.h", "voting_ensemble.h",
...@@ -71,7 +69,6 @@ source_set("unit_tests") { ...@@ -71,7 +69,6 @@ source_set("unit_tests") {
"one_hot_unittest.cc", "one_hot_unittest.cc",
"random_number_generator_unittest.cc", "random_number_generator_unittest.cc",
"random_tree_trainer_unittest.cc", "random_tree_trainer_unittest.cc",
"target_histogram_unittest.cc",
"test_random_number_generator.cc", "test_random_number_generator.cc",
"test_random_number_generator.h", "test_random_number_generator.h",
] ]
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/optional.h" #include "base/optional.h"
#include "media/learning/common/learning_task.h" #include "media/learning/common/learning_task.h"
#include "media/learning/common/target_histogram.h"
#include "media/learning/impl/model.h" #include "media/learning/impl/model.h"
#include "media/learning/impl/target_histogram.h"
#include "services/metrics/public/cpp/ukm_source_id.h" #include "services/metrics/public/cpp/ukm_source_id.h"
namespace media { namespace media {
......
...@@ -90,6 +90,12 @@ class WeakLearningTaskController : public LearningTaskController { ...@@ -90,6 +90,12 @@ class WeakLearningTaskController : public LearningTaskController {
const LearningTask& GetLearningTask() override { return task_; } const LearningTask& GetLearningTask() override { return task_; }
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
controller_->Post(FROM_HERE, &LearningTaskController::PredictDistribution,
features, std::move(callback));
}
base::WeakPtr<LearningSessionImpl> weak_session_; base::WeakPtr<LearningSessionImpl> weak_session_;
base::SequenceBound<LearningTaskController>* controller_; base::SequenceBound<LearningTaskController>* controller_;
LearningTask task_; LearningTask task_;
......
...@@ -45,14 +45,14 @@ class LearningSessionImplTest : public testing::Test { ...@@ -45,14 +45,14 @@ class LearningSessionImplTest : public testing::Test {
const FeatureVector& features, const FeatureVector& features,
const base::Optional<TargetValue>& default_target) override { const base::Optional<TargetValue>& default_target) override {
id_ = id; id_ = id;
features_ = features; observation_features_ = features;
default_target_ = default_target; default_target_ = default_target;
} }
void CompleteObservation(base::UnguessableToken id, void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override { const ObservationCompletion& completion) override {
EXPECT_EQ(id_, id); EXPECT_EQ(id_, id);
example_.features = std::move(features_); example_.features = std::move(observation_features_);
example_.target_value = completion.target_value; example_.target_value = completion.target_value;
example_.weight = completion.weight; example_.weight = completion.weight;
} }
...@@ -74,9 +74,17 @@ class LearningSessionImplTest : public testing::Test { ...@@ -74,9 +74,17 @@ class LearningSessionImplTest : public testing::Test {
return LearningTask::Empty(); return LearningTask::Empty();
} }
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
predict_features_ = features;
predict_cb_ = std::move(callback);
}
SequenceBoundFeatureProvider feature_provider_; SequenceBoundFeatureProvider feature_provider_;
base::UnguessableToken id_; base::UnguessableToken id_;
FeatureVector features_; FeatureVector observation_features_;
FeatureVector predict_features_;
PredictionCB predict_cb_;
base::Optional<TargetValue> default_target_; base::Optional<TargetValue> default_target_;
LabelledExample example_; LabelledExample example_;
...@@ -317,5 +325,34 @@ TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) { ...@@ -317,5 +325,34 @@ TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) {
EXPECT_FALSE(task_controllers_[0]->updated_id_); EXPECT_FALSE(task_controllers_[0]->updated_id_);
} }
TEST_F(LearningSessionImplTest, PredictDistribution) {
session_->RegisterTask(task_0_);
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
task_environment_.RunUntilIdle();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetHistogram observed_prediction;
controller->PredictDistribution(
features, base::BindOnce(
[](TargetHistogram* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = *predicted;
},
&observed_prediction));
task_environment_.RunUntilIdle();
EXPECT_EQ(features, task_controllers_[0]->predict_features_);
EXPECT_FALSE(task_controllers_[0]->predict_cb_.is_null());
TargetHistogram expected_prediction;
expected_prediction[TargetValue(1)] = 1.0;
expected_prediction[TargetValue(2)] = 2.0;
expected_prediction[TargetValue(3)] = 3.0;
std::move(task_controllers_[0]->predict_cb_).Run(expected_prediction);
task_environment_.RunUntilIdle();
EXPECT_EQ(expected_prediction, observed_prediction);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -92,6 +92,15 @@ const LearningTask& LearningTaskControllerImpl::GetLearningTask() { ...@@ -92,6 +92,15 @@ const LearningTask& LearningTaskControllerImpl::GetLearningTask() {
return task_; return task_;
} }
void LearningTaskControllerImpl::PredictDistribution(
const FeatureVector& features,
PredictionCB callback) {
if (model_)
std::move(callback).Run(model_->PredictDistribution(features));
else
std::move(callback).Run(base::nullopt);
}
void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example, void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example,
ukm::SourceId source_id) { ukm::SourceId source_id) {
// Verify that we have a trainer and that we got the right number of features. // Verify that we have a trainer and that we got the right number of features.
......
...@@ -62,6 +62,8 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl ...@@ -62,6 +62,8 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
base::UnguessableToken id, base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override; const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override; const LearningTask& GetLearningTask() override;
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override;
private: private:
// Add |example| to the training data, and process it. // Add |example| to the training data, and process it.
......
...@@ -137,6 +137,20 @@ class LearningTaskControllerImplTest : public testing::Test { ...@@ -137,6 +137,20 @@ class LearningTaskControllerImplTest : public testing::Test {
id, ObservationCompletion(example.target_value, example.weight)); id, ObservationCompletion(example.target_value, example.weight));
} }
void VerifyPrediction(const FeatureVector& features,
base::Optional<TargetHistogram> expectation) {
base::Optional<TargetHistogram> observed_prediction;
controller_->PredictDistribution(
features, base::BindOnce(
[](base::Optional<TargetHistogram>* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = predicted;
},
&observed_prediction));
task_environment_.RunUntilIdle();
EXPECT_EQ(observed_prediction, expectation);
}
base::test::TaskEnvironment task_environment_; base::test::TaskEnvironment task_environment_;
// Number of models that we trained. // Number of models that we trained.
...@@ -258,5 +272,18 @@ TEST_F(LearningTaskControllerImplTest, FeatureSubsetsWork) { ...@@ -258,5 +272,18 @@ TEST_F(LearningTaskControllerImplTest, FeatureSubsetsWork) {
EXPECT_EQ(trainer_raw_->training_data()[0].features, expected_features); EXPECT_EQ(trainer_raw_->training_data()[0].features, expected_features);
} }
TEST_F(LearningTaskControllerImplTest, PredictDistribution) {
CreateController();
// Predictions should be base::nullopt until we have a model.
LabelledExample example;
VerifyPrediction(example.features, base::nullopt);
AddExample(example);
TargetHistogram expected_histogram;
expected_histogram += predicted_target_;
VerifyPrediction(example.features, expected_histogram);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "base/component_export.h" #include "base/component_export.h"
#include "media/learning/common/labelled_example.h" #include "media/learning/common/labelled_example.h"
#include "media/learning/common/target_histogram.h"
#include "media/learning/impl/model.h" #include "media/learning/impl/model.h"
#include "media/learning/impl/target_histogram.h"
namespace media { namespace media {
namespace learning { namespace learning {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <utility> #include <utility>
#include "base/bind.h"
#include "media/learning/common/learning_task_controller.h" #include "media/learning/common/learning_task_controller.h"
namespace media { namespace media {
...@@ -72,5 +73,11 @@ void MojoLearningTaskControllerService::UpdateDefaultTarget( ...@@ -72,5 +73,11 @@ void MojoLearningTaskControllerService::UpdateDefaultTarget(
impl_->UpdateDefaultTarget(id, default_target); impl_->UpdateDefaultTarget(id, default_target);
} }
void MojoLearningTaskControllerService::PredictDistribution(
const FeatureVector& features,
PredictDistributionCallback callback) {
impl_->PredictDistribution(features, std::move(callback));
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -38,6 +38,8 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService ...@@ -38,6 +38,8 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService
void UpdateDefaultTarget( void UpdateDefaultTarget(
const base::UnguessableToken& id, const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) override; const base::Optional<TargetValue>& default_target) override;
void PredictDistribution(const FeatureVector& features,
PredictDistributionCallback callback) override;
protected: protected:
const LearningTask task_; const LearningTask task_;
......
...@@ -50,6 +50,12 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test { ...@@ -50,6 +50,12 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
return LearningTask::Empty(); return LearningTask::Empty();
} }
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
predict_distribution_args_.features_ = features;
predict_distribution_args_.callback_ = std::move(callback);
}
struct { struct {
base::UnguessableToken id_; base::UnguessableToken id_;
FeatureVector features_; FeatureVector features_;
...@@ -69,6 +75,11 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test { ...@@ -69,6 +75,11 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
base::UnguessableToken id_; base::UnguessableToken id_;
base::Optional<TargetValue> default_target_; base::Optional<TargetValue> default_target_;
} update_default_args_; } update_default_args_;
struct {
FeatureVector features_;
PredictionCB callback_;
} predict_distribution_args_;
}; };
public: public:
...@@ -193,5 +204,27 @@ TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToNoValue) { ...@@ -193,5 +204,27 @@ TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToNoValue) {
controller_raw_->update_default_args_.default_target_); controller_raw_->update_default_args_.default_target_);
} }
TEST_F(MojoLearningTaskControllerServiceTest, PredictDistribution) {
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetHistogram observed_prediction;
service_->PredictDistribution(
features, base::BindOnce(
[](TargetHistogram* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = *predicted;
},
&observed_prediction));
EXPECT_EQ(features, controller_raw_->predict_distribution_args_.features_);
EXPECT_FALSE(controller_raw_->predict_distribution_args_.callback_.is_null());
TargetHistogram expected_prediction;
expected_prediction[TargetValue(1)] = 1.0;
expected_prediction[TargetValue(2)] = 2.0;
expected_prediction[TargetValue(3)] = 3.0;
std::move(controller_raw_->predict_distribution_args_.callback_)
.Run(expected_prediction);
EXPECT_EQ(expected_prediction, observed_prediction);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -49,4 +49,16 @@ bool StructTraits<media::learning::mojom::ObservationCompletionDataView, ...@@ -49,4 +49,16 @@ bool StructTraits<media::learning::mojom::ObservationCompletionDataView,
out_observation_completion->weight = data.weight(); out_observation_completion->weight = data.weight();
return true; return true;
} }
// static
bool StructTraits<media::learning::mojom::TargetHistogramDataView,
media::learning::TargetHistogram>::
Read(media::learning::mojom::TargetHistogramDataView data,
media::learning::TargetHistogram* out_target_histogram) {
if (!data.ReadCounts(&out_target_histogram->counts_))
return false;
return true;
}
} // namespace mojo } // namespace mojo
...@@ -15,9 +15,8 @@ ...@@ -15,9 +15,8 @@
namespace mojo { namespace mojo {
template <> template <>
class StructTraits<media::learning::mojom::LabelledExampleDataView, struct StructTraits<media::learning::mojom::LabelledExampleDataView,
media::learning::LabelledExample> { media::learning::LabelledExample> {
public:
static const std::vector<media::learning::FeatureValue>& features( static const std::vector<media::learning::FeatureValue>& features(
const media::learning::LabelledExample& e) { const media::learning::LabelledExample& e) {
return e.features; return e.features;
...@@ -32,9 +31,8 @@ class StructTraits<media::learning::mojom::LabelledExampleDataView, ...@@ -32,9 +31,8 @@ class StructTraits<media::learning::mojom::LabelledExampleDataView,
}; };
template <> template <>
class StructTraits<media::learning::mojom::FeatureValueDataView, struct StructTraits<media::learning::mojom::FeatureValueDataView,
media::learning::FeatureValue> { media::learning::FeatureValue> {
public:
static int64_t value(const media::learning::FeatureValue& e) { static int64_t value(const media::learning::FeatureValue& e) {
return e.value(); return e.value();
} }
...@@ -43,9 +41,8 @@ class StructTraits<media::learning::mojom::FeatureValueDataView, ...@@ -43,9 +41,8 @@ class StructTraits<media::learning::mojom::FeatureValueDataView,
}; };
template <> template <>
class StructTraits<media::learning::mojom::TargetValueDataView, struct StructTraits<media::learning::mojom::TargetValueDataView,
media::learning::TargetValue> { media::learning::TargetValue> {
public:
static int64_t value(const media::learning::TargetValue& e) { static int64_t value(const media::learning::TargetValue& e) {
return e.value(); return e.value();
} }
...@@ -54,9 +51,8 @@ class StructTraits<media::learning::mojom::TargetValueDataView, ...@@ -54,9 +51,8 @@ class StructTraits<media::learning::mojom::TargetValueDataView,
}; };
template <> template <>
class StructTraits<media::learning::mojom::ObservationCompletionDataView, struct StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion> { media::learning::ObservationCompletion> {
public:
static media::learning::TargetValue target_value( static media::learning::TargetValue target_value(
const media::learning::ObservationCompletion& e) { const media::learning::ObservationCompletion& e) {
return e.target_value; return e.target_value;
...@@ -70,6 +66,18 @@ class StructTraits<media::learning::mojom::ObservationCompletionDataView, ...@@ -70,6 +66,18 @@ class StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion* out_observation_completion); media::learning::ObservationCompletion* out_observation_completion);
}; };
template <>
struct StructTraits<media::learning::mojom::TargetHistogramDataView,
media::learning::TargetHistogram> {
static media::learning::TargetHistogram::CountMap counts(
const media::learning::TargetHistogram& e) {
return e.counts();
}
static bool Read(media::learning::mojom::TargetHistogramDataView data,
media::learning::TargetHistogram* out_target_histogram);
};
} // namespace mojo } // namespace mojo
#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_ #endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_
...@@ -45,5 +45,11 @@ const LearningTask& MojoLearningTaskController::GetLearningTask() { ...@@ -45,5 +45,11 @@ const LearningTask& MojoLearningTaskController::GetLearningTask() {
return task_; return task_;
} }
void MojoLearningTaskController::PredictDistribution(
const FeatureVector& features,
PredictionCB callback) {
controller_->PredictDistribution(features, std::move(callback));
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -40,6 +40,8 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController ...@@ -40,6 +40,8 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
base::UnguessableToken id, base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override; const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override; const LearningTask& GetLearningTask() override;
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override;
private: private:
LearningTask task_; LearningTask task_;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <utility> #include <utility>
#include "base/bind.h" #include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/ptr_util.h" #include "base/memory/ptr_util.h"
#include "base/test/task_environment.h" #include "base/test/task_environment.h"
...@@ -48,6 +49,12 @@ class MojoLearningTaskControllerTest : public ::testing::Test { ...@@ -48,6 +49,12 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
update_default_args_.default_target_ = default_target; update_default_args_.default_target_ = default_target;
} }
void PredictDistribution(const FeatureVector& features,
PredictDistributionCallback callback) override {
predict_args_.features_ = features;
predict_args_.callback_ = std::move(callback);
}
struct { struct {
base::UnguessableToken id_; base::UnguessableToken id_;
FeatureVector features_; FeatureVector features_;
...@@ -67,6 +74,11 @@ class MojoLearningTaskControllerTest : public ::testing::Test { ...@@ -67,6 +74,11 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
base::UnguessableToken id_; base::UnguessableToken id_;
base::Optional<TargetValue> default_target_; base::Optional<TargetValue> default_target_;
} update_default_args_; } update_default_args_;
struct {
FeatureVector features_;
PredictDistributionCallback callback_;
} predict_args_;
}; };
public: public:
...@@ -165,5 +177,30 @@ TEST_F(MojoLearningTaskControllerTest, Cancel) { ...@@ -165,5 +177,30 @@ TEST_F(MojoLearningTaskControllerTest, Cancel) {
EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_); EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_);
} }
TEST_F(MojoLearningTaskControllerTest, PredictDistribution) {
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetHistogram observed_prediction;
learning_controller_->PredictDistribution(
features, base::BindOnce(
[](TargetHistogram* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = *predicted;
},
&observed_prediction));
task_environment_.RunUntilIdle();
EXPECT_EQ(features, fake_learning_controller_.predict_args_.features_);
EXPECT_FALSE(fake_learning_controller_.predict_args_.callback_.is_null());
TargetHistogram expected_prediction;
expected_prediction[TargetValue(1)] = 1.0;
expected_prediction[TargetValue(2)] = 2.0;
expected_prediction[TargetValue(3)] = 3.0;
std::move(fake_learning_controller_.predict_args_.callback_)
.Run(expected_prediction);
task_environment_.RunUntilIdle();
EXPECT_EQ(observed_prediction, expected_prediction);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
...@@ -41,4 +41,9 @@ interface LearningTaskController { ...@@ -41,4 +41,9 @@ interface LearningTaskController {
// so that the observation will be cancelled if the controller is destroyed. // so that the observation will be cancelled if the controller is destroyed.
UpdateDefaultTarget(mojo_base.mojom.UnguessableToken id, UpdateDefaultTarget(mojo_base.mojom.UnguessableToken id,
TargetValue? default_target); TargetValue? default_target);
// Asynchronously predicts distribution for given |features|. |callback| will
// receive a base::nullopt prediction when model is not available.
PredictDistribution(array<FeatureValue> features)
=> (TargetHistogram? predicted);
}; };
...@@ -25,3 +25,8 @@ struct ObservationCompletion { ...@@ -25,3 +25,8 @@ struct ObservationCompletion {
TargetValue target_value; TargetValue target_value;
uint64 weight = 1; uint64 weight = 1;
}; };
// learning::TargetHistogram (common/target_histogram.h)
struct TargetHistogram {
map<TargetValue, double> counts;
};
...@@ -9,12 +9,11 @@ sources = [ ...@@ -9,12 +9,11 @@ sources = [
"//media/learning/mojo/public/cpp/learning_mojom_traits.cc", "//media/learning/mojo/public/cpp/learning_mojom_traits.cc",
"//media/learning/mojo/public/cpp/learning_mojom_traits.h", "//media/learning/mojo/public/cpp/learning_mojom_traits.h",
] ]
public_deps = [ public_deps = [ "//media/learning/common" ]
"//media/learning/common",
]
type_mappings = [ type_mappings = [
"media.learning.mojom.LabelledExample=::media::learning::LabelledExample", "media.learning.mojom.LabelledExample=::media::learning::LabelledExample",
"media.learning.mojom.FeatureValue=::media::learning::FeatureValue", "media.learning.mojom.FeatureValue=::media::learning::FeatureValue",
"media.learning.mojom.TargetValue=::media::learning::TargetValue", "media.learning.mojom.TargetValue=::media::learning::TargetValue",
"media.learning.mojom.ObservationCompletion=::media::learning::ObservationCompletion", "media.learning.mojom.ObservationCompletion=::media::learning::ObservationCompletion",
"media.learning.mojom.TargetHistogram=::media::learning::TargetHistogram",
] ]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment