Commit 05fcc7df authored by liberato@chromium.org, committed by Commit Bot

Add Distribution to replace Model::TargetDistribution.

This adds the TargetDistribution class, which will replace the
TargetDistribution typedef in Model.  It provides some
convenient functionality, such as += and max-finding.

None of the uses of TargetDistribution are updated here.

Change-Id: I0e86d9dbaddcd5421cbc65e73213329603346733
Reviewed-on: https://chromium-review.googlesource.com/c/1329881
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: Fredrik Hubinette <hubbe@chromium.org>
Cr-Commit-Position: refs/heads/master@{#611569}
parent 72b6d2ff
......@@ -15,6 +15,8 @@ component("impl") {
"model.h",
"random_tree_trainer.cc",
"random_tree_trainer.h",
"target_distribution.cc",
"target_distribution.h",
"training_algorithm.h",
]
......@@ -35,6 +37,7 @@ source_set("unit_tests") {
sources = [
"learning_session_impl_unittest.cc",
"random_tree_trainer_unittest.cc",
"target_distribution_unittest.cc",
]
deps = [
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/target_distribution.h"
namespace media {
namespace learning {
// Defaulted out of line.  |counts_| is default-constructed and
// |total_counts_| takes its in-class initializer of 0, so a new distribution
// starts out empty.
TargetDistribution::TargetDistribution() = default;
// Defaulted out of line; the members clean up after themselves.
TargetDistribution::~TargetDistribution() = default;
// Merges |rhs| into this distribution: every per-value count in |rhs| is added
// to the matching entry here (creating the entry if needed), and
// |total_counts_| grows by the same amounts.  Returns *this so that calls can
// be chained.  Adding a distribution to itself doubles every count.
TargetDistribution& TargetDistribution::operator+=(
    const TargetDistribution& rhs) {
  for (const auto& rhs_pair : rhs.counts()) {
    // Copy the count before touching |counts_|.  When |rhs| is |*this|,
    // |rhs_pair.second| refers to the very entry updated on the next line, so
    // reading it afterwards would add the already-doubled value to
    // |total_counts_| and leave the total out of sync with the entries.
    const int rhs_count = rhs_pair.second;
    counts_[rhs_pair.first] += rhs_count;
    total_counts_ += rhs_count;
  }
  return *this;
}
// Records one more observation of |rhs|: both its per-value count and the
// running total go up by one.  Returns *this.
TargetDistribution& TargetDistribution::operator+=(const TargetValue& rhs) {
  ++counts_[rhs];
  ++total_counts_;
  return *this;
}
// Looks up the count recorded for |value|.  A value that has no entry reports
// zero; the lookup never inserts anything.
int TargetDistribution::operator[](const TargetValue& value) const {
  const auto entry = counts_.find(value);
  return entry != counts_.end() ? entry->second : 0;
}
// Adds |counts| observations of |value|, creating its entry if this is the
// first time |value| is seen, and keeps the running total in step.
void TargetDistribution::Add(const TargetValue& value, int counts) {
  int& value_count = counts_[value];
  value_count += counts;
  total_counts_ += counts;
}
// Finds the entry with the highest count.  On a non-empty distribution,
// |*value_out| receives the first entry (in map order) that holds the highest
// count and, if |counts_out| is non-null, |*counts_out| receives that count.
// Returns true only when no other entry shares the highest count.  An empty
// distribution returns false and leaves the out params untouched.
// |value_out| is dereferenced unconditionally, so it must not be null.
bool TargetDistribution::FindSingularMax(TargetValue* value_out,
                                         int* counts_out) const {
  if (counts_.empty())
    return false;

  // |best| is the first entry holding the largest count seen so far;
  // |is_unique| records whether any later entry has matched that count.
  auto best = counts_.begin();
  bool is_unique = true;

  auto candidate = best;
  for (++candidate; candidate != counts_.end(); ++candidate) {
    if (candidate->second > best->second) {
      // A strictly larger count starts a new, so far unshared, maximum.
      best = candidate;
      is_unique = true;
    } else if (candidate->second == best->second) {
      // Ties with the current maximum; only matters if nothing larger follows.
      is_unique = false;
    }
  }

  *value_out = best->first;
  if (counts_out)
    *counts_out = best->second;
  return is_unique;
}
} // namespace learning
} // namespace media
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_
#define MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_
#include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "media/learning/common/value.h"
namespace media {
namespace learning {
// Tracks how many times each TargetValue has been observed, together with the
// running total over all values.  Counts can only be changed through
// operator+= and Add(), both of which also update the total.
class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
public:
TargetDistribution();
~TargetDistribution();
// Add every per-value count in |rhs| to our counts, and grow our total by the
// same amounts.  Returns *this.
TargetDistribution& operator+=(const TargetDistribution& rhs);
// Increment the count for |rhs| (and the total) by one.  Returns *this.
TargetDistribution& operator+=(const TargetValue& rhs);
// Return the number of counts for |value|, or 0 if |value| has no entry.
// Never creates an entry for |value|.
int operator[](const TargetValue& value) const;
// It would be nice to have an int& variant of operator[], but then we can't
// keep |total_counts_| up to date.
// Include |counts| counts for |value|, and add |counts| to the total.
// TODO(liberato): operator+=(std::pair<value, int>)?
void Add(const TargetValue& value, int counts);
// Return the total counts in the map, i.e. the sum over all values.
int total_counts() const { return total_counts_; }
// Find the singular value with the highest counts, and copy it into
// |value_out| and (optionally) |counts_out|.  |value_out| must not be null;
// |counts_out| may be null if the caller does not need the count.  Returns
// true if there is a singular maximum.  Returns false if the distribution is
// empty or if several values tie for the highest counts; in that case the out
// params are undefined.
bool FindSingularMax(TargetValue* value_out, int* counts_out = nullptr) const;
private:
// We use a flat_map since this will often have only one or two TargetValues,
// such as "true" or "false".
using distribution_map_t = base::flat_map<TargetValue, int>;
// Read-only view of the per-value counts; used by operator+= to walk |rhs|.
const distribution_map_t& counts() const { return counts_; }
// [value] == counts
distribution_map_t counts_;
// Sum of all entries in |counts_|.
int total_counts_ = 0;
// Allow copy and assign.
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/target_distribution.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
// Shared fixture: a distribution to exercise, three target values built from
// different integers, and two different counts for the tests to combine.
class TargetDistributionTest : public testing::Test {
 public:
  TargetDistributionTest()
      : value_1(123),
        value_2(456),
        value_3(789) {}

  // The distribution under test.
  TargetDistribution distribution_;

  // Target values used as keys.
  TargetValue value_1;
  TargetValue value_2;
  TargetValue value_3;

  // Counts to attach to the values; |counts_1| is the larger of the two.
  const int counts_1 = 100;
  const int counts_2 = 10;
};
// A distribution that has had nothing added to it reports a total of zero.
TEST_F(TargetDistributionTest, EmptyTargetDistributionHasZeroCounts) {
  const int initial_total = distribution_.total_counts();
  EXPECT_EQ(initial_total, 0);
}
// Adding the same value repeatedly accumulates in both the per-value count and
// the total.
TEST_F(TargetDistributionTest, AddingCountsWorks) {
  for (int times_added = 1; times_added <= 2; ++times_added) {
    distribution_.Add(value_1, counts_1);
    EXPECT_EQ(distribution_.total_counts(), counts_1 * times_added);
    EXPECT_EQ(distribution_[value_1], counts_1 * times_added);
  }
}
// Counts added for different values are tracked independently, while the total
// covers both.
TEST_F(TargetDistributionTest, MultipleValuesAreSeparate) {
  distribution_.Add(value_1, counts_1);
  distribution_.Add(value_2, counts_2);

  EXPECT_EQ(distribution_[value_1], counts_1);
  EXPECT_EQ(distribution_[value_2], counts_2);

  const int expected_total = counts_1 + counts_2;
  EXPECT_EQ(distribution_.total_counts(), expected_total);
}
// operator+=(TargetValue) bumps only the named value's count, and the total,
// by one each time.
TEST_F(TargetDistributionTest, AddingTargetValues) {
  // First observation of |value_1|.
  distribution_ += value_1;
  EXPECT_EQ(distribution_[value_1], 1);
  EXPECT_EQ(distribution_[value_2], 0);
  EXPECT_EQ(distribution_.total_counts(), 1);

  // Second observation of |value_1|; |value_2| is still unseen.
  distribution_ += value_1;
  EXPECT_EQ(distribution_[value_1], 2);
  EXPECT_EQ(distribution_[value_2], 0);
  EXPECT_EQ(distribution_.total_counts(), 2);

  // First observation of |value_2| leaves |value_1| alone.
  distribution_ += value_2;
  EXPECT_EQ(distribution_[value_1], 2);
  EXPECT_EQ(distribution_[value_2], 1);
  EXPECT_EQ(distribution_.total_counts(), 3);
}
// operator+=(TargetDistribution) folds another distribution's counts into this
// one, keeping existing entries and the total consistent.
TEST_F(TargetDistributionTest, AddingTargetDistributions) {
  TargetDistribution other;
  other.Add(value_2, counts_2);

  distribution_.Add(value_1, counts_1);
  distribution_ += other;

  EXPECT_EQ(distribution_[value_1], counts_1);
  EXPECT_EQ(distribution_[value_2], counts_2);
  EXPECT_EQ(distribution_.total_counts(), counts_1 + counts_2);
}
// With two values of unequal counts, FindSingularMax() reports the value with
// the larger count along with that count.
TEST_F(TargetDistributionTest, FindSingularMaxFindsTheSingularMax) {
  distribution_.Add(value_1, counts_1);
  distribution_.Add(value_2, counts_2);
  // Fixture sanity check.  ASSERT_GT prints both operands if it ever fails,
  // which ASSERT_TRUE(a > b) does not.
  ASSERT_GT(counts_1, counts_2);

  TargetValue max_value(0);
  int max_counts = 0;
  EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts));
  EXPECT_EQ(max_value, value_1);
  EXPECT_EQ(max_counts, counts_1);
}
// Same as the test above, but with the larger count attached to |value_2|, so
// that the result does not depend on which value the map happens to visit
// first.
TEST_F(TargetDistributionTest,
       FindSingularMaxFindsTheSingularMaxAlternateOrder) {
  // Switch the order, to handle sorting in different directions.
  distribution_.Add(value_1, counts_2);
  distribution_.Add(value_2, counts_1);
  // Fixture sanity check.  ASSERT_GT prints both operands if it ever fails,
  // which ASSERT_TRUE(a > b) does not.
  ASSERT_GT(counts_1, counts_2);

  TargetValue max_value(0);
  int max_counts = 0;
  EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts));
  EXPECT_EQ(max_value, value_2);
  EXPECT_EQ(max_counts, counts_1);
}
// When two values tie for the highest count there is no singular maximum, so
// FindSingularMax() must report failure.
TEST_F(TargetDistributionTest, FindSingularMaxReturnsFalsForNonSingularMax) {
  // Give both values the same count.
  distribution_.Add(value_1, counts_1);
  distribution_.Add(value_2, counts_1);

  int max_counts = 0;
  TargetValue max_value(0);
  EXPECT_FALSE(distribution_.FindSingularMax(&max_value, &max_counts));
}
// A tie among values that are below the maximum must not stop
// FindSingularMax() from reporting the one value that is the maximum.
TEST_F(TargetDistributionTest, FindSingularMaxIgnoresNonSingularNonMax) {
  distribution_.Add(value_1, counts_1);
  // |value_2| and |value_3| are tied, but not the max.
  distribution_.Add(value_2, counts_2);
  distribution_.Add(value_3, counts_2);
  // Fixture sanity check.  ASSERT_GT prints both operands if it ever fails,
  // which ASSERT_TRUE(a > b) does not.
  ASSERT_GT(counts_1, counts_2);

  TargetValue max_value(0);
  int max_counts = 0;
  EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts));
  EXPECT_EQ(max_value, value_1);
  EXPECT_EQ(max_counts, counts_1);
}
// The |counts_out| argument is optional: calling FindSingularMax() with only
// |value_out| still finds the maximum.
TEST_F(TargetDistributionTest, FindSingularMaxDoesntRequireCounts) {
  distribution_.Add(value_1, counts_1);

  TargetValue max_value(0);
  const bool found_singular_max = distribution_.FindSingularMax(&max_value);
  EXPECT_TRUE(found_singular_max);
  EXPECT_EQ(max_value, value_1);
}
} // namespace learning
} // namespace media
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment