Commit dddfb603 authored by liberato@chromium.org, committed by Commit Bot

Remove RandomForest in favor of ExtraTrees.

Since ExtraTrees seem to be more useful than RandomForest for
our application, remove RF in favor of them.

Change-Id: Ia7301725aa55032bc7fded74139cf143f76f82f2
Reviewed-on: https://chromium-review.googlesource.com/c/1407605
Reviewed-by: Dan Sanders <sandersd@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#622159}
parent e777a3d0
media/learning/common/learning_task.h
......@@ -27,7 +27,6 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// support kUnordered features or targets. kRandomForest might support more
// combinations of orderings and types.
enum class Model {
kRandomForest,
kExtraTrees,
};
......@@ -77,7 +76,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// Unique name for this learner.
std::string name;
Model model = Model::kRandomForest;
Model model = Model::kExtraTrees;
std::vector<ValueDescription> feature_descriptions;
......
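With kRandomForest gone, a task selects the remaining model either explicitly or via the new default. A minimal sketch, assuming only the LearningTask constructor exercised in the one_hot_unittest.cc hunks below (the task, feature, and target names are illustrative placeholders):

// Sketch only; "MyTask", "feature1", and "target" are made-up names.
LearningTask task("MyTask", LearningTask::Model::kExtraTrees,
                  {{"feature1", LearningTask::Ordering::kUnordered}},
                  LearningTask::ValueDescription({"target"}));
// Tasks that relied on the default model switch as well, since the default
// member initializer changes from Model::kRandomForest to Model::kExtraTrees.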
media/learning/impl/BUILD.gn
......@@ -17,8 +17,6 @@ component("impl") {
"model.h",
"one_hot.cc",
"one_hot.h",
"random_forest_trainer.cc",
"random_forest_trainer.h",
"random_number_generator.cc",
"random_number_generator.h",
"random_tree_trainer.cc",
......@@ -51,7 +49,6 @@ source_set("unit_tests") {
"learning_session_impl_unittest.cc",
"learning_task_controller_impl_unittest.cc",
"one_hot_unittest.cc",
"random_forest_trainer_unittest.cc",
"random_number_generator_unittest.cc",
"random_tree_trainer_unittest.cc",
"target_distribution_unittest.cc",
......
media/learning/impl/learning_task_controller_impl.cc
......@@ -25,10 +25,6 @@ LearningTaskControllerImpl::LearningTaskControllerImpl(const LearningTask& task)
},
task_);
break;
case LearningTask::Model::kRandomForest:
// TODO(liberato): forest!
training_cb_ = RandomTreeTrainer::GetTrainingAlgorithmCB(task_);
break;
}
// TODO(liberato): Record via UMA based on the task name.
......
media/learning/impl/one_hot_unittest.cc
......@@ -15,7 +15,7 @@ class OneHotTest : public testing::Test {
};
TEST_F(OneHotTest, EmptyLearningTaskWorks) {
LearningTask empty_task("EmptyTask", LearningTask::Model::kRandomForest, {},
LearningTask empty_task("EmptyTask", LearningTask::Model::kExtraTrees, {},
LearningTask::ValueDescription({"target"}));
TrainingData empty_training_data;
OneHotConverter one_hot(empty_task, empty_training_data);
......@@ -23,7 +23,7 @@ TEST_F(OneHotTest, EmptyLearningTaskWorks) {
}
TEST_F(OneHotTest, SimpleConversionWorks) {
LearningTask task("SimpleTask", LearningTask::Model::kRandomForest,
LearningTask task("SimpleTask", LearningTask::Model::kExtraTrees,
{{"feature1", LearningTask::Ordering::kUnordered}},
LearningTask::ValueDescription({"target"}));
TrainingData training_data;
......@@ -78,7 +78,7 @@ TEST_F(OneHotTest, SimpleConversionWorks) {
}
TEST_F(OneHotTest, NumericsAreNotConverted) {
LearningTask task("SimpleTask", LearningTask::Model::kRandomForest,
LearningTask task("SimpleTask", LearningTask::Model::kExtraTrees,
{{"feature1", LearningTask::Ordering::kNumeric}},
LearningTask::ValueDescription({"target"}));
OneHotConverter one_hot(task, TrainingData());
......@@ -97,7 +97,7 @@ TEST_F(OneHotTest, NumericsAreNotConverted) {
}
TEST_F(OneHotTest, UnknownValuesAreZeroHot) {
LearningTask task("SimpleTask", LearningTask::Model::kRandomForest,
LearningTask task("SimpleTask", LearningTask::Model::kExtraTrees,
{{"feature1", LearningTask::Ordering::kUnordered}},
LearningTask::ValueDescription({"target"}));
TrainingData training_data;
......
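For context, the behavior these tests pin down, as an illustration only (the OneHotConverter API beyond the constructor is not visible in this CL, so the mapping below describes standard one-hot encoding rather than quoting the implementation):

// Illustration with made-up values:
//   A kUnordered feature whose training data contains the values {A, B} is
//   replaced by one numeric indicator feature per observed value:
//     A            -> {1, 0}
//     B            -> {0, 1}
//     unseen value -> {0, 0}   // "zero-hot", per UnknownValuesAreZeroHot
//   kNumeric features are passed through unchanged, per NumericsAreNotConverted.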
media/learning/impl/random_forest_trainer.cc (deleted)
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_forest_trainer.h"
#include <set>
#include "base/logging.h"
#include "media/learning/impl/random_tree_trainer.h"
#include "media/learning/impl/voting_ensemble.h"
namespace media {
namespace learning {
RandomForestTrainer::TrainingResult::TrainingResult() = default;
RandomForestTrainer::TrainingResult::~TrainingResult() = default;
RandomForestTrainer::RandomForestTrainer() = default;
RandomForestTrainer::~RandomForestTrainer() = default;
std::unique_ptr<RandomForestTrainer::TrainingResult> RandomForestTrainer::Train(
const LearningTask& task,
const TrainingData& training_data) {
int n_trees = task.rf_number_of_trees;
RandomTreeTrainer tree_trainer(rng());
std::vector<std::unique_ptr<Model>> trees;
trees.reserve(n_trees);
// [example index] = sum of all oob predictions
std::map<size_t, TargetDistribution> oob_distributions;
// We don't support weighted training data, since bagging with weights is
// very hard to do right without spending time that depends on the value of
// the weights. Since ExtraTrees don't need bagging, we skip it.
if (!training_data.is_unweighted())
return std::make_unique<TrainingResult>();
const size_t n_examples = training_data.size();
for (int i = 0; i < n_trees; i++) {
// Collect a bagged training set and oob data for it.
std::vector<size_t> bagged_idx;
bagged_idx.reserve(training_data.size());
std::set<size_t> bagged_set;
for (size_t e = 0; e < n_examples; e++) {
size_t idx = rng()->Generate(n_examples);
bagged_idx.push_back(idx);
bagged_set.insert(idx);
}
// Train the tree.
std::unique_ptr<Model> tree =
tree_trainer.Train(task, training_data, bagged_idx);
// Compute OOB distribution.
for (size_t e = 0; e < n_examples; e++) {
if (bagged_set.find(e) != bagged_set.end())
continue;
const LabelledExample& example = training_data[e];
TargetDistribution predicted =
tree->PredictDistribution(example.features);
// Add the predicted distribution to this example's total distribution.
// Remember that the distribution is not normalized, so the counts will
// scale with the number of examples.
// TODO(liberato): Should it be normalized before being combined?
TargetDistribution& our_oob_dist = oob_distributions[e];
our_oob_dist += predicted;
}
trees.push_back(std::move(tree));
}
std::unique_ptr<VotingEnsemble> forest =
std::make_unique<VotingEnsemble>(std::move(trees));
// Compute OOB accuracy.
int num_correct = 0;
for (auto& oob_pair : oob_distributions) {
const LabelledExample& example = training_data[oob_pair.first];
const TargetDistribution& distribution = oob_pair.second;
// If there are no guesses, or if it's a tie, then count it as wrong.
TargetValue max_value;
if (distribution.FindSingularMax(&max_value) &&
max_value == example.target_value) {
num_correct++;
}
}
std::unique_ptr<TrainingResult> result = std::make_unique<TrainingResult>();
result->model = std::move(forest);
result->oob_correct = num_correct;
result->oob_total = oob_distributions.size();
return result;
}
} // namespace learning
} // namespace media
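For reference, the calling pattern the deleted trainer supported, mirroring random_forest_trainer_unittest.cc below; a minimal sketch that assumes |task|, |training_data|, and |example| are already populated:

RandomForestTrainer trainer;
std::unique_ptr<RandomForestTrainer::TrainingResult> result =
    trainer.Train(task, training_data);
// Out-of-bag accuracy, over the examples left out of at least one tree's
// bagged sample.
double oob_accuracy =
    static_cast<double>(result->oob_correct) / result->oob_total;
// The trained forest itself is a Model that predicts a target distribution.
TargetDistribution prediction =
    result->model->PredictDistribution(example.features);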
media/learning/impl/random_forest_trainer.h (deleted)
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_RANDOM_FOREST_TRAINER_H_
#define MEDIA_LEARNING_IMPL_RANDOM_FOREST_TRAINER_H_
#include <memory>
#include <vector>
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/random_number_generator.h"
#include "media/learning/impl/training_algorithm.h"
namespace media {
namespace learning {
// Bagged forest of randomized trees.
// TODO(liberato): consider a generic Bagging class.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomForestTrainer
: public HasRandomNumberGenerator {
public:
RandomForestTrainer();
~RandomForestTrainer();
struct COMPONENT_EXPORT(LEARNING_IMPL) TrainingResult {
TrainingResult();
~TrainingResult();
std::unique_ptr<Model> model;
// Number of correctly classified oob samples.
size_t oob_correct = 0;
// TODO: include oob entropy and oob unrepresentable?
// Total number of oob samples.
size_t oob_total = 0;
DISALLOW_COPY_AND_ASSIGN(TrainingResult);
};
std::unique_ptr<TrainingResult> Train(const LearningTask& task,
const TrainingData& training_data);
private:
DISALLOW_COPY_AND_ASSIGN(RandomForestTrainer);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_RANDOM_FOREST_TRAINER_H_
media/learning/impl/random_forest_trainer_unittest.cc (deleted)
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_forest_trainer.h"
#include "base/memory/ref_counted.h"
#include "media/learning/impl/fisher_iris_dataset.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class RandomForestTest : public testing::TestWithParam<LearningTask::Ordering> {
public:
RandomForestTest()
: rng_(0),
ordering_(GetParam()) {
trainer_.SetRandomNumberGeneratorForTesting(&rng_);
}
// Set up |task_| to have |n| features with the given ordering.
void SetupFeatures(size_t n) {
for (size_t i = 0; i < n; i++) {
LearningTask::ValueDescription desc;
desc.ordering = ordering_;
task_.feature_descriptions.push_back(desc);
}
}
TestRandomNumberGenerator rng_;
RandomForestTrainer trainer_;
LearningTask task_;
// Feature ordering.
LearningTask::Ordering ordering_;
};
TEST_P(RandomForestTest, EmptyTrainingDataWorks) {
TrainingData empty;
auto result = trainer_.Train(task_, empty);
EXPECT_NE(result->model.get(), nullptr);
EXPECT_EQ(result->model->PredictDistribution(FeatureVector()),
TargetDistribution());
}
TEST_P(RandomForestTest, UniformTrainingDataWorks) {
SetupFeatures(2);
LabelledExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
const int n_examples = 10;
// We need distinct pointers.
TrainingData training_data;
for (int i = 0; i < n_examples; i++)
training_data.push_back(example);
auto result = trainer_.Train(task_, training_data);
// The tree should produce a distribution for one value (our target), which
// has |n_examples| counts.
TargetDistribution distribution =
result->model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
}
TEST_P(RandomForestTest, SimpleSeparableTrainingData) {
SetupFeatures(1);
// TODO: OOB estimates aren't so good if a target only shows up once. Any
// tree that trains on it won't be used to predict it during OOB accuracy,
// and the remaining trees will get it wrong.
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
TrainingData training_data;
training_data.push_back(example_1);
training_data.push_back(example_2);
auto result = trainer_.Train(task_, training_data);
// Each value should have a distribution with the correct target value.
EXPECT_NE(result->model.get(), nullptr);
TargetDistribution distribution =
result->model->PredictDistribution(example_1.features);
TargetValue max_1;
EXPECT_TRUE(distribution.FindSingularMax(&max_1));
EXPECT_EQ(max_1, example_1.target_value);
distribution = result->model->PredictDistribution(example_2.features);
TargetValue max_2;
EXPECT_TRUE(distribution.FindSingularMax(&max_2));
EXPECT_EQ(max_2, example_2.target_value);
}
TEST_P(RandomForestTest, ComplexSeparableTrainingData) {
SetupFeatures(4);
// Build a four-feature training set that's completely separable, but one
// needs all four features to do it.
TrainingData training_data;
for (int f1 = 0; f1 < 2; f1++) {
for (int f2 = 0; f2 < 2; f2++) {
for (int f3 = 0; f3 < 2; f3++) {
for (int f4 = 0; f4 < 2; f4++) {
LabelledExample example(
{FeatureValue(f1), FeatureValue(f2), FeatureValue(f3),
FeatureValue(f4)},
TargetValue(f1 * 1 + f2 * 2 + f3 * 4 + f4 * 8));
// Add two distinct copies of each example.
// We don't strictly need to, but OOB estimation won't work otherwise.
training_data.push_back(example);
training_data.push_back(example);
}
}
}
}
auto result = trainer_.Train(task_, training_data);
EXPECT_NE(result->model.get(), nullptr);
// Each example should have a distribution in which it is the max.
for (const LabelledExample& example : training_data) {
TargetDistribution distribution =
result->model->PredictDistribution(example.features);
TargetValue max_value;
EXPECT_TRUE(distribution.FindSingularMax(&max_value));
EXPECT_EQ(max_value, example.target_value);
}
}
TEST_P(RandomForestTest, UnseparableTrainingData) {
SetupFeatures(1);
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
TrainingData training_data;
training_data.push_back(example_1);
training_data.push_back(example_2);
auto result = trainer_.Train(task_, training_data);
EXPECT_NE(result->model.get(), nullptr);
// Each value should have a distribution with two targets.
TargetDistribution distribution =
result->model->PredictDistribution(example_1.features);
EXPECT_EQ(distribution.size(), 2u);
distribution = result->model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 2u);
}
TEST_P(RandomForestTest, FisherIrisDataset) {
SetupFeatures(4);
FisherIrisDataset iris;
TrainingData training_data = iris.GetTrainingData();
auto result = trainer_.Train(task_, training_data);
// Require at least 75% of the examples to have OOB data; in practice it
// should be ~100%, since a bagged sample leaves a given example out with
// probability (1 - 1/n)^n, roughly 1/e, so being in-bag for every tree in
// the forest is unlikely.
EXPECT_GT(result->oob_total, training_data.total_weight() * 0.75);
// Require at least 85% oob accuracy. We actually get about 88% (kUnordered)
// or 95% (kOrdered).
double oob_accuracy =
static_cast<double>(result->oob_correct) / result->oob_total;
EXPECT_GT(oob_accuracy, 0.85);
// Verify predictions on the training set, just for sanity.
size_t num_correct = 0;
for (const LabelledExample& example : training_data) {
TargetDistribution distribution =
result->model->PredictDistribution(example.features);
TargetValue predicted_value;
if (distribution.FindSingularMax(&predicted_value) &&
predicted_value == example.target_value) {
num_correct += example.weight;
}
}
// Expect very high accuracy. We should get ~100%.
// Currently, we seem to get about ~95%. If we switch to kEmptyDistribution
// in the learning task, then it goes back to 1 for kUnordered features. It's
// 1 for kOrdered.
double train_accuracy =
static_cast<double>(num_correct) / training_data.total_weight();
EXPECT_GT(train_accuracy, 0.95);
}
TEST_P(RandomForestTest, WeightedTrainingSetIsUnsupported) {
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
const size_t weight = 100;
TrainingData training_data;
example_1.weight = weight;
training_data.push_back(example_1);
example_2.weight = weight;
training_data.push_back(example_2);
// Create a weighted set with |weight| for each example's weight.
EXPECT_FALSE(training_data.is_unweighted());
auto weighted_result = trainer_.Train(task_, training_data);
EXPECT_EQ(weighted_result->model.get(), nullptr);
}
INSTANTIATE_TEST_CASE_P(RandomForestTest,
RandomForestTest,
testing::ValuesIn({LearningTask::Ordering::kUnordered,
LearningTask::Ordering::kNumeric}));
} // namespace learning
} // namespace media