Commit cd25ee21 authored by Luis Sanchez Padilla's avatar Luis Sanchez Padilla Committed by Commit Bot

Add heuristic approach option to KalmanPredictor.

Heuristic approaches, which change the weight given to velocity and
acceleration depending on the magnitude of direction change, are
mentioned in https://dl.acm.org/citation.cfm?id=2984590.

Although the Kalman Predictor produces less Jitter than non filtered
methods, it is still fairly noticeable. This is addressed by adding a
heuristic step in the prediction generation.

NIT: Changed InputPredictor::kMinimumTimeInterval to kMinTimeInterval
to resemble variables using Max instead of Maximum.

Change-Id: I5a3e3366b474bb453930bdddb70c85d9ea0e8a94
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1822657Reviewed-by: default avatarAvi Drissman <avi@chromium.org>
Reviewed-by: default avatarElla Ge <eirage@chromium.org>
Reviewed-by: default avatarNavid Zolghadr <nzolghadr@chromium.org>
Commit-Queue: Luis Sanchez Padilla <lusanpad@microsoft.com>
Cr-Commit-Position: refs/heads/master@{#699861}
parent bd599611
...@@ -1136,6 +1136,10 @@ const FeatureEntry::FeatureParam kResamplingInputEventsLSQEnabled[] = { ...@@ -1136,6 +1136,10 @@ const FeatureEntry::FeatureParam kResamplingInputEventsLSQEnabled[] = {
{"predictor", ui::input_prediction::kScrollPredictorNameLsq}}; {"predictor", ui::input_prediction::kScrollPredictorNameLsq}};
const FeatureEntry::FeatureParam kResamplingInputEventsKalmanEnabled[] = { const FeatureEntry::FeatureParam kResamplingInputEventsKalmanEnabled[] = {
{"predictor", ui::input_prediction::kScrollPredictorNameKalman}}; {"predictor", ui::input_prediction::kScrollPredictorNameKalman}};
const FeatureEntry::FeatureParam
kResamplingInputEventsKalmanHeuristicEnabled[] = {
{"predictor",
ui::input_prediction::kScrollPredictorNameKalmanHeuristic}};
const FeatureEntry::FeatureParam kResamplingInputEventsLinearFirstEnabled[] = { const FeatureEntry::FeatureParam kResamplingInputEventsLinearFirstEnabled[] = {
{"predictor", ui::input_prediction::kScrollPredictorNameLinearFirst}}; {"predictor", ui::input_prediction::kScrollPredictorNameLinearFirst}};
const FeatureEntry::FeatureParam kResamplingInputEventsLinearSecondEnabled[] = { const FeatureEntry::FeatureParam kResamplingInputEventsLinearSecondEnabled[] = {
...@@ -1152,6 +1156,9 @@ const FeatureEntry::FeatureVariation kResamplingInputEventsFeatureVariations[] = ...@@ -1152,6 +1156,9 @@ const FeatureEntry::FeatureVariation kResamplingInputEventsFeatureVariations[] =
{ui::input_prediction::kScrollPredictorNameKalman, {ui::input_prediction::kScrollPredictorNameKalman,
kResamplingInputEventsKalmanEnabled, kResamplingInputEventsKalmanEnabled,
base::size(kResamplingInputEventsKalmanEnabled), nullptr}, base::size(kResamplingInputEventsKalmanEnabled), nullptr},
{ui::input_prediction::kScrollPredictorNameKalmanHeuristic,
kResamplingInputEventsKalmanHeuristicEnabled,
base::size(kResamplingInputEventsKalmanHeuristicEnabled), nullptr},
{ui::input_prediction::kScrollPredictorNameLinearFirst, {ui::input_prediction::kScrollPredictorNameLinearFirst,
kResamplingInputEventsLinearFirstEnabled, kResamplingInputEventsLinearFirstEnabled,
base::size(kResamplingInputEventsLinearFirstEnabled), nullptr}, base::size(kResamplingInputEventsLinearFirstEnabled), nullptr},
......
...@@ -60,13 +60,13 @@ class InputPredictor { ...@@ -60,13 +60,13 @@ class InputPredictor {
static constexpr base::TimeDelta kTimeInterval = static constexpr base::TimeDelta kTimeInterval =
base::TimeDelta::FromMilliseconds(8); base::TimeDelta::FromMilliseconds(8);
// Minimum time interval between events. // Minimum time interval between events.
static constexpr base::TimeDelta kMinimumTimeInterval = static constexpr base::TimeDelta kMinTimeInterval =
base::TimeDelta::FromMillisecondsD(2.5); base::TimeDelta::FromMillisecondsD(2.5);
// Maximum amount of prediction when resampling // Maximum amount of prediction when resampling.
static constexpr base::TimeDelta kMaxResampleTime = static constexpr base::TimeDelta kMaxResampleTime =
base::TimeDelta::FromMilliseconds(20); base::TimeDelta::FromMilliseconds(20);
// Maximum time delta for prediction // Maximum time delta for prediction.
static constexpr base::TimeDelta kMaxPredictionTime = static constexpr base::TimeDelta kMaxPredictionTime =
base::TimeDelta::FromMilliseconds(25); base::TimeDelta::FromMilliseconds(25);
}; };
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
#define _USE_MATH_DEFINES // For VC++ to get M_PI. This has to be first.
#include "ui/events/blink/prediction/kalman_predictor.h" #include "ui/events/blink/prediction/kalman_predictor.h"
#include "ui/events/blink/prediction/predictor_factory.h" #include "ui/events/blink/prediction/predictor_factory.h"
...@@ -20,9 +22,11 @@ constexpr base::TimeDelta InputPredictor::kMaxTimeDelta; ...@@ -20,9 +22,11 @@ constexpr base::TimeDelta InputPredictor::kMaxTimeDelta;
constexpr base::TimeDelta InputPredictor::kMaxResampleTime; constexpr base::TimeDelta InputPredictor::kMaxResampleTime;
constexpr base::TimeDelta InputPredictor::kMaxPredictionTime; constexpr base::TimeDelta InputPredictor::kMaxPredictionTime;
constexpr base::TimeDelta InputPredictor::kTimeInterval; constexpr base::TimeDelta InputPredictor::kTimeInterval;
constexpr base::TimeDelta InputPredictor::kMinimumTimeInterval; constexpr base::TimeDelta InputPredictor::kMinTimeInterval;
constexpr base::TimeDelta KalmanPredictor::kMaxTimeInQueue;
KalmanPredictor::KalmanPredictor() = default; KalmanPredictor::KalmanPredictor(HeuristicsMode heuristics_mode)
: heuristics_mode_(heuristics_mode) {}
KalmanPredictor::~KalmanPredictor() = default; KalmanPredictor::~KalmanPredictor() = default;
...@@ -33,15 +37,15 @@ const char* KalmanPredictor::GetName() const { ...@@ -33,15 +37,15 @@ const char* KalmanPredictor::GetName() const {
void KalmanPredictor::Reset() { void KalmanPredictor::Reset() {
x_predictor_.Reset(); x_predictor_.Reset();
y_predictor_.Reset(); y_predictor_.Reset();
last_point_.time_stamp = base::TimeTicks(); last_points_.clear();
time_filter_.Reset(); time_filter_.Reset();
} }
void KalmanPredictor::Update(const InputData& cur_input) { void KalmanPredictor::Update(const InputData& cur_input) {
base::TimeDelta dt; base::TimeDelta dt;
if (!last_point_.time_stamp.is_null()) { if (last_points_.size()) {
// When last point is kMaxTimeDelta away, consider it is incontinuous. // When last point is kMaxTimeDelta away, consider it is incontinuous.
dt = cur_input.time_stamp - last_point_.time_stamp; dt = cur_input.time_stamp - last_points_.back().time_stamp;
if (dt > kMaxTimeDelta) if (dt > kMaxTimeDelta)
Reset(); Reset();
else else
...@@ -49,9 +53,14 @@ void KalmanPredictor::Update(const InputData& cur_input) { ...@@ -49,9 +53,14 @@ void KalmanPredictor::Update(const InputData& cur_input) {
} }
double dt_ms = time_filter_.GetPosition(); double dt_ms = time_filter_.GetPosition();
last_point_ = cur_input; last_points_.push_back(cur_input);
x_predictor_.Update(cur_input.pos.x(), dt_ms); x_predictor_.Update(cur_input.pos.x(), dt_ms);
y_predictor_.Update(cur_input.pos.y(), dt_ms); y_predictor_.Update(cur_input.pos.y(), dt_ms);
while (last_points_.back().time_stamp - last_points_.front().time_stamp >
kMaxTimeInQueue) {
last_points_.pop_front();
}
} }
bool KalmanPredictor::HasPrediction() const { bool KalmanPredictor::HasPrediction() const {
...@@ -63,17 +72,37 @@ bool KalmanPredictor::GeneratePrediction(base::TimeTicks predict_time, ...@@ -63,17 +72,37 @@ bool KalmanPredictor::GeneratePrediction(base::TimeTicks predict_time,
if (!HasPrediction()) if (!HasPrediction())
return false; return false;
float pred_dt = (predict_time - last_point_.time_stamp).InMillisecondsF(); DCHECK(last_points_.size());
float pred_dt =
(predict_time - last_points_.back().time_stamp).InMillisecondsF();
std::vector<InputData> pred_points; std::vector<InputData> pred_points;
gfx::Vector2dF position(last_point_.pos.x(), last_point_.pos.y()); gfx::Vector2dF position(last_points_.back().pos.x(),
// gfx::Vector2dF position = PredictPosition(); last_points_.back().pos.y());
gfx::Vector2dF velocity = PredictVelocity(); gfx::Vector2dF velocity = PredictVelocity();
gfx::Vector2dF acceleration = PredictAcceleration(); gfx::Vector2dF acceleration = PredictAcceleration();
position += position += ScaleVector2d(velocity, kVelocityInfluence * pred_dt);
ScaleVector2d(velocity, kVelocityInfluence * pred_dt) +
ScaleVector2d(acceleration, kAccelerationInfluence * pred_dt * pred_dt); if (heuristics_mode_ == HeuristicsMode::kHeuristicsEnabled) {
float points_angle = 0.0f;
for (size_t i = 2; i < last_points_.size(); i++) {
gfx::Vector2dF first_dir =
last_points_[i - 1].pos - last_points_[i - 2].pos;
gfx::Vector2dF second_dir = last_points_[i].pos - last_points_[i - 1].pos;
if (first_dir.Length() && second_dir.Length()) {
points_angle += atan2(first_dir.x(), first_dir.y()) -
atan2(second_dir.x(), second_dir.y());
}
}
if (abs(points_angle) * 180 / M_PI > 15) {
position += ScaleVector2d(acceleration,
kAccelerationInfluence * pred_dt * pred_dt);
}
} else {
position +=
ScaleVector2d(acceleration, kAccelerationInfluence * pred_dt * pred_dt);
}
result->pos.set_x(position.x()); result->pos.set_x(position.x());
result->pos.set_y(position.y()); result->pos.set_y(position.y());
...@@ -82,8 +111,8 @@ bool KalmanPredictor::GeneratePrediction(base::TimeTicks predict_time, ...@@ -82,8 +111,8 @@ bool KalmanPredictor::GeneratePrediction(base::TimeTicks predict_time,
base::TimeDelta KalmanPredictor::TimeInterval() const { base::TimeDelta KalmanPredictor::TimeInterval() const {
return time_filter_.GetPosition() return time_filter_.GetPosition()
? std::max(kMinimumTimeInterval, base::TimeDelta::FromMilliseconds( ? std::max(kMinTimeInterval, base::TimeDelta::FromMilliseconds(
time_filter_.GetPosition())) time_filter_.GetPosition()))
: kTimeInterval; : kTimeInterval;
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef UI_EVENTS_BLINK_PREDICTION_KALMAN_PREDICTOR_H_ #ifndef UI_EVENTS_BLINK_PREDICTION_KALMAN_PREDICTOR_H_
#define UI_EVENTS_BLINK_PREDICTION_KALMAN_PREDICTOR_H_ #define UI_EVENTS_BLINK_PREDICTION_KALMAN_PREDICTOR_H_
#include <deque>
#include <vector> #include <vector>
#include "ui/events/blink/prediction/input_predictor.h" #include "ui/events/blink/prediction/input_predictor.h"
...@@ -19,7 +20,9 @@ namespace ui { ...@@ -19,7 +20,9 @@ namespace ui {
// be used to predict one dimension (x, y). // be used to predict one dimension (x, y).
class KalmanPredictor : public InputPredictor { class KalmanPredictor : public InputPredictor {
public: public:
explicit KalmanPredictor(); enum class HeuristicsMode { kHeuristicsDisabled, kHeuristicsEnabled };
explicit KalmanPredictor(HeuristicsMode heuristics_mode);
~KalmanPredictor() override; ~KalmanPredictor() override;
const char* GetName() const override; const char* GetName() const override;
...@@ -54,8 +57,16 @@ class KalmanPredictor : public InputPredictor { ...@@ -54,8 +57,16 @@ class KalmanPredictor : public InputPredictor {
// Filter to smooth time intervals. // Filter to smooth time intervals.
KalmanFilter time_filter_; KalmanFilter time_filter_;
// The last input point. // Most recent input data.
InputData last_point_; std::deque<InputData> last_points_;
// Maximum time interval between first and last events in last points queue.
static constexpr base::TimeDelta kMaxTimeInQueue =
base::TimeDelta::FromMilliseconds(40);
// Flag to determine heuristic behavior based on the accumulated angle between
// the last set of points.
const HeuristicsMode heuristics_mode_;
DISALLOW_COPY_AND_ASSIGN(KalmanPredictor); DISALLOW_COPY_AND_ASSIGN(KalmanPredictor);
}; };
......
...@@ -44,7 +44,8 @@ class KalmanPredictorTest : public InputPredictorTest { ...@@ -44,7 +44,8 @@ class KalmanPredictorTest : public InputPredictorTest {
explicit KalmanPredictorTest() {} explicit KalmanPredictorTest() {}
void SetUp() override { void SetUp() override {
predictor_ = std::make_unique<ui::KalmanPredictor>(); predictor_ = std::make_unique<ui::KalmanPredictor>(
ui::KalmanPredictor::HeuristicsMode::kHeuristicsDisabled);
} }
DISALLOW_COPY_AND_ASSIGN(KalmanPredictorTest); DISALLOW_COPY_AND_ASSIGN(KalmanPredictorTest);
...@@ -132,7 +133,6 @@ TEST_F(KalmanPredictorTest, PredictQuadraticValue) { ...@@ -132,7 +133,6 @@ TEST_F(KalmanPredictorTest, PredictQuadraticValue) {
// Tests the kalman predictor time interval filter. // Tests the kalman predictor time interval filter.
TEST_F(KalmanPredictorTest, TimeInterval) { TEST_F(KalmanPredictorTest, TimeInterval) {
predictor_ = std::make_unique<ui::KalmanPredictor>();
EXPECT_EQ(predictor_->TimeInterval(), kExpectedDefaultTimeInterval); EXPECT_EQ(predictor_->TimeInterval(), kExpectedDefaultTimeInterval);
std::vector<double> x = {0, 2, 8, 18}; std::vector<double> x = {0, 2, 8, 18};
std::vector<double> y = {10, 11, 14, 19}; std::vector<double> y = {10, 11, 14, 19};
...@@ -146,5 +146,41 @@ TEST_F(KalmanPredictorTest, TimeInterval) { ...@@ -146,5 +146,41 @@ TEST_F(KalmanPredictorTest, TimeInterval) {
base::TimeDelta::FromMilliseconds(7).InMillisecondsF()); base::TimeDelta::FromMilliseconds(7).InMillisecondsF());
} }
// Test the benefit from the heuristic approach on noisy data.
TEST_F(KalmanPredictorTest, HeuristicApproach) {
std::unique_ptr<InputPredictor> heuristic_predictor =
std::make_unique<ui::KalmanPredictor>(
ui::KalmanPredictor::HeuristicsMode::kHeuristicsEnabled);
std::vector<double> x_stabilizer = {-40, -32, -24, -16, -8, 0};
std::vector<double> y_stabilizer = {-40, -32, -24, -16, -8, 0};
std::vector<double> t_stabilizer = {-40, -32, -24, -16, -8, 0};
for (size_t i = 0; i < t_stabilizer.size(); i++) {
InputPredictor::InputData data = {
gfx::PointF(x_stabilizer[i], y_stabilizer[i]),
FromMilliseconds(t_stabilizer[i])};
predictor_->Update(data);
heuristic_predictor->Update(data);
}
std::vector<double> x = {7, 17, 23, 33, 39, 49, 60};
std::vector<double> y = {9, 15, 25, 31, 41, 47, 60};
std::vector<double> t = {8, 16, 24, 32, 40, 48, 60};
for (size_t i = 0; i < t.size(); i++) {
gfx::PointF point(x[i], y[i]);
if (heuristic_predictor->HasPrediction() && predictor_->HasPrediction()) {
ui::InputPredictor::InputData result, heuristic_result;
EXPECT_TRUE(heuristic_predictor->GeneratePrediction(
FromMilliseconds(t[i]), &heuristic_result));
EXPECT_TRUE(
predictor_->GeneratePrediction(FromMilliseconds(t[i]), &result));
EXPECT_LE((heuristic_result.pos - point).Length(),
(result.pos - point).Length());
}
InputPredictor::InputData data = {point, FromMilliseconds(t[i])};
heuristic_predictor->Update(data);
predictor_->Update(data);
}
}
} // namespace test } // namespace test
} // namespace ui } // namespace ui
...@@ -100,7 +100,7 @@ bool LeastSquaresPredictor::GeneratePrediction(base::TimeTicks predict_time, ...@@ -100,7 +100,7 @@ bool LeastSquaresPredictor::GeneratePrediction(base::TimeTicks predict_time,
base::TimeDelta LeastSquaresPredictor::TimeInterval() const { base::TimeDelta LeastSquaresPredictor::TimeInterval() const {
if (time_.size() > 1) { if (time_.size() > 1) {
return std::max(kMinimumTimeInterval, return std::max(kMinTimeInterval,
(time_.back() - time_.front()) / (time_.size() - 1)); (time_.back() - time_.front()) / (time_.size() - 1));
} }
return kTimeInterval; return kTimeInterval;
......
...@@ -116,9 +116,9 @@ void LinearPredictor::GeneratePredictionSecondOrder(float pred_dt, ...@@ -116,9 +116,9 @@ void LinearPredictor::GeneratePredictionSecondOrder(float pred_dt,
base::TimeDelta LinearPredictor::TimeInterval() const { base::TimeDelta LinearPredictor::TimeInterval() const {
if (events_queue_.size() > 1) { if (events_queue_.size() > 1) {
return std::max(kMinimumTimeInterval, (events_queue_.back().time_stamp - return std::max(kMinTimeInterval, (events_queue_.back().time_stamp -
events_queue_.front().time_stamp) / events_queue_.front().time_stamp) /
(events_queue_.size() - 1)); (events_queue_.size() - 1));
} }
return kTimeInterval; return kTimeInterval;
} }
......
...@@ -16,6 +16,7 @@ namespace input_prediction { ...@@ -16,6 +16,7 @@ namespace input_prediction {
const char kScrollPredictorNameLsq[] = "lsq"; const char kScrollPredictorNameLsq[] = "lsq";
const char kScrollPredictorNameKalman[] = "kalman"; const char kScrollPredictorNameKalman[] = "kalman";
const char kScrollPredictorNameKalmanHeuristic[] = "kalman_heuristic";
const char kScrollPredictorNameLinearFirst[] = "linear_first"; const char kScrollPredictorNameLinearFirst[] = "linear_first";
const char kScrollPredictorNameLinearSecond[] = "linear_second"; const char kScrollPredictorNameLinearSecond[] = "linear_second";
const char kScrollPredictorNameLinearResampling[] = "linear_resampling"; const char kScrollPredictorNameLinearResampling[] = "linear_resampling";
...@@ -35,6 +36,9 @@ PredictorType PredictorFactory::GetPredictorTypeFromName( ...@@ -35,6 +36,9 @@ PredictorType PredictorFactory::GetPredictorTypeFromName(
return PredictorType::kScrollPredictorTypeLsq; return PredictorType::kScrollPredictorTypeLsq;
else if (predictor_name == input_prediction::kScrollPredictorNameKalman) else if (predictor_name == input_prediction::kScrollPredictorNameKalman)
return PredictorType::kScrollPredictorTypeKalman; return PredictorType::kScrollPredictorTypeKalman;
else if (predictor_name ==
input_prediction::kScrollPredictorNameKalmanHeuristic)
return PredictorType::kScrollPredictorTypeKalmanHeuristic;
else if (predictor_name == input_prediction::kScrollPredictorNameLinearFirst) else if (predictor_name == input_prediction::kScrollPredictorNameLinearFirst)
return PredictorType::kScrollPredictorTypeLinearFirst; return PredictorType::kScrollPredictorTypeLinearFirst;
else if (predictor_name == input_prediction::kScrollPredictorNameLinearSecond) else if (predictor_name == input_prediction::kScrollPredictorNameLinearSecond)
...@@ -50,7 +54,11 @@ std::unique_ptr<InputPredictor> PredictorFactory::GetPredictor( ...@@ -50,7 +54,11 @@ std::unique_ptr<InputPredictor> PredictorFactory::GetPredictor(
else if (predictor_type == PredictorType::kScrollPredictorTypeLsq) else if (predictor_type == PredictorType::kScrollPredictorTypeLsq)
return std::make_unique<LeastSquaresPredictor>(); return std::make_unique<LeastSquaresPredictor>();
else if (predictor_type == PredictorType::kScrollPredictorTypeKalman) else if (predictor_type == PredictorType::kScrollPredictorTypeKalman)
return std::make_unique<KalmanPredictor>(); return std::make_unique<KalmanPredictor>(
KalmanPredictor::HeuristicsMode::kHeuristicsDisabled);
else if (predictor_type == PredictorType::kScrollPredictorTypeKalmanHeuristic)
return std::make_unique<KalmanPredictor>(
KalmanPredictor::HeuristicsMode::kHeuristicsEnabled);
else if (predictor_type == PredictorType::kScrollPredictorTypeLinearFirst) else if (predictor_type == PredictorType::kScrollPredictorTypeLinearFirst)
return std::make_unique<LinearPredictor>( return std::make_unique<LinearPredictor>(
LinearPredictor::EquationOrder::kFirstOrder); LinearPredictor::EquationOrder::kFirstOrder);
......
...@@ -13,6 +13,7 @@ namespace input_prediction { ...@@ -13,6 +13,7 @@ namespace input_prediction {
extern const char kScrollPredictorNameLsq[]; extern const char kScrollPredictorNameLsq[];
extern const char kScrollPredictorNameKalman[]; extern const char kScrollPredictorNameKalman[];
extern const char kScrollPredictorNameKalmanHeuristic[];
extern const char kScrollPredictorNameLinearFirst[]; extern const char kScrollPredictorNameLinearFirst[];
extern const char kScrollPredictorNameLinearSecond[]; extern const char kScrollPredictorNameLinearSecond[];
extern const char kScrollPredictorNameLinearResampling[]; extern const char kScrollPredictorNameLinearResampling[];
...@@ -21,6 +22,7 @@ extern const char kScrollPredictorNameEmpty[]; ...@@ -21,6 +22,7 @@ extern const char kScrollPredictorNameEmpty[];
enum class PredictorType { enum class PredictorType {
kScrollPredictorTypeLsq, kScrollPredictorTypeLsq,
kScrollPredictorTypeKalman, kScrollPredictorTypeKalman,
kScrollPredictorTypeKalmanHeuristic,
kScrollPredictorTypeLinearFirst, kScrollPredictorTypeLinearFirst,
kScrollPredictorTypeLinearSecond, kScrollPredictorTypeLinearSecond,
kScrollPredictorTypeLinearResampling, kScrollPredictorTypeLinearResampling,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment