Commit f81b392f authored by Rob Schonberger's avatar Rob Schonberger Committed by Commit Bot

Add an inline implementation of NeuralStylusPalmDetectionFilterModel.

Adds an inline implementation of NeuralStylusPalmDetectionFilterModel,
with a few extra changes:

1. Updates the factory to add a Flag and appropriate instantiation of
NeuralStylusPalmDetectionFilter when that flag is turned on.

2. Update the unit test of the factory to appropriately test that the
NeuralStylusPalmDetectionFilter is correctly instantiated.

3. Update NeuralStylusPalmDetectionFilter with a few lint updates
suggested by git cl lint

4. Move 2 items from private to public in NeuralStylusPalmDetectionFilter .

Bug: 1009290
Change-Id: I35c5a5e2c61c436dd2eaaeba439245cc7cc9433a
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/1890376
Commit-Queue: Rob Schonberger <robsc@chromium.org>
Reviewed-by: default avatarMichael Spang <spang@chromium.org>
Cr-Commit-Position: refs/heads/master@{#712835}
parent 90498869
...@@ -134,6 +134,10 @@ if (use_ozone) { ...@@ -134,6 +134,10 @@ if (use_ozone) {
"evdev/touch_filter/palm_detection_filter.h", "evdev/touch_filter/palm_detection_filter.h",
"evdev/touch_filter/palm_detection_filter_factory.cc", "evdev/touch_filter/palm_detection_filter_factory.cc",
"evdev/touch_filter/palm_detection_filter_factory.h", "evdev/touch_filter/palm_detection_filter_factory.h",
"evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_inference.cc",
"evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_inference.h",
"evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_model.cc",
"evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_model.h",
"evdev/touch_filter/shared_palm_detection_filter_state.h", "evdev/touch_filter/shared_palm_detection_filter_state.h",
"evdev/touch_filter/single_position_touch_noise_filter.cc", "evdev/touch_filter/single_position_touch_noise_filter.cc",
"evdev/touch_filter/single_position_touch_noise_filter.h", "evdev/touch_filter/single_position_touch_noise_filter.h",
......
...@@ -8,6 +8,10 @@ ...@@ -8,6 +8,10 @@
#include <bitset> #include <bitset>
#include <cstdint> #include <cstdint>
#include <deque> #include <deque>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector> #include <vector>
#include "base/time/time.h" #include "base/time/time.h"
...@@ -41,12 +45,13 @@ class EVENTS_OZONE_EVDEV_EXPORT NeuralStylusPalmDetectionFilter ...@@ -41,12 +45,13 @@ class EVENTS_OZONE_EVDEV_EXPORT NeuralStylusPalmDetectionFilter
static bool CompatibleWithNeuralStylusPalmDetectionFilter( static bool CompatibleWithNeuralStylusPalmDetectionFilter(
const EventDeviceInfo& devinfo); const EventDeviceInfo& devinfo);
const static int kFeaturesPerSample; static const int kFeaturesPerSample;
const static int kExtraFeaturesForNeighbor; static const int kExtraFeaturesForNeighbor;
private: static const char kFilterName[];
std::string FilterNameForTesting() const override; std::string FilterNameForTesting() const override;
private:
void FindNearestNeighborsWithin( void FindNearestNeighborsWithin(
int neighbor_count, int neighbor_count,
float max_distance, float max_distance,
...@@ -80,7 +85,6 @@ class EVENTS_OZONE_EVDEV_EXPORT NeuralStylusPalmDetectionFilter ...@@ -80,7 +85,6 @@ class EVENTS_OZONE_EVDEV_EXPORT NeuralStylusPalmDetectionFilter
const PalmFilterDeviceInfo palm_filter_dev_info_; const PalmFilterDeviceInfo palm_filter_dev_info_;
std::unique_ptr<NeuralStylusPalmDetectionFilterModel> model_; std::unique_ptr<NeuralStylusPalmDetectionFilterModel> model_;
const static char kFilterName[];
static const std::vector<int> kRequiredAbsMtCodes; static const std::vector<int> kRequiredAbsMtCodes;
DISALLOW_COPY_AND_ASSIGN(NeuralStylusPalmDetectionFilter); DISALLOW_COPY_AND_ASSIGN(NeuralStylusPalmDetectionFilter);
......
...@@ -5,19 +5,26 @@ ...@@ -5,19 +5,26 @@
#include "ui/events/ozone/evdev/touch_filter/palm_detection_filter_factory.h" #include "ui/events/ozone/evdev/touch_filter/palm_detection_filter_factory.h"
#include <memory> #include <memory>
#include <utility>
#include "base/feature_list.h" #include "base/feature_list.h"
#include "base/time/time.h" #include "base/time/time.h"
#include "ui/events/ozone/evdev/event_device_info.h" #include "ui/events/ozone/evdev/event_device_info.h"
#include "ui/events/ozone/evdev/touch_filter/heuristic_stylus_palm_detection_filter.h" #include "ui/events/ozone/evdev/touch_filter/heuristic_stylus_palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
#include "ui/events/ozone/evdev/touch_filter/open_palm_detection_filter.h" #include "ui/events/ozone/evdev/touch_filter/open_palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/palm_detection_filter.h" #include "ui/events/ozone/evdev/touch_filter/palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_model.h"
namespace ui { namespace ui {
const base::Feature kEnableHeuristicPalmDetectionFilter{ const base::Feature kEnableHeuristicPalmDetectionFilter{
"EnableHeuristicPalmDetectionFilter", base::FEATURE_DISABLED_BY_DEFAULT}; "EnableHeuristicPalmDetectionFilter", base::FEATURE_DISABLED_BY_DEFAULT};
const base::Feature kEnableNeuralPalmDetectionFilter{
"EnableNeuralPalmDetectionFilter", base::FEATURE_DISABLED_BY_DEFAULT};
const base::FeatureParam<double> kHeuristicCancelThresholdSeconds{ const base::FeatureParam<double> kHeuristicCancelThresholdSeconds{
&kEnableHeuristicPalmDetectionFilter, &kEnableHeuristicPalmDetectionFilter,
"heuristic_palm_cancel_threshold_seconds", 0.4}; "heuristic_palm_cancel_threshold_seconds", 0.4};
...@@ -32,6 +39,16 @@ const base::FeatureParam<int> kHeuristicStrokeCount{ ...@@ -32,6 +39,16 @@ const base::FeatureParam<int> kHeuristicStrokeCount{
std::unique_ptr<PalmDetectionFilter> CreatePalmDetectionFilter( std::unique_ptr<PalmDetectionFilter> CreatePalmDetectionFilter(
const EventDeviceInfo& devinfo, const EventDeviceInfo& devinfo,
SharedPalmDetectionFilterState* shared_palm_state) { SharedPalmDetectionFilterState* shared_palm_state) {
if (base::FeatureList::IsEnabled(kEnableNeuralPalmDetectionFilter) &&
NeuralStylusPalmDetectionFilter::
CompatibleWithNeuralStylusPalmDetectionFilter(devinfo)) {
// Theres only one model right now.
std::unique_ptr<NeuralStylusPalmDetectionFilterModel> model =
std::make_unique<OneDeviceTrainNeuralStylusPalmDetectionFilterModel>();
return std::make_unique<NeuralStylusPalmDetectionFilter>(
devinfo, std::move(model), shared_palm_state);
}
if (base::FeatureList::IsEnabled(kEnableHeuristicPalmDetectionFilter)) { if (base::FeatureList::IsEnabled(kEnableHeuristicPalmDetectionFilter)) {
const base::TimeDelta hold_time = const base::TimeDelta hold_time =
base::TimeDelta::FromSecondsD(kHeuristicHoldThresholdSeconds.Get()); base::TimeDelta::FromSecondsD(kHeuristicHoldThresholdSeconds.Get());
......
...@@ -21,6 +21,9 @@ namespace ui { ...@@ -21,6 +21,9 @@ namespace ui {
EVENTS_OZONE_EVDEV_EXPORT EVENTS_OZONE_EVDEV_EXPORT
extern const base::Feature kEnableHeuristicPalmDetectionFilter; extern const base::Feature kEnableHeuristicPalmDetectionFilter;
EVENTS_OZONE_EVDEV_EXPORT
extern const base::Feature kEnableNeuralPalmDetectionFilter;
EVENTS_OZONE_EVDEV_EXPORT EVENTS_OZONE_EVDEV_EXPORT
extern const base::FeatureParam<double> kHeuristicCancelThresholdSeconds; extern const base::FeatureParam<double> kHeuristicCancelThresholdSeconds;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ui/events/ozone/evdev/event_device_info.h" #include "ui/events/ozone/evdev/event_device_info.h"
#include "ui/events/ozone/evdev/event_device_test_util.h" #include "ui/events/ozone/evdev/event_device_test_util.h"
#include "ui/events/ozone/evdev/touch_filter/heuristic_stylus_palm_detection_filter.h" #include "ui/events/ozone/evdev/touch_filter/heuristic_stylus_palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/open_palm_detection_filter.h" #include "ui/events/ozone/evdev/touch_filter/open_palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/palm_detection_filter.h" #include "ui/events/ozone/evdev/touch_filter/palm_detection_filter.h"
#include "ui/events/ozone/evdev/touch_filter/shared_palm_detection_filter_state.h" #include "ui/events/ozone/evdev/touch_filter/shared_palm_detection_filter_state.h"
...@@ -42,8 +43,9 @@ class PalmDetectionFilterFactoryTest : public testing::Test { ...@@ -42,8 +43,9 @@ class PalmDetectionFilterFactoryTest : public testing::Test {
}; };
TEST_F(PalmDetectionFilterFactoryTest, AllDisabled) { TEST_F(PalmDetectionFilterFactoryTest, AllDisabled) {
scoped_feature_list_->InitAndDisableFeature( scoped_feature_list_->InitWithFeatures(
ui::kEnableHeuristicPalmDetectionFilter); {}, {ui::kEnableHeuristicPalmDetectionFilter,
ui::kEnableNeuralPalmDetectionFilter});
std::unique_ptr<PalmDetectionFilter> palm_filter = std::unique_ptr<PalmDetectionFilter> palm_filter =
CreatePalmDetectionFilter(eve_touchscreen_info_, &shared_palm_state_); CreatePalmDetectionFilter(eve_touchscreen_info_, &shared_palm_state_);
EXPECT_EQ(OpenPalmDetectionFilter::kFilterName, EXPECT_EQ(OpenPalmDetectionFilter::kFilterName,
...@@ -59,7 +61,7 @@ TEST_F(PalmDetectionFilterFactoryTest, HeuristicEnabledForEve) { ...@@ -59,7 +61,7 @@ TEST_F(PalmDetectionFilterFactoryTest, HeuristicEnabledForEve) {
scoped_feature_list_->InitWithFeaturesAndParameters( scoped_feature_list_->InitWithFeaturesAndParameters(
{base::test::ScopedFeatureList::FeatureAndParams( {base::test::ScopedFeatureList::FeatureAndParams(
ui::kEnableHeuristicPalmDetectionFilter, {})}, ui::kEnableHeuristicPalmDetectionFilter, {})},
{}); {ui::kEnableNeuralPalmDetectionFilter});
std::unique_ptr<PalmDetectionFilter> palm_filter = std::unique_ptr<PalmDetectionFilter> palm_filter =
CreatePalmDetectionFilter(eve_touchscreen_info_, &shared_palm_state_); CreatePalmDetectionFilter(eve_touchscreen_info_, &shared_palm_state_);
EXPECT_EQ(HeuristicStylusPalmDetectionFilter::kFilterName, EXPECT_EQ(HeuristicStylusPalmDetectionFilter::kFilterName,
...@@ -87,7 +89,7 @@ TEST_F(PalmDetectionFilterFactoryTest, HeuristicTimesSet) { ...@@ -87,7 +89,7 @@ TEST_F(PalmDetectionFilterFactoryTest, HeuristicTimesSet) {
ui::kEnableHeuristicPalmDetectionFilter, ui::kEnableHeuristicPalmDetectionFilter,
{{"heuristic_palm_cancel_threshold_seconds", "0.8"}, {{"heuristic_palm_cancel_threshold_seconds", "0.8"},
{"heuristic_palm_hold_threshold_seconds", "15.327"}})}, {"heuristic_palm_hold_threshold_seconds", "15.327"}})},
{}); {ui::kEnableNeuralPalmDetectionFilter});
std::unique_ptr<PalmDetectionFilter> palm_filter = CreatePalmDetectionFilter( std::unique_ptr<PalmDetectionFilter> palm_filter = CreatePalmDetectionFilter(
nocturne_touchscreen_info_, &shared_palm_state_); nocturne_touchscreen_info_, &shared_palm_state_);
...@@ -100,4 +102,16 @@ TEST_F(PalmDetectionFilterFactoryTest, HeuristicTimesSet) { ...@@ -100,4 +102,16 @@ TEST_F(PalmDetectionFilterFactoryTest, HeuristicTimesSet) {
heuristic_filter->HoldTime()); heuristic_filter->HoldTime());
} }
TEST_F(PalmDetectionFilterFactoryTest, NeuralBeatsHeuristic) {
scoped_feature_list_->InitWithFeaturesAndParameters(
{base::test::ScopedFeatureList::FeatureAndParams(
ui::kEnableHeuristicPalmDetectionFilter, {}),
base::test::ScopedFeatureList::FeatureAndParams(
ui::kEnableNeuralPalmDetectionFilter, {})},
{});
std::unique_ptr<PalmDetectionFilter> palm_filter = CreatePalmDetectionFilter(
nocturne_touchscreen_info_, &shared_palm_state_);
ASSERT_EQ(NeuralStylusPalmDetectionFilter::kFilterName,
palm_filter->FilterNameForTesting());
}
} // namespace ui } // namespace ui
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Code generated by tf.native.
#ifndef UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_PALM_MODEL_ONEDEVICE_TRAIN_PALM_DETECTION_FILTER_INFERENCE_H_
#define UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_PALM_MODEL_ONEDEVICE_TRAIN_PALM_DETECTION_FILTER_INFERENCE_H_
#include <cstdint>
namespace ui {
namespace internal_onedevice {
struct alignas(16) FixedAllocations {
float alloc0[117];
float alloc1[115];
int32_t shape0[2];
};
extern int32_t input_from_feature_columns_input_layer_concat_concat0Shape[2];
extern int32_t logits_MatMul_merged_with_dnn_logits_BiasAdd0Shape[2];
#define CHROME_KNOWLEDGE_INPUT_FROM_FEATURE_COLUMNS_INPUT_LAYER_CONCAT_CONCAT0_RANK \
2
#define CHROME_KNOWLEDGE_INPUT_FROM_FEATURE_COLUMNS_INPUT_LAYER_CONCAT_CONCAT0_DIM0_SIZE \
1
#define CHROME_KNOWLEDGE_INPUT_FROM_FEATURE_COLUMNS_INPUT_LAYER_CONCAT_CONCAT0_DIM1_SIZE \
193
#define CHROME_KNOWLEDGE_LOGITS_MATMUL_MERGED_WITH_DNN_LOGITS_BIASADD0_RANK 2
#define CHROME_KNOWLEDGE_LOGITS_MATMUL_MERGED_WITH_DNN_LOGITS_BIASADD0_DIM0_SIZE \
1
#define CHROME_KNOWLEDGE_LOGITS_MATMUL_MERGED_WITH_DNN_LOGITS_BIASADD0_DIM1_SIZE \
1
void Inference(
const float* __restrict input_from_feature_columns_input_layer_concat_concat0 /* shape: 1,193 */
,
float* __restrict logits_MatMul_merged_with_dnn_logits_BiasAdd0 /* shape:
1,1 */
,
FixedAllocations* __restrict fixed);
} // namespace internal_onedevice
} // namespace ui
#endif // UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_PALM_MODEL_ONEDEVICE_TRAIN_PALM_DETECTION_FILTER_INFERENCE_H_
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ui/events/ozone/evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_model.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ui/events/ozone/evdev/touch_filter/palm_model/onedevice_train_palm_detection_filter_inference.h"
#define USE_EIGEN 0
namespace ui {
float OneDeviceTrainNeuralStylusPalmDetectionFilterModel::Inference(
const std::vector<float>& features) const {
DVLOG(1) << "In Inference.";
std::unique_ptr<internal_onedevice::FixedAllocations> fixed_allocations(
new internal_onedevice::FixedAllocations());
if (features.size() != 193) {
LOG(DFATAL) << "Bad count. Is " << features.size() << " expected " << 193;
return nanf("");
}
// TODO(robsc): Update to DVLOG_IS_ON if relevant.
if (DCHECK_IS_ON() && VLOG_IS_ON(1)) {
for (unsigned i = 0; i < features.size(); ++i) {
DVLOG(1) << "Feature " << i << " is " << features[i];
}
}
float output = 0;
internal_onedevice::Inference(&features[0], &output, fixed_allocations.get());
return output;
}
const NeuralStylusPalmDetectionFilterModelConfig&
OneDeviceTrainNeuralStylusPalmDetectionFilterModel::config() const {
return config_;
}
OneDeviceTrainNeuralStylusPalmDetectionFilterModel::
OneDeviceTrainNeuralStylusPalmDetectionFilterModel() {
config_.nearest_neighbor_count = 1;
config_.biggest_near_neighbor_count = 1;
config_.include_sequence_count_in_strokes = true;
config_.max_neighbor_distance_in_mm = 100.0f;
config_.min_sample_count = 6;
config_.max_sample_count = 12;
config_.max_dead_neighbor_time = base::TimeDelta::FromMillisecondsD(100.0f);
config_.heuristic_palm_touch_limit = 20.0f;
config_.heuristic_palm_area_limit = 400.0f;
}
} // namespace ui
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_PALM_MODEL_ONEDEVICE_TRAIN_PALM_DETECTION_FILTER_MODEL_H_
#define UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_PALM_MODEL_ONEDEVICE_TRAIN_PALM_DETECTION_FILTER_MODEL_H_
#include <cstdint>
#include <vector>
#include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
namespace ui {
// A simplified Neural stylus Palm Detection Model, trained on the data based on
// a single device class but translatable to others. Neural inference
// implementation based on inline neural net inference.
class EVENTS_OZONE_EVDEV_EXPORT
OneDeviceTrainNeuralStylusPalmDetectionFilterModel
: public NeuralStylusPalmDetectionFilterModel {
public:
OneDeviceTrainNeuralStylusPalmDetectionFilterModel();
float Inference(const std::vector<float>& features) const override;
const NeuralStylusPalmDetectionFilterModelConfig& config() const override;
DISALLOW_COPY_AND_ASSIGN(OneDeviceTrainNeuralStylusPalmDetectionFilterModel);
private:
NeuralStylusPalmDetectionFilterModelConfig config_;
};
} // namespace ui
#endif // UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_PALM_MODEL_ONEDEVICE_TRAIN_PALM_DETECTION_FILTER_MODEL_H_
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment