Commit 64546ed4 authored by Chris Cunningham's avatar Chris Cunningham Committed by Commit Bot

Get media::learning::LearningTasks by name rather than ID.

The primary purpose of a LearningTask is to describe/configure a
LearningTaskController. LearningTaskController's are registered
by name, whereas LearningTasks have been registered by enum ID. This CL
remedies the asymmetry, making retrieval of Tasks match that of
Controllers by using a "name" in both places.

Change-Id: Ibc58b42e9c92876134f3da524b0804f3e1d71809
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2053619
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarFrank Liberato <liberato@chromium.org>
Auto-Submit: Chrome Cunningham <chcunningham@chromium.org>
Cr-Commit-Position: refs/heads/master@{#741100}
parent 81a3365f
...@@ -3739,18 +3739,16 @@ void WebMediaPlayerImpl::UpdateSmoothnessHelper() { ...@@ -3739,18 +3739,16 @@ void WebMediaPlayerImpl::UpdateSmoothnessHelper() {
// Create or restart the smoothness helper with |features|. // Create or restart the smoothness helper with |features|.
smoothness_helper_ = SmoothnessHelper::Create( smoothness_helper_ = SmoothnessHelper::Create(
GetLearningTaskController( GetLearningTaskController(learning::tasknames::kConsecutiveBadWindows),
learning::MediaLearningTasks::Id::kConsecutiveBadWindows), GetLearningTaskController(learning::tasknames::kConsecutiveNNRs),
GetLearningTaskController(
learning::MediaLearningTasks::Id::kConsecutiveNNRs),
features, this); features, this);
} }
std::unique_ptr<learning::LearningTaskController> std::unique_ptr<learning::LearningTaskController>
WebMediaPlayerImpl::GetLearningTaskController( WebMediaPlayerImpl::GetLearningTaskController(const char* task_name) {
learning::MediaLearningTasks::Id task_id) {
// Get the LearningTaskController for |task_id|. // Get the LearningTaskController for |task_id|.
learning::LearningTask task = learning::MediaLearningTasks::Get(task_id); learning::LearningTask task = learning::MediaLearningTasks::Get(task_name);
DCHECK_EQ(task.name, task_name);
mojo::Remote<media::learning::mojom::LearningTaskController> remote_ltc; mojo::Remote<media::learning::mojom::LearningTaskController> remote_ltc;
media_metrics_provider_->AcquireLearningTaskController( media_metrics_provider_->AcquireLearningTaskController(
......
...@@ -644,9 +644,9 @@ class MEDIA_BLINK_EXPORT WebMediaPlayerImpl ...@@ -644,9 +644,9 @@ class MEDIA_BLINK_EXPORT WebMediaPlayerImpl
// smoothness right now. // smoothness right now.
void UpdateSmoothnessHelper(); void UpdateSmoothnessHelper();
// Get the LearningTaskController for |task_id|. // Get the LearningTaskController for |task_name|.
std::unique_ptr<learning::LearningTaskController> GetLearningTaskController( std::unique_ptr<learning::LearningTaskController> GetLearningTaskController(
learning::MediaLearningTasks::Id task_id); const char* task_name);
blink::WebLocalFrame* const frame_; blink::WebLocalFrame* const frame_;
......
...@@ -10,7 +10,7 @@ namespace learning { ...@@ -10,7 +10,7 @@ namespace learning {
static const LearningTask& GetWillPlayTask() { static const LearningTask& GetWillPlayTask() {
static LearningTask task_; static LearningTask task_;
if (!task_.feature_descriptions.size()) { if (!task_.feature_descriptions.size()) {
task_.name = "MediaLearningWillPlay"; task_.name = tasknames::kWillPlay;
// TODO(liberato): fill in the rest here, once we have the features picked. // TODO(liberato): fill in the rest here, once we have the features picked.
} }
...@@ -34,7 +34,7 @@ static void PushWMPIFeatures(LearningTask& task) { ...@@ -34,7 +34,7 @@ static void PushWMPIFeatures(LearningTask& task) {
static const LearningTask& GetConsecutiveBadWindowsTask() { static const LearningTask& GetConsecutiveBadWindowsTask() {
static LearningTask task_; static LearningTask task_;
if (!task_.feature_descriptions.size()) { if (!task_.feature_descriptions.size()) {
task_.name = "MediaLearningConsecutiveBadWindows"; task_.name = tasknames::kConsecutiveBadWindows;
task_.model = LearningTask::Model::kExtraTrees; task_.model = LearningTask::Model::kExtraTrees;
// Target is max number of consecutive bad windows. // Target is max number of consecutive bad windows.
...@@ -57,7 +57,7 @@ static const LearningTask& GetConsecutiveBadWindowsTask() { ...@@ -57,7 +57,7 @@ static const LearningTask& GetConsecutiveBadWindowsTask() {
static const LearningTask& GetConsecutiveNNRsTask() { static const LearningTask& GetConsecutiveNNRsTask() {
static LearningTask task_; static LearningTask task_;
if (!task_.feature_descriptions.size()) { if (!task_.feature_descriptions.size()) {
task_.name = "MediaLearningConsecutiveNNRs"; task_.name = tasknames::kConsecutiveNNRs;
task_.model = LearningTask::Model::kExtraTrees; task_.model = LearningTask::Model::kExtraTrees;
// Target is max number of consecutive bad windows. // Target is max number of consecutive bad windows.
...@@ -75,23 +75,25 @@ static const LearningTask& GetConsecutiveNNRsTask() { ...@@ -75,23 +75,25 @@ static const LearningTask& GetConsecutiveNNRsTask() {
} }
// static // static
const LearningTask& MediaLearningTasks::Get(Id id) { const LearningTask& MediaLearningTasks::Get(const char* task_name) {
switch (id) { if (strcmp(task_name, tasknames::kWillPlay) == 0)
case Id::kWillPlay: return GetWillPlayTask();
return GetWillPlayTask(); if (strcmp(task_name, tasknames::kConsecutiveBadWindows) == 0)
case Id::kConsecutiveBadWindows: return GetConsecutiveBadWindowsTask();
return GetConsecutiveBadWindowsTask(); if (strcmp(task_name, tasknames::kConsecutiveNNRs) == 0)
case Id::kConsecutiveNNRs: return GetConsecutiveNNRsTask();
return GetConsecutiveNNRsTask();
} NOTREACHED() << " Unknown learning task:" << task_name;
static LearningTask empty_task;
return empty_task;
} }
// static // static
void MediaLearningTasks::Register( void MediaLearningTasks::Register(
base::RepeatingCallback<void(const LearningTask&)> cb) { base::RepeatingCallback<void(const LearningTask&)> cb) {
cb.Run(Get(Id::kWillPlay)); cb.Run(Get(tasknames::kWillPlay));
cb.Run(Get(Id::kConsecutiveBadWindows)); cb.Run(Get(tasknames::kConsecutiveBadWindows));
cb.Run(Get(Id::kConsecutiveNNRs)); cb.Run(Get(tasknames::kConsecutiveNNRs));
} }
} // namespace learning } // namespace learning
......
...@@ -12,21 +12,20 @@ ...@@ -12,21 +12,20 @@
namespace media { namespace media {
namespace learning { namespace learning {
namespace tasknames {
constexpr char kWillPlay[] = "MediaLearningWillPlay";
constexpr char kConsecutiveBadWindows[] = "MediaLearningConsecutiveBadWindows";
constexpr char kConsecutiveNNRs[] = "MediaLearningConsecutiveNNRs";
} // namespace tasknames
// All learning experiments for media/ . // All learning experiments for media/ .
// TODO(liberato): This should be in media/ somewhere, since the learning // TODO(liberato): This should be in media/ somewhere, since the learning
// framework doesn't care about it. For now, this is simpler to make deps // framework doesn't care about it. For now, this is simpler to make deps
// easier to handle. // easier to handle.
class COMPONENT_EXPORT(LEARNING_COMMON) MediaLearningTasks { class COMPONENT_EXPORT(LEARNING_COMMON) MediaLearningTasks {
public: public:
// Ids for each LearningTask. // Return the LearningTask for |name|.
enum class Id { static const learning::LearningTask& Get(const char* task_name);
kWillPlay,
kConsecutiveBadWindows,
kConsecutiveNNRs,
};
// Return the LearningTask for |id|.
static const learning::LearningTask& Get(Id id);
// Register all tasks by calling |registration_cb| repeatedly. // Register all tasks by calling |registration_cb| repeatedly.
static void Register( static void Register(
......
...@@ -18,22 +18,20 @@ namespace learning { ...@@ -18,22 +18,20 @@ namespace learning {
class MediaLearningTasksTest : public testing::Test {}; class MediaLearningTasksTest : public testing::Test {};
TEST_F(MediaLearningTasksTest, WillPlayTask) { TEST_F(MediaLearningTasksTest, WillPlayTask) {
LearningTask task = LearningTask task = MediaLearningTasks::Get(tasknames::kWillPlay);
MediaLearningTasks::Get(MediaLearningTasks::Id::kWillPlay);
// Make sure the name is correct, mostly to reduce cut-and-paste errors. // Make sure the name is correct, mostly to reduce cut-and-paste errors.
EXPECT_EQ(task.name, "MediaLearningWillPlay"); EXPECT_EQ(task.name, "MediaLearningWillPlay");
} }
TEST_F(MediaLearningTasksTest, ConsecutiveBadWindowsTask) { TEST_F(MediaLearningTasksTest, ConsecutiveBadWindowsTask) {
LearningTask task = LearningTask task =
MediaLearningTasks::Get(MediaLearningTasks::Id::kConsecutiveBadWindows); MediaLearningTasks::Get(tasknames::kConsecutiveBadWindows);
// Make sure the name is correct, mostly to reduce cut-and-paste errors. // Make sure the name is correct, mostly to reduce cut-and-paste errors.
EXPECT_EQ(task.name, "MediaLearningConsecutiveBadWindows"); EXPECT_EQ(task.name, "MediaLearningConsecutiveBadWindows");
} }
TEST_F(MediaLearningTasksTest, ConsecutiveNNRsTask) { TEST_F(MediaLearningTasksTest, ConsecutiveNNRsTask) {
LearningTask task = LearningTask task = MediaLearningTasks::Get(tasknames::kConsecutiveNNRs);
MediaLearningTasks::Get(MediaLearningTasks::Id::kConsecutiveNNRs);
// Make sure the name is correct, mostly to reduce cut-and-paste errors. // Make sure the name is correct, mostly to reduce cut-and-paste errors.
EXPECT_EQ(task.name, "MediaLearningConsecutiveNNRs"); EXPECT_EQ(task.name, "MediaLearningConsecutiveNNRs");
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment