Commit c04ea2cd authored by liberato@chromium.org, committed by Commit Bot

Not-very-random RandomTree implementation

Adds a RandomTree that chooses each split to maximize the entropy
improvement.  It does not select a random subset of features when
deciding on each split.  It also does not stop growing the tree when
no improvement is found, even at nodes that are already pure.

Bug: 902857
Change-Id: I742f64c71fd535bd7acc20b02b6478b20d9dae24
Reviewed-on: https://chromium-review.googlesource.com/c/1324130
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: Fredrik Hubinette <hubbe@chromium.org>
Cr-Commit-Position: refs/heads/master@{#607368}
parent a5f88b7e
@@ -9,6 +9,7 @@ source_set("unit_tests") {
"//base/test:test_support",
"//media:test_support",
"//media/learning/common:unit_tests",
"//media/learning/impl:unit_tests",
"//testing/gtest",
]
}
@@ -39,5 +39,10 @@ bool TrainingExample::operator!=(const TrainingExample& rhs) const {
return !((*this) == rhs);
}
TrainingExample& TrainingExample::operator=(const TrainingExample& rhs) =
default;
TrainingExample& TrainingExample::operator=(TrainingExample&& rhs) = default;
} // namespace learning
} // namespace media
@@ -16,6 +16,12 @@
namespace media {
namespace learning {
// Vector of features, for training or prediction.
// To interpret the features, one probably needs to check a LearningTask. It
// provides a description for each index. For example, [0]=="height",
// [1]=="url", etc.
using FeatureVector = std::vector<FeatureValue>;
// One training example == group of feature values, plus the desired target.
struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
TrainingExample();
@@ -28,8 +34,13 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
bool operator==(const TrainingExample& rhs) const;
bool operator!=(const TrainingExample& rhs) const;
TrainingExample& operator=(const TrainingExample& rhs);
TrainingExample& operator=(TrainingExample&& rhs);
// Observed feature values.
std::vector<FeatureValue> features;
// Note that to interpret these values, you probably need to have the
// LearningTask that they're supposed to be used with.
FeatureVector features;
// Observed output value, when given |features| as input.
TargetValue target_value;
@@ -37,6 +48,13 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
// Copy / assignment is allowed.
};
// Collection of training examples. We use a vector since we allow duplicates.
using TrainingDataStorage = std::vector<TrainingExample>;
// Collection of pointers to training examples. References would be more
// convenient, but a std::vector can't hold references.
using TrainingData = std::vector<const TrainingExample*>;
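// A minimal sketch of how these two collections are meant to be used together
// (illustrative only; |storage| and |data| are hypothetical local names):
//
//   TrainingDataStorage storage;
//   storage.push_back(TrainingExample({FeatureValue(1)}, TargetValue(2)));
//   // Take pointers only once |storage| will no longer reallocate.
//   TrainingData data;
//   for (const TrainingExample& example : storage)
//     data.push_back(&example);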
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const TrainingExample& example);
@@ -4,10 +4,14 @@
component("impl") {
output_name = "learning_impl"
visibility = [ "//media/learning/impl:unit_tests" ]
sources = [
"learner.h",
"learning_session_impl.cc",
"learning_session_impl.h",
"random_tree.cc",
"random_tree.h",
]
defines = [ "IS_LEARNING_IMPL_IMPL" ]
@@ -20,3 +24,18 @@ component("impl") {
"//media/learning/common",
]
}
source_set("unit_tests") {
testonly = true
sources = [
"random_tree_unittest.cc",
]
deps = [
"//base/test:test_support",
"//media:test_support",
"//media/learning/impl",
"//testing/gtest",
]
}
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_tree.h"
namespace media {
namespace learning {
RandomTree::TreeNode::~TreeNode() = default;
RandomTree::Split::Split() = default;
RandomTree::Split::Split(int index) : split_index(index) {}
RandomTree::Split::Split(Split&& rhs) = default;
RandomTree::Split::~Split() = default;
RandomTree::Split& RandomTree::Split::operator=(Split&& rhs) = default;
RandomTree::Split::BranchInfo::BranchInfo() = default;
RandomTree::Split::BranchInfo::BranchInfo(const BranchInfo& rhs) = default;
RandomTree::Split::BranchInfo::~BranchInfo() = default;
struct InteriorNode : public RandomTree::TreeNode {
InteriorNode(int split_index) : split_index_(split_index) {}
// TreeNode
TargetDistribution* ComputeDistribution(
const FeatureVector& features) override {
auto iter = children_.find(features[split_index_]);
// If we've never seen this feature value, then make no prediction.
if (iter == children_.end())
return nullptr;
return iter->second->ComputeDistribution(features);
}
// Add |child| as the node for feature value |v|.
void AddChild(FeatureValue v, std::unique_ptr<TreeNode> child) {
DCHECK_EQ(children_.count(v), 0u);
children_.emplace(v, std::move(child));
}
private:
// Feature value that we split on.
int split_index_ = -1;
std::map<FeatureValue, std::unique_ptr<TreeNode>> children_;
};
struct LeafNode : public RandomTree::TreeNode {
LeafNode(const TrainingData& training_data) {
for (const TrainingExample* example : training_data)
distribution_[example->target_value]++;
}
// TreeNode
TargetDistribution* ComputeDistribution(const FeatureVector&) override {
return &distribution_;
}
private:
TargetDistribution distribution_;
};
RandomTree::RandomTree() = default;
RandomTree::~RandomTree() = default;
void RandomTree::Train(const TrainingData& training_data) {
root_ = nullptr;
if (training_data.empty())
return;
root_ = Build(training_data, FeatureSet());
}
const RandomTree::TreeNode::TargetDistribution*
RandomTree::ComputeDistributionForTesting(const FeatureVector& instance) {
if (!root_)
return nullptr;
return root_->ComputeDistribution(instance);
}
std::unique_ptr<RandomTree::TreeNode> RandomTree::Build(
const TrainingData& training_data,
const FeatureSet& used_set) {
DCHECK(training_data.size());
// TODO(liberato): Does it help if we refuse to split without an info gain?
Split best_potential_split;
// Select the feature subset to consider at this leaf.
// TODO(liberato): This should select a subset, which is why it's not merged
// with the loop below.
FeatureSet feature_candidates;
for (size_t i = 0; i < training_data[0]->features.size(); i++) {
if (used_set.find(i) != used_set.end())
continue;
feature_candidates.insert(i);
}
// Find the best split among the candidates that we have.
for (int i : feature_candidates) {
Split potential_split = ConstructSplit(training_data, i);
if (potential_split.nats_remaining < best_potential_split.nats_remaining) {
best_potential_split = std::move(potential_split);
}
}
// Note that we can have a split with no index (i.e., no features were left, or
// no feature improved the nats), or a split with only a single branch (there
// were features, but all examples had the same value for them). Either way, we
// should end up with a leaf.
if (best_potential_split.branch_infos.size() < 2) {
// Stop when there is no more tree.
return std::make_unique<LeafNode>(training_data);
}
// Build an interior node
std::unique_ptr<InteriorNode> node =
std::make_unique<InteriorNode>(best_potential_split.split_index);
// Don't let the subtree use this feature.
FeatureSet new_used_set(used_set);
new_used_set.insert(best_potential_split.split_index);
for (auto& branch_iter : best_potential_split.branch_infos) {
node->AddChild(branch_iter.first,
Build(branch_iter.second.training_data, new_used_set));
}
return node;
}
RandomTree::Split RandomTree::ConstructSplit(const TrainingData& training_data,
int index) {
// We should not be given a training set of size 0, since there's no need to
// check an empty split.
DCHECK_GT(training_data.size(), 0u);
Split split(index);
// Find the split's feature values and construct the training set for each.
// I think we want to iterate on the underlying vector, and look up the int in
// the training data directly.
for (const TrainingExample* example : training_data) {
// Get the value of the |index|-th feature for this example.
FeatureValue v_i = example->features[split.split_index];
// Add |v_i| to the right training set.
Split::BranchInfo& branch_info = split.branch_infos[v_i];
branch_info.training_data.push_back(example);
branch_info.class_counts[example->target_value]++;
}
// Compute the expected nats needed to encode the class, given that we're at
// this node. Each branch's entropy is weighted by the fraction of the
// examples that take that branch.
split.nats_remaining = 0;
const double total_examples = training_data.size();
for (auto& info_iter : split.branch_infos) {
Split::BranchInfo& branch_info = info_iter.second;
const double branch_counts = branch_info.training_data.size();
const double branch_weight = branch_counts / total_examples;
for (auto& iter : branch_info.class_counts) {
double p = iter.second / branch_counts;
split.nats_remaining -= branch_weight * p * std::log(p);
}
}
return split;
}
} // namespace learning
} // namespace media
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_H_
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <vector>
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/impl/learner.h"
namespace media {
namespace learning {
// RandomTree decision tree classifier (doesn't handle regression currently).
//
// Decision trees, including RandomTree, classify instances as follows. Each
// non-leaf node is marked with a feature number |i|. The value of the |i|-th
// feature of the instance is then used to select which outgoing edge is
// traversed. This repeats until arriving at a leaf, which has a distribution
// over target values that is the prediction. The tree structure, including the
// feature index at each node and distribution at each leaf, is chosen once when
// the tree is trained.
//
// Training involves starting with a set of training examples, each of which has
// features and a target value. The tree is constructed recursively, starting
// with the root. For the node being constructed, the training algorithm is
// given the portion of the training set that would reach the node, if it were
// sent down the tree in a similar fashion as described above. It then
// considers assigning each (unused) feature index as the index to split the
// training examples at this node. For each index |i|, it groups the training
// set into subsets, each of which consists of those examples with the same
// value of the |i|-th feature. It then computes a score for the split using
// the target values that ended up in each group. The index with the best
// score is chosen for the split.
//
// The training algorithm then recurses to build child nodes. One child node is
// created for each observed value of the |i|-th feature in the training set.
// The child node is trained using the subset of the training set that shares
// that node's value for feature |i|.
//
// The above is a generic decision tree training algorithm. A RandomTree
// differs from that mostly in how it selects the feature to split at each node
// during training. Rather than computing a score for each feature, a
// RandomTree chooses a random subset of the features and only compares those.
//
// See https://en.wikipedia.org/wiki/Random_forest for information. Note that
// this is just a single tree, not the whole forest.
//
// TODO(liberato): Right now, it not-so-randomly selects from the entire set.
// TODO(liberato): consider PRF or other simplified approximations.
// TODO(liberato): separate Model and TrainingAlgorithm. This is the latter.
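//
// As a concrete, purely illustrative example of the split scoring described
// above, suppose a node receives four one-feature examples:
//
//   {FeatureValue(0)} => TargetValue(1)    {FeatureValue(0)} => TargetValue(1)
//   {FeatureValue(1)} => TargetValue(1)    {FeatureValue(1)} => TargetValue(2)
//
// Splitting on feature 0 yields two branches. The value-0 branch is pure
// (entropy 0 nats); the value-1 branch has one example of each target value
// (entropy ln(2) ~= 0.69 nats). Weighting each branch by the fraction of the
// examples it receives, the expected remaining entropy for this split is
// 0.5 * 0 + 0.5 * ln(2) ~= 0.35 nats. At each node, the feature with the
// lowest such value is chosen.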
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTree {
public:
struct TreeNode {
// [target value] == counts
using TargetDistribution = std::map<TargetValue, int>;
virtual ~TreeNode();
virtual TargetDistribution* ComputeDistribution(
const FeatureVector& features) = 0;
};
RandomTree();
virtual ~RandomTree();
// Train the tree.
void Train(const TrainingData& examples);
const TreeNode::TargetDistribution* ComputeDistributionForTesting(
const FeatureVector& instance);
private:
// Set of feature indices.
using FeatureSet = std::set<int>;
// Information about a proposed split, and the training sets that would result
// from that split.
struct Split {
Split();
explicit Split(int index);
Split(Split&& rhs);
~Split();
Split& operator=(Split&& rhs);
// Feature index to split on.
size_t split_index = 0;
// Expected nats needed to compute the class, given that we're at this
// node in the tree.
// "nat" == entropy measured with natural log rather than base-2.
double nats_remaining = std::numeric_limits<double>::infinity();
// Per-branch (i.e. per-child node) information about this split.
struct BranchInfo {
explicit BranchInfo();
BranchInfo(const BranchInfo& rhs);
~BranchInfo();
// Training set for this branch of the split.
TrainingData training_data;
// Number of occurrences of each target value in |training_data| along this
// branch of the split.
std::map<TargetValue, int> class_counts;
};
// [feature value at this split] = info about which examples take this
// branch of the split.
// TODO(liberato): This complained about not having copy-assignment,
// which makes me worried that it does a lot of copy-assignment. That
// was when it was a map<FeatureValue, TrainingData>. Consider just
// making it a unique_ptr<BranchInfo>.
std::map<FeatureValue, BranchInfo> branch_infos;
DISALLOW_COPY_AND_ASSIGN(Split);
};
// Build this node from |training_data|. |used_set| is the set of features
// that we already used higher in the tree.
std::unique_ptr<TreeNode> Build(const TrainingData& training_data,
const FeatureSet& used_set);
// Compute and return a split of |training_data| on the |index|-th feature.
Split ConstructSplit(const TrainingData& training_data, int index);
// Root of the random tree, or null.
std::unique_ptr<TreeNode> root_;
DISALLOW_COPY_AND_ASSIGN(RandomTree);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_RANDOM_TREE_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_tree.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class RandomTreeTest : public testing::Test {
public:
RandomTree tree_;
};
TEST_F(RandomTreeTest, EmptyTrainingDataWorks) {
TrainingData empty;
tree_.Train(empty);
EXPECT_EQ(tree_.ComputeDistributionForTesting(FeatureVector()), nullptr);
}
TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
TrainingExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingData training_data;
const int n_examples = 10;
for (int i = 0; i < n_examples; i++)
training_data.push_back(&example);
tree_.Train(training_data);
// The tree should produce a distribution for one value (our target), which
// has |n_examples| counts.
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example.target_value)->second, n_examples);
}
TEST_F(RandomTreeTest, SimpleSeparableTrainingData) {
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(456)}, TargetValue(2));
TrainingData training_data({&example_1, &example_2});
tree_.Train(training_data);
// Each value should have a distribution with one target value with one count.
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example_1.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example_1.target_value)->second, 1);
distribution = tree_.ComputeDistributionForTesting(example_2.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example_2.target_value)->second, 1);
}
TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
// Build a four-feature training set that's completely separable, but one
// needs all four features to do it.
TrainingDataStorage training_data_storage;
for (int f1 = 0; f1 < 2; f1++) {
for (int f2 = 0; f2 < 2; f2++) {
for (int f3 = 0; f3 < 2; f3++) {
for (int f4 = 0; f4 < 2; f4++) {
training_data_storage.push_back(
TrainingExample({FeatureValue(f1), FeatureValue(f2),
FeatureValue(f3), FeatureValue(f4)},
TargetValue(f1 * 1 + f2 * 2 + f3 * 4 + f4 * 8)));
}
}
}
}
// Add two copies of each example. Note that we do this after fully
// constructing |training_data_storage|, since it may realloc.
TrainingData training_data;
for (auto& example : training_data_storage) {
training_data.push_back(&example);
training_data.push_back(&example);
}
tree_.Train(training_data);
// Each example should have a distribution by itself, with two counts.
for (const TrainingExample* example : training_data) {
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example->features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 1u);
EXPECT_EQ(distribution->find(example->target_value)->second, 2);
}
}
TEST_F(RandomTreeTest, UnseparableTrainingData) {
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
TrainingData training_data({&example_1, &example_2});
tree_.Train(training_data);
// Each value should have a distribution with two targets with one count each.
const RandomTree::TreeNode::TargetDistribution* distribution =
tree_.ComputeDistributionForTesting(example_1.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 2u);
EXPECT_EQ(distribution->find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution->find(example_2.target_value)->second, 1);
distribution = tree_.ComputeDistributionForTesting(example_2.features);
EXPECT_NE(distribution, nullptr);
EXPECT_EQ(distribution->size(), 2u);
EXPECT_EQ(distribution->find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution->find(example_2.target_value)->second, 1);
}
} // namespace learning
} // namespace media