Commit faadd6b9, authored by liberato@chromium.org and committed by Commit Bot

Update RandomTreeTrainer to use TargetDistribution.

Updates Model and RandomTreeTrainer to use TargetDistribution
instead of Model::TargetDistribution.

Change-Id: I0e1a28dc78d5716aa54fd4d1061e86975e4815db
Reviewed-on: https://chromium-review.googlesource.com/c/1352056
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: Fredrik Hubinette <hubbe@chromium.org>
Cr-Commit-Position: refs/heads/master@{#611775}
parent 78ae8cdb
...@@ -5,11 +5,10 @@ ...@@ -5,11 +5,10 @@
#ifndef MEDIA_LEARNING_IMPL_MODEL_H_ #ifndef MEDIA_LEARNING_IMPL_MODEL_H_
#define MEDIA_LEARNING_IMPL_MODEL_H_ #define MEDIA_LEARNING_IMPL_MODEL_H_
#include <map>
#include "base/component_export.h" #include "base/component_export.h"
#include "media/learning/common/training_example.h" #include "media/learning/common/training_example.h"
#include "media/learning/impl/model.h" #include "media/learning/impl/model.h"
#include "media/learning/impl/target_distribution.h"
namespace media { namespace media {
namespace learning { namespace learning {
...@@ -19,11 +18,6 @@ namespace learning { ...@@ -19,11 +18,6 @@ namespace learning {
// can support it. // can support it.
class COMPONENT_EXPORT(LEARNING_IMPL) Model { class COMPONENT_EXPORT(LEARNING_IMPL) Model {
public: public:
// [target value] == counts
// This is classification-centric. Not sure about the right interface for
// regressors. Mostly for testing.
using TargetDistribution = std::map<TargetValue, int>;
virtual ~Model() = default; virtual ~Model() = default;
virtual TargetDistribution PredictDistribution( virtual TargetDistribution PredictDistribution(
......
...@@ -43,7 +43,7 @@ struct InteriorNode : public Model { ...@@ -43,7 +43,7 @@ struct InteriorNode : public Model {
InteriorNode(int split_index) : split_index_(split_index) {} InteriorNode(int split_index) : split_index_(split_index) {}
// Model // Model
Model::TargetDistribution PredictDistribution( TargetDistribution PredictDistribution(
const FeatureVector& features) override { const FeatureVector& features) override {
auto iter = children_.find(features[split_index_]); auto iter = children_.find(features[split_index_]);
// If we've never seen this feature value, then make no prediction. // If we've never seen this feature value, then make no prediction.
...@@ -62,22 +62,22 @@ struct InteriorNode : public Model { ...@@ -62,22 +62,22 @@ struct InteriorNode : public Model {
private: private:
// Feature value that we split on. // Feature value that we split on.
int split_index_ = -1; int split_index_ = -1;
std::map<FeatureValue, std::unique_ptr<Model>> children_; base::flat_map<FeatureValue, std::unique_ptr<Model>> children_;
}; };
struct LeafNode : public Model { struct LeafNode : public Model {
LeafNode(const TrainingData& training_data) { LeafNode(const TrainingData& training_data) {
for (const TrainingExample* example : training_data) for (const TrainingExample* example : training_data)
distribution_[example->target_value]++; distribution_ += example->target_value;
} }
// Model // Model
Model::TargetDistribution PredictDistribution(const FeatureVector&) override { TargetDistribution PredictDistribution(const FeatureVector&) override {
return distribution_; return distribution_;
} }
private: private:
Model::TargetDistribution distribution_; TargetDistribution distribution_;
}; };
RandomTreeTrainer::RandomTreeTrainer() = default; RandomTreeTrainer::RandomTreeTrainer() = default;
......
...@@ -6,10 +6,12 @@ ...@@ -6,10 +6,12 @@
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_ #define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#include <limits> #include <limits>
#include <map>
#include <memory> #include <memory>
#include <set> #include <set>
#include "base/component_export.h" #include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/macros.h" #include "base/macros.h"
#include "media/learning/impl/training_algorithm.h" #include "media/learning/impl/training_algorithm.h"
...@@ -101,7 +103,9 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer { ...@@ -101,7 +103,9 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer {
// Number of occurrences of each target value in |training_data| along this // Number of occurrences of each target value in |training_data| along this
// branch of the split. // branch of the split.
std::map<TargetValue, int> class_counts; // This is a flat_map since we're likely to have a very small (e.g.,
// "true" / "false") number of targets.
base::flat_map<TargetValue, int> class_counts;
}; };
// [feature value at this split] = info about which examples take this // [feature value at this split] = info about which examples take this
......
...@@ -26,8 +26,7 @@ TEST_F(RandomTreeTest, EmptyTrainingDataWorks) { ...@@ -26,8 +26,7 @@ TEST_F(RandomTreeTest, EmptyTrainingDataWorks) {
TrainingData empty(storage_); TrainingData empty(storage_);
std::unique_ptr<Model> model = trainer_.Train(empty); std::unique_ptr<Model> model = trainer_.Train(empty);
EXPECT_NE(model.get(), nullptr); EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(model->PredictDistribution(FeatureVector()), EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution());
Model::TargetDistribution());
} }
TEST_F(RandomTreeTest, UniformTrainingDataWorks) { TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
...@@ -41,10 +40,10 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorks) { ...@@ -41,10 +40,10 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
// The tree should produce a distribution for one value (our target), which // The tree should produce a distribution for one value (our target), which
// has |n_examples| counts. // has |n_examples| counts.
Model::TargetDistribution distribution = TargetDistribution distribution =
model->PredictDistribution(example.features); model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example.target_value)->second, n_examples); EXPECT_EQ(distribution[example.target_value], n_examples);
} }
TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) { TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
...@@ -68,10 +67,10 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) { ...@@ -68,10 +67,10 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
std::move(model_cb)); std::move(model_cb));
base::RunLoop().RunUntilIdle(); base::RunLoop().RunUntilIdle();
Model::TargetDistribution distribution = TargetDistribution distribution =
model->PredictDistribution(example.features); model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example.target_value)->second, n_examples); EXPECT_EQ(distribution[example.target_value], n_examples);
} }
TEST_F(RandomTreeTest, SimpleSeparableTrainingData) { TEST_F(RandomTreeTest, SimpleSeparableTrainingData) {
...@@ -83,15 +82,15 @@ TEST_F(RandomTreeTest, SimpleSeparableTrainingData) { ...@@ -83,15 +82,15 @@ TEST_F(RandomTreeTest, SimpleSeparableTrainingData) {
std::unique_ptr<Model> model = trainer_.Train(training_data); std::unique_ptr<Model> model = trainer_.Train(training_data);
// Each value should have a distribution with one target value with one count. // Each value should have a distribution with one target value with one count.
Model::TargetDistribution distribution = TargetDistribution distribution =
model->PredictDistribution(example_1.features); model->PredictDistribution(example_1.features);
EXPECT_NE(model.get(), nullptr); EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1); EXPECT_EQ(distribution[example_1.target_value], 1);
distribution = model->PredictDistribution(example_2.features); distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1); EXPECT_EQ(distribution[example_2.target_value], 1);
} }
TEST_F(RandomTreeTest, ComplexSeparableTrainingData) { TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
...@@ -123,10 +122,10 @@ TEST_F(RandomTreeTest, ComplexSeparableTrainingData) { ...@@ -123,10 +122,10 @@ TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
// Each example should have a distribution by itself, with two counts. // Each example should have a distribution by itself, with two counts.
for (const TrainingExample* example : training_data) { for (const TrainingExample* example : training_data) {
Model::TargetDistribution distribution = TargetDistribution distribution =
model->PredictDistribution(example->features); model->PredictDistribution(example->features);
EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example->target_value)->second, 2); EXPECT_EQ(distribution[example->target_value], 2);
} }
} }
...@@ -140,16 +139,16 @@ TEST_F(RandomTreeTest, UnseparableTrainingData) { ...@@ -140,16 +139,16 @@ TEST_F(RandomTreeTest, UnseparableTrainingData) {
EXPECT_NE(model.get(), nullptr); EXPECT_NE(model.get(), nullptr);
// Each value should have a distribution with two targets with one count each. // Each value should have a distribution with two targets with one count each.
Model::TargetDistribution distribution = TargetDistribution distribution =
model->PredictDistribution(example_1.features); model->PredictDistribution(example_1.features);
EXPECT_EQ(distribution.size(), 2u); EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1); EXPECT_EQ(distribution[example_1.target_value], 1);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1); EXPECT_EQ(distribution[example_2.target_value], 1);
distribution = model->PredictDistribution(example_2.features); distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 2u); EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1); EXPECT_EQ(distribution[example_1.target_value], 1);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1); EXPECT_EQ(distribution[example_2.target_value], 1);
} }
} // namespace learning } // namespace learning
......
...@@ -8,8 +8,23 @@ namespace media { ...@@ -8,8 +8,23 @@ namespace media {
namespace learning { namespace learning {
TargetDistribution::TargetDistribution() = default; TargetDistribution::TargetDistribution() = default;
TargetDistribution::TargetDistribution(const TargetDistribution& rhs) = default;
TargetDistribution::TargetDistribution(TargetDistribution&& rhs) = default;
TargetDistribution::~TargetDistribution() = default; TargetDistribution::~TargetDistribution() = default;
TargetDistribution& TargetDistribution::operator=(
const TargetDistribution& rhs) = default;
TargetDistribution& TargetDistribution::operator=(TargetDistribution&& rhs) =
default;
// Equality: both distributions must report the same total_counts() and hold
// equal |counts_| containers.
bool TargetDistribution::operator==(const TargetDistribution& rhs) const {
  // Differing totals settle the comparison without looking at the buckets.
  if (rhs.total_counts() != total_counts())
    return false;

  return rhs.counts_ == counts_;
}
TargetDistribution& TargetDistribution::operator+=( TargetDistribution& TargetDistribution::operator+=(
const TargetDistribution& rhs) { const TargetDistribution& rhs) {
for (auto& rhs_pair : rhs.counts()) for (auto& rhs_pair : rhs.counts())
......
...@@ -17,8 +17,15 @@ namespace learning { ...@@ -17,8 +17,15 @@ namespace learning {
class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution { class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
public: public:
TargetDistribution(); TargetDistribution();
TargetDistribution(const TargetDistribution& rhs);
TargetDistribution(TargetDistribution&& rhs);
~TargetDistribution(); ~TargetDistribution();
TargetDistribution& operator=(const TargetDistribution& rhs);
TargetDistribution& operator=(TargetDistribution&& rhs);
bool operator==(const TargetDistribution& rhs) const;
// Add |rhs| to our counts. // Add |rhs| to our counts.
TargetDistribution& operator+=(const TargetDistribution& rhs); TargetDistribution& operator+=(const TargetDistribution& rhs);
...@@ -37,6 +44,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution { ...@@ -37,6 +44,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
return total; return total;
} }
// Return the number of buckets in the distribution.
// TODO(liberato): Do we want this?
size_t size() const { return counts_.size(); }
// Find the singular value with the highest counts, and copy it into // Find the singular value with the highest counts, and copy it into
// |value_out| and (optionally) |counts_out|. Returns true if there is a // |value_out| and (optionally) |counts_out|. Returns true if there is a
// singular maximum, else returns false with the out params undefined. // singular maximum, else returns false with the out params undefined.
......
...@@ -132,5 +132,21 @@ TEST_F(TargetDistributionTest, FindSingularMaxDoesntRequireCounts) { ...@@ -132,5 +132,21 @@ TEST_F(TargetDistributionTest, FindSingularMaxDoesntRequireCounts) {
EXPECT_EQ(max_value, value_1); EXPECT_EQ(max_value, value_1);
} }
// Two distributions populated with the same (value, count) entry must compare
// equal via operator==.
TEST_F(TargetDistributionTest, EqualDistributionsCompareAsEqual) {
  // Give a fresh distribution the same single bucket as the fixture's.
  TargetDistribution matching;
  matching[value_1] = counts_1;
  distribution_[value_1] = counts_1;

  EXPECT_TRUE(distribution_ == matching);
}
// Distributions populated with different (value, count) entries must not
// compare equal via operator==.
TEST_F(TargetDistributionTest, UnequalDistributionsCompareAsNotEqual) {
  // The fixture's distribution gets |value_1|; the other gets |value_2|.
  TargetDistribution differing;
  differing[value_2] = counts_2;
  distribution_[value_1] = counts_1;

  EXPECT_FALSE(distribution_ == differing);
}
} // namespace learning } // namespace learning
} // namespace media } // namespace media
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment