Commit 242ea367 authored by Chris Cunningham's avatar Chris Cunningham Committed by Commit Bot

Add PredictDistribution() to LearningTaskController

Plumbs through various implementations including mojo.
Adds typemaps for the new mojo types.

Necessary for making predections in the renderer. A follow up CL will
call the API to make smoothness predictions for MediaCapabilities.

Change-Id: Ib6e46b2ac7f375bd343e1b9cda389ecf82eb880f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2009843
Commit-Queue: Chrome Cunningham <chcunningham@chromium.org>
Auto-Submit: Chrome Cunningham <chcunningham@chromium.org>
Reviewed-by: default avatarRobert Sesek <rsesek@chromium.org>
Reviewed-by: default avatarKen Rockot <rockot@google.com>
Reviewed-by: default avatarFrank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#734184}
parent a7a4dd07
......@@ -37,8 +37,10 @@ class MockLearningTaskController : public LearningTaskController {
MOCK_METHOD2(UpdateDefaultTarget,
void(base::UnguessableToken id,
const base::Optional<TargetValue>& default_target));
MOCK_METHOD2(PredictDistribution,
void(const FeatureVector& features, PredictionCB callback));
const LearningTask& GetLearningTask() { return task_; }
const LearningTask& GetLearningTask() override { return task_; }
private:
LearningTask task_;
......
......@@ -61,6 +61,8 @@ class SmoothnessHelperTest : public testing::Test {
const base::Optional<TargetValue>& default_target));
MOCK_METHOD0(GetLearningTask, const LearningTask&());
MOCK_METHOD2(PredictDistribution,
void(const FeatureVector& features, PredictionCB callback));
};
class MockClient : public SmoothnessHelper::Client {
......
......@@ -37,6 +37,8 @@ component("common") {
"learning_task_controller.h",
"media_learning_tasks.cc",
"media_learning_tasks.h",
"target_histogram.cc",
"target_histogram.h",
"value.cc",
"value.h",
]
......@@ -53,6 +55,7 @@ source_set("unit_tests") {
"feature_dictionary_unittest.cc",
"labelled_example_unittest.cc",
"media_learning_tasks_unittest.cc",
"target_histogram_unittest.cc",
"value_unittest.cc",
]
......
include_rules = [
"+services/metrics",
# Needed for typemap to befriend TargetHistogram.
"+mojo/public",
]
......@@ -12,6 +12,7 @@
#include "base/unguessable_token.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/target_histogram.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
namespace media {
......@@ -49,6 +50,9 @@ struct ObservationCompletion {
// observed to do that.
class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
public:
using PredictionCB = base::OnceCallback<void(
const base::Optional<TargetHistogram>& predicted)>;
LearningTaskController() = default;
virtual ~LearningTaskController() = default;
......@@ -91,6 +95,12 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
// Returns the LearningTask associated with |this|.
virtual const LearningTask& GetLearningTask() = 0;
// Asynchronously predicts distribution for given |features|. |callback| will
// receive a base::nullopt prediction when model is not available. |callback|
// may be called immediately without posting.
virtual void PredictDistribution(const FeatureVector& features,
PredictionCB callback) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(LearningTaskController);
};
......
......@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/target_histogram.h"
#include "media/learning/common/target_histogram.h"
#include <sstream>
......
......@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_
#define MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_
#ifndef MEDIA_LEARNING_COMMON_TARGET_HISTOGRAM_H_
#define MEDIA_LEARNING_COMMON_TARGET_HISTOGRAM_H_
#include <ostream>
#include <string>
......@@ -14,17 +14,22 @@
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/value.h"
#include "mojo/public/cpp/bindings/struct_traits.h" // nogncheck
namespace media {
namespace learning {
namespace mojom {
class TargetHistogramDataView;
}
// Histogram of target values that allows fractional counts.
class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram {
private:
class COMPONENT_EXPORT(LEARNING_COMMON) TargetHistogram {
public:
// We use a flat_map since this will often have only one or two TargetValues,
// such as "true" or "false".
using CountMap = base::flat_map<TargetValue, double>;
public:
TargetHistogram();
TargetHistogram(const TargetHistogram& rhs);
TargetHistogram(TargetHistogram&& rhs);
......@@ -81,6 +86,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram {
std::string ToString() const;
private:
friend struct mojo::StructTraits<
media::learning::mojom::TargetHistogramDataView,
media::learning::TargetHistogram>;
const CountMap& counts() const { return counts_; }
// [value] == counts
......@@ -89,10 +98,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram {
// Allow copy and assign.
};
COMPONENT_EXPORT(LEARNING_IMPL)
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const TargetHistogram& dist);
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_
#endif // MEDIA_LEARNING_COMMON_TARGET_HISTOGRAM_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/target_histogram.h"
#include "media/learning/common/target_histogram.h"
#include "testing/gtest/include/gtest/gtest.h"
......
......@@ -38,8 +38,6 @@ component("impl") {
"random_number_generator.h",
"random_tree_trainer.cc",
"random_tree_trainer.h",
"target_histogram.cc",
"target_histogram.h",
"training_algorithm.h",
"voting_ensemble.cc",
"voting_ensemble.h",
......@@ -71,7 +69,6 @@ source_set("unit_tests") {
"one_hot_unittest.cc",
"random_number_generator_unittest.cc",
"random_tree_trainer_unittest.cc",
"target_histogram_unittest.cc",
"test_random_number_generator.cc",
"test_random_number_generator.h",
]
......
......@@ -13,8 +13,8 @@
#include "base/memory/weak_ptr.h"
#include "base/optional.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/target_histogram.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_histogram.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
namespace media {
......
......@@ -90,6 +90,12 @@ class WeakLearningTaskController : public LearningTaskController {
const LearningTask& GetLearningTask() override { return task_; }
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
controller_->Post(FROM_HERE, &LearningTaskController::PredictDistribution,
features, std::move(callback));
}
base::WeakPtr<LearningSessionImpl> weak_session_;
base::SequenceBound<LearningTaskController>* controller_;
LearningTask task_;
......
......@@ -45,14 +45,14 @@ class LearningSessionImplTest : public testing::Test {
const FeatureVector& features,
const base::Optional<TargetValue>& default_target) override {
id_ = id;
features_ = features;
observation_features_ = features;
default_target_ = default_target;
}
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override {
EXPECT_EQ(id_, id);
example_.features = std::move(features_);
example_.features = std::move(observation_features_);
example_.target_value = completion.target_value;
example_.weight = completion.weight;
}
......@@ -74,9 +74,17 @@ class LearningSessionImplTest : public testing::Test {
return LearningTask::Empty();
}
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
predict_features_ = features;
predict_cb_ = std::move(callback);
}
SequenceBoundFeatureProvider feature_provider_;
base::UnguessableToken id_;
FeatureVector features_;
FeatureVector observation_features_;
FeatureVector predict_features_;
PredictionCB predict_cb_;
base::Optional<TargetValue> default_target_;
LabelledExample example_;
......@@ -317,5 +325,34 @@ TEST_F(LearningSessionImplTest, ChangeDefaultTargetToNoValue) {
EXPECT_FALSE(task_controllers_[0]->updated_id_);
}
TEST_F(LearningSessionImplTest, PredictDistribution) {
session_->RegisterTask(task_0_);
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
task_environment_.RunUntilIdle();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetHistogram observed_prediction;
controller->PredictDistribution(
features, base::BindOnce(
[](TargetHistogram* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = *predicted;
},
&observed_prediction));
task_environment_.RunUntilIdle();
EXPECT_EQ(features, task_controllers_[0]->predict_features_);
EXPECT_FALSE(task_controllers_[0]->predict_cb_.is_null());
TargetHistogram expected_prediction;
expected_prediction[TargetValue(1)] = 1.0;
expected_prediction[TargetValue(2)] = 2.0;
expected_prediction[TargetValue(3)] = 3.0;
std::move(task_controllers_[0]->predict_cb_).Run(expected_prediction);
task_environment_.RunUntilIdle();
EXPECT_EQ(expected_prediction, observed_prediction);
}
} // namespace learning
} // namespace media
......@@ -92,6 +92,15 @@ const LearningTask& LearningTaskControllerImpl::GetLearningTask() {
return task_;
}
void LearningTaskControllerImpl::PredictDistribution(
const FeatureVector& features,
PredictionCB callback) {
if (model_)
std::move(callback).Run(model_->PredictDistribution(features));
else
std::move(callback).Run(base::nullopt);
}
void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example,
ukm::SourceId source_id) {
// Verify that we have a trainer and that we got the right number of features.
......
......@@ -62,6 +62,8 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override;
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override;
private:
// Add |example| to the training data, and process it.
......
......@@ -137,6 +137,20 @@ class LearningTaskControllerImplTest : public testing::Test {
id, ObservationCompletion(example.target_value, example.weight));
}
void VerifyPrediction(const FeatureVector& features,
base::Optional<TargetHistogram> expectation) {
base::Optional<TargetHistogram> observed_prediction;
controller_->PredictDistribution(
features, base::BindOnce(
[](base::Optional<TargetHistogram>* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = predicted;
},
&observed_prediction));
task_environment_.RunUntilIdle();
EXPECT_EQ(observed_prediction, expectation);
}
base::test::TaskEnvironment task_environment_;
// Number of models that we trained.
......@@ -258,5 +272,18 @@ TEST_F(LearningTaskControllerImplTest, FeatureSubsetsWork) {
EXPECT_EQ(trainer_raw_->training_data()[0].features, expected_features);
}
TEST_F(LearningTaskControllerImplTest, PredictDistribution) {
CreateController();
// Predictions should be base::nullopt until we have a model.
LabelledExample example;
VerifyPrediction(example.features, base::nullopt);
AddExample(example);
TargetHistogram expected_histogram;
expected_histogram += predicted_target_;
VerifyPrediction(example.features, expected_histogram);
}
} // namespace learning
} // namespace media
......@@ -7,8 +7,8 @@
#include "base/component_export.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/target_histogram.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_histogram.h"
namespace media {
namespace learning {
......
......@@ -6,6 +6,7 @@
#include <utility>
#include "base/bind.h"
#include "media/learning/common/learning_task_controller.h"
namespace media {
......@@ -72,5 +73,11 @@ void MojoLearningTaskControllerService::UpdateDefaultTarget(
impl_->UpdateDefaultTarget(id, default_target);
}
void MojoLearningTaskControllerService::PredictDistribution(
const FeatureVector& features,
PredictDistributionCallback callback) {
impl_->PredictDistribution(features, std::move(callback));
}
} // namespace learning
} // namespace media
......@@ -38,6 +38,8 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService
void UpdateDefaultTarget(
const base::UnguessableToken& id,
const base::Optional<TargetValue>& default_target) override;
void PredictDistribution(const FeatureVector& features,
PredictDistributionCallback callback) override;
protected:
const LearningTask task_;
......
......@@ -50,6 +50,12 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
return LearningTask::Empty();
}
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override {
predict_distribution_args_.features_ = features;
predict_distribution_args_.callback_ = std::move(callback);
}
struct {
base::UnguessableToken id_;
FeatureVector features_;
......@@ -69,6 +75,11 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
base::UnguessableToken id_;
base::Optional<TargetValue> default_target_;
} update_default_args_;
struct {
FeatureVector features_;
PredictionCB callback_;
} predict_distribution_args_;
};
public:
......@@ -193,5 +204,27 @@ TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToNoValue) {
controller_raw_->update_default_args_.default_target_);
}
TEST_F(MojoLearningTaskControllerServiceTest, PredictDistribution) {
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetHistogram observed_prediction;
service_->PredictDistribution(
features, base::BindOnce(
[](TargetHistogram* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = *predicted;
},
&observed_prediction));
EXPECT_EQ(features, controller_raw_->predict_distribution_args_.features_);
EXPECT_FALSE(controller_raw_->predict_distribution_args_.callback_.is_null());
TargetHistogram expected_prediction;
expected_prediction[TargetValue(1)] = 1.0;
expected_prediction[TargetValue(2)] = 2.0;
expected_prediction[TargetValue(3)] = 3.0;
std::move(controller_raw_->predict_distribution_args_.callback_)
.Run(expected_prediction);
EXPECT_EQ(expected_prediction, observed_prediction);
}
} // namespace learning
} // namespace media
......@@ -49,4 +49,16 @@ bool StructTraits<media::learning::mojom::ObservationCompletionDataView,
out_observation_completion->weight = data.weight();
return true;
}
// static
bool StructTraits<media::learning::mojom::TargetHistogramDataView,
media::learning::TargetHistogram>::
Read(media::learning::mojom::TargetHistogramDataView data,
media::learning::TargetHistogram* out_target_histogram) {
if (!data.ReadCounts(&out_target_histogram->counts_))
return false;
return true;
}
} // namespace mojo
......@@ -15,9 +15,8 @@
namespace mojo {
template <>
class StructTraits<media::learning::mojom::LabelledExampleDataView,
media::learning::LabelledExample> {
public:
struct StructTraits<media::learning::mojom::LabelledExampleDataView,
media::learning::LabelledExample> {
static const std::vector<media::learning::FeatureValue>& features(
const media::learning::LabelledExample& e) {
return e.features;
......@@ -32,9 +31,8 @@ class StructTraits<media::learning::mojom::LabelledExampleDataView,
};
template <>
class StructTraits<media::learning::mojom::FeatureValueDataView,
media::learning::FeatureValue> {
public:
struct StructTraits<media::learning::mojom::FeatureValueDataView,
media::learning::FeatureValue> {
static int64_t value(const media::learning::FeatureValue& e) {
return e.value();
}
......@@ -43,9 +41,8 @@ class StructTraits<media::learning::mojom::FeatureValueDataView,
};
template <>
class StructTraits<media::learning::mojom::TargetValueDataView,
media::learning::TargetValue> {
public:
struct StructTraits<media::learning::mojom::TargetValueDataView,
media::learning::TargetValue> {
static int64_t value(const media::learning::TargetValue& e) {
return e.value();
}
......@@ -54,9 +51,8 @@ class StructTraits<media::learning::mojom::TargetValueDataView,
};
template <>
class StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion> {
public:
struct StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion> {
static media::learning::TargetValue target_value(
const media::learning::ObservationCompletion& e) {
return e.target_value;
......@@ -70,6 +66,18 @@ class StructTraits<media::learning::mojom::ObservationCompletionDataView,
media::learning::ObservationCompletion* out_observation_completion);
};
template <>
struct StructTraits<media::learning::mojom::TargetHistogramDataView,
media::learning::TargetHistogram> {
static media::learning::TargetHistogram::CountMap counts(
const media::learning::TargetHistogram& e) {
return e.counts();
}
static bool Read(media::learning::mojom::TargetHistogramDataView data,
media::learning::TargetHistogram* out_target_histogram);
};
} // namespace mojo
#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_
......@@ -45,5 +45,11 @@ const LearningTask& MojoLearningTaskController::GetLearningTask() {
return task_;
}
void MojoLearningTaskController::PredictDistribution(
const FeatureVector& features,
PredictionCB callback) {
controller_->PredictDistribution(features, std::move(callback));
}
} // namespace learning
} // namespace media
......@@ -40,6 +40,8 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
base::UnguessableToken id,
const base::Optional<TargetValue>& default_target) override;
const LearningTask& GetLearningTask() override;
void PredictDistribution(const FeatureVector& features,
PredictionCB callback) override;
private:
LearningTask task_;
......
......@@ -6,6 +6,7 @@
#include <utility>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/test/task_environment.h"
......@@ -48,6 +49,12 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
update_default_args_.default_target_ = default_target;
}
void PredictDistribution(const FeatureVector& features,
PredictDistributionCallback callback) override {
predict_args_.features_ = features;
predict_args_.callback_ = std::move(callback);
}
struct {
base::UnguessableToken id_;
FeatureVector features_;
......@@ -67,6 +74,11 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
base::UnguessableToken id_;
base::Optional<TargetValue> default_target_;
} update_default_args_;
struct {
FeatureVector features_;
PredictDistributionCallback callback_;
} predict_args_;
};
public:
......@@ -165,5 +177,30 @@ TEST_F(MojoLearningTaskControllerTest, Cancel) {
EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_);
}
TEST_F(MojoLearningTaskControllerTest, PredictDistribution) {
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
TargetHistogram observed_prediction;
learning_controller_->PredictDistribution(
features, base::BindOnce(
[](TargetHistogram* test_storage,
const base::Optional<TargetHistogram>& predicted) {
*test_storage = *predicted;
},
&observed_prediction));
task_environment_.RunUntilIdle();
EXPECT_EQ(features, fake_learning_controller_.predict_args_.features_);
EXPECT_FALSE(fake_learning_controller_.predict_args_.callback_.is_null());
TargetHistogram expected_prediction;
expected_prediction[TargetValue(1)] = 1.0;
expected_prediction[TargetValue(2)] = 2.0;
expected_prediction[TargetValue(3)] = 3.0;
std::move(fake_learning_controller_.predict_args_.callback_)
.Run(expected_prediction);
task_environment_.RunUntilIdle();
EXPECT_EQ(observed_prediction, expected_prediction);
}
} // namespace learning
} // namespace media
......@@ -41,4 +41,9 @@ interface LearningTaskController {
// so that the observation will be cancelled if the controller is destroyed.
UpdateDefaultTarget(mojo_base.mojom.UnguessableToken id,
TargetValue? default_target);
// Asynchronously predicts distribution for given |features|. |callback| will
// receive a base::nullopt prediction when model is not available.
PredictDistribution(array<FeatureValue> features)
=> (TargetHistogram? predicted);
};
......@@ -25,3 +25,8 @@ struct ObservationCompletion {
TargetValue target_value;
uint64 weight = 1;
};
// learning::TargetHistogram (common/target_histogram.h)
struct TargetHistogram {
map<TargetValue, double> counts;
};
......@@ -9,12 +9,11 @@ sources = [
"//media/learning/mojo/public/cpp/learning_mojom_traits.cc",
"//media/learning/mojo/public/cpp/learning_mojom_traits.h",
]
public_deps = [
"//media/learning/common",
]
public_deps = [ "//media/learning/common" ]
type_mappings = [
"media.learning.mojom.LabelledExample=::media::learning::LabelledExample",
"media.learning.mojom.FeatureValue=::media::learning::FeatureValue",
"media.learning.mojom.TargetValue=::media::learning::TargetValue",
"media.learning.mojom.ObservationCompletion=::media::learning::ObservationCompletion",
"media.learning.mojom.TargetHistogram=::media::learning::TargetHistogram",
]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment