Commit 39bb38e8 authored by tby's avatar tby Committed by Commit Bot

[Cros SR] Provide the app shortcut SP with access to Roselle.

See here for context:

  https://docs.google.com/document/d/1GTQ3vvrszK4I1pDzmTFb6krmUzAXox0V9ZIqIi_GH1E/edit?usp=sharing

We are expanding Roselle to re-rank the results returned when the user enters
a query. However, these results include Arc app shortcuts as well as apps
themselves, so they need to be ranked. This CL gives the shortcut search
provider access to the Roselle model, for future use in training/ranking.

Bug: 931149
Change-Id: I826be986a98d870de21ecd9c88f5a266c5d0308e
Reviewed-on: https://chromium-review.googlesource.com/c/1478570
Commit-Queue: Tony Yeoman <tby@chromium.org>
Reviewed-by: default avatarXiyuan Xia <xiyuan@chromium.org>
Reviewed-by: default avatarJia Meng <jiameng@chromium.org>
Cr-Commit-Position: refs/heads/master@{#634428}
parent d444574c
...@@ -34,7 +34,6 @@ ...@@ -34,7 +34,6 @@
#include "chrome/browser/chromeos/crostini/crostini_registry_service_factory.h" #include "chrome/browser/chromeos/crostini/crostini_registry_service_factory.h"
#include "chrome/browser/chromeos/crostini/crostini_util.h" #include "chrome/browser/chromeos/crostini/crostini_util.h"
#include "chrome/browser/chromeos/extensions/gfx_utils.h" #include "chrome/browser/chromeos/extensions/gfx_utils.h"
#include "chrome/browser/chromeos/profiles/profile_helper.h"
#include "chrome/browser/extensions/extension_service.h" #include "chrome/browser/extensions/extension_service.h"
#include "chrome/browser/extensions/extension_ui_util.h" #include "chrome/browser/extensions/extension_ui_util.h"
#include "chrome/browser/extensions/extension_util.h" #include "chrome/browser/extensions/extension_util.h"
...@@ -620,14 +619,13 @@ class CrostiniDataSource : public AppSearchProvider::DataSource, ...@@ -620,14 +619,13 @@ class CrostiniDataSource : public AppSearchProvider::DataSource,
AppSearchProvider::AppSearchProvider(Profile* profile, AppSearchProvider::AppSearchProvider(Profile* profile,
AppListControllerDelegate* list_controller, AppListControllerDelegate* list_controller,
base::Clock* clock, base::Clock* clock,
AppListModelUpdater* model_updater) AppListModelUpdater* model_updater,
AppSearchResultRanker* ranker)
: profile_(profile), : profile_(profile),
list_controller_(list_controller), list_controller_(list_controller),
model_updater_(model_updater), model_updater_(model_updater),
clock_(clock), clock_(clock),
ranker_(std::make_unique<AppSearchResultRanker>( ranker_(ranker),
profile->GetPath(),
chromeos::ProfileHelper::IsEphemeralUserProfile(profile))),
refresh_apps_factory_(this), refresh_apps_factory_(this),
update_results_factory_(this) { update_results_factory_(this) {
bool app_service_enabled = bool app_service_enabled =
......
...@@ -44,7 +44,8 @@ class AppSearchProvider : public SearchProvider { ...@@ -44,7 +44,8 @@ class AppSearchProvider : public SearchProvider {
AppSearchProvider(Profile* profile, AppSearchProvider(Profile* profile,
AppListControllerDelegate* list_controller, AppListControllerDelegate* list_controller,
base::Clock* clock, base::Clock* clock,
AppListModelUpdater* model_updater); AppListModelUpdater* model_updater,
AppSearchResultRanker* ranker);
~AppSearchProvider() override; ~AppSearchProvider() override;
// SearchProvider overrides: // SearchProvider overrides:
...@@ -88,7 +89,7 @@ class AppSearchProvider : public SearchProvider { ...@@ -88,7 +89,7 @@ class AppSearchProvider : public SearchProvider {
AppListModelUpdater* const model_updater_; AppListModelUpdater* const model_updater_;
base::Clock* clock_; base::Clock* clock_;
std::vector<std::unique_ptr<DataSource>> data_sources_; std::vector<std::unique_ptr<DataSource>> data_sources_;
std::unique_ptr<AppSearchResultRanker> ranker_; AppSearchResultRanker* ranker_;
sync_sessions::OpenTabsUIDelegate* open_tabs_ui_delegate_for_testing_ = sync_sessions::OpenTabsUIDelegate* open_tabs_ui_delegate_for_testing_ =
nullptr; nullptr;
base::WeakPtrFactory<AppSearchProvider> refresh_apps_factory_; base::WeakPtrFactory<AppSearchProvider> refresh_apps_factory_;
......
...@@ -7,11 +7,13 @@ ...@@ -7,11 +7,13 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "ash/public/cpp/app_list/app_list_features.h"
#include "base/bind.h" #include "base/bind.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
#include "chrome/browser/profiles/profile.h" #include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/app_list/arc/arc_app_utils.h" #include "chrome/browser/ui/app_list/arc/arc_app_utils.h"
#include "chrome/browser/ui/app_list/search/arc/arc_app_shortcut_search_result.h" #include "chrome/browser/ui/app_list/search/arc/arc_app_shortcut_search_result.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_search_result_ranker.h"
#include "components/arc/arc_bridge_service.h" #include "components/arc/arc_bridge_service.h"
#include "components/arc/arc_service_manager.h" #include "components/arc/arc_service_manager.h"
...@@ -20,10 +22,12 @@ namespace app_list { ...@@ -20,10 +22,12 @@ namespace app_list {
ArcAppShortcutsSearchProvider::ArcAppShortcutsSearchProvider( ArcAppShortcutsSearchProvider::ArcAppShortcutsSearchProvider(
int max_results, int max_results,
Profile* profile, Profile* profile,
AppListControllerDelegate* list_controller) AppListControllerDelegate* list_controller,
AppSearchResultRanker* ranker)
: max_results_(max_results), : max_results_(max_results),
profile_(profile), profile_(profile),
list_controller_(list_controller), list_controller_(list_controller),
ranker_(ranker),
weak_ptr_factory_(this) {} weak_ptr_factory_(this) {}
ArcAppShortcutsSearchProvider::~ArcAppShortcutsSearchProvider() = default; ArcAppShortcutsSearchProvider::~ArcAppShortcutsSearchProvider() = default;
...@@ -66,7 +70,14 @@ void ArcAppShortcutsSearchProvider::OnGetAppShortcutGlobalQueryItems( ...@@ -66,7 +70,14 @@ void ArcAppShortcutsSearchProvider::OnGetAppShortcutGlobalQueryItems(
continue; continue;
search_results.emplace_back(std::make_unique<ArcAppShortcutSearchResult>( search_results.emplace_back(std::make_unique<ArcAppShortcutSearchResult>(
std::move(item), profile_, list_controller_)); std::move(item), profile_, list_controller_));
if (app_list_features::IsAppSearchResultRankerEnabled() &&
ranker_ != nullptr) {
// TODO(crbug.com/931149): tweak the scores of each search result item
// using the ranker.
}
} }
SwapResults(&search_results); SwapResults(&search_results);
} }
......
...@@ -18,11 +18,14 @@ class Profile; ...@@ -18,11 +18,14 @@ class Profile;
namespace app_list { namespace app_list {
class AppSearchResultRanker;
class ArcAppShortcutsSearchProvider : public SearchProvider { class ArcAppShortcutsSearchProvider : public SearchProvider {
public: public:
ArcAppShortcutsSearchProvider(int max_results, ArcAppShortcutsSearchProvider(int max_results,
Profile* profile, Profile* profile,
AppListControllerDelegate* list_controller); AppListControllerDelegate* list_controller,
AppSearchResultRanker* ranker);
~ArcAppShortcutsSearchProvider() override; ~ArcAppShortcutsSearchProvider() override;
// SearchProvider: // SearchProvider:
...@@ -35,6 +38,9 @@ class ArcAppShortcutsSearchProvider : public SearchProvider { ...@@ -35,6 +38,9 @@ class ArcAppShortcutsSearchProvider : public SearchProvider {
const int max_results_; const int max_results_;
Profile* const profile_; // Owned by ProfileInfo. Profile* const profile_; // Owned by ProfileInfo.
AppListControllerDelegate* const list_controller_; // Owned by AppListClient. AppListControllerDelegate* const list_controller_; // Owned by AppListClient.
// TODO(crbug.com/931149): train this ranker on app shortcut clicks, and use
// it to tweak their relevance scores.
AppSearchResultRanker* ranker_;
base::WeakPtrFactory<ArcAppShortcutsSearchProvider> weak_ptr_factory_; base::WeakPtrFactory<ArcAppShortcutsSearchProvider> weak_ptr_factory_;
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "base/files/scoped_temp_dir.h"
#include "base/macros.h" #include "base/macros.h"
#include "base/strings/stringprintf.h" #include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
...@@ -15,6 +16,7 @@ ...@@ -15,6 +16,7 @@
#include "chrome/browser/ui/app_list/arc/arc_app_list_prefs.h" #include "chrome/browser/ui/app_list/arc/arc_app_list_prefs.h"
#include "chrome/browser/ui/app_list/arc/arc_app_test.h" #include "chrome/browser/ui/app_list/arc/arc_app_test.h"
#include "chrome/browser/ui/app_list/search/chrome_search_result.h" #include "chrome/browser/ui/app_list/search/chrome_search_result.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_search_result_ranker.h"
#include "chrome/browser/ui/app_list/test/test_app_list_controller_delegate.h" #include "chrome/browser/ui/app_list/test/test_app_list_controller_delegate.h"
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
...@@ -36,6 +38,10 @@ class ArcAppShortcutsSearchProviderTest ...@@ -36,6 +38,10 @@ class ArcAppShortcutsSearchProviderTest
AppListTestBase::SetUp(); AppListTestBase::SetUp();
arc_test_.SetUp(profile()); arc_test_.SetUp(profile());
controller_ = std::make_unique<test::TestAppListControllerDelegate>(); controller_ = std::make_unique<test::TestAppListControllerDelegate>();
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
ranker_ =
std::make_unique<AppSearchResultRanker>(temp_dir_.GetPath(),
/*is_ephemeral_user=*/true);
} }
void TearDown() override { void TearDown() override {
...@@ -70,6 +76,8 @@ class ArcAppShortcutsSearchProviderTest ...@@ -70,6 +76,8 @@ class ArcAppShortcutsSearchProviderTest
return app_id; return app_id;
} }
base::ScopedTempDir temp_dir_;
std::unique_ptr<AppSearchResultRanker> ranker_;
std::unique_ptr<test::TestAppListControllerDelegate> controller_; std::unique_ptr<test::TestAppListControllerDelegate> controller_;
ArcAppTest arc_test_; ArcAppTest arc_test_;
...@@ -88,7 +96,7 @@ TEST_P(ArcAppShortcutsSearchProviderTest, Basic) { ...@@ -88,7 +96,7 @@ TEST_P(ArcAppShortcutsSearchProviderTest, Basic) {
constexpr char kQuery[] = "shortlabel"; constexpr char kQuery[] = "shortlabel";
auto provider = std::make_unique<ArcAppShortcutsSearchProvider>( auto provider = std::make_unique<ArcAppShortcutsSearchProvider>(
kMaxResults, profile(), controller_.get()); kMaxResults, profile(), controller_.get(), ranker_.get());
EXPECT_TRUE(provider->results().empty()); EXPECT_TRUE(provider->results().empty());
arc::IconDecodeRequest::DisableSafeDecodingForTesting(); arc::IconDecodeRequest::DisableSafeDecodingForTesting();
......
...@@ -13,18 +13,25 @@ ...@@ -13,18 +13,25 @@
#include "base/bind.h" #include "base/bind.h"
#include "base/strings/string_util.h" #include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
#include "chrome/browser/chromeos/profiles/profile_helper.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/app_list/app_list_controller_delegate.h" #include "chrome/browser/ui/app_list/app_list_controller_delegate.h"
#include "chrome/browser/ui/app_list/app_list_model_updater.h" #include "chrome/browser/ui/app_list/app_list_model_updater.h"
#include "chrome/browser/ui/app_list/search/chrome_search_result.h" #include "chrome/browser/ui/app_list/search/chrome_search_result.h"
#include "chrome/browser/ui/app_list/search/search_provider.h" #include "chrome/browser/ui/app_list/search/search_provider.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/app_search_result_ranker.h"
#include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker.h" #include "chrome/browser/ui/app_list/search/search_result_ranker/recurrence_ranker.h"
#include "chrome/browser/ui/ash/tablet_mode_client.h" #include "chrome/browser/ui/ash/tablet_mode_client.h"
namespace app_list { namespace app_list {
SearchController::SearchController(AppListModelUpdater* model_updater, SearchController::SearchController(AppListModelUpdater* model_updater,
AppListControllerDelegate* list_controller) AppListControllerDelegate* list_controller,
Profile* profile)
: mixer_(std::make_unique<Mixer>(model_updater)), : mixer_(std::make_unique<Mixer>(model_updater)),
ranker_(std::make_unique<AppSearchResultRanker>(
profile->GetPath(),
chromeos::ProfileHelper::IsEphemeralUserProfile(profile))),
list_controller_(list_controller) {} list_controller_(list_controller) {}
SearchController::~SearchController() {} SearchController::~SearchController() {}
...@@ -126,4 +133,8 @@ void SearchController::Train(const std::string& id, RankingItemType type) { ...@@ -126,4 +133,8 @@ void SearchController::Train(const std::string& id, RankingItemType type) {
mixer_->Train(id, type); mixer_->Train(id, type);
} }
AppSearchResultRanker* SearchController::GetSearchResultRanker() {
return ranker_.get();
}
} // namespace app_list } // namespace app_list
...@@ -18,9 +18,11 @@ ...@@ -18,9 +18,11 @@
class AppListControllerDelegate; class AppListControllerDelegate;
class AppListModelUpdater; class AppListModelUpdater;
class ChromeSearchResult; class ChromeSearchResult;
class Profile;
namespace app_list { namespace app_list {
class AppSearchResultRanker;
class RecurrenceRanker; class RecurrenceRanker;
class SearchProvider; class SearchProvider;
enum class RankingItemType; enum class RankingItemType;
...@@ -31,7 +33,8 @@ enum class RankingItemType; ...@@ -31,7 +33,8 @@ enum class RankingItemType;
class SearchController { class SearchController {
public: public:
SearchController(AppListModelUpdater* model_updater, SearchController(AppListModelUpdater* model_updater,
AppListControllerDelegate* list_controller); AppListControllerDelegate* list_controller,
Profile* profile);
virtual ~SearchController(); virtual ~SearchController();
void Start(const base::string16& query); void Start(const base::string16& query);
...@@ -56,6 +59,9 @@ class SearchController { ...@@ -56,6 +59,9 @@ class SearchController {
// Sends training signal to each |providers_| // Sends training signal to each |providers_|
void Train(const std::string& id, RankingItemType type); void Train(const std::string& id, RankingItemType type);
// Get the app search result ranker owned by this object.
AppSearchResultRanker* GetSearchResultRanker();
private: private:
// Invoked when the search results are changed. // Invoked when the search results are changed.
void OnResultsChanged(); void OnResultsChanged();
...@@ -68,6 +74,7 @@ class SearchController { ...@@ -68,6 +74,7 @@ class SearchController {
using Providers = std::vector<std::unique_ptr<SearchProvider>>; using Providers = std::vector<std::unique_ptr<SearchProvider>>;
Providers providers_; Providers providers_;
std::unique_ptr<Mixer> mixer_; std::unique_ptr<Mixer> mixer_;
std::unique_ptr<AppSearchResultRanker> ranker_;
AppListControllerDelegate* list_controller_; AppListControllerDelegate* list_controller_;
DISALLOW_COPY_AND_ASSIGN(SearchController); DISALLOW_COPY_AND_ASSIGN(SearchController);
......
...@@ -73,7 +73,9 @@ std::unique_ptr<SearchController> CreateSearchController( ...@@ -73,7 +73,9 @@ std::unique_ptr<SearchController> CreateSearchController(
AppListModelUpdater* model_updater, AppListModelUpdater* model_updater,
AppListControllerDelegate* list_controller) { AppListControllerDelegate* list_controller) {
std::unique_ptr<SearchController> controller = std::unique_ptr<SearchController> controller =
std::make_unique<SearchController>(model_updater, list_controller); std::make_unique<SearchController>(model_updater, list_controller,
profile);
AppSearchResultRanker* ranker = controller->GetSearchResultRanker();
// Add mixer groups. There are four main groups: answer card, apps // Add mixer groups. There are four main groups: answer card, apps
// and omnibox. Each group has a "soft" maximum number of results. However, if // and omnibox. Each group has a "soft" maximum number of results. However, if
...@@ -89,10 +91,10 @@ std::unique_ptr<SearchController> CreateSearchController( ...@@ -89,10 +91,10 @@ std::unique_ptr<SearchController> CreateSearchController(
size_t omnibox_group_id = controller->AddGroup(kMaxOmniboxResults, 1.0, 0.0); size_t omnibox_group_id = controller->AddGroup(kMaxOmniboxResults, 1.0, 0.0);
// Add search providers. // Add search providers.
controller->AddProvider( controller->AddProvider(apps_group_id, std::make_unique<AppSearchProvider>(
apps_group_id, std::make_unique<AppSearchProvider>( profile, list_controller,
profile, list_controller, base::DefaultClock::GetInstance(),
base::DefaultClock::GetInstance(), model_updater)); model_updater, ranker));
controller->AddProvider(omnibox_group_id, std::make_unique<OmniboxProvider>( controller->AddProvider(omnibox_group_id, std::make_unique<OmniboxProvider>(
profile, list_controller)); profile, list_controller));
if (app_list_features::IsAnswerCardEnabled()) { if (app_list_features::IsAnswerCardEnabled()) {
...@@ -155,7 +157,7 @@ std::unique_ptr<SearchController> CreateSearchController( ...@@ -155,7 +157,7 @@ std::unique_ptr<SearchController> CreateSearchController(
controller->AddProvider( controller->AddProvider(
app_shortcut_group_id, app_shortcut_group_id,
std::make_unique<ArcAppShortcutsSearchProvider>( std::make_unique<ArcAppShortcutsSearchProvider>(
kMaxAppShortcutResults, profile, list_controller)); kMaxAppShortcutResults, profile, list_controller, ranker));
} }
// TODO(https://crbug.com/921429): Put feature switch in ash/public/app_list/ // TODO(https://crbug.com/921429): Put feature switch in ash/public/app_list/
......
...@@ -80,6 +80,8 @@ constexpr char kRankingNormalAppPackageName[] = "test.ranking.app.normal"; ...@@ -80,6 +80,8 @@ constexpr char kRankingNormalAppPackageName[] = "test.ranking.app.normal";
constexpr char kSettingsInternalName[] = "Settings"; constexpr char kSettingsInternalName[] = "Settings";
constexpr bool kEphemeralUser = true;
// Waits for base::Time::Now() is updated. // Waits for base::Time::Now() is updated.
void WaitTimeUpdated() { void WaitTimeUpdated() {
base::RunLoop run_loop; base::RunLoop run_loop;
...@@ -108,18 +110,25 @@ class AppSearchProviderTest : public AppListTestBase { ...@@ -108,18 +110,25 @@ class AppSearchProviderTest : public AppListTestBase {
model_updater_ = std::make_unique<FakeAppListModelUpdater>(); model_updater_ = std::make_unique<FakeAppListModelUpdater>();
controller_ = std::make_unique<::test::TestAppListControllerDelegate>(); controller_ = std::make_unique<::test::TestAppListControllerDelegate>();
ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
} }
void CreateSearch() { void CreateSearch() {
clock_.SetNow(kTestCurrentTime); clock_.SetNow(kTestCurrentTime);
// Create ranker here so that tests can modify feature flags.
ranker_ = std::make_unique<AppSearchResultRanker>(temp_dir_.GetPath(),
kEphemeralUser);
app_search_ = std::make_unique<AppSearchProvider>( app_search_ = std::make_unique<AppSearchProvider>(
profile_.get(), nullptr, &clock_, model_updater_.get()); profile_.get(), nullptr, &clock_, model_updater_.get(), ranker_.get());
} }
void CreateSearchWithContinueReading() { void CreateSearchWithContinueReading() {
clock_.SetNow(kTestCurrentTime); clock_.SetNow(kTestCurrentTime);
// Create ranker here so that tests can modify feature flags.
ranker_ = std::make_unique<AppSearchResultRanker>(temp_dir_.GetPath(),
kEphemeralUser);
app_search_ = std::make_unique<AppSearchProvider>( app_search_ = std::make_unique<AppSearchProvider>(
profile_.get(), nullptr, &clock_, model_updater_.get()); profile_.get(), nullptr, &clock_, model_updater_.get(), ranker_.get());
session_tracker_ = std::make_unique<sync_sessions::SyncedSessionTracker>( session_tracker_ = std::make_unique<sync_sessions::SyncedSessionTracker>(
&mock_sync_sessions_client_); &mock_sync_sessions_client_);
...@@ -216,9 +225,11 @@ class AppSearchProviderTest : public AppListTestBase { ...@@ -216,9 +225,11 @@ class AppSearchProviderTest : public AppListTestBase {
private: private:
base::SimpleTestClock clock_; base::SimpleTestClock clock_;
base::ScopedTempDir temp_dir_;
std::unique_ptr<FakeAppListModelUpdater> model_updater_; std::unique_ptr<FakeAppListModelUpdater> model_updater_;
std::unique_ptr<AppSearchProvider> app_search_; std::unique_ptr<AppSearchProvider> app_search_;
std::unique_ptr<::test::TestAppListControllerDelegate> controller_; std::unique_ptr<::test::TestAppListControllerDelegate> controller_;
std::unique_ptr<AppSearchResultRanker> ranker_;
ArcAppTest arc_test_; ArcAppTest arc_test_;
// For continue reading. // For continue reading.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment