Commit cb16b703 authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

ExtraTrees model.

This CL introduces the ExtraTrees model, as a replacement for
RandomForest.  This allows us to skip bagging during training, and
perform simpler feature splits when constructing the trees.

This also renames the RandomForest model to VotingEnsemble, since it
is used by ExtraTrees as well as RandomForest, and isn't really
specific to RF at all.

Change-Id: I91647ef374a7889fabcbf57d902651e8373438cc
Reviewed-on: https://chromium-review.googlesource.com/c/1389016Reviewed-by: default avatarDan Sanders <sandersd@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#619489}
parent c394cd4e
......@@ -28,6 +28,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// combination of orderings and types.
enum class Model {
kRandomForest,
kExtraTrees,
};
enum class Ordering {
......
......@@ -22,10 +22,14 @@ TrainingExample::TrainingExample(TrainingExample&& rhs) noexcept = default;
TrainingExample::~TrainingExample() = default;
std::ostream& operator<<(std::ostream& out, const TrainingExample& example) {
for (const auto& feature : example.features)
out << " " << feature;
out << example.features << " => " << example.target_value;
out << " => " << example.target_value;
return out;
}
std::ostream& operator<<(std::ostream& out, const FeatureVector& features) {
for (const auto& feature : features)
out << " " << feature;
return out;
}
......
......@@ -107,6 +107,9 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const TrainingExample& example);
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const FeatureVector& features);
} // namespace learning
} // namespace media
......
......@@ -7,6 +7,8 @@ component("impl") {
visibility = [ "//media/learning/impl:unit_tests" ]
sources = [
"extra_trees_trainer.cc",
"extra_trees_trainer.h",
"learning_session_impl.cc",
"learning_session_impl.h",
"learning_task_controller.h",
......@@ -15,8 +17,6 @@ component("impl") {
"model.h",
"one_hot.cc",
"one_hot.h",
"random_forest.cc",
"random_forest.h",
"random_forest_trainer.cc",
"random_forest_trainer.h",
"random_number_generator.cc",
......@@ -26,6 +26,8 @@ component("impl") {
"target_distribution.cc",
"target_distribution.h",
"training_algorithm.h",
"voting_ensemble.cc",
"voting_ensemble.h",
]
defines = [ "IS_LEARNING_IMPL_IMPL" ]
......@@ -43,6 +45,7 @@ source_set("unit_tests") {
testonly = true
sources = [
"extra_trees_trainer_unittest.cc",
"fisher_iris_dataset.cc",
"fisher_iris_dataset.h",
"learning_session_impl_unittest.cc",
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/extra_trees_trainer.h"
#include <set>
#include "base/logging.h"
#include "media/learning/impl/one_hot.h"
#include "media/learning/impl/random_tree_trainer.h"
#include "media/learning/impl/voting_ensemble.h"
namespace media {
namespace learning {
ExtraTreesTrainer::ExtraTreesTrainer() = default;
ExtraTreesTrainer::~ExtraTreesTrainer() = default;
std::unique_ptr<Model> ExtraTreesTrainer::Train(
const LearningTask& task,
const TrainingData& training_data) {
int n_trees = task.rf_number_of_trees;
RandomTreeTrainer tree_trainer(rng());
std::vector<std::unique_ptr<Model>> trees;
trees.reserve(n_trees);
// RandomTree requires one-hot vectors to properly choose split points the way
// that ExtraTrees require.
std::unique_ptr<OneHotConverter> converter =
std::make_unique<OneHotConverter>(task, training_data);
TrainingData converted_training_data = converter->Convert(training_data);
for (int i = 0; i < n_trees; i++) {
// Train the tree.
std::unique_ptr<Model> tree = tree_trainer.Train(
converter->converted_task(), converted_training_data);
trees.push_back(std::move(tree));
}
return std::make_unique<ConvertingModel>(
std::move(converter), std::make_unique<VotingEnsemble>(std::move(trees)));
}
} // namespace learning
} // namespace media
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_EXTRA_TREES_TRAINER_H_
#define MEDIA_LEARNING_IMPL_EXTRA_TREES_TRAINER_H_
#include <memory>
#include <vector>
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/random_number_generator.h"
#include "media/learning/impl/training_algorithm.h"
namespace media {
namespace learning {
// Bagged forest of extremely randomized trees.
//
// These are an ensemble of trees. Each tree is constructed from the full
// training set. The trees are constructed by selecting a random subset of
// features at each node. For each feature, a uniformly random split point is
// chosen. The feature with the best randomly chosen split point is used.
//
// These will automatically convert nominal values to one-hot vectors.
class COMPONENT_EXPORT(LEARNING_IMPL) ExtraTreesTrainer
: public HasRandomNumberGenerator {
public:
ExtraTreesTrainer();
~ExtraTreesTrainer();
std::unique_ptr<Model> Train(const LearningTask& task,
const TrainingData& training_data);
private:
DISALLOW_COPY_AND_ASSIGN(ExtraTreesTrainer);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_EXTRA_TREES_TRAINER_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/extra_trees_trainer.h"
#include "base/memory/ref_counted.h"
#include "media/learning/impl/fisher_iris_dataset.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class ExtraTreesTest : public testing::TestWithParam<LearningTask::Ordering> {
public:
ExtraTreesTest() : rng_(0), ordering_(GetParam()) {
trainer_.SetRandomNumberGeneratorForTesting(&rng_);
}
// Set up |task_| to have |n| features with the given ordering.
void SetupFeatures(size_t n) {
for (size_t i = 0; i < n; i++) {
LearningTask::ValueDescription desc;
desc.ordering = ordering_;
task_.feature_descriptions.push_back(desc);
}
}
TestRandomNumberGenerator rng_;
ExtraTreesTrainer trainer_;
LearningTask task_;
// Feature ordering.
LearningTask::Ordering ordering_;
};
TEST_P(ExtraTreesTest, EmptyTrainingDataWorks) {
TrainingData empty;
auto model = trainer_.Train(task_, empty);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution());
}
TEST_P(ExtraTreesTest, FisherIrisDataset) {
SetupFeatures(4);
FisherIrisDataset iris;
TrainingData training_data = iris.GetTrainingData();
auto model = trainer_.Train(task_, training_data);
// Verify predictions on the training set, just for sanity.
size_t num_correct = 0;
for (const TrainingExample& example : training_data) {
TargetDistribution distribution =
model->PredictDistribution(example.features);
TargetValue predicted_value;
if (distribution.FindSingularMax(&predicted_value) &&
predicted_value == example.target_value) {
num_correct += example.weight;
}
}
// Expect very high accuracy. We should get ~100%.
// We get about 96% for one-hot features, and 100% for numeric. Since the
// data really is numeric, that seems reasonable.
double train_accuracy = ((double)num_correct) / training_data.total_weight();
EXPECT_GT(train_accuracy, 0.95);
}
TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) {
// Create a training set with unseparable data, but give one of them a large
// weight. See if that one wins.
SetupFeatures(1);
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
const size_t weight = 100;
TrainingData training_data;
example_1.weight = weight;
training_data.push_back(example_1);
// Push many |example_2|'s, which will win without the weights.
training_data.push_back(example_2);
training_data.push_back(example_2);
training_data.push_back(example_2);
training_data.push_back(example_2);
// Create a weighed set with |weight| for each example's weight.
EXPECT_FALSE(training_data.is_unweighted());
auto model = trainer_.Train(task_, training_data);
// The singular max should be example_1.
TargetDistribution distribution =
model->PredictDistribution(example_1.features);
TargetValue predicted_value;
EXPECT_TRUE(distribution.FindSingularMax(&predicted_value));
EXPECT_EQ(predicted_value, example_1.target_value);
}
INSTANTIATE_TEST_CASE_P(ExtraTreesTest,
ExtraTreesTest,
testing::ValuesIn({LearningTask::Ordering::kUnordered,
LearningTask::Ordering::kNumeric}));
} // namespace learning
} // namespace media
......@@ -7,6 +7,7 @@
#include <memory>
#include "base/bind.h"
#include "media/learning/impl/extra_trees_trainer.h"
#include "media/learning/impl/random_tree_trainer.h"
namespace media {
......@@ -15,6 +16,15 @@ namespace learning {
LearningTaskControllerImpl::LearningTaskControllerImpl(const LearningTask& task)
: task_(task), training_data_(std::make_unique<TrainingData>()) {
switch (task_.model) {
case LearningTask::Model::kExtraTrees:
training_cb_ = base::BindRepeating(
[](const LearningTask& task, TrainingData training_data,
TrainedModelCB model_cb) {
ExtraTreesTrainer trainer;
std::move(model_cb).Run(trainer.Train(task, training_data));
},
task_);
break;
case LearningTask::Model::kRandomForest:
// TODO(liberato): forest!
training_cb_ = RandomTreeTrainer::GetTrainingAlgorithmCB(task_);
......
......@@ -7,8 +7,8 @@
#include <set>
#include "base/logging.h"
#include "media/learning/impl/random_forest.h"
#include "media/learning/impl/random_tree_trainer.h"
#include "media/learning/impl/voting_ensemble.h"
namespace media {
namespace learning {
......@@ -77,8 +77,8 @@ std::unique_ptr<RandomForestTrainer::TrainingResult> RandomForestTrainer::Train(
trees.push_back(std::move(tree));
}
std::unique_ptr<RandomForest> forest =
std::make_unique<RandomForest>(std::move(trees));
std::unique_ptr<VotingEnsemble> forest =
std::make_unique<VotingEnsemble>(std::move(trees));
// Compute OOB accuracy.
int num_correct = 0;
......
......@@ -42,6 +42,22 @@ namespace learning {
// target values that ended up in each group. The index with the best score is
// chosen for the split.
//
// For nominal features, we split the feature into all of its nominal values.
// This is somewhat nonstandard; one would normally convert to one-hot numeric
// features first. See OneHotConverter if you'd like to do this.
//
// For numeric features, we choose a split point uniformly at random between its
// min and max values in the training data. We do this because it's suitable
// for extra trees. RandomForest trees want to select the best split point for
// each feature, rather than uniformly. Either way, of course, we choose the
// best split among the (feature, split point) pairs we're considering.
//
// Also note that for one-hot features, these are the same thing. So, this
// implementation is suitable for extra trees with numeric (possibly one hot)
// features, or RF with one-hot nominal features. Note that non-one-hot nominal
// features probably work fine with RF too. Numeric, non-binary features don't
// work with RF, unless one changes the split point selection.
//
// The training algorithm then recurses to build child nodes. One child node is
// created for each observed value of the |i|-th feature in the training set.
// The child node is trained using the subset of the training set that shares
......
......@@ -2,24 +2,24 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_forest.h"
#include "media/learning/impl/voting_ensemble.h"
namespace media {
namespace learning {
RandomForest::RandomForest(std::vector<std::unique_ptr<Model>> trees)
: trees_(std::move(trees)) {}
VotingEnsemble::VotingEnsemble(std::vector<std::unique_ptr<Model>> models)
: models_(std::move(models)) {}
RandomForest::~RandomForest() = default;
VotingEnsemble::~VotingEnsemble() = default;
TargetDistribution RandomForest::PredictDistribution(
TargetDistribution VotingEnsemble::PredictDistribution(
const FeatureVector& instance) {
TargetDistribution forest_distribution;
TargetDistribution distribution;
for (auto iter = trees_.begin(); iter != trees_.end(); iter++)
forest_distribution += (*iter)->PredictDistribution(instance);
for (auto iter = models_.begin(); iter != models_.end(); iter++)
distribution += (*iter)->PredictDistribution(instance);
return forest_distribution;
return distribution;
}
} // namespace learning
......
......@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_RANDOM_FOREST_H_
#define MEDIA_LEARNING_IMPL_RANDOM_FOREST_H_
#ifndef MEDIA_LEARNING_IMPL_VOTING_ENSEMBLE_H_
#define MEDIA_LEARNING_IMPL_VOTING_ENSEMBLE_H_
#include <memory>
#include <vector>
......@@ -15,25 +15,24 @@
namespace media {
namespace learning {
// Bagged forest of randomized trees.
// TODO(liberato): consider a generic Bagging class, since this doesn't really
// depend on RandomTree at all.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomForest : public Model {
// Ensemble classifier. Takes multiple models and returns an aggregate of the
// individual predictions.
class COMPONENT_EXPORT(LEARNING_IMPL) VotingEnsemble : public Model {
public:
RandomForest(std::vector<std::unique_ptr<Model>> trees);
~RandomForest() override;
VotingEnsemble(std::vector<std::unique_ptr<Model>> models);
~VotingEnsemble() override;
// Model
TargetDistribution PredictDistribution(
const FeatureVector& instance) override;
private:
std::vector<std::unique_ptr<Model>> trees_;
std::vector<std::unique_ptr<Model>> models_;
DISALLOW_COPY_AND_ASSIGN(RandomForest);
DISALLOW_COPY_AND_ASSIGN(VotingEnsemble);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_RANDOM_FOREST_H_
#endif // MEDIA_LEARNING_IMPL_VOTING_ENSEMBLE_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment