Commit faadd6b9 authored by liberato@chromium.org, committed by Commit Bot

Update RandomTreeTrainer to use TargetDistribution.

Updates Model and RandomTreeTrainer to use TargetDistribution
instead of Model::TargetDistribution.

Change-Id: I0e1a28dc78d5716aa54fd4d1061e86975e4815db
Reviewed-on: https://chromium-review.googlesource.com/c/1352056
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: Fredrik Hubinette <hubbe@chromium.org>
Cr-Commit-Position: refs/heads/master@{#611775}
parent 78ae8cdb
......@@ -5,11 +5,10 @@
#ifndef MEDIA_LEARNING_IMPL_MODEL_H_
#define MEDIA_LEARNING_IMPL_MODEL_H_
#include <map>
#include "base/component_export.h"
#include "media/learning/common/training_example.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_distribution.h"
namespace media {
namespace learning {
......@@ -19,11 +18,6 @@ namespace learning {
// can support it.
class COMPONENT_EXPORT(LEARNING_IMPL) Model {
public:
// [target value] == counts
// This is classification-centric. Not sure about the right interface for
// regressors. Mostly for testing.
using TargetDistribution = std::map<TargetValue, int>;
virtual ~Model() = default;
virtual TargetDistribution PredictDistribution(
......
......@@ -43,7 +43,7 @@ struct InteriorNode : public Model {
InteriorNode(int split_index) : split_index_(split_index) {}
// Model
Model::TargetDistribution PredictDistribution(
TargetDistribution PredictDistribution(
const FeatureVector& features) override {
auto iter = children_.find(features[split_index_]);
// If we've never seen this feature value, then make no prediction.
......@@ -62,22 +62,22 @@ struct InteriorNode : public Model {
private:
// Feature value that we split on.
int split_index_ = -1;
std::map<FeatureValue, std::unique_ptr<Model>> children_;
base::flat_map<FeatureValue, std::unique_ptr<Model>> children_;
};
struct LeafNode : public Model {
LeafNode(const TrainingData& training_data) {
for (const TrainingExample* example : training_data)
distribution_[example->target_value]++;
distribution_ += example->target_value;
}
// TreeNode
Model::TargetDistribution PredictDistribution(const FeatureVector&) override {
TargetDistribution PredictDistribution(const FeatureVector&) override {
return distribution_;
}
private:
Model::TargetDistribution distribution_;
TargetDistribution distribution_;
};
RandomTreeTrainer::RandomTreeTrainer() = default;
......
......@@ -6,10 +6,12 @@
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#include <limits>
#include <map>
#include <memory>
#include <set>
#include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "media/learning/impl/training_algorithm.h"
......@@ -101,7 +103,9 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer {
// Number of occurrences of each target value in |training_data| along this
// branch of the split.
std::map<TargetValue, int> class_counts;
// This is a flat_map since we're likely to have a very small (e.g.,
// "true" / "false") number of targets.
base::flat_map<TargetValue, int> class_counts;
};
// [feature value at this split] = info about which examples take this
......
......@@ -26,8 +26,7 @@ TEST_F(RandomTreeTest, EmptyTrainingDataWorks) {
TrainingData empty(storage_);
std::unique_ptr<Model> model = trainer_.Train(empty);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(model->PredictDistribution(FeatureVector()),
Model::TargetDistribution());
EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution());
}
TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
......@@ -41,10 +40,10 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorks) {
// The tree should produce a distribution for one value (our target), which
// has |n_examples| counts.
Model::TargetDistribution distribution =
TargetDistribution distribution =
model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example.target_value)->second, n_examples);
EXPECT_EQ(distribution[example.target_value], n_examples);
}
TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
......@@ -68,10 +67,10 @@ TEST_F(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
std::move(model_cb));
base::RunLoop().RunUntilIdle();
Model::TargetDistribution distribution =
TargetDistribution distribution =
model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example.target_value)->second, n_examples);
EXPECT_EQ(distribution[example.target_value], n_examples);
}
TEST_F(RandomTreeTest, SimpleSeparableTrainingData) {
......@@ -83,15 +82,15 @@ TEST_F(RandomTreeTest, SimpleSeparableTrainingData) {
std::unique_ptr<Model> model = trainer_.Train(training_data);
// Each value should have a distribution with one target value with one count.
Model::TargetDistribution distribution =
TargetDistribution distribution =
model->PredictDistribution(example_1.features);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution[example_1.target_value], 1);
distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1);
EXPECT_EQ(distribution[example_2.target_value], 1);
}
TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
......@@ -123,10 +122,10 @@ TEST_F(RandomTreeTest, ComplexSeparableTrainingData) {
// Each example should have a distribution by itself, with two counts.
for (const TrainingExample* example : training_data) {
Model::TargetDistribution distribution =
TargetDistribution distribution =
model->PredictDistribution(example->features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution.find(example->target_value)->second, 2);
EXPECT_EQ(distribution[example->target_value], 2);
}
}
......@@ -140,16 +139,16 @@ TEST_F(RandomTreeTest, UnseparableTrainingData) {
EXPECT_NE(model.get(), nullptr);
// Each value should have a distribution with two targets with one count each.
Model::TargetDistribution distribution =
TargetDistribution distribution =
model->PredictDistribution(example_1.features);
EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1);
EXPECT_EQ(distribution[example_1.target_value], 1);
EXPECT_EQ(distribution[example_2.target_value], 1);
distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution.find(example_1.target_value)->second, 1);
EXPECT_EQ(distribution.find(example_2.target_value)->second, 1);
EXPECT_EQ(distribution[example_1.target_value], 1);
EXPECT_EQ(distribution[example_2.target_value], 1);
}
} // namespace learning
......
......@@ -8,8 +8,23 @@ namespace media {
namespace learning {
// Out-of-line defaulted special members.  TargetDistribution is a plain
// value type wrapping |counts_|, so the compiler-generated copy / move /
// destroy behavior is correct; they are defaulted here (rather than in the
// header) to keep the header light.
TargetDistribution::TargetDistribution() = default;

TargetDistribution::TargetDistribution(const TargetDistribution& rhs) = default;

TargetDistribution::TargetDistribution(TargetDistribution&& rhs) = default;

TargetDistribution::~TargetDistribution() = default;

TargetDistribution& TargetDistribution::operator=(
    const TargetDistribution& rhs) = default;

TargetDistribution& TargetDistribution::operator=(TargetDistribution&& rhs) =
    default;

// Two distributions are equal iff they record identical per-target counts.
// NOTE(review): the |total_counts()| comparison appears redundant — equal
// |counts_| already implies equal totals, and total_counts() itself walks the
// map — presumably it was meant as an early-out; confirm before removing.
bool TargetDistribution::operator==(const TargetDistribution& rhs) const {
  return rhs.total_counts() == total_counts() && rhs.counts_ == counts_;
}
TargetDistribution& TargetDistribution::operator+=(
const TargetDistribution& rhs) {
for (auto& rhs_pair : rhs.counts())
......
......@@ -17,8 +17,15 @@ namespace learning {
class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
public:
TargetDistribution();
TargetDistribution(const TargetDistribution& rhs);
TargetDistribution(TargetDistribution&& rhs);
~TargetDistribution();
TargetDistribution& operator=(const TargetDistribution& rhs);
TargetDistribution& operator=(TargetDistribution&& rhs);
bool operator==(const TargetDistribution& rhs) const;
// Add |rhs| to our counts.
TargetDistribution& operator+=(const TargetDistribution& rhs);
......@@ -37,6 +44,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
return total;
}
// Return the number of buckets in the distribution.
// TODO(liberato): Do we want this?
size_t size() const { return counts_.size(); }
// Find the singular value with the highest counts, and copy it into
// |value_out| and (optionally) |counts_out|. Returns true if there is a
// singular maximum, else returns false with the out params undefined.
......
......@@ -132,5 +132,21 @@ TEST_F(TargetDistributionTest, FindSingularMaxDoesntRequireCounts) {
EXPECT_EQ(max_value, value_1);
}
TEST_F(TargetDistributionTest, EqualDistributionsCompareAsEqual) {
  // Build a second distribution with exactly the same counts as the
  // fixture's |distribution_|; operator== must report them equal.
  TargetDistribution other;
  other[value_1] = counts_1;
  distribution_[value_1] = counts_1;
  EXPECT_TRUE(distribution_ == other);
}
TEST_F(TargetDistributionTest, UnequalDistributionsCompareAsNotEqual) {
  // Distributions whose counts live under different target values must not
  // compare equal.
  TargetDistribution other;
  other[value_2] = counts_2;
  distribution_[value_1] = counts_1;
  EXPECT_FALSE(distribution_ == other);
}
} // namespace learning
} // namespace media
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment