Commit 6549c4a6 authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Implement LearningTaskControllerImpl.

Replace LearningTaskControllerImpl's do-nothing initial
implementation with one that:
  - constructs the training callback based on the learning task
  - collects training examples as they're added
  - re-trains the model periodically
  - records stats about model performance

Change-Id: I6a949e6188c71d00be5449dc9539508254c7d74a
Reviewed-on: https://chromium-review.googlesource.com/c/1355434
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarFredrik Hubinette <hubbe@chromium.org>
Cr-Commit-Position: refs/heads/master@{#612860}
parent 1f6fd36b
...@@ -23,16 +23,11 @@ namespace learning { ...@@ -23,16 +23,11 @@ namespace learning {
// registering tasks. // registering tasks.
struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// Not all models support all feature / target descriptions. For example, // Not all models support all feature / target descriptions. For example,
// NaiveBayes requires kUnordered features. Similarly, kLogLinear doesn't // NaiveBayes requires kUnordered features. Similarly, LogLinear wouldn't
// support kUnordered features or targets. kRandomForest might support more // support kUnordered features or targets. kRandomForest might support more
// combination of orderings and types. // combination of orderings and types.
//
// Also note that not all of these are implemented yet.
enum class Model { enum class Model {
kMostCommonTarget,
kNaiveBayes,
kRandomForest, kRandomForest,
kLogLinear,
}; };
enum class Ordering { enum class Ordering {
...@@ -81,7 +76,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { ...@@ -81,7 +76,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// Unique name for this learner. // Unique name for this learner.
std::string name; std::string name;
Model model = Model::kMostCommonTarget; Model model = Model::kRandomForest;
std::vector<ValueDescription> feature_descriptions; std::vector<ValueDescription> feature_descriptions;
...@@ -91,6 +86,13 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { ...@@ -91,6 +86,13 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// TODO(liberato): add training parameters, like smoothing constants. It's // TODO(liberato): add training parameters, like smoothing constants. It's
// okay if some of these are model-specific. // okay if some of these are model-specific.
// TODO(liberato): switch to base::DictionaryValue?
// Number of examples before we'll train a model.
size_t min_data_set_size = 10u;
// Should the accuracy of this model be recorded to UMA?
bool record_accuracy_via_uma = true;
}; };
} // namespace learning } // namespace learning
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_ #ifndef MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_ #define MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#include <deque>
#include <initializer_list> #include <initializer_list>
#include <ostream> #include <ostream>
#include <vector> #include <vector>
...@@ -49,11 +50,15 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample { ...@@ -49,11 +50,15 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
// Copy / assignment is allowed. // Copy / assignment is allowed.
}; };
// Collection of training examples. We use a vector since we allow duplicates. // Collection of training examples.
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingDataStorage class COMPONENT_EXPORT(LEARNING_COMMON) TrainingDataStorage
: public base::RefCountedThreadSafe<TrainingDataStorage> { : public base::RefCountedThreadSafe<TrainingDataStorage> {
public: public:
using StorageVector = std::vector<TrainingExample>; // We store examples in a deque, since we don't want to invalidate pointers in
// TrainingData collections (see below) as new examples are added. Deques
// promise not to do that when inserting at either end.
using StorageVector = std::deque<TrainingExample>;
using const_iterator = StorageVector::const_iterator; using const_iterator = StorageVector::const_iterator;
TrainingDataStorage(); TrainingDataStorage();
...@@ -61,31 +66,31 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingDataStorage ...@@ -61,31 +66,31 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingDataStorage
StorageVector::const_iterator begin() const { return examples_.begin(); } StorageVector::const_iterator begin() const { return examples_.begin(); }
StorageVector::const_iterator end() const { return examples_.end(); } StorageVector::const_iterator end() const { return examples_.end(); }
// Note that it's okay to add examples at any time.
void push_back(const TrainingExample& example) { void push_back(const TrainingExample& example) {
examples_.push_back(example); examples_.push_back(example);
} }
// Returns true if and only if |example| is included in our data. Note that // Notice that there's no option to clear storage; that might invalidate
// this checks that the pointer itself is included, so that one might tell if // outstanding pointers in TrainingData (see below). Instead, just create a
// an example is backed by this storage or not. It does not care if there is // new TrainingDataStorage.
// an example in our storage that would TrainingExample::operator==(*example).
bool contains(const TrainingExample* example) const { // Return the number of examples that we store.
return (example >= examples_.data()) && size_t size() const { return examples_.size(); }
(example < examples_.data() + examples_.size());
}
private: private:
friend class base::RefCountedThreadSafe<TrainingDataStorage>; friend class base::RefCountedThreadSafe<TrainingDataStorage>;
~TrainingDataStorage(); ~TrainingDataStorage();
std::vector<TrainingExample> examples_; StorageVector examples_;
DISALLOW_COPY_AND_ASSIGN(TrainingDataStorage); DISALLOW_COPY_AND_ASSIGN(TrainingDataStorage);
}; };
// Collection of pointers to training data. References would be more convenient // Collection of pointers to training data. References would be more convenient
// but they're not allowed. // but they're not allowed.
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData { class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
public: public:
using ExampleVector = std::vector<const TrainingExample*>; using ExampleVector = std::vector<const TrainingExample*>;
...@@ -107,7 +112,6 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData { ...@@ -107,7 +112,6 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
void push_back(const TrainingExample* example) { void push_back(const TrainingExample* example) {
DCHECK(backing_storage_); DCHECK(backing_storage_);
DCHECK(backing_storage_->contains(example));
examples_.push_back(example); examples_.push_back(example);
} }
......
...@@ -90,17 +90,6 @@ TEST_F(LearnerTrainingExampleTest, StoragePushBack) { ...@@ -90,17 +90,6 @@ TEST_F(LearnerTrainingExampleTest, StoragePushBack) {
EXPECT_EQ(*storage->begin(), example); EXPECT_EQ(*storage->begin(), example);
} }
TEST_F(LearnerTrainingExampleTest, StorageCheckWorks) {
// Verify that TrainingDataStorage can tell if an example is in its storage.
TrainingExample example({FeatureValue(123)}, TargetValue(789));
scoped_refptr<TrainingDataStorage> storage =
base::MakeRefCounted<TrainingDataStorage>();
storage->push_back(example);
EXPECT_TRUE(storage->contains(&(*storage->begin())));
EXPECT_FALSE(storage->contains(&example));
}
TEST_F(LearnerTrainingExampleTest, TrainingDataPushBack) { TEST_F(LearnerTrainingExampleTest, TrainingDataPushBack) {
TrainingExample example({FeatureValue(123)}, TargetValue(789)); TrainingExample example({FeatureValue(123)}, TargetValue(789));
scoped_refptr<TrainingDataStorage> storage = scoped_refptr<TrainingDataStorage> storage =
......
...@@ -36,6 +36,7 @@ source_set("unit_tests") { ...@@ -36,6 +36,7 @@ source_set("unit_tests") {
sources = [ sources = [
"learning_session_impl_unittest.cc", "learning_session_impl_unittest.cc",
"learning_task_controller_impl_unittest.cc",
"random_tree_trainer_unittest.cc", "random_tree_trainer_unittest.cc",
"target_distribution_unittest.cc", "target_distribution_unittest.cc",
] ]
......
...@@ -7,16 +7,67 @@ ...@@ -7,16 +7,67 @@
#include <memory> #include <memory>
#include "base/bind.h" #include "base/bind.h"
#include "media/learning/impl/random_tree_trainer.h"
namespace media { namespace media {
namespace learning { namespace learning {
// Constructs a controller for |task|: picks a training algorithm based on
// task_.model and installs a default accuracy-reporting callback.
// NOTE(review): scraped diff view — the old one-line signature and the new
// signature are fused on the two lines below; only the right-hand (new)
// tokens are the current code.
LearningTaskControllerImpl::LearningTaskControllerImpl( LearningTaskControllerImpl::LearningTaskControllerImpl(const LearningTask& task)
const LearningTask& task) {} : task_(task), storage_(base::MakeRefCounted<TrainingDataStorage>()) {
// Select the trainer for the requested model.  Only kRandomForest exists in
// the Model enum after this change; no default case, so adding a new Model
// value without handling it here should surface as a -Wswitch warning.
switch (task_.model) {
case LearningTask::Model::kRandomForest:
// TODO(liberato): send in the task, so that it can get params.
// TODO(liberato): forest!
training_cb_ = RandomTreeTrainer::GetTrainingAlgorithmCB();
break;
}
// The default reporting callback intentionally does nothing; UMA reporting
// is still TODO.  Tests replace it via SetAccuracyReportingCBForTesting().
// TODO(liberato): Record via UMA based on the task name.
accuracy_reporting_cb_ =
base::BindRepeating([](const LearningTask&, bool is_correct) {});
}
LearningTaskControllerImpl::~LearningTaskControllerImpl() = default; LearningTaskControllerImpl::~LearningTaskControllerImpl() = default;
// Accumulates |example| into |storage_|, scores the current model (if any)
// against it, and kicks off a retrain once per |min_data_set_size| examples.
// NOTE(review): scraped diff view — old and new text are fused on the two
// lines below; the right-hand tokens are the current code.
void LearningTaskControllerImpl::AddExample(const TrainingExample& example) { void LearningTaskControllerImpl::AddExample(const TrainingExample& example) {
// TODO: do something. // TODO(liberato): do we ever trim older examples?
storage_->push_back(example);
// Once we have a model, see if we'd get |example| correct.
if (model_) {
TargetDistribution distribution =
model_->PredictDistribution(example.features);
TargetValue predicted_value;
// "Correct" means the distribution has a single most-likely value and it
// equals the example's actual target.
const bool is_correct = distribution.FindSingularMax(&predicted_value) &&
predicted_value == example.target_value;
accuracy_reporting_cb_.Run(task_, is_correct);
// TODO(liberato): record entropy / not representable?
}
// Train every time we get a multiple of |data_set_size|.
if ((storage_->size() % task_.min_data_set_size) != 0)
return;
// Snapshot the whole storage; TrainingData holds pointers into |storage_|,
// which is safe because storage is deque-backed and append-only.
TrainingData training_data(storage_, storage_->begin(), storage_->end());
// Bind with a weak pointer: training may complete after |this| is destroyed,
// in which case OnModelTrained is simply dropped.
TrainedModelCB model_cb =
base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr());
training_cb_.Run(training_data, std::move(model_cb));
}
// Completion callback for |training_cb_|: adopt the freshly trained model,
// replacing any previous one.  Reached only while |this| is alive (bound via
// AsWeakPtr() in AddExample).
void LearningTaskControllerImpl::OnModelTrained(std::unique_ptr<Model> model) {
model_ = std::move(model);
// TODO(liberato): record oob results.
}
// Test seam: replaces the training algorithm chosen in the constructor so
// tests can observe / fake model training.
void LearningTaskControllerImpl::SetTrainingCBForTesting(
TrainingAlgorithmCB cb) {
training_cb_ = std::move(cb);
}
// Test seam: replaces the (default no-op) accuracy-reporting callback so
// tests can count correct / incorrect predictions.
void LearningTaskControllerImpl::SetAccuracyReportingCBForTesting(
AccuracyReportingCB cb) {
accuracy_reporting_cb_ = std::move(cb);
}
} // namespace learning } // namespace learning
......
...@@ -7,20 +7,55 @@ ...@@ -7,20 +7,55 @@
#include <memory> #include <memory>
#include "base/callback.h"
#include "base/component_export.h" #include "base/component_export.h"
#include "base/memory/weak_ptr.h"
#include "media/learning/impl/learning_task_controller.h" #include "media/learning/impl/learning_task_controller.h"
#include "media/learning/impl/training_algorithm.h"
namespace media { namespace media {
namespace learning { namespace learning {
class LearningTaskControllerImplTest;
// Drives one LearningTask: collects examples, periodically (re)trains a model
// via |training_cb_|, and reports per-example prediction accuracy via
// |accuracy_reporting_cb_|.
// NOTE(review): scraped diff view — several lines below fuse old and new text;
// the right-hand tokens are the current code.
class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
: public LearningTaskController : public LearningTaskController,
// Weak pointers let async training completion be safely dropped if the
// controller is destroyed first (see AddExample / OnModelTrained).
public base::SupportsWeakPtr<LearningTaskControllerImpl> {
public: public:
explicit LearningTaskControllerImpl(const LearningTask& task); explicit LearningTaskControllerImpl(const LearningTask& task);
~LearningTaskControllerImpl() override; ~LearningTaskControllerImpl() override;
// LearningTaskController // LearningTaskController
void AddExample(const TrainingExample& example) override; void AddExample(const TrainingExample& example) override;
private:
// Called with accuracy results as new examples are added. Only tests should
// need to worry about this.
using AccuracyReportingCB =
base::RepeatingCallback<void(const LearningTask& task, bool is_correct)>;
// Override the training CB for testing.
void SetTrainingCBForTesting(TrainingAlgorithmCB cb);
// Override the reporting CB for testing.
void SetAccuracyReportingCBForTesting(AccuracyReportingCB cb);
// Called by |training_cb_| when the model is trained.
void OnModelTrained(std::unique_ptr<Model> model);
// The task this controller services (copied at construction).
LearningTask task_;
// Current batch of examples.
scoped_refptr<TrainingDataStorage> storage_;
// Most recently trained model, or null.
std::unique_ptr<Model> model_;
// Produces a model from a TrainingData snapshot; set by the constructor,
// replaceable via SetTrainingCBForTesting.
TrainingAlgorithmCB training_cb_;
// Invoked once per example (after a model exists) with the correctness of
// the model's prediction for that example.
AccuracyReportingCB accuracy_reporting_cb_;
friend class LearningTaskControllerImplTest;
}; };
} // namespace learning } // namespace learning
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/learning_task_controller_impl.h"
#include "base/bind.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
// Fixture for LearningTaskControllerImpl.  Replaces the controller's training
// and accuracy-reporting callbacks with local fakes so tests can observe
// (a) when a model gets trained and (b) how prediction accuracy is reported.
class LearningTaskControllerImplTest : public testing::Test {
 public:
  LearningTaskControllerImplTest()
      : predicted_target_(123), not_predicted_target_(456) {
    // Don't require too many training examples per report.
    task_.min_data_set_size = 4;
    controller_ = std::make_unique<LearningTaskControllerImpl>(task_);
    // base::Unretained is safe: |controller_| (the only holder of these
    // callbacks) is destroyed before |this|.
    controller_->SetTrainingCBForTesting(base::BindRepeating(
        &LearningTaskControllerImplTest::OnTrain, base::Unretained(this)));
    controller_->SetAccuracyReportingCBForTesting(base::BindRepeating(
        &LearningTaskControllerImplTest::OnAccuracy, base::Unretained(this)));
  }

  // Model that always predicts a constant.
  class FakeModel : public Model {
   public:
    // |target| is the constant value this model always predicts.  explicit:
    // single-argument constructors should not act as implicit conversions.
    explicit FakeModel(TargetValue target) : target_(target) {}

    // Model
    TargetDistribution PredictDistribution(
        const FeatureVector& features) override {
      // Return a distribution whose entire mass is on |target_|.
      TargetDistribution dist;
      dist += target_;
      return dist;
    }

   private:
    // The value we predict.
    TargetValue target_;
  };

  // Fake training algorithm: counts invocations and immediately "trains" a
  // FakeModel that always predicts |predicted_target_|.
  void OnTrain(TrainingData training_data, TrainedModelCB model_cb) {
    num_models_++;
    std::move(model_cb).Run(std::make_unique<FakeModel>(predicted_target_));
  }

  // Fake accuracy reporter: tallies total reports and correct predictions.
  void OnAccuracy(const LearningTask& task, bool is_correct) {
    num_reported_++;
    if (is_correct)
      num_correct_++;
  }

  // Number of models that we trained.
  int num_models_ = 0;

  // Results reported via OnAccuracy.
  int num_reported_ = 0;
  int num_correct_ = 0;

  // Two distinct targets.
  TargetValue predicted_target_;
  TargetValue not_predicted_target_;

  LearningTask task_;

  std::unique_ptr<LearningTaskControllerImpl> controller_;
};
TEST_F(LearningTaskControllerImplTest, AddingExamplesTrainsModelAndReports) {
  TrainingExample training_example;

  // Feeding fewer than |min_data_set_size| examples must not trigger training.
  for (size_t i = 0; i + 1 < task_.min_data_set_size; ++i)
    controller_->AddExample(training_example);
  EXPECT_EQ(num_models_, 0);

  // The example that completes the first batch trains the first model.
  controller_->AddExample(training_example);
  EXPECT_EQ(num_models_, 1);

  // Accuracy is only reported once a model exists, so nothing yet.
  EXPECT_EQ(num_reported_, 0);
  EXPECT_EQ(num_correct_, 0);

  // An example matching the fake model's constant prediction counts as a
  // correct report.
  training_example.target_value = predicted_target_;
  controller_->AddExample(training_example);
  EXPECT_EQ(num_models_, 1);
  EXPECT_EQ(num_reported_, 1);
  EXPECT_EQ(num_correct_, 1);

  // A mismatching target adds a report without adding a correct one.
  training_example.target_value = not_predicted_target_;
  controller_->AddExample(training_example);
  EXPECT_EQ(num_models_, 1);
  EXPECT_EQ(num_reported_, 2);
  EXPECT_EQ(num_correct_, 1);  // Still 1.
}
} // namespace learning
} // namespace media
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment