Commit d1d183b2 authored by liberato@chromium.org's avatar liberato@chromium.org Committed by Commit Bot

Allow registration of LearningTasks.

This CL adds the ability to register a LearningTask with a
LearningSessionImpl.  It, in turn, creates a LearningTaskController
for that task.

The LearningTaskController is responsible for taking training
examples, and forwarding them to learner(s), along with feature
selection and accuracy reporting.

Change-Id: I071dd08063b1ff2482ca4e9170f7f2d0cb6bd10a
Reviewed-on: https://chromium-review.googlesource.com/c/1315934
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarXiaohan Wang <xhwang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#609127}
parent eb84fe2e
......@@ -10,6 +10,9 @@ component("impl") {
"learner.h",
"learning_session_impl.cc",
"learning_session_impl.h",
"learning_task_controller.h",
"learning_task_controller_impl.cc",
"learning_task_controller_impl.h",
"random_tree.cc",
"random_tree.h",
]
......@@ -29,10 +32,12 @@ source_set("unit_tests") {
testonly = true
sources = [
"learning_session_impl_unittest.cc",
"random_tree_unittest.cc",
]
deps = [
":impl",
"//base/test:test_support",
"//media:test_support",
"//media/learning/impl",
......
......@@ -4,19 +4,37 @@
#include "media/learning/impl/learning_session_impl.h"
#include "base/bind.h"
#include "base/logging.h"
#include "media/learning/impl/learning_task_controller_impl.h"
namespace media {
namespace learning {
LearningSessionImpl::LearningSessionImpl() = default;
LearningSessionImpl::LearningSessionImpl()
: controller_factory_(
base::BindRepeating([](const LearningTask& task)
-> std::unique_ptr<LearningTaskController> {
return std::make_unique<LearningTaskControllerImpl>(task);
})) {}
LearningSessionImpl::~LearningSessionImpl() = default;
void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
CreateTaskControllerCB cb) {
controller_factory_ = std::move(cb);
}
void LearningSessionImpl::AddExample(const std::string& task_name,
const TrainingExample& example) {
// TODO: match |task_name| against a list of learning tasks, and find the
// learner(s) for it. Then add |instance|, |target| to it.
NOTIMPLEMENTED();
auto iter = task_map_.find(task_name);
if (iter != task_map_.end())
iter->second->AddExample(example);
}
void LearningSessionImpl::RegisterTask(const LearningTask& task) {
DCHECK(task_map_.count(task.name) == 0);
task_map_.emplace(task.name, controller_factory_.Run(task));
}
} // namespace learning
......
......@@ -5,23 +5,44 @@
#ifndef MEDIA_LEARNING_IMPL_LEARNING_SESSION_IMPL_H_
#define MEDIA_LEARNING_IMPL_LEARNING_SESSION_IMPL_H_
#include <map>
#include "base/component_export.h"
#include "media/learning/common/learning_session.h"
#include "media/learning/impl/learning_task_controller.h"
namespace media {
namespace learning {
// Concrete implementation of a LearningSession. This would have a list of
// learning tasks, and could provide local learners for each task.
// Concrete implementation of a LearningSession. This allows registration of
// learning tasks.
class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
: public LearningSession {
public:
LearningSessionImpl();
explicit LearningSessionImpl();
~LearningSessionImpl() override;
using CreateTaskControllerCB =
base::RepeatingCallback<std::unique_ptr<LearningTaskController>(
const LearningTask&)>;
void SetTaskControllerFactoryCBForTesting(CreateTaskControllerCB cb);
// LearningSession
void AddExample(const std::string& task_name,
const TrainingExample& example) override;
// Registers |task|, so that calls to AddExample with |task.name| will work.
// This will create a new controller for the task.
void RegisterTask(const LearningTask& task);
private:
// [task_name] = task controller.
using LearningTaskMap =
std::map<std::string, std::unique_ptr<LearningTaskController>>;
LearningTaskMap task_map_;
CreateTaskControllerCB controller_factory_;
};
} // namespace learning
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <vector>
#include "base/bind.h"
#include "media/learning/impl/learning_session_impl.h"
#include "media/learning/impl/learning_task_controller.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class LearningSessionImplTest : public testing::Test {
public:
class FakeLearningTaskController : public LearningTaskController {
public:
FakeLearningTaskController(const LearningTask& task) {}
void AddExample(const TrainingExample& example) override {
example_ = example;
}
TrainingExample example_;
};
using ControllerVector = std::vector<FakeLearningTaskController*>;
LearningSessionImplTest() {
session_ = std::make_unique<LearningSessionImpl>();
session_->SetTaskControllerFactoryCBForTesting(base::BindRepeating(
[](ControllerVector* controllers, const LearningTask& task)
-> std::unique_ptr<LearningTaskController> {
auto controller = std::make_unique<FakeLearningTaskController>(task);
controllers->push_back(controller.get());
return controller;
},
&task_controllers_));
task_0_.name = "task_0";
task_1_.name = "task_1";
}
std::unique_ptr<LearningSessionImpl> session_;
LearningTask task_0_;
LearningTask task_1_;
ControllerVector task_controllers_;
};
TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) {
EXPECT_EQ(task_controllers_.size(), 0u);
session_->RegisterTask(task_0_);
EXPECT_EQ(task_controllers_.size(), 1u);
session_->RegisterTask(task_1_);
EXPECT_EQ(task_controllers_.size(), 2u);
}
TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) {
session_->RegisterTask(task_0_);
session_->RegisterTask(task_1_);
TrainingExample example_0({FeatureValue(123), FeatureValue(456)},
TargetValue(1234));
session_->AddExample(task_0_.name, example_0);
TrainingExample example_1({FeatureValue(321), FeatureValue(654)},
TargetValue(4321));
session_->AddExample(task_1_.name, example_1);
EXPECT_EQ(task_controllers_[0]->example_, example_0);
EXPECT_EQ(task_controllers_[1]->example_, example_1);
}
} // namespace learning
} // namespace media
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_H_
#define MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_H_
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/common/training_example.h"
namespace media {
namespace learning {
// Controller for a single learning task. Takes training examples, and forwards
// them to the learner(s). Responsible for things like:
// - Managing underlying learner(s) based on the learning task
// - Feature subset selection
// - UMA reporting on accuracy / feature importance
//
// The idea is that one can create a LearningTask, give it to an LTC, and the
// LTC will do the work of building / evaluating the model based on training
// examples that are provided to it.
class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskController {
public:
LearningTaskController() = default;
virtual ~LearningTaskController() = default;
// Receive an example for this task.
virtual void AddExample(const TrainingExample& example) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(LearningTaskController);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/learning_task_controller_impl.h"
#include <memory>
#include "base/bind.h"
namespace media {
namespace learning {
LearningTaskControllerImpl::LearningTaskControllerImpl(
const LearningTask& task) {}
LearningTaskControllerImpl::~LearningTaskControllerImpl() = default;
void LearningTaskControllerImpl::AddExample(const TrainingExample& example) {
// TODO: do something.
}
} // namespace learning
} // namespace media
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_IMPL_H_
#define MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_IMPL_H_
#include <memory>
#include "base/component_export.h"
#include "media/learning/impl/learning_task_controller.h"
namespace media {
namespace learning {
class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
: public LearningTaskController {
public:
explicit LearningTaskControllerImpl(const LearningTask& task);
~LearningTaskControllerImpl() override;
// LearningTaskController
void AddExample(const TrainingExample& example) override;
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_IMPL_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment