Commit 16a2a8e7 authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Renamed TrainingExample and weight_t

Per previous CL comments, TrainingExample is better as
LabelledExample, and weight_t should follow chromium style as
WeightType.  This CL renames both.

Change-Id: Icd3cd0a52df86bf00d12da2147159a464f22b7ed
Reviewed-on: https://chromium-review.googlesource.com/c/1394400
Reviewed-by: Dan Sanders <sandersd@chromium.org>
Reviewed-by: Daniel Cheng <dcheng@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#620804}
parent e2592da6
......@@ -17,12 +17,12 @@ component("common") {
defines = [ "IS_LEARNING_COMMON_IMPL" ]
sources = [
"labelled_example.cc",
"labelled_example.h",
"learning_session.cc",
"learning_session.h",
"learning_task.cc",
"learning_task.h",
"training_example.cc",
"training_example.h",
"value.cc",
"value.h",
]
......@@ -35,7 +35,7 @@ component("common") {
source_set("unit_tests") {
testonly = true
sources = [
"training_example_unittest.cc",
"labelled_example_unittest.cc",
"value_unittest.cc",
]
......
......@@ -2,26 +2,26 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/training_example.h"
#include "media/learning/common/labelled_example.h"
#include "base/containers/flat_set.h"
namespace media {
namespace learning {
TrainingExample::TrainingExample() = default;
LabelledExample::LabelledExample() = default;
TrainingExample::TrainingExample(std::initializer_list<FeatureValue> init_list,
LabelledExample::LabelledExample(std::initializer_list<FeatureValue> init_list,
TargetValue target)
: features(init_list), target_value(target) {}
TrainingExample::TrainingExample(const TrainingExample& rhs) = default;
LabelledExample::LabelledExample(const LabelledExample& rhs) = default;
TrainingExample::TrainingExample(TrainingExample&& rhs) noexcept = default;
LabelledExample::LabelledExample(LabelledExample&& rhs) noexcept = default;
TrainingExample::~TrainingExample() = default;
LabelledExample::~LabelledExample() = default;
std::ostream& operator<<(std::ostream& out, const TrainingExample& example) {
std::ostream& operator<<(std::ostream& out, const LabelledExample& example) {
out << example.features << " => " << example.target_value;
return out;
......@@ -34,17 +34,17 @@ std::ostream& operator<<(std::ostream& out, const FeatureVector& features) {
return out;
}
bool TrainingExample::operator==(const TrainingExample& rhs) const {
bool LabelledExample::operator==(const LabelledExample& rhs) const {
// Do not check weight.
return target_value == rhs.target_value && features == rhs.features;
}
bool TrainingExample::operator!=(const TrainingExample& rhs) const {
bool LabelledExample::operator!=(const LabelledExample& rhs) const {
// Do not check weight.
return !((*this) == rhs);
}
bool TrainingExample::operator<(const TrainingExample& rhs) const {
bool LabelledExample::operator<(const LabelledExample& rhs) const {
// Impose a somewhat arbitrary ordering.
// Do not check weight.
if (target_value != rhs.target_value)
......@@ -56,10 +56,10 @@ bool TrainingExample::operator<(const TrainingExample& rhs) const {
return features < rhs.features;
}
TrainingExample& TrainingExample::operator=(const TrainingExample& rhs) =
LabelledExample& LabelledExample::operator=(const LabelledExample& rhs) =
default;
TrainingExample& TrainingExample::operator=(TrainingExample&& rhs) = default;
LabelledExample& LabelledExample::operator=(LabelledExample&& rhs) = default;
TrainingData::TrainingData() = default;
......@@ -72,7 +72,7 @@ TrainingData::~TrainingData() = default;
TrainingData TrainingData::DeDuplicate() const {
// flat_set has non-const iterators, while std::set does not. const_cast is
// not allowed by chromium style outside of getters, so flat_set it is.
base::flat_set<TrainingExample> example_set;
base::flat_set<LabelledExample> example_set;
for (auto& example : examples_) {
auto iter = example_set.find(example);
if (iter != example_set.end())
......
......@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#ifndef MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_
#include <initializer_list>
#include <ostream>
......@@ -23,25 +23,24 @@ namespace learning {
// [1]=="url", etc.
using FeatureVector = std::vector<FeatureValue>;
// TODO(liberato): Rename.
using weight_t = size_t;
using WeightType = size_t;
// One training example == group of feature values, plus the desired target.
struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
TrainingExample();
TrainingExample(std::initializer_list<FeatureValue> init_list,
struct COMPONENT_EXPORT(LEARNING_COMMON) LabelledExample {
LabelledExample();
LabelledExample(std::initializer_list<FeatureValue> init_list,
TargetValue target);
TrainingExample(const TrainingExample& rhs);
TrainingExample(TrainingExample&& rhs) noexcept;
~TrainingExample();
LabelledExample(const LabelledExample& rhs);
LabelledExample(LabelledExample&& rhs) noexcept;
~LabelledExample();
// Comparisons ignore weight, because it's convenient.
bool operator==(const TrainingExample& rhs) const;
bool operator!=(const TrainingExample& rhs) const;
bool operator<(const TrainingExample& rhs) const;
bool operator==(const LabelledExample& rhs) const;
bool operator!=(const LabelledExample& rhs) const;
bool operator<(const LabelledExample& rhs) const;
TrainingExample& operator=(const TrainingExample& rhs);
TrainingExample& operator=(TrainingExample&& rhs);
LabelledExample& operator=(const LabelledExample& rhs);
LabelledExample& operator=(LabelledExample&& rhs);
// Observed feature values.
// Note that to interpret these values, you probably need to have the
......@@ -51,7 +50,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
// Observed output value, when given |features| as input.
TargetValue target_value;
weight_t weight = 1u;
WeightType weight = 1u;
// Copy / assignment is allowed.
};
......@@ -59,7 +58,7 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
public:
using ExampleVector = std::vector<TrainingExample>;
using ExampleVector = std::vector<LabelledExample>;
using const_iterator = ExampleVector::const_iterator;
TrainingData();
......@@ -69,7 +68,7 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
~TrainingData();
// Add |example| with weight |weight|.
void push_back(const TrainingExample& example) {
void push_back(const LabelledExample& example) {
DCHECK_GT(example.weight, 0u);
examples_.push_back(example);
total_weight_ += example.weight;
......@@ -82,7 +81,7 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
// Returns the number of instances, taking into account their weight. For
// example, if one adds an example with weight 2, then this will return two
// more than it did before.
weight_t total_weight() const { return total_weight_; }
WeightType total_weight() const { return total_weight_; }
const_iterator begin() const { return examples_.begin(); }
const_iterator end() const { return examples_.end(); }
......@@ -90,7 +89,7 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
bool is_unweighted() const { return examples_.size() == total_weight_; }
// Provide the |i|-th example, over [0, size()).
const TrainingExample& operator[](size_t i) const { return examples_[i]; }
const LabelledExample& operator[](size_t i) const { return examples_[i]; }
// Return a copy of this data with duplicate entries merged. Example weights
// will be summed.
......@@ -99,13 +98,13 @@ class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
private:
ExampleVector examples_;
weight_t total_weight_ = 0u;
WeightType total_weight_ = 0u;
// Copy / assignment is allowed.
};
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const TrainingExample& example);
std::ostream& operator<<(std::ostream& out, const LabelledExample& example);
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const FeatureVector& features);
......@@ -113,4 +112,4 @@ std::ostream& operator<<(std::ostream& out, const FeatureVector& features);
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#endif // MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_
......@@ -9,8 +9,8 @@
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
......@@ -24,7 +24,7 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession {
// Add an observed example |example| to the learning task |task_name|.
// TODO(liberato): Consider making this an enum to match mojo.
virtual void AddExample(const std::string& task_name,
const TrainingExample& example) = 0;
const LabelledExample& example) = 0;
// TODO(liberato): Add prediction API.
......
......@@ -49,7 +49,7 @@ TEST_P(ExtraTreesTest, FisherIrisDataset) {
// Verify predictions on the training set, just for sanity.
size_t num_correct = 0;
for (const TrainingExample& example : training_data) {
for (const LabelledExample& example : training_data) {
TargetDistribution distribution =
model->PredictDistribution(example.features);
TargetValue predicted_value;
......@@ -68,8 +68,8 @@ TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) {
// Create a training set with unseparable data, but give one of them a large
// weight. See if that one wins.
SetupFeatures(1);
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
const size_t weight = 100;
TrainingData training_data;
example_1.weight = weight;
......@@ -96,13 +96,13 @@ TEST_P(ExtraTreesTest, RegressionWorks) {
// Create a training set with unseparable data, but give one of them a large
// weight. See if that one wins.
SetupFeatures(2);
TrainingExample example_1({FeatureValue(1), FeatureValue(123)},
LabelledExample example_1({FeatureValue(1), FeatureValue(123)},
TargetValue(1));
TrainingExample example_1_a({FeatureValue(1), FeatureValue(123)},
LabelledExample example_1_a({FeatureValue(1), FeatureValue(123)},
TargetValue(5));
TrainingExample example_2({FeatureValue(1), FeatureValue(456)},
LabelledExample example_2({FeatureValue(1), FeatureValue(456)},
TargetValue(20));
TrainingExample example_2_a({FeatureValue(1), FeatureValue(456)},
LabelledExample example_2_a({FeatureValue(1), FeatureValue(456)},
TargetValue(25));
TrainingData training_data;
example_1.weight = 100;
......@@ -137,13 +137,13 @@ TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) {
SetupFeatures(3);
TrainingData c_data, r_data;
std::set<TrainingExample> r_examples;
std::set<LabelledExample> r_examples;
for (size_t i = 0; i < 4 * 4 * 4; i++) {
FeatureValue f1(i & 3);
FeatureValue f2((i >> 2) & 3);
FeatureValue f3((i >> 4) & 3);
int pct = (100 * (f1.value() + f2.value() + f3.value())) / 9;
TrainingExample e({f1, f2, f3}, TargetValue(0));
LabelledExample e({f1, f2, f3}, TargetValue(0));
// TODO(liberato): Consider adding noise, and verifying that the model
// predictions are roughly the same as each other, rather than the same as
......@@ -162,7 +162,7 @@ TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) {
// For the regression data, add an example with |pct| directly. Also save
// it so that we can look up the right answer below.
TrainingExample r_example(TrainingExample({f1, f2, f3}, TargetValue(pct)));
LabelledExample r_example(LabelledExample({f1, f2, f3}, TargetValue(pct)));
r_examples.insert(r_example);
r_data.push_back(r_example);
}
......
......@@ -7,7 +7,7 @@
#include <vector>
namespace {
struct IrisExample : public media::learning::TrainingExample {
struct IrisExample : public media::learning::LabelledExample {
IrisExample(float sepal_length,
float sepal_width,
float petal_length,
......
......@@ -8,7 +8,7 @@
#include <vector>
#include "base/memory/ref_counted.h"
#include "media/learning/common/training_example.h"
#include "media/learning/common/labelled_example.h"
namespace media {
namespace learning {
......
......@@ -26,7 +26,7 @@ void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
}
void LearningSessionImpl::AddExample(const std::string& task_name,
const TrainingExample& example) {
const LabelledExample& example) {
auto iter = task_map_.find(task_name);
if (iter != task_map_.end())
iter->second->AddExample(example);
......
......@@ -30,7 +30,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
// LearningSession
void AddExample(const std::string& task_name,
const TrainingExample& example) override;
const LabelledExample& example) override;
// Registers |task|, so that calls to AddExample with |task.name| will work.
// This will create a new controller for the task.
......
......@@ -19,11 +19,11 @@ class LearningSessionImplTest : public testing::Test {
public:
FakeLearningTaskController(const LearningTask& task) {}
void AddExample(const TrainingExample& example) override {
void AddExample(const LabelledExample& example) override {
example_ = example;
}
TrainingExample example_;
LabelledExample example_;
};
using ControllerVector = std::vector<FakeLearningTaskController*>;
......@@ -63,11 +63,11 @@ TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) {
session_->RegisterTask(task_0_);
session_->RegisterTask(task_1_);
TrainingExample example_0({FeatureValue(123), FeatureValue(456)},
LabelledExample example_0({FeatureValue(123), FeatureValue(456)},
TargetValue(1234));
session_->AddExample(task_0_.name, example_0);
TrainingExample example_1({FeatureValue(321), FeatureValue(654)},
LabelledExample example_1({FeatureValue(321), FeatureValue(654)},
TargetValue(4321));
session_->AddExample(task_1_.name, example_1);
EXPECT_EQ(task_controllers_[0]->example_, example_0);
......
......@@ -8,8 +8,8 @@
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
......@@ -29,7 +29,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskController {
virtual ~LearningTaskController() = default;
// Receive an example for this task.
virtual void AddExample(const TrainingExample& example) = 0;
virtual void AddExample(const LabelledExample& example) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(LearningTaskController);
......
......@@ -38,7 +38,7 @@ LearningTaskControllerImpl::LearningTaskControllerImpl(const LearningTask& task)
LearningTaskControllerImpl::~LearningTaskControllerImpl() = default;
void LearningTaskControllerImpl::AddExample(const TrainingExample& example) {
void LearningTaskControllerImpl::AddExample(const LabelledExample& example) {
// TODO(liberato): do we ever trim older examples?
training_data_->push_back(example);
......
......@@ -26,7 +26,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
~LearningTaskControllerImpl() override;
// LearningTaskController
void AddExample(const TrainingExample& example) override;
void AddExample(const LabelledExample& example) override;
private:
// Called with accuracy results as new examples are added. Only tests should
......
......@@ -69,7 +69,7 @@ class LearningTaskControllerImplTest : public testing::Test {
};
TEST_F(LearningTaskControllerImplTest, AddingExamplesTrainsModelAndReports) {
TrainingExample example;
LabelledExample example;
// Adding the first n-1 examples shouldn't cause it to train a model.
for (size_t i = 0; i < task_.min_data_set_size - 1; i++)
......
......@@ -6,7 +6,7 @@
#define MEDIA_LEARNING_IMPL_MODEL_H_
#include "base/component_export.h"
#include "media/learning/common/training_example.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_distribution.h"
......
......@@ -37,7 +37,7 @@ OneHotConverter::~OneHotConverter() = default;
TrainingData OneHotConverter::Convert(const TrainingData& training_data) const {
TrainingData converted_training_data;
for (auto& example : training_data) {
TrainingExample converted_example(example);
LabelledExample converted_example(example);
converted_example.features = Convert(example.features);
converted_training_data.push_back(converted_example);
}
......
......@@ -12,8 +12,8 @@
#include "base/component_export.h"
#include "base/macros.h"
#include "base/optional.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/training_example.h"
#include "media/learning/common/value.h"
#include "media/learning/impl/model.h"
......
......@@ -61,7 +61,7 @@ std::unique_ptr<RandomForestTrainer::TrainingResult> RandomForestTrainer::Train(
if (bagged_set.find(e) != bagged_set.end())
continue;
const TrainingExample& example = training_data[e];
const LabelledExample& example = training_data[e];
TargetDistribution predicted =
tree->PredictDistribution(example.features);
......@@ -83,7 +83,7 @@ std::unique_ptr<RandomForestTrainer::TrainingResult> RandomForestTrainer::Train(
// Compute OOB accuracy.
int num_correct = 0;
for (auto& oob_pair : oob_distributions) {
const TrainingExample& example = training_data[oob_pair.first];
const LabelledExample& example = training_data[oob_pair.first];
const TargetDistribution& distribution = oob_pair.second;
// If there are no guesses, or if it's a tie, then count it as wrong.
......
......@@ -46,7 +46,7 @@ TEST_P(RandomForestTest, EmptyTrainingDataWorks) {
TEST_P(RandomForestTest, UniformTrainingDataWorks) {
SetupFeatures(2);
TrainingExample example({FeatureValue(123), FeatureValue(456)},
LabelledExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
const int n_examples = 10;
......@@ -69,8 +69,8 @@ TEST_P(RandomForestTest, SimpleSeparableTrainingData) {
// TODO: oob estimates aren't so good if a target only shows up once. any
// tree that trains on it won't be used to predict it during oob accuracy,
// and the remaining trees will get it wrong.
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(456)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
TrainingData training_data;
training_data.push_back(example_1);
training_data.push_back(example_2);
......@@ -99,7 +99,7 @@ TEST_P(RandomForestTest, ComplexSeparableTrainingData) {
for (int f2 = 0; f2 < 2; f2++) {
for (int f3 = 0; f3 < 2; f3++) {
for (int f4 = 0; f4 < 2; f4++) {
TrainingExample example(
LabelledExample example(
{FeatureValue(f1), FeatureValue(f2), FeatureValue(f3),
FeatureValue(f4)},
TargetValue(f1 * 1 + f2 * 2 + f3 * 4 + f4 * 8));
......@@ -116,7 +116,7 @@ TEST_P(RandomForestTest, ComplexSeparableTrainingData) {
EXPECT_NE(result->model.get(), nullptr);
// Each example should have a distribution in which it is the max.
for (const TrainingExample& example : training_data) {
for (const LabelledExample& example : training_data) {
TargetDistribution distribution =
result->model->PredictDistribution(example.features);
TargetValue max_value;
......@@ -127,8 +127,8 @@ TEST_P(RandomForestTest, ComplexSeparableTrainingData) {
TEST_P(RandomForestTest, UnseparableTrainingData) {
SetupFeatures(1);
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
TrainingData training_data;
training_data.push_back(example_1);
training_data.push_back(example_2);
......@@ -160,7 +160,7 @@ TEST_P(RandomForestTest, FisherIrisDataset) {
// Verify predictions on the training set, just for sanity.
size_t num_correct = 0;
for (const TrainingExample& example : training_data) {
for (const LabelledExample& example : training_data) {
TargetDistribution distribution =
result->model->PredictDistribution(example.features);
TargetValue predicted_value;
......@@ -179,8 +179,8 @@ TEST_P(RandomForestTest, FisherIrisDataset) {
}
TEST_P(RandomForestTest, WeightedTrainingSetIsUnsupported) {
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
const size_t weight = 100;
TrainingData training_data;
example_1.weight = weight;
......
......@@ -192,7 +192,7 @@ std::unique_ptr<Model> RandomTreeTrainer::Build(
std::vector<std::set<FeatureValue>> feature_values;
feature_values.resize(training_data[0].features.size());
for (size_t idx : training_idx) {
const TrainingExample& example = training_data[idx];
const LabelledExample& example = training_data[idx];
// Record this target value to see if there is more than one. We skip the
// insertion if we've already determined that it's not constant.
if (target_values.size() < 2)
......@@ -315,7 +315,7 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
// the training data directly.
double total_weight = 0.;
for (size_t idx : training_idx) {
const TrainingExample& example = training_data[idx];
const LabelledExample& example = training_data[idx];
total_weight += example.weight;
// Get the value of the |index|-th feature for |example|.
......@@ -417,7 +417,7 @@ FeatureValue RandomTreeTrainer::FindNumericSplitPoint(
FeatureValue v_min = training_data[training_idx[0]].features[split_index];
FeatureValue v_max = training_data[training_idx[0]].features[split_index];
for (size_t idx : training_idx) {
const TrainingExample& example = training_data[idx];
const LabelledExample& example = training_data[idx];
// Get the value of the |split_index|-th feature for
FeatureValue v_i = example.features[split_index];
if (v_i < v_min)
......
......@@ -47,7 +47,7 @@ TEST_P(RandomTreeTest, EmptyTrainingDataWorks) {
TEST_P(RandomTreeTest, UniformTrainingDataWorks) {
SetupFeatures(2);
TrainingExample example({FeatureValue(123), FeatureValue(456)},
LabelledExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingData training_data;
const size_t n_examples = 10;
......@@ -65,7 +65,7 @@ TEST_P(RandomTreeTest, UniformTrainingDataWorks) {
TEST_P(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
SetupFeatures(2);
TrainingExample example({FeatureValue(123), FeatureValue(456)},
LabelledExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingData training_data;
const size_t n_examples = 10;
......@@ -94,8 +94,8 @@ TEST_P(RandomTreeTest, UniformTrainingDataWorksWithCallback) {
TEST_P(RandomTreeTest, SimpleSeparableTrainingData) {
SetupFeatures(1);
TrainingData training_data;
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(456)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
training_data.push_back(example_1);
training_data.push_back(example_2);
std::unique_ptr<Model> model = trainer_.Train(task_, training_data);
......@@ -127,7 +127,7 @@ TEST_P(RandomTreeTest, ComplexSeparableTrainingData) {
for (int f2 = 0; f2 < 2; f2++) {
for (int f3 = 0; f3 < 2; f3++) {
for (int f4 = 0; f4 < 2; f4++) {
TrainingExample example(
LabelledExample example(
{FeatureValue(f1), FeatureValue(f2), FeatureValue(f3),
FeatureValue(f4)},
TargetValue(f1 * 1 + f2 * 2 + f3 * 4 + f4 * 8));
......@@ -143,7 +143,7 @@ TEST_P(RandomTreeTest, ComplexSeparableTrainingData) {
EXPECT_NE(model.get(), nullptr);
// Each example should have a distribution that selects the right value.
for (const TrainingExample& example : training_data) {
for (const LabelledExample& example : training_data) {
TargetDistribution distribution =
model->PredictDistribution(example.features);
TargetValue singular_max;
......@@ -155,8 +155,8 @@ TEST_P(RandomTreeTest, ComplexSeparableTrainingData) {
TEST_P(RandomTreeTest, UnseparableTrainingData) {
SetupFeatures(1);
TrainingData training_data;
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(123)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
training_data.push_back(example_1);
training_data.push_back(example_2);
std::unique_ptr<Model> model = trainer_.Train(task_, training_data);
......@@ -179,8 +179,8 @@ TEST_P(RandomTreeTest, UnknownFeatureValueHandling) {
// Verify how a previously unseen feature value is handled.
SetupFeatures(1);
TrainingData training_data;
TrainingExample example_1({FeatureValue(123)}, TargetValue(1));
TrainingExample example_2({FeatureValue(456)}, TargetValue(2));
LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
training_data.push_back(example_1);
training_data.push_back(example_2);
......@@ -223,7 +223,7 @@ TEST_P(RandomTreeTest, NumericFeaturesSplitMultipleTimes) {
TrainingData training_data;
const int feature_mult = 10;
for (size_t i = 0; i < 4; i++) {
TrainingExample example({FeatureValue(i * feature_mult)}, TargetValue(i));
LabelledExample example({FeatureValue(i * feature_mult)}, TargetValue(i));
training_data.push_back(example);
}
......
......@@ -41,7 +41,7 @@ TargetDistribution& TargetDistribution::operator+=(const TargetValue& rhs) {
}
TargetDistribution& TargetDistribution::operator+=(
const TrainingExample& example) {
const LabelledExample& example) {
counts_[example.target_value] += example.weight;
return *this;
}
......
......@@ -11,7 +11,7 @@
#include "base/component_export.h"
#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "media/learning/common/training_example.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/value.h"
namespace media {
......@@ -42,7 +42,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
TargetDistribution& operator+=(const TargetValue& rhs);
// Increment the distribution by |example|'s target value and weight.
TargetDistribution& operator+=(const TrainingExample& example);
TargetDistribution& operator+=(const LabelledExample& example);
// Return the number of counts for |value|.
size_t operator[](const TargetValue& value) const;
......
......@@ -148,8 +148,8 @@ TEST_F(TargetDistributionTest, UnequalDistributionsCompareAsNotEqual) {
EXPECT_FALSE(distribution_ == distribution_2);
}
TEST_F(TargetDistributionTest, WeightedTrainingExamplesCountCorrectly) {
TrainingExample example = {{}, value_1};
TEST_F(TargetDistributionTest, WeightedLabelledExamplesCountCorrectly) {
LabelledExample example = {{}, value_1};
example.weight = counts_1;
distribution_ += example;
......
......@@ -8,7 +8,7 @@
#include <memory>
#include "base/callback.h"
#include "media/learning/common/training_example.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/impl/model.h"
namespace media {
......
......@@ -20,7 +20,7 @@ void MojoLearningSessionImpl::Bind(mojom::LearningSessionRequest request) {
}
void MojoLearningSessionImpl::AddExample(mojom::LearningTaskType task_type,
const TrainingExample& example) {
const LabelledExample& example) {
// TODO(liberato): Convert |task_type| into a task name.
std::string task_name("no_task");
......
......@@ -29,7 +29,7 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningSessionImpl
// mojom::LearningSession
void AddExample(mojom::LearningTaskType task_type,
const TrainingExample& example) override;
const LabelledExample& example) override;
protected:
explicit MojoLearningSessionImpl(
......
......@@ -18,13 +18,13 @@ class MojoLearningSessionImplTest : public ::testing::Test {
class FakeLearningSession : public ::media::learning::LearningSession {
public:
void AddExample(const std::string& task_name,
const TrainingExample& example) override {
const LabelledExample& example) override {
most_recent_task_name_ = task_name;
most_recent_example_ = example;
}
std::string most_recent_task_name_;
TrainingExample most_recent_example_;
LabelledExample most_recent_example_;
};
public:
......@@ -50,8 +50,8 @@ class MojoLearningSessionImplTest : public ::testing::Test {
};
TEST_F(MojoLearningSessionImplTest, FeaturesAndTargetValueAreCopied) {
mojom::TrainingExamplePtr example_ptr = mojom::TrainingExample::New();
const TrainingExample example = {{Value(123), Value(456), Value(890)},
mojom::LabelledExamplePtr example_ptr = mojom::LabelledExample::New();
const LabelledExample example = {{Value(123), Value(456), Value(890)},
TargetValue(1234)};
learning_session_impl_->AddExample(task_type_, example);
......
......@@ -7,10 +7,10 @@
namespace mojo {
// static
bool StructTraits<media::learning::mojom::TrainingExampleDataView,
media::learning::TrainingExample>::
Read(media::learning::mojom::TrainingExampleDataView data,
media::learning::TrainingExample* out_example) {
bool StructTraits<media::learning::mojom::LabelledExampleDataView,
media::learning::LabelledExample>::
Read(media::learning::mojom::LabelledExampleDataView data,
media::learning::LabelledExample* out_example) {
out_example->features.clear();
if (!data.ReadFeatures(&out_example->features))
return false;
......
......@@ -14,20 +14,20 @@
namespace mojo {
template <>
class StructTraits<media::learning::mojom::TrainingExampleDataView,
media::learning::TrainingExample> {
class StructTraits<media::learning::mojom::LabelledExampleDataView,
media::learning::LabelledExample> {
public:
static const std::vector<media::learning::FeatureValue>& features(
const media::learning::TrainingExample& e) {
const media::learning::LabelledExample& e) {
return e.features;
}
static media::learning::TargetValue target_value(
const media::learning::TrainingExample& e) {
const media::learning::LabelledExample& e) {
return e.target_value;
}
static bool Read(media::learning::mojom::TrainingExampleDataView data,
media::learning::TrainingExample* out_example);
static bool Read(media::learning::mojom::LabelledExampleDataView data,
media::learning::LabelledExample* out_example);
};
template <>
......
......@@ -15,7 +15,7 @@ MojoLearningSession::MojoLearningSession(mojom::LearningSessionPtr session_ptr)
MojoLearningSession::~MojoLearningSession() = default;
void MojoLearningSession::AddExample(const std::string& task_name,
const TrainingExample& example) {
const LabelledExample& example) {
// TODO(liberato): Convert from |task_name| to a task type.
session_ptr_->AddExample(mojom::LearningTaskType::kPlaceHolderTask, example);
}
......
......@@ -22,7 +22,7 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningSession
// LearningSession
void AddExample(const std::string& task_name,
const TrainingExample& example) override;
const LabelledExample& example) override;
private:
mojom::LearningSessionPtr session_ptr_;
......
......@@ -22,13 +22,13 @@ class MojoLearningSessionTest : public ::testing::Test {
class FakeMojoLearningSessionImpl : public mojom::LearningSession {
public:
void AddExample(mojom::LearningTaskType task_type,
const TrainingExample& example) override {
const LabelledExample& example) override {
task_type_ = std::move(task_type);
example_ = example;
}
mojom::LearningTaskType task_type_;
TrainingExample example_;
LabelledExample example_;
};
public:
......@@ -57,7 +57,7 @@ class MojoLearningSessionTest : public ::testing::Test {
};
TEST_F(MojoLearningSessionTest, ExampleIsCopied) {
TrainingExample example({FeatureValue(123), FeatureValue(456)},
LabelledExample example({FeatureValue(123), FeatureValue(456)},
TargetValue(1234));
learning_session_->AddExample("unused task id", example);
learning_session_binding_.FlushForTesting();
......
......@@ -15,5 +15,5 @@ enum LearningTaskType {
// media/learning/public/learning_session.h
interface LearningSession {
// Add |example| to |task_type|.
AddExample(LearningTaskType task_type, TrainingExample example);
AddExample(LearningTaskType task_type, LabelledExample example);
};
......@@ -14,8 +14,8 @@ struct TargetValue {
int64 value;
};
// learning::TrainingExample (common/training_example.h)
struct TrainingExample {
// learning::LabelledExample (common/labelled_example.h)
struct LabelledExample {
array<FeatureValue> features;
TargetValue target_value;
};
mojom = "//media/learning/mojo/public/mojom/learning_types.mojom"
public_headers = [
"//media/learning/common/training_example.h",
"//media/learning/common/labelled_example.h",
"//media/learning/common/value.h",
]
traits_headers = [ "//media/learning/mojo/public/cpp/learning_mojom_traits.h" ]
......@@ -12,7 +12,7 @@ public_deps = [
"//media/learning/common",
]
type_mappings = [
"media.learning.mojom.TrainingExample=media::learning::TrainingExample",
"media.learning.mojom.LabelledExample=media::learning::LabelledExample",
"media.learning.mojom.FeatureValue=media::learning::FeatureValue",
"media.learning.mojom.TargetValue=media::learning::TargetValue",
]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment