Commit a2089b42 authored by Ella Ge's avatar Ella Ge Committed by Commit Bot

Add lsq predictor to input prediction

This CL adds LSQ predictor to input prediction.
LSQ predictor uses a quadratic least square regression model:
y = b0 + b1 * x + b2 * x ^ 2.

The resampling is still behind flag: kResampleInputEvents.

design doc:
https://docs.google.com/document/d/1DLfibi2NkV85p7AfEtNkvy24d283VRBSd3hz1Qh24Jw/edit#

Bug: 836352
Change-Id: Ie1959c6b95152a80076debfb762c49d88a758af3
Reviewed-on: https://chromium-review.googlesource.com/1003220
Commit-Queue: Ella Ge <eirage@chromium.org>
Reviewed-by: default avatarDave Tapuska <dtapuska@chromium.org>
Reviewed-by: default avatardanakj <danakj@chromium.org>
Reviewed-by: default avatarNavid Zolghadr <nzolghadr@chromium.org>
Cr-Commit-Position: refs/heads/master@{#565035}
parent 3e82cf3a
......@@ -5,8 +5,11 @@
#include "content/renderer/input/input_event_prediction.h"
#include "base/feature_list.h"
#include "base/metrics/field_trial.h"
#include "base/metrics/field_trial_params.h"
#include "content/public/common/content_features.h"
#include "ui/events/blink/prediction/empty_predictor.h"
#include "ui/events/blink/prediction/least_squares_predictor.h"
using blink::WebInputEvent;
using blink::WebMouseEvent;
......@@ -18,14 +21,19 @@ namespace content {
namespace {
std::unique_ptr<ui::InputPredictor> SetUpPredictor() {
return std::make_unique<ui::EmptyPredictor>();
constexpr char kPredictor[] = "predictor";
constexpr char kInputEventPredictorTypeLsq[] = "lsq";
}
} // namespace
InputEventPrediction::InputEventPrediction() {
mouse_predictor_ = SetUpPredictor();
std::string predictor_type_ = GetFieldTrialParamValueByFeature(
features::kResamplingInputEvents, kPredictor);
if (predictor_type_ == kInputEventPredictorTypeLsq)
selected_predictor_type_ = PredictorType::kLsq;
else
selected_predictor_type_ = PredictorType::kEmpty;
mouse_predictor_ = CreatePredictor();
}
InputEventPrediction::~InputEventPrediction() {}
......@@ -54,6 +62,16 @@ void InputEventPrediction::HandleEvents(
}
}
std::unique_ptr<ui::InputPredictor> InputEventPrediction::CreatePredictor()
const {
switch (selected_predictor_type_) {
case PredictorType::kEmpty:
return std::make_unique<ui::EmptyPredictor>();
case PredictorType::kLsq:
return std::make_unique<ui::LeastSquaresPredictor>();
}
}
void InputEventPrediction::UpdatePrediction(const WebInputEvent& event) {
if (WebInputEvent::IsTouchEventType(event.GetType())) {
DCHECK(event.GetType() == WebInputEvent::kTouchMove);
......@@ -121,7 +139,7 @@ void InputEventPrediction::UpdateSinglePointer(
} else {
// Workaround for GLIBC C++ < 7.3 that fails to insert with braces
// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=82522
auto pair = std::make_pair(event.id, SetUpPredictor());
auto pair = std::make_pair(event.id, CreatePredictor());
pointer_id_predictor_map_.insert(std::move(pair));
pointer_id_predictor_map_[event.id]->Update(data);
}
......
......@@ -28,7 +28,14 @@ class CONTENT_EXPORT InputEventPrediction {
base::TimeTicks frame_time,
blink::WebInputEvent* event);
// Initialize predictor for different pointer.
std::unique_ptr<ui::InputPredictor> CreatePredictor() const;
private:
friend class InputEventPredictionTest;
enum class PredictorType { kEmpty, kLsq };
// The following three function is for handling multiple TouchPoints in a
// WebTouchEvent. They should be more neat when WebTouchEvent is elimated.
// Cast events from WebInputEvent to WebPointerProperties. Call
......@@ -52,12 +59,14 @@ class CONTENT_EXPORT InputEventPrediction {
// predictor, for other pointer type, remove it from mapping.
void ResetSinglePredictor(const WebPointerProperties& event);
friend class InputEventPredictionTest;
std::unordered_map<ui::PointerId, std::unique_ptr<ui::InputPredictor>>
pointer_id_predictor_map_;
std::unique_ptr<ui::InputPredictor> mouse_predictor_;
// Store the field trial parameter used for choosing different types of
// predictor.
PredictorType selected_predictor_type_;
DISALLOW_COPY_AND_ASSIGN(InputEventPrediction);
};
......
......@@ -210,4 +210,4 @@ TEST_F(InputEventPredictionTest, TouchScrollStartedRemoveAllTouchPoints) {
EXPECT_EQ(GetPredictorMapSize(), 0);
}
} // namespace content
\ No newline at end of file
} // namespace content
......@@ -426,6 +426,7 @@ if (!is_ios) {
"blink/fling_booster_unittest.cc",
"blink/input_handler_proxy_unittest.cc",
"blink/input_scroll_elasticity_controller_unittest.cc",
"blink/prediction/least_squares_predictor_unittest.cc",
"blink/web_input_event_traits_unittest.cc",
"blink/web_input_event_unittest.cc",
"cocoa/events_mac_unittest.mm",
......
......@@ -27,6 +27,8 @@ jumbo_source_set("blink") {
"prediction/empty_predictor.cc",
"prediction/empty_predictor.h",
"prediction/input_predictor.h",
"prediction/least_squares_predictor.cc",
"prediction/least_squares_predictor.h",
"synchronous_input_handler_proxy.h",
"web_input_event.cc",
"web_input_event.h",
......
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ui/events/blink/prediction/least_squares_predictor.h"
#include <cmath>
namespace ui {
namespace {
constexpr double kEpsilon = std::numeric_limits<double>::epsilon();
// Solve XB = y.
static bool SolveLeastSquares(const gfx::Matrix3F& x,
const std::deque<double>& y,
gfx::Vector3dF& result) {
// return last point if y didn't change.
if (std::abs(y[0] - y[1]) < kEpsilon && std::abs(y[1] - y[2]) < kEpsilon) {
result = gfx::Vector3dF(y[2], 0, 0);
return true;
}
gfx::Matrix3F x_transpose = x.Transpose();
gfx::Matrix3F temp = gfx::MatrixProduct(x_transpose, x).Inverse();
// Return false if x is singular.
if (temp == gfx::Matrix3F::Zeros())
return false;
result = gfx::MatrixProduct(gfx::MatrixProduct(temp, x_transpose),
gfx::Vector3dF(y[0], y[1], y[2]));
return true;
}
} // namespace
LeastSquaresPredictor::LeastSquaresPredictor() {}
LeastSquaresPredictor::~LeastSquaresPredictor() {}
void LeastSquaresPredictor::Reset() {
x_queue_.clear();
y_queue_.clear();
time_.clear();
}
void LeastSquaresPredictor::Update(const InputData& cur_input) {
// Reset curve if last point is 50 milliseconds away.
constexpr double max_interval_millisecond = 50.0;
if (!time_.empty() &&
(cur_input.time_stamp - time_.back()).InMillisecondsF() >
max_interval_millisecond)
Reset();
x_queue_.push_back(cur_input.pos_x);
y_queue_.push_back(cur_input.pos_y);
time_.push_back(cur_input.time_stamp);
if (time_.size() > kSize) {
x_queue_.pop_front();
y_queue_.pop_front();
time_.pop_front();
}
}
bool LeastSquaresPredictor::HasPrediction() const {
return time_.size() >= kSize;
}
gfx::Matrix3F LeastSquaresPredictor::GetXMatrix() const {
gfx::Matrix3F x = gfx::Matrix3F::Zeros();
double t1 = (time_[1] - time_[0]).InMillisecondsF();
double t2 = (time_[2] - time_[0]).InMillisecondsF();
x.set(1, 0, 0, 1, t1, t1 * t1, 1, t2, t2 * t2);
return x;
}
bool LeastSquaresPredictor::GeneratePrediction(base::TimeTicks frame_time,
InputData* result) const {
if (!HasPrediction())
return false;
gfx::Matrix3F time_matrix = GetXMatrix();
double dt = (frame_time - time_[0]).InMillisecondsF();
if (dt > 0) {
gfx::Vector3dF b1, b2;
if (SolveLeastSquares(time_matrix, x_queue_, b1) &&
SolveLeastSquares(time_matrix, y_queue_, b2)) {
gfx::Vector3dF prediction_time(1, dt, dt * dt);
result->pos_x = gfx::DotProduct(prediction_time, b1);
result->pos_y = gfx::DotProduct(prediction_time, b2);
return true;
}
}
return false;
}
} // namespace ui
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef UI_EVENTS_BLINK_PREDICTION_LEAST_SQUARES_PREDICTOR_H_
#define UI_EVENTS_BLINK_PREDICTION_LEAST_SQUARES_PREDICTOR_H_
#include <deque>
#include "ui/events/blink/prediction/input_predictor.h"
#include "ui/gfx/geometry/matrix3_f.h"
namespace ui {
// This class use a quadratic least square regression model:
// y = b0 + b1 * x + b2 * x ^ 2.
// See https://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)
class LeastSquaresPredictor : public InputPredictor {
public:
static constexpr size_t kSize = 3;
explicit LeastSquaresPredictor();
~LeastSquaresPredictor() override;
// Reset the predictor to initial state.
void Reset() override;
// Store current input in queue.
void Update(const InputData& cur_input) override;
// Return if there is enough data in the queue to generate prediction.
bool HasPrediction() const override;
// Generate the prediction based on stored points and given time_stamp.
// Return an empty vector if no prediction available.
bool GeneratePrediction(base::TimeTicks frame_time,
InputData* result) const override;
private:
// Generate X matrix from time_ queue.
gfx::Matrix3F GetXMatrix() const;
std::deque<double> x_queue_;
std::deque<double> y_queue_;
std::deque<base::TimeTicks> time_;
DISALLOW_COPY_AND_ASSIGN(LeastSquaresPredictor);
};
} // namespace ui
#endif // UI_EVENTS_BLINK_PREDICTION_LEAST_SQUARES_PREDICTOR_H_
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ui/events/blink/prediction/least_squares_predictor.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace ui {
// The epsilon of predicted result.
const double kEpsilon = 0.01;
base::TimeTicks FromMilliseconds(int64_t ms) {
return base::TimeTicks() + base::TimeDelta::FromMilliseconds(ms);
}
TEST(LSQPredictorTest, ShouldHasPrediction) {
LeastSquaresPredictor predictor;
for (size_t i = 0; i < LeastSquaresPredictor::kSize; i++) {
EXPECT_FALSE(predictor.HasPrediction());
InputPredictor::InputData data = {1 /* x */, 1 /* y */,
FromMilliseconds(8 * i)};
predictor.Update(data);
}
EXPECT_TRUE(predictor.HasPrediction());
}
// Test the lest squares filter behavior.
// The data set is generated by a "known to work" quadratic fit.
TEST(LSQPredictorTest, PredictedValue) {
LeastSquaresPredictor predictor;
std::vector<double> x = {22, 58, 102};
std::vector<double> y = {100, 100, 100};
std::vector<base::TimeTicks> t = {FromMilliseconds(13), FromMilliseconds(21),
FromMilliseconds(37)};
for (int i = 0; i < 3; i++) {
InputPredictor::InputData data = {x[i], y[i], t[i]};
predictor.Update(data);
}
ui::InputPredictor::InputData result;
EXPECT_TRUE(predictor.GeneratePrediction(FromMilliseconds(42), &result));
EXPECT_NEAR(result.pos_x, 108.094, kEpsilon);
EXPECT_NEAR(result.pos_y, 100, kEpsilon);
x = {100, 100, 101};
y = {120, 280, 600};
t = {FromMilliseconds(101), FromMilliseconds(126), FromMilliseconds(148)};
for (int i = 0; i < 3; i++) {
InputPredictor::InputData data = {x[i], y[i], t[i]};
predictor.Update(data);
}
EXPECT_TRUE(predictor.GeneratePrediction(FromMilliseconds(180), &result));
EXPECT_NEAR(result.pos_x, 104.126, kEpsilon);
EXPECT_NEAR(result.pos_y, 1364.93, kEpsilon);
}
// Test that lsq predictor will not crash when given constant time stamp.
TEST(LSQPredictorTest, ConstantTimeStampNotCrash) {
LeastSquaresPredictor predictor;
InputPredictor::InputData data = {100 /* x */, 101 /* y */,
FromMilliseconds(0)};
predictor.Update(data);
data = {101 /* x */, 102 /* y */, FromMilliseconds(0)};
predictor.Update(data);
data = {102 /* x */, 103 /* y */, FromMilliseconds(0)};
predictor.Update(data);
EXPECT_FALSE(predictor.GeneratePrediction(FromMilliseconds(42), &data));
data = {100 /* x */, 100 /* y */, FromMilliseconds(100)};
predictor.Update(data);
data = {100 /* x */, 100 /* y */, FromMilliseconds(100)};
predictor.Update(data);
EXPECT_FALSE(predictor.GeneratePrediction(FromMilliseconds(42), &data));
}
} // namespace ui
......@@ -9,6 +9,7 @@
#include <limits>
#include "base/numerics/math_constants.h"
#include "base/strings/stringprintf.h"
namespace {
......@@ -124,6 +125,13 @@ Matrix3F Matrix3F::Inverse() const {
return inverse;
}
Matrix3F Matrix3F::Transpose() const {
Matrix3F transpose;
transpose.set(data_[M00], data_[M10], data_[M20], data_[M01], data_[M11],
data_[M21], data_[M02], data_[M12], data_[M22]);
return transpose;
}
float Matrix3F::Determinant() const {
return static_cast<float>(Determinant3x3(data_));
}
......@@ -241,4 +249,29 @@ Vector3dF Matrix3F::SolveEigenproblem(Matrix3F* eigenvectors) const {
return Vector3dF(eigenvalues[0], eigenvalues[1], eigenvalues[2]);
}
Matrix3F MatrixProduct(const Matrix3F& lhs, const Matrix3F& rhs) {
Matrix3F result = Matrix3F::Zeros();
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
result.set(i, j, DotProduct(lhs.get_row(i), rhs.get_column(j)));
}
}
return result;
}
Vector3dF MatrixProduct(const Matrix3F& lhs, const Vector3dF& rhs) {
return Vector3dF(DotProduct(lhs.get_row(0), rhs),
DotProduct(lhs.get_row(1), rhs),
DotProduct(lhs.get_row(2), rhs));
}
std::string Matrix3F::ToString() const {
return base::StringPrintf(
"[[%+0.4f, %+0.4f, %+0.4f],"
" [%+0.4f, %+0.4f, %+0.4f],"
" [%+0.4f, %+0.4f, %+0.4f]]",
data_[M00], data_[M01], data_[M02], data_[M10], data_[M11], data_[M12],
data_[M20], data_[M21], data_[M22]);
}
} // namespace gfx
......@@ -46,6 +46,12 @@ class GFX_EXPORT Matrix3F {
data_[8] = m22;
}
Vector3dF get_row(int i) const {
return Vector3dF(data_[MatrixToArrayCoords(i, 0)],
data_[MatrixToArrayCoords(i, 1)],
data_[MatrixToArrayCoords(i, 2)]);
}
Vector3dF get_column(int i) const {
return Vector3dF(
data_[MatrixToArrayCoords(0, i)],
......@@ -63,6 +69,9 @@ class GFX_EXPORT Matrix3F {
// otherwise.
Matrix3F Inverse() const;
// Returns a transpose of this matrix.
Matrix3F Transpose() const;
// Value of the determinant of the matrix.
float Determinant() const;
......@@ -87,6 +96,8 @@ class GFX_EXPORT Matrix3F {
// to eigenvalues.
Vector3dF SolveEigenproblem(Matrix3F* eigenvectors) const;
std::string ToString() const;
private:
Matrix3F(); // Uninitialized default.
......@@ -103,6 +114,9 @@ inline bool operator==(const Matrix3F& lhs, const Matrix3F& rhs) {
return lhs.IsEqual(rhs);
}
GFX_EXPORT Matrix3F MatrixProduct(const Matrix3F& lhs, const Matrix3F& rhs);
GFX_EXPORT Vector3dF MatrixProduct(const Matrix3F& lhs, const Vector3dF& rhs);
} // namespace gfx
#endif // UI_GFX_GEOMETRY_MATRIX3_F_H_
......@@ -34,10 +34,14 @@ TEST(Matrix3fTest, DataAccess) {
Matrix3F identity = Matrix3F::Identity();
EXPECT_EQ(Vector3dF(0.0f, 1.0f, 0.0f), identity.get_column(1));
EXPECT_EQ(Vector3dF(0.0f, 1.0f, 0.0f), identity.get_row(1));
matrix.set(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f);
EXPECT_EQ(Vector3dF(2.0f, 5.0f, 8.0f), matrix.get_column(2));
EXPECT_EQ(Vector3dF(6.0f, 7.0f, 8.0f), matrix.get_row(2));
matrix.set_column(0, Vector3dF(0.1f, 0.2f, 0.3f));
matrix.set_column(0, Vector3dF(0.1f, 0.2f, 0.3f));
EXPECT_EQ(Vector3dF(0.1f, 0.2f, 0.3f), matrix.get_column(0));
EXPECT_EQ(Vector3dF(0.1f, 1.0f, 2.0f), matrix.get_row(0));
EXPECT_EQ(0.1f, matrix.get(0, 0));
EXPECT_EQ(5.0f, matrix.get(1, 2));
......@@ -86,6 +90,19 @@ TEST(Matrix3fTest, Inverse) {
EXPECT_TRUE(regular.IsNear(inv_regular, 0.00001f));
}
TEST(Matrix3fTest, Transpose) {
Matrix3F matrix = Matrix3F::Zeros();
matrix.set(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f);
Matrix3F transpose = matrix.Transpose();
EXPECT_EQ(Vector3dF(0.0f, 1.0f, 2.0f), transpose.get_column(0));
EXPECT_EQ(Vector3dF(3.0f, 4.0f, 5.0f), transpose.get_column(1));
EXPECT_EQ(Vector3dF(6.0f, 7.0f, 8.0f), transpose.get_column(2));
EXPECT_TRUE(matrix.IsEqual(transpose.Transpose()));
}
TEST(Matrix3fTest, EigenvectorsIdentity) {
// This block tests the trivial case of eigenvalues of the identity matrix.
Matrix3F identity = Matrix3F::Identity();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment