Commit c5f4d14d authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Instance => TrainingExample.

Rather than sending (Instance, TargetValue), just bundle it up
into TrainingExample.

Change-Id: I145dd136af2fec8e2d1177ad8c9fc42dd2e79cf9
Reviewed-on: https://chromium-review.googlesource.com/c/1318389Reviewed-by: default avatarFredrik Hubinette <hubbe@chromium.org>
Commit-Queue: Frank Liberato <liberato@chromium.org>
Cr-Commit-Position: refs/heads/master@{#607365}
parent b52b13f9
......@@ -14,12 +14,12 @@ component("common") {
defines = [ "IS_LEARNING_COMMON_IMPL" ]
sources = [
"instance.cc",
"instance.h",
"learning_session.cc",
"learning_session.h",
"learning_task.cc",
"learning_task.h",
"training_example.cc",
"training_example.h",
"value.cc",
"value.h",
]
......@@ -32,6 +32,7 @@ component("common") {
source_set("unit_tests") {
testonly = true
sources = [
"training_example_unittest.cc",
"value_unittest.cc",
]
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/instance.h"
namespace media {
namespace learning {
Instance::Instance() = default;
Instance::Instance(std::initializer_list<FeatureValue> init_list)
: features(init_list) {}
Instance::~Instance() = default;
std::ostream& operator<<(std::ostream& out, const Instance& instance) {
for (const auto& feature : instance.features)
out << " " << feature;
return out;
}
bool Instance::operator==(const Instance& rhs) const {
return features == rhs.features;
}
} // namespace learning
} // namespace media
......@@ -10,6 +10,7 @@
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
......@@ -20,12 +21,10 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession {
LearningSession();
virtual ~LearningSession();
// Add an observed example of |instance| with target value |target| to the
// learning task |task_name|.
// Add an observed example |example| to the learning task |task_name|.
// TODO(liberato): Consider making this an enum to match mojo.
virtual void AddExample(const std::string& task_name,
const Instance& instance,
const TargetValue& target) = 0;
const TrainingExample& example) = 0;
// TODO(liberato): Add prediction API.
......
......@@ -7,9 +7,10 @@
#include <initializer_list>
#include <string>
#include <vector>
#include "base/component_export.h"
#include "media/learning/common/instance.h"
#include "media/learning/common/value.h"
namespace media {
namespace learning {
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
TrainingExample::TrainingExample() = default;
TrainingExample::TrainingExample(std::initializer_list<FeatureValue> init_list,
TargetValue target)
: features(init_list), target_value(target) {}
TrainingExample::TrainingExample(const TrainingExample& rhs)
: features(rhs.features), target_value(rhs.target_value) {}
TrainingExample::TrainingExample(TrainingExample&& rhs) noexcept
: features(std::move(rhs.features)),
target_value(std::move(rhs.target_value)) {}
TrainingExample::~TrainingExample() = default;
std::ostream& operator<<(std::ostream& out, const TrainingExample& example) {
for (const auto& feature : example.features)
out << " " << feature;
out << " => " << example.target_value;
return out;
}
bool TrainingExample::operator==(const TrainingExample& rhs) const {
return target_value == rhs.target_value && features == rhs.features;
}
bool TrainingExample::operator!=(const TrainingExample& rhs) const {
return !((*this) == rhs);
}
} // namespace learning
} // namespace media
......@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_COMMON_INSTANCE_H_
#define MEDIA_LEARNING_COMMON_INSTANCE_H_
#ifndef MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
#include <initializer_list>
#include <ostream>
......@@ -16,26 +16,31 @@
namespace media {
namespace learning {
// One instance == group of feature values.
struct COMPONENT_EXPORT(LEARNING_COMMON) Instance {
Instance();
Instance(std::initializer_list<FeatureValue> init_list);
~Instance();
// One training example == group of feature values, plus the desired target.
struct COMPONENT_EXPORT(LEARNING_COMMON) TrainingExample {
TrainingExample();
TrainingExample(std::initializer_list<FeatureValue> init_list,
TargetValue target);
TrainingExample(const TrainingExample& rhs);
TrainingExample(TrainingExample&& rhs) noexcept;
~TrainingExample();
bool operator==(const Instance& rhs) const;
bool operator==(const TrainingExample& rhs) const;
bool operator!=(const TrainingExample& rhs) const;
// It's up to you to add the right number of features to match the learner
// description. Otherwise, the learner will ignore (training) or lie to you
// (inference), silently.
// Observed feature values.
std::vector<FeatureValue> features;
// Observed output value, when given |features| as input.
TargetValue target_value;
// Copy / assignment is allowed.
};
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const Instance& instance);
std::ostream& operator<<(std::ostream& out, const TrainingExample& example);
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_COMMON_INSTANCE_H_
#endif // MEDIA_LEARNING_COMMON_TRAINING_EXAMPLE_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/common/training_example.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class LearnerTrainingExampleTest : public testing::Test {};
TEST_F(LearnerTrainingExampleTest, InitListWorks) {
const int kFeature1 = 123;
const int kFeature2 = 456;
std::vector<FeatureValue> features = {FeatureValue(kFeature1),
FeatureValue(kFeature2)};
TargetValue target(789);
TrainingExample example({FeatureValue(kFeature1), FeatureValue(kFeature2)},
target);
EXPECT_EQ(example.features, features);
EXPECT_EQ(example.target_value, target);
}
TEST_F(LearnerTrainingExampleTest, CopyConstructionWorks) {
TrainingExample example_1({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingExample example_2(example_1);
EXPECT_EQ(example_1, example_2);
}
TEST_F(LearnerTrainingExampleTest, MoveConstructionWorks) {
TrainingExample example_1({FeatureValue(123), FeatureValue(456)},
TargetValue(789));
TrainingExample example_1_copy(example_1);
TrainingExample example_1_move(std::move(example_1));
EXPECT_EQ(example_1_copy, example_1_move);
EXPECT_NE(example_1_copy, example_1);
}
TEST_F(LearnerTrainingExampleTest, EqualExamplesCompareAsEqual) {
const int kFeature1 = 123;
const int kFeature2 = 456;
TargetValue target(789);
TrainingExample example_1({FeatureValue(kFeature1), FeatureValue(kFeature2)},
target);
TrainingExample example_2({FeatureValue(kFeature1), FeatureValue(kFeature2)},
target);
// Verify both that == and != work.
EXPECT_EQ(example_1, example_2);
EXPECT_FALSE(example_1 != example_2);
}
TEST_F(LearnerTrainingExampleTest, UnequalFeaturesCompareAsUnequal) {
const int kFeature1 = 123;
const int kFeature2 = 456;
TargetValue target(789);
TrainingExample example_1({FeatureValue(kFeature1), FeatureValue(kFeature1)},
target);
TrainingExample example_2({FeatureValue(kFeature2), FeatureValue(kFeature2)},
target);
EXPECT_NE(example_1, example_2);
EXPECT_FALSE(example_1 == example_2);
}
TEST_F(LearnerTrainingExampleTest, UnequalTargetsCompareAsUnequal) {
const int kFeature1 = 123;
const int kFeature2 = 456;
TrainingExample example_1({FeatureValue(kFeature1), FeatureValue(kFeature1)},
TargetValue(789));
TrainingExample example_2({FeatureValue(kFeature2), FeatureValue(kFeature2)},
TargetValue(987));
EXPECT_NE(example_1, example_2);
EXPECT_FALSE(example_1 == example_2);
}
} // namespace learning
} // namespace media
......@@ -7,7 +7,7 @@
#include "base/component_export.h"
#include "base/values.h"
#include "media/learning/common/instance.h"
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
......@@ -20,10 +20,8 @@ class COMPONENT_EXPORT(LEARNING_IMPL) Learner {
public:
virtual ~Learner() = default;
// Tell the learner that |instance| has been observed with the target value
// |target| during training.
virtual void AddExample(const Instance& instance,
const TargetValue& target) = 0;
// Tell the learner that |example| has been observed during training.
virtual void AddExample(const TrainingExample& example) = 0;
};
} // namespace learning
......
......@@ -13,8 +13,7 @@ LearningSessionImpl::LearningSessionImpl() = default;
LearningSessionImpl::~LearningSessionImpl() = default;
void LearningSessionImpl::AddExample(const std::string& task_name,
const Instance& instance,
const TargetValue& target) {
const TrainingExample& example) {
// TODO: match |task_name| against a list of learning tasks, and find the
// learner(s) for it. Then add |instance|, |target| to it.
NOTIMPLEMENTED();
......
......@@ -21,8 +21,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
// LearningSession
void AddExample(const std::string& task_name,
const Instance& instance,
const TargetValue& target) override;
const TrainingExample& example) override;
};
} // namespace learning
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment