Commit 64546ed4 authored by Chris Cunningham's avatar Chris Cunningham Committed by Commit Bot

Get media::learning::LearningTasks by name rather than ID.

The primary purpose of a LearningTask is to describe/configure a
LearningTaskController. LearningTaskController's are registered
by name, whereas LearningTasks have been registered by enum ID. This CL
remedies the asymmetry, making retrieval of Tasks match that of
Controllers by using a "name" in both places.

Change-Id: Ibc58b42e9c92876134f3da524b0804f3e1d71809
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2053619
Commit-Queue: Frank Liberato <liberato@chromium.org>
Reviewed-by: default avatarFrank Liberato <liberato@chromium.org>
Auto-Submit: Chrome Cunningham <chcunningham@chromium.org>
Cr-Commit-Position: refs/heads/master@{#741100}
parent 81a3365f
......@@ -3739,18 +3739,16 @@ void WebMediaPlayerImpl::UpdateSmoothnessHelper() {
// Create or restart the smoothness helper with |features|.
smoothness_helper_ = SmoothnessHelper::Create(
GetLearningTaskController(
learning::MediaLearningTasks::Id::kConsecutiveBadWindows),
GetLearningTaskController(
learning::MediaLearningTasks::Id::kConsecutiveNNRs),
GetLearningTaskController(learning::tasknames::kConsecutiveBadWindows),
GetLearningTaskController(learning::tasknames::kConsecutiveNNRs),
features, this);
}
std::unique_ptr<learning::LearningTaskController>
WebMediaPlayerImpl::GetLearningTaskController(
learning::MediaLearningTasks::Id task_id) {
WebMediaPlayerImpl::GetLearningTaskController(const char* task_name) {
// Get the LearningTaskController for |task_id|.
learning::LearningTask task = learning::MediaLearningTasks::Get(task_id);
learning::LearningTask task = learning::MediaLearningTasks::Get(task_name);
DCHECK_EQ(task.name, task_name);
mojo::Remote<media::learning::mojom::LearningTaskController> remote_ltc;
media_metrics_provider_->AcquireLearningTaskController(
......
......@@ -644,9 +644,9 @@ class MEDIA_BLINK_EXPORT WebMediaPlayerImpl
// smoothness right now.
void UpdateSmoothnessHelper();
// Get the LearningTaskController for |task_id|.
// Get the LearningTaskController for |task_name|.
std::unique_ptr<learning::LearningTaskController> GetLearningTaskController(
learning::MediaLearningTasks::Id task_id);
const char* task_name);
blink::WebLocalFrame* const frame_;
......
......@@ -10,7 +10,7 @@ namespace learning {
static const LearningTask& GetWillPlayTask() {
static LearningTask task_;
if (!task_.feature_descriptions.size()) {
task_.name = "MediaLearningWillPlay";
task_.name = tasknames::kWillPlay;
// TODO(liberato): fill in the rest here, once we have the features picked.
}
......@@ -34,7 +34,7 @@ static void PushWMPIFeatures(LearningTask& task) {
static const LearningTask& GetConsecutiveBadWindowsTask() {
static LearningTask task_;
if (!task_.feature_descriptions.size()) {
task_.name = "MediaLearningConsecutiveBadWindows";
task_.name = tasknames::kConsecutiveBadWindows;
task_.model = LearningTask::Model::kExtraTrees;
// Target is max number of consecutive bad windows.
......@@ -57,7 +57,7 @@ static const LearningTask& GetConsecutiveBadWindowsTask() {
static const LearningTask& GetConsecutiveNNRsTask() {
static LearningTask task_;
if (!task_.feature_descriptions.size()) {
task_.name = "MediaLearningConsecutiveNNRs";
task_.name = tasknames::kConsecutiveNNRs;
task_.model = LearningTask::Model::kExtraTrees;
// Target is max number of consecutive bad windows.
......@@ -75,23 +75,25 @@ static const LearningTask& GetConsecutiveNNRsTask() {
}
// static
const LearningTask& MediaLearningTasks::Get(Id id) {
switch (id) {
case Id::kWillPlay:
return GetWillPlayTask();
case Id::kConsecutiveBadWindows:
return GetConsecutiveBadWindowsTask();
case Id::kConsecutiveNNRs:
return GetConsecutiveNNRsTask();
}
const LearningTask& MediaLearningTasks::Get(const char* task_name) {
if (strcmp(task_name, tasknames::kWillPlay) == 0)
return GetWillPlayTask();
if (strcmp(task_name, tasknames::kConsecutiveBadWindows) == 0)
return GetConsecutiveBadWindowsTask();
if (strcmp(task_name, tasknames::kConsecutiveNNRs) == 0)
return GetConsecutiveNNRsTask();
NOTREACHED() << " Unknown learning task:" << task_name;
static LearningTask empty_task;
return empty_task;
}
// static
void MediaLearningTasks::Register(
base::RepeatingCallback<void(const LearningTask&)> cb) {
cb.Run(Get(Id::kWillPlay));
cb.Run(Get(Id::kConsecutiveBadWindows));
cb.Run(Get(Id::kConsecutiveNNRs));
cb.Run(Get(tasknames::kWillPlay));
cb.Run(Get(tasknames::kConsecutiveBadWindows));
cb.Run(Get(tasknames::kConsecutiveNNRs));
}
} // namespace learning
......
......@@ -12,21 +12,20 @@
namespace media {
namespace learning {
namespace tasknames {
constexpr char kWillPlay[] = "MediaLearningWillPlay";
constexpr char kConsecutiveBadWindows[] = "MediaLearningConsecutiveBadWindows";
constexpr char kConsecutiveNNRs[] = "MediaLearningConsecutiveNNRs";
} // namespace tasknames
// All learning experiments for media/ .
// TODO(liberato): This should be in media/ somewhere, since the learning
// framework doesn't care about it. For now, this is simpler to make deps
// easier to handle.
class COMPONENT_EXPORT(LEARNING_COMMON) MediaLearningTasks {
public:
// Ids for each LearningTask.
enum class Id {
kWillPlay,
kConsecutiveBadWindows,
kConsecutiveNNRs,
};
// Return the LearningTask for |id|.
static const learning::LearningTask& Get(Id id);
// Return the LearningTask for |name|.
static const learning::LearningTask& Get(const char* task_name);
// Register all tasks by calling |registration_cb| repeatedly.
static void Register(
......
......@@ -18,22 +18,20 @@ namespace learning {
class MediaLearningTasksTest : public testing::Test {};
TEST_F(MediaLearningTasksTest, WillPlayTask) {
LearningTask task =
MediaLearningTasks::Get(MediaLearningTasks::Id::kWillPlay);
LearningTask task = MediaLearningTasks::Get(tasknames::kWillPlay);
// Make sure the name is correct, mostly to reduce cut-and-paste errors.
EXPECT_EQ(task.name, "MediaLearningWillPlay");
}
TEST_F(MediaLearningTasksTest, ConsecutiveBadWindowsTask) {
LearningTask task =
MediaLearningTasks::Get(MediaLearningTasks::Id::kConsecutiveBadWindows);
MediaLearningTasks::Get(tasknames::kConsecutiveBadWindows);
// Make sure the name is correct, mostly to reduce cut-and-paste errors.
EXPECT_EQ(task.name, "MediaLearningConsecutiveBadWindows");
}
TEST_F(MediaLearningTasksTest, ConsecutiveNNRsTask) {
LearningTask task =
MediaLearningTasks::Get(MediaLearningTasks::Id::kConsecutiveNNRs);
LearningTask task = MediaLearningTasks::Get(tasknames::kConsecutiveNNRs);
// Make sure the name is correct, mostly to reduce cut-and-paste errors.
EXPECT_EQ(task.name, "MediaLearningConsecutiveNNRs");
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment