Commit c5f4d14d authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Instance => TrainingExample.

Rather than sending (Instance, TargetValue), just bundle it up
into TrainingExample.

Change-Id: I145dd136af2fec8e2d1177ad8c9fc42dd2e79cf9
Reviewed-on: https://chromium-review.googlesource.com/c/1318389Reviewed-by: default avatarFredrik Hubinette <hubbe@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#607365}
parent b52b13f9
...@@ -14,12 +14,12 @@ component("common") { ...@@ -14,12 +14,12 @@ component("common") {
defines = [ "IS_LEARNING_COMMON_IMPL" ] defines = [ "IS_LEARNING_COMMON_IMPL" ]
sources = [ sources = [
"instance.cc",
"instance.h",
"learning_session.cc", "learning_session.cc",
"learning_session.h", "learning_session.h",
"learning_task.cc", "learning_task.cc",
"learning_task.h", "learning_task.h",
"training_example.cc",
"training_example.h",
"value.cc", "value.cc",
"value.h", "value.h",
] ]
...@@ -32,6 +32,7 @@ component("common") { ...@@ -32,6 +32,7 @@ component("common") {
source_set("unit_tests") { source_set("unit_tests") {
testonly = true testonly = true
sources = [ sources = [
"training_example_unittest.cc",
"value_unittest.cc", "value_unittest.cc",
] ]
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/instance.h"
namespace media {
namespace learning {
Instance::Instance() = default;
Instance::Instance(std::initializer_list<FeatureValue> init_list)
: features(init_list) {}
Instance::~Instance() = default;
std::ostream& operator<<(std::ostream& out, const Instance& instance) {
for (const auto& feature : instance.features)
out << " " << feature;
return out;
}
bool Instance::operator==(const Instance& rhs) const {
return features == rhs.features;
}
} // namespace learning
} // namespace media
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/component_export.h" #include "base/component_export.h"
#include "base/macros.h" #include "base/macros.h"
#include "media/learning/common/learning_task.h" #include "media/learning/common/learning_task.h"
#include "media/learning/common/training_example.h"
namespace media { namespace media {
namespace learning { namespace learning {
...@@ -20,12 +21,10 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession { ...@@ -20,12 +21,10 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession {
LearningSession(); LearningSession();
virtual ~LearningSession(); virtual ~LearningSession();
// Add an observed example of |instance| with target value |target| to the // Add an observed example |example| to the learning task |task_name|.
// learning task |task_name|.
// TODO(liberato): Consider making this an enum to match mojo. // TODO(liberato): Consider making this an enum to match mojo.
virtual void AddExample(const std::string& task_name, virtual void AddExample(const std::string& task_name,
const Instance& instance, const TrainingExample& example) = 0;
const TargetValue& target) = 0;
// TODO(liberato): Add prediction API. // TODO(liberato): Add prediction API.
......
...@@ -7,9 +7,10 @@ ...@@ -7,9 +7,10 @@
#include <initializer_list> #include <initializer_list>
#include <string> #include <string>
#include <vector>
#include "base/component_export.h" #include "base/component_export.h"
#include "media/learning/common/instance.h" #include "media/learning/common/value.h"
namespace media { namespace media {
namespace learning { namespace learning {
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
TrainingExample::TrainingExample() = default;
TrainingExample::TrainingExample(std::initializer_list<FeatureValue> init_list,
TargetValue target)
: features(init_list), target_value(target) {}
TrainingExample::TrainingExample(const TrainingExample& rhs)
: features(rhs.features), target_value(rhs.target_value) {}
TrainingExample::TrainingExample(TrainingExample&& rhs) noexcept
: features(std::move(rhs.features)),
target_value(std::move(rhs.target_value)) {}
TrainingExample::~TrainingExample() = default;
std::ostream& operator<<(std::ostream& out, const TrainingExample& example) {
for (const auto& feature : example.features)
out << " " << feature;
out << " => " << example.target_value;
return out;
}
bool TrainingExample::operator==(const TrainingExample& rhs) const {
return target_value == rhs.target_value && features == rhs.features;
}
bool TrainingExample::operator!=(const TrainingExample& rhs) const {
return !((*this) == rhs);
}
} // namespace learning
} // namespace media
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#ifndef MEDIA_LEARNING_COMMON_INSTANCE_H_ #ifndef MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_INSTANCE_H_ #define MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#include <initializer_list> #include <initializer_list>
#include <ostream> #include <ostream>
...@@ -16,26 +16,31 @@ ...@@ -16,26 +16,31 @@
namespace media { namespace media {
namespace learning { namespace learning {
// One instance == group of feature values. // One training example == group of feature values, plus the desired target.
struct COMPONENT_EXPORT(LEARNING_COMMON) Instance { struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
Instance(); TrainingExample();
Instance(std::initializer_list<FeatureValue> init_list); TrainingExample(std::initializer_list<FeatureValue> init_list,
~Instance(); TargetValue target);
TrainingExample(const TrainingExample& rhs);
TrainingExample(TrainingExample&& rhs) noexcept;
~TrainingExample();
bool operator==(const Instance& rhs) const; bool operator==(const TrainingExample& rhs) const;
bool operator!=(const TrainingExample& rhs) const;
// It's up to you to add the right number of features to match the learner // Observed feature values.
// description. Otherwise, the learner will ignore (training) or lie to you
// (inference), silently.
std::vector<FeatureValue> features; std::vector<FeatureValue> features;
// Observed output value, when given |features| as input.
TargetValue target_value;
// Copy / assignment is allowed. // Copy / assignment is allowed.
}; };
COMPONENT_EXPORT(LEARNING_COMMON) COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const Instance& instance); std::ostream& operator<<(std::ostream& out, const TrainingExample& example);
} // namespace learning } // namespace learning
} // namespace media } // namespace media
#endif // MEDIA_LEARNING_COMMON_INSTANCE_H_ #endif // MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/training_example.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class LearnerTrainingExampleTest : public testing::Test {};
TEST_F(LearnerTrainingExampleTest, InitListWorks) {
const int kFeature1 = 123;
const int kFeature2 = 456;
std::vector<FeatureValue> features = {FeatureValue(kFeature1),
FeatureValue(kFeature2)};
TargetValue target(789);
TrainingExample example({FeatureValue(kFeature1), FeatureValue(kFeature2)},
target);
EXPECT_EQ(example.features, features);
EXPECT_EQ(example.target_value, target);
}
TEST_F(LearnerTrainingExampleTest, CopyConstructionWorks) {
TrainingExample example_1({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingExample example_2(example_1);
EXPECT_EQ(example_1, example_2);
}
TEST_F(LearnerTrainingExampleTest, MoveConstructionWorks) {
TrainingExample example_1({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingExample example_1_copy(example_1);
TrainingExample example_1_move(std::move(example_1));
EXPECT_EQ(example_1_copy, example_1_move);
EXPECT_NE(example_1_copy, example_1);
}
TEST_F(LearnerTrainingExampleTest, EqualExamplesCompareAsEqual) {
const int kFeature1 = 123;
const int kFeature2 = 456;
TargetValue target(789);
TrainingExample example_1({FeatureValue(kFeature1), FeatureValue(kFeature2)},
target);
TrainingExample example_2({FeatureValue(kFeature1), FeatureValue(kFeature2)},
target);
// Verify both that == and != work.
EXPECT_EQ(example_1, example_2);
EXPECT_FALSE(example_1 != example_2);
}
TEST_F(LearnerTrainingExampleTest, UnequalFeaturesCompareAsUnequal) {
const int kFeature1 = 123;
const int kFeature2 = 456;
TargetValue target(789);
TrainingExample example_1({FeatureValue(kFeature1), FeatureValue(kFeature1)},
target);
TrainingExample example_2({FeatureValue(kFeature2), FeatureValue(kFeature2)},
target);
EXPECT_NE(example_1, example_2);
EXPECT_FALSE(example_1 == example_2);
}
TEST_F(LearnerTrainingExampleTest, UnequalTargetsCompareAsUnequal) {
const int kFeature1 = 123;
const int kFeature2 = 456;
TrainingExample example_1({FeatureValue(kFeature1), FeatureValue(kFeature1)},
TargetValue(789));
TrainingExample example_2({FeatureValue(kFeature2), FeatureValue(kFeature2)},
TargetValue(987));
EXPECT_NE(example_1, example_2);
EXPECT_FALSE(example_1 == example_2);
}
} // namespace learning
} // namespace media
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "base/component_export.h" #include "base/component_export.h"
#include "base/values.h" #include "base/values.h"
#include "media/learning/common/instance.h" #include "media/learning/common/training_example.h"
namespace media { namespace media {
namespace learning { namespace learning {
...@@ -20,10 +20,8 @@ class COMPONENT_EXPORT(LEARNING_IMPL) Learner { ...@@ -20,10 +20,8 @@ class COMPONENT_EXPORT(LEARNING_IMPL) Learner {
public: public:
virtual ~Learner() = default; virtual ~Learner() = default;
// Tell the learner that |instance| has been observed with the target value // Tell the learner that |example| has been observed during training.
// |target| during training. virtual void AddExample(const TrainingExample& example) = 0;
virtual void AddExample(const Instance& instance,
const TargetValue& target) = 0;
}; };
} // namespace learning } // namespace learning
......
...@@ -13,8 +13,7 @@ LearningSessionImpl::LearningSessionImpl() = default; ...@@ -13,8 +13,7 @@ LearningSessionImpl::LearningSessionImpl() = default;
LearningSessionImpl::~LearningSessionImpl() = default; LearningSessionImpl::~LearningSessionImpl() = default;
void LearningSessionImpl::AddExample(const std::string& task_name, void LearningSessionImpl::AddExample(const std::string& task_name,
const Instance& instance, const TrainingExample& example) {
const TargetValue& target) {
// TODO: match |task_name| against a list of learning tasks, and find the // TODO: match |task_name| against a list of learning tasks, and find the
// learner(s) for it. Then add |instance|, |target| to it. // learner(s) for it. Then add |instance|, |target| to it.
NOTIMPLEMENTED(); NOTIMPLEMENTED();
......
...@@ -21,8 +21,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl ...@@ -21,8 +21,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
// LearningSession // LearningSession
void AddExample(const std::string& task_name, void AddExample(const std::string& task_name,
const Instance& instance, const TrainingExample& example) override;
const TargetValue& target) override;
}; };
} // namespace learning } // namespace learning
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment