Commit 8b0bd21f authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Separated model from training algorithm.

Previously, 'Learner' was a roll-up trainer + model, intended to
make it easy for clients to use the learning system.  However, since
LearningSession can do this too, it makes sense to split the model
apart from the particular training algorithm used to build it.

Change-Id: If207bf64c6a5b34fb84591b9b149bdd9a3ff6af4
Reviewed-on: https://chromium-review.googlesource.com/c/1327544Reviewed-by: default avatarFredrik Hubinette <hubbe@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#609802}
parent 6acbabc0
......@@ -7,14 +7,15 @@ component("impl") {
visibility = [ "//media/learning/impl:unit_tests" ]
sources = [
"learner.h",
"learning_session_impl.cc",
"learning_session_impl.h",
"learning_task_controller.h",
"learning_task_controller_impl.cc",
"learning_task_controller_impl.h",
"random_tree.cc",
"random_tree.h",
"model.h",
"random_tree_trainer.cc",
"random_tree_trainer.h",
"training_algorithm.h",
]
defines = [ "IS_LEARNING_IMPL_IMPL" ]
......@@ -33,7 +34,7 @@ source_set("unit_tests") {
sources = [
"learning_session_impl_unittest.cc",
"random_tree_unittest.cc",
"random_tree_trainer_unittest.cc",
]
deps = [
......
......@@ -2,29 +2,35 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_LEARNER_H_
#define MEDIA_LEARNING_IMPL_LEARNER_H_
#ifndef MEDIA_LEARNING_IMPL_MODEL_H_
#define MEDIA_LEARNING_IMPL_MODEL_H_
#include <map>
#include "base/component_export.h"
#include "base/values.h"
#include "media/learning/common/training_example.h"
#include "media/learning/impl/model.h"
namespace media {
namespace learning {
// Base class for a thing that takes examples of the form {features, target},
// and trains a model to predict the target given the features. The target may
// be either nominal (classification) or numeric (regression), though this must
// be chosen in advance when creating the learner via LearnerFactory.
class COMPONENT_EXPORT(LEARNING_IMPL) Learner {
// One trained model, useful for making predictions.
// TODO(liberato): Provide an API for incremental update, for those models that
// can support it.
class COMPONENT_EXPORT(LEARNING_IMPL) Model {
public:
virtual ~Learner() = default;
// [target value] == counts
// This is classification-centric. Not sure about the right interface for
// regressors. Mostly for testing.
using TargetDistribution = std::map<TargetValue, int>;
virtual ~Model() = default;
// Tell the learner that |example| has been observed during training.
virtual void AddExample(const TrainingExample& example) = 0;
virtual TargetDistribution PredictDistribution(
const FeatureVector& instance) = 0;
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_LEARNER_H_
#endif // MEDIA_LEARNING_IMPL_MODEL_H_
......@@ -2,37 +2,51 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_tree.h"
#include "media/learning/impl/random_tree_trainer.h"
#include <math.h>
#include "base/bind.h"
#include "base/logging.h"
namespace media {
namespace learning {
RandomTree::TreeNode::~TreeNode() = default;
RandomTree::Split::Split() = default;
RandomTree::Split::Split(int index) : split_index(index) {}
RandomTree::Split::Split(Split&& rhs) = default;
RandomTree::Split::~Split() = default;
RandomTree::Split& RandomTree::Split::operator=(Split&& rhs) = default;
RandomTree::Split::BranchInfo::BranchInfo() = default;
RandomTree::Split::BranchInfo::BranchInfo(const BranchInfo& rhs) = default;
RandomTree::Split::BranchInfo::~BranchInfo() = default;
struct InteriorNode : public RandomTree::TreeNode {
// static
TrainingAlgorithmCB RandomTreeTrainer::GetTrainingAlgorithmCB() {
return base::BindRepeating(
[](const TrainingData& training_data) -> std::unique_ptr<Model> {
return RandomTreeTrainer().Train(training_data);
});
}
RandomTreeTrainer::Split::Split() = default;
RandomTreeTrainer::Split::Split(int index) : split_index(index) {}
RandomTreeTrainer::Split::Split(Split&& rhs) = default;
RandomTreeTrainer::Split::~Split() = default;
RandomTreeTrainer::Split& RandomTreeTrainer::Split::operator=(Split&& rhs) =
default;
RandomTreeTrainer::Split::BranchInfo::BranchInfo() = default;
RandomTreeTrainer::Split::BranchInfo::BranchInfo(const BranchInfo& rhs) =
default;
RandomTreeTrainer::Split::BranchInfo::~BranchInfo() = default;
struct InteriorNode : public Model {
InteriorNode(int split_index) : split_index_(split_index) {}
// TreeNode
TargetDistribution* ComputeDistribution(
// Model
Model::TargetDistribution PredictDistribution(
const FeatureVector& features) override {
auto iter = children_.find(features[split_index_]);
// If we've never seen this feature value, then make no prediction.
if (iter == children_.end())
return nullptr;
return TargetDistribution();
return iter->second->ComputeDistribution(features);
return iter->second->PredictDistribution(features);
}
// Add |child| has the node for feature value |v|.
void AddChild(FeatureValue v, std::unique_ptr<TreeNode> child) {
void AddChild(FeatureValue v, std::unique_ptr<Model> child) {
DCHECK_EQ(children_.count(v), 0u);
children_.emplace(v, std::move(child));
}
......@@ -40,44 +54,37 @@ struct InteriorNode : public RandomTree::TreeNode {
private:
// Feature value that we split on.
int split_index_ = -1;
std::map<FeatureValue, std::unique_ptr<TreeNode>> children_;
std::map<FeatureValue, std::unique_ptr<Model>> children_;
};
struct LeafNode : public RandomTree::TreeNode {
struct LeafNode : public Model {
LeafNode(const TrainingData& training_data) {
for (const TrainingExample* example : training_data)
distribution_[example->target_value]++;
}
// TreeNode
TargetDistribution* ComputeDistribution(const FeatureVector&) override {
return &distribution_;
Model::TargetDistribution PredictDistribution(const FeatureVector&) override {
return distribution_;
}
private:
TargetDistribution distribution_;
Model::TargetDistribution distribution_;
};
RandomTree::RandomTree() = default;
RandomTreeTrainer::RandomTreeTrainer() = default;
RandomTree::~RandomTree() = default;
RandomTreeTrainer::~RandomTreeTrainer() = default;
void RandomTree::Train(const TrainingData& training_data) {
root_ = nullptr;
std::unique_ptr<Model> RandomTreeTrainer::Train(
const TrainingData& training_data) {
if (training_data.empty())
return;
root_ = Build(training_data, FeatureSet());
}
return std::make_unique<InteriorNode>(-1);
const RandomTree::TreeNode::TargetDistribution*
RandomTree::ComputeDistributionForTesting(const FeatureVector& instance) {
if (!root_)
return nullptr;
return root_->ComputeDistribution(instance);
return Build(training_data, FeatureSet());
}
std::unique_ptr<RandomTree::TreeNode> RandomTree::Build(
std::unique_ptr<Model> RandomTreeTrainer::Build(
const TrainingData& training_data,
const FeatureSet& used_set) {
DCHECK(training_data.size());
......@@ -86,8 +93,7 @@ std::unique_ptr<RandomTree::TreeNode> RandomTree::Build(
Split best_potential_split;
// Select the feature subset to consider at this leaf.
// TODO(liberato): This should select a subset, which is why it's not merged
// with the loop below.
// TODO(liberato): subset.
FeatureSet feature_candidates;
for (size_t i = 0; i < training_data[0]->features.size(); i++) {
if (used_set.find(i) != used_set.end())
......@@ -127,8 +133,9 @@ std::unique_ptr<RandomTree::TreeNode> RandomTree::Build(
return node;
}
RandomTree::Split RandomTree::ConstructSplit(const TrainingData& training_data,
int index) {
RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
const TrainingData& training_data,
int index) {
// We should not be given a training set of size 0, since there's no need to
// check an empty split.
DCHECK_GT(training_data.size(), 0u);
......
......@@ -2,23 +2,21 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_H_
#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#include <limits>
#include <memory>
#include <set>
#include <vector>
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/impl/learner.h"
#include "media/learning/impl/training_algorithm.h"
namespace media {
namespace learning {
// RandomTree decision tree classifier (doesn't handle regression currently).
// Trains RandomTree decision tree classifier (doesn't handle regression).
//
// Decision trees, including RandomTree, classify instances as follows. Each
// non-leaf node is marked with a feature number |i|. The value of the |i|-th
......@@ -56,24 +54,15 @@ namespace learning {
// TODO(liberato): Right now, it not-so-randomly selects from the entire set.
// TODO(liberato): consider PRF or other simplified approximations.
// TODO(liberato): separate Model and TrainingAlgorithm. This is the latter.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTree {
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer {
public:
struct TreeNode {
// [target value] == counts
using TargetDistribution = std::map<TargetValue, int>;
virtual ~TreeNode();
virtual TargetDistribution* ComputeDistribution(
const FeatureVector& features) = 0;
};
RandomTree();
virtual ~RandomTree();
RandomTreeTrainer();
~RandomTreeTrainer();
// Train the tree.
void Train(const TrainingData& examples);
// Return a callback that can be used to train a random tree.
static TrainingAlgorithmCB GetTrainingAlgorithmCB();
const TreeNode::TargetDistribution* ComputeDistributionForTesting(
const FeatureVector& instance);
std::unique_ptr<Model> Train(const TrainingData& examples);
private:
// Set of feature indices.
......@@ -124,19 +113,16 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTree {
// Build this node from |training_data|. |used_set| is the set of features
// that we already used higher in the tree.
std::unique_ptr<TreeNode> Build(const TrainingData& training_data,
const FeatureSet& used_set);
std::unique_ptr<Model> Build(const TrainingData& training_data,
const FeatureSet& used_set);
// Compute and return a split of |training_data| on the |index|-th feature.
Split ConstructSplit(const TrainingData& training_data, int index);
// Root of the random tree, or null.
std::unique_ptr<TreeNode> root_;
DISALLOW_COPY_AND_ASSIGN(RandomTree);
DISALLOW_COPY_AND_ASSIGN(RandomTreeTrainer);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_RANDOM_TREE_H_
#endif // MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
......@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_tree.h"
#include "media/learning/impl/random_tree_trainer.h"
#include "testing/gtest/include/gtest/gtest.h"
......@@ -11,13 +11,15 @@ namespace learning {
class RandomTreeTest : public testing::Test {
public:
RandomTree tree_;
RandomTreeTrainer trainer_;
};
TEST_F(RandomTreeTest, EmptyTrainingDataWorks) {
TrainingData empty;
tree_.Train(empty);
EXPECT_EQ(tree_.ComputeDistributionForTesting(FeatureVector()), nullptr);
std::unique_ptr<Model> model = trainer_.Train(empty);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(model->PredictDistribution(FeatureVector()),
Model::TargetDistribution());
}
TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
......@@ -27,34 +29,49 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
const int n_examples = 10;
for (int i = 0; i < n_examples; i++)
training_data.push_back(&example);
tree_.Train(training_data);
std::unique_ptr<Model> model = trainer_.Train(training_data);
// The tree should produce a distribution for one value (our target), which
// has |n_examples| counts.
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example.target_value)->second, n_examples);
Model::TargetDistribution distribution =
model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example.target_value)->second, n_examples);
}
TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
TrainingExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingData training_data;
const int n_examples = 10;
for (int i = 0; i < n_examples; i++)
training_data.push_back(&example);
std::unique_ptr<Model> model =
RandomTreeTrainer::GetTrainingAlgorithmCB().Run(training_data);
Model::TargetDistribution distribution =
model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example.target_value)->second, n_examples);
}
TEST_F(RandomTreeTest, SimpleSeparableTrainingData) {
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(456)}, TargetValue(2));
TrainingData training_data({&example_1, &example_2});
tree_.Train(training_data);
std::unique_ptr<Model> model = trainer_.Train(training_data);
// Each value should have a distribution with one target value with one count.
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example_1.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example_1.target_value)->second, 1);
distribution = tree_.ComputeDistributionForTesting(example_2.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example_2.target_value)->second, 1);
Model::TargetDistribution distribution =
model->PredictDistribution(example_1.features);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1);
distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1);
}
TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
......@@ -82,15 +99,15 @@ TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
training_data.push_back(&example);
}
tree_.Train(training_data);
std::unique_ptr<Model> model = trainer_.Train(training_data);
EXPECT_NE(model.get(), nullptr);
// Each example should have a distribution by itself, with two counts.
for (const TrainingExample* example : training_data) {
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example->features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example->target_value)->second, 2);
Model::TargetDistribution distribution =
model->PredictDistribution(example->features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example->target_value)->second, 2);
}
}
......@@ -98,21 +115,20 @@ TEST_F(RandomTreeTest, UnseparableTrainingData) {
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
TrainingData training_data({&example_1, &example_2});
tree_.Train(training_data);
std::unique_ptr<Model> model = trainer_.Train(training_data);
EXPECT_NE(model.get(), nullptr);
// Each value should have a distribution with two targets with one count each.
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example_1.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 2u);
EXPECT_EQ(distribution->find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution->find(example_2.target_value)->second, 1);
distribution = tree_.ComputeDistributionForTesting(example_2.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 2u);
EXPECT_EQ(distribution->find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution->find(example_2.target_value)->second, 1);
Model::TargetDistribution distribution =
model->PredictDistribution(example_1.features);
EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1);
distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1);
}
} // namespace learning
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_TRAINING_ALGORITHM_H_
#define MEDIA_LEARNING_IMPL_TRAINING_ALGORITHM_H_
#include <memory>
#include "base/callback.h"
#include "media/learning/common/training_example.h"
#include "media/learning/impl/model.h"
namespace media {
namespace learning {
// A TrainingAlgorithm takes as input training examples, and produces as output
// a trained model that can be used for prediction.
// Train a model with on |examples| and return it.
// TODO(liberato): Switch to a callback to return the model.
using TrainingAlgorithmCB = base::RepeatingCallback<std::unique_ptr<Model>(
const TrainingData& examples)>;
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_TRAINING_ALGORITHM_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment