Commit 6549c4a6 authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Implement LearningTaskControllerImpl.

Replace LearningTaskControllerImpl's do-nothing initial
implementation with one that:
  - constructs the training callback based on the learning task
  - collects training examples as they're added
  - re-train the model periodically.
  - record stats about model performance

Change-Id: I6a949e6188c71d00be5449dc9539508254c7d74a
Reviewed-on: https://chromium-review.googlesource.com/c/1355434
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarFredrik Hubinette <hubbe@chromium.org>
Cr-Commit-Position: refs/heads/master@{#612860}
parent 1f6fd36b
......@@ -23,16 +23,11 @@ namespace learning {
// registering tasks.
struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// Not all models support all feature / target descriptions. For example,
// NaiveBayes requires kUnordered features. Similarly, kLogLinear doesn't
// NaiveBayes requires kUnordered features. Similarly, LogLinear woudln't
// support kUnordered features or targets. kRandomForest might support more
// combination of orderings and types.
//
// Also note that not all of these are implemented yet.
enum class Model {
kMostCommonTarget,
kNaiveBayes,
kRandomForest,
kLogLinear,
};
enum class Ordering {
......@@ -81,7 +76,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// Unique name for this learner.
std::string name;
Model model = Model::kMostCommonTarget;
Model model = Model::kRandomForest;
std::vector<ValueDescription> feature_descriptions;
......@@ -91,6 +86,13 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// TODO(liberato): add training parameters, like smoothing constants. It's
// okay if some of these are model-specific.
// TODO(liberato): switch to base::DictionaryValue?
// Number of examples before we'll train a model.
size_t min_data_set_size = 10u;
// Should the accuracy of this model be recorded to UMA?
bool record_accuracy_via_uma = true;
};
} // namespace learning
......
......@@ -5,6 +5,7 @@
#ifndef MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#include <deque>
#include <initializer_list>
#include <ostream>
#include <vector>
......@@ -49,11 +50,15 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
// Copy / assignment is allowed.
};
// Collection of training examples. We use a vector since we allow duplicates.
// Collection of training examples.
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingDataStorage
: public base::RefCountedThreadSafe<TrainingDataStorage> {
public:
using StorageVector = std::vector<TrainingExample>;
// We store examples in a deque, since we don't want to invalidate pointers in
// TrainingData collections (see below) as new examples are added. Deques
// promise not to do that when inserting at either end.
using StorageVector = std::deque<TrainingExample>;
using const_iterator = StorageVector::const_iterator;
TrainingDataStorage();
......@@ -61,31 +66,31 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingDataStorage
StorageVector::const_iterator begin() const { return examples_.begin(); }
StorageVector::const_iterator end() const { return examples_.end(); }
// Note that it's okay to add examples at any time.
void push_back(const TrainingExample& example) {
examples_.push_back(example);
}
// Returns true if and only if |example| is included in our data. Note that
// this checks that the pointer itself is included, so that one might tell if
// an example is backed by this storage or not. It does not care if there is
// an example in our storage that would TrainingExample::operator==(*example).
bool contains(const TrainingExample* example) const {
return (example >= examples_.data()) &&
(example < examples_.data() + examples_.size());
}
// Notice that there's no option to clear storage; that might invalidate
// outstanding pointers in TrainingData (see below). Instead, just create a
// new TrainingDataStorage.
// Return the number of examples that we store.
size_t size() const { return examples_.size(); }
private:
friend class base::RefCountedThreadSafe<TrainingDataStorage>;
~TrainingDataStorage();
std::vector<TrainingExample> examples_;
StorageVector examples_;
DISALLOW_COPY_AND_ASSIGN(TrainingDataStorage);
};
// Collection of pointers to training data. References would be more convenient
// but they're not allowed.
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
public:
using ExampleVector = std::vector<const TrainingExample*>;
......@@ -107,7 +112,6 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
void push_back(const TrainingExample* example) {
DCHECK(backing_storage_);
DCHECK(backing_storage_->contains(example));
examples_.push_back(example);
}
......
......@@ -90,17 +90,6 @@ TEST_F(LearnerTrainingExampleTest, StoragePushBack) {
EXPECT_EQ(*storage->begin(), example);
}
TEST_F(LearnerTrainingExampleTest, StorageCheckWorks) {
// Verify that TrainingDataStorage can tell if an example is in its storage.
TrainingExample example({FeatureValue(123)}, TargetValue(789));
scoped_refptr<TrainingDataStorage> storage =
base::MakeRefCounted<TrainingDataStorage>();
storage->push_back(example);
EXPECT_TRUE(storage->contains(&(*storage->begin())));
EXPECT_FALSE(storage->contains(&example));
}
TEST_F(LearnerTrainingExampleTest, TrainingDataPushBack) {
TrainingExample example({FeatureValue(123)}, TargetValue(789));
scoped_refptr<TrainingDataStorage> storage =
......
......@@ -36,6 +36,7 @@ source_set("unit_tests") {
sources = [
"learning_session_impl_unittest.cc",
"learning_task_controller_impl_unittest.cc",
"random_tree_trainer_unittest.cc",
"target_distribution_unittest.cc",
]
......
......@@ -7,16 +7,67 @@
#include <memory>
#include "base/bind.h"
#include "media/learning/impl/random_tree_trainer.h"
namespace media {
namespace learning {
LearningTaskControllerImpl::LearningTaskControllerImpl(
const LearningTask& task) {}
LearningTaskControllerImpl::LearningTaskControllerImpl(const LearningTask& task)
: task_(task), storage_(base::MakeRefCounted<TrainingDataStorage>()) {
switch (task_.model) {
case LearningTask::Model::kRandomForest:
// TODO(liberato): send in the task, so that it can get params.
// TODO(liberato): forest!
training_cb_ = RandomTreeTrainer::GetTrainingAlgorithmCB();
break;
}
// TODO(liberato): Record via UMA based on the task name.
accuracy_reporting_cb_ =
base::BindRepeating([](const LearningTask&, bool is_correct) {});
}
LearningTaskControllerImpl::~LearningTaskControllerImpl() = default;
void LearningTaskControllerImpl::AddExample(const TrainingExample& example) {
// TODO: do something.
// TODO(liberato): do we ever trim older examples?
storage_->push_back(example);
// Once we have a model, see if we'd get |example| correct.
if (model_) {
TargetDistribution distribution =
model_->PredictDistribution(example.features);
TargetValue predicted_value;
const bool is_correct = distribution.FindSingularMax(&predicted_value) &&
predicted_value == example.target_value;
accuracy_reporting_cb_.Run(task_, is_correct);
// TODO(liberato): record entropy / not representable?
}
// Train every time we get a multiple of |data_set_size|.
if ((storage_->size() % task_.min_data_set_size) != 0)
return;
TrainingData training_data(storage_, storage_->begin(), storage_->end());
TrainedModelCB model_cb =
base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr());
training_cb_.Run(training_data, std::move(model_cb));
}
void LearningTaskControllerImpl::OnModelTrained(std::unique_ptr<Model> model) {
model_ = std::move(model);
// TODO(liberato): record oob results.
}
void LearningTaskControllerImpl::SetTrainingCBForTesting(
TrainingAlgorithmCB cb) {
training_cb_ = std::move(cb);
}
void LearningTaskControllerImpl::SetAccuracyReportingCBForTesting(
AccuracyReportingCB cb) {
accuracy_reporting_cb_ = std::move(cb);
}
} // namespace learning
......
......@@ -7,20 +7,55 @@
#include <memory>
#include "base/callback.h"
#include "base/component_export.h"
#include "base/memory/weak_ptr.h"
#include "media/learning/impl/learning_task_controller.h"
#include "media/learning/impl/training_algorithm.h"
namespace media {
namespace learning {
class LearningTaskControllerImplTest;
class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
: public LearningTaskController {
: public LearningTaskController,
public base::SupportsWeakPtr<LearningTaskControllerImpl> {
public:
explicit LearningTaskControllerImpl(const LearningTask& task);
~LearningTaskControllerImpl() override;
// LearningTaskController
void AddExample(const TrainingExample& example) override;
private:
// Called with accuracy results as new examples are added. Only tests should
// need to worry about this.
using AccuracyReportingCB =
base::RepeatingCallback<void(const LearningTask& task, bool is_correct)>;
// Override the training CB for testing.
void SetTrainingCBForTesting(TrainingAlgorithmCB cb);
// Override the reporting CB for testing.
void SetAccuracyReportingCBForTesting(AccuracyReportingCB cb);
// Called by |training_cb_| when the model is trained.
void OnModelTrained(std::unique_ptr<Model> model);
LearningTask task_;
// Current batch of examples.
scoped_refptr<TrainingDataStorage> storage_;
// Most recently trained model, or null.
std::unique_ptr<Model> model_;
TrainingAlgorithmCB training_cb_;
AccuracyReportingCB accuracy_reporting_cb_;
friend class LearningTaskControllerImplTest;
};
} // namespace learning
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/learning_task_controller_impl.h"
#include "base/bind.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class LearningTaskControllerImplTest : public testing::Test {
public:
LearningTaskControllerImplTest()
: predicted_target_(123), not_predicted_target_(456) {
// Don't require too many training examples per report.
task_.min_data_set_size = 4;
controller_ = std::make_unique<LearningTaskControllerImpl>(task_);
controller_->SetTrainingCBForTesting(base::BindRepeating(
&LearningTaskControllerImplTest::OnTrain, base::Unretained(this)));
controller_->SetAccuracyReportingCBForTesting(base::BindRepeating(
&LearningTaskControllerImplTest::OnAccuracy, base::Unretained(this)));
}
// Model that always predicts a constant.
class FakeModel : public Model {
public:
FakeModel(TargetValue target) : target_(target) {}
// Model
TargetDistribution PredictDistribution(
const FeatureVector& features) override {
TargetDistribution dist;
dist += target_;
return dist;
}
private:
// The value we predict.
TargetValue target_;
};
void OnTrain(TrainingData training_data, TrainedModelCB model_cb) {
num_models_++;
std::move(model_cb).Run(std::make_unique<FakeModel>(predicted_target_));
}
void OnAccuracy(const LearningTask& task, bool is_correct) {
num_reported_++;
if (is_correct)
num_correct_++;
}
// Number of models that we trained.
int num_models_ = 0;
// Results reported via OnAccuracy.
int num_reported_ = 0;
int num_correct_ = 0;
// Two distinct targets.
TargetValue predicted_target_;
TargetValue not_predicted_target_;
LearningTask task_;
std::unique_ptr<LearningTaskControllerImpl> controller_;
};
TEST_F(LearningTaskControllerImplTest, AddingExamplesTrainsModelAndReports) {
TrainingExample example;
// Adding the first n-1 examples shouldn't cause it to train a model.
for (size_t i = 0; i < task_.min_data_set_size - 1; i++)
controller_->AddExample(example);
EXPECT_EQ(num_models_, 0);
// Adding one more example should train a model.
controller_->AddExample(example);
EXPECT_EQ(num_models_, 1);
// No results should be reported yet.
EXPECT_EQ(num_reported_, 0);
EXPECT_EQ(num_correct_, 0);
// Adding one more example should report results.
example.target_value = predicted_target_;
controller_->AddExample(example);
EXPECT_EQ(num_models_, 1);
EXPECT_EQ(num_reported_, 1);
EXPECT_EQ(num_correct_, 1);
// Adding a value that doesn't match should report one more attempt.
example.target_value = not_predicted_target_;
controller_->AddExample(example);
EXPECT_EQ(num_models_, 1);
EXPECT_EQ(num_reported_, 2);
EXPECT_EQ(num_correct_, 1); // Still 1.
}
} // namespace learning
} // namespace media
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment