Commit f2d543fb authored by Yao Xiao's avatar Yao Xiao Committed by Commit Bot

Integrate sorting-lsh in floc computation

The version needs to be passed around, and for a single floc
computation cycle, the blocklist service is supposed validate the
version from sorting-lsh output to make sure they are in sync. This is
to address race conditions when a new component comes between the 2
file reads.

Bug: 1062736
Change-Id: I597565902fb8e893c7481cee496411880c6839b5
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2443495
Commit-Queue: Yao Xiao <yaoxia@chromium.org>
Reviewed-by: default avatarDavid Trainor <dtrainor@chromium.org>
Reviewed-by: default avatarJosh Karlin <jkarlin@chromium.org>
Cr-Commit-Position: refs/heads/master@{#813010}
parent 7dda812a
...@@ -59,15 +59,12 @@ void FlocComponentInstallerPolicy::ComponentReady( ...@@ -59,15 +59,12 @@ void FlocComponentInstallerPolicy::ComponentReady(
std::unique_ptr<base::DictionaryValue> manifest) { std::unique_ptr<base::DictionaryValue> manifest) {
DCHECK(!install_dir.empty()); DCHECK(!install_dir.empty());
// TODO(yaoxia): Pass along the |version| to each service. At the end of each
// floc computation cycle, it should verify the two versions match. This is
// not needed currently as the floc_sorting_lsh_clusters_service is not set up
// yet.
floc_blocklist_service_->OnBlocklistFileReady( floc_blocklist_service_->OnBlocklistFileReady(
install_dir.Append(federated_learning::kBlocklistFileName)); install_dir.Append(federated_learning::kBlocklistFileName), version);
floc_sorting_lsh_clusters_service_->OnSortingLshClustersFileReady( floc_sorting_lsh_clusters_service_->OnSortingLshClustersFileReady(
install_dir.Append(federated_learning::kSortingLshClustersFileName)); install_dir.Append(federated_learning::kSortingLshClustersFileName),
version);
} }
// Called during startup and installation before ComponentReady(). // Called during startup and installation before ComponentReady().
......
...@@ -49,14 +49,18 @@ class MockFlocBlocklistService ...@@ -49,14 +49,18 @@ class MockFlocBlocklistService
~MockFlocBlocklistService() override = default; ~MockFlocBlocklistService() override = default;
void OnBlocklistFileReady(const base::FilePath& file_path) override { void OnBlocklistFileReady(const base::FilePath& file_path,
const base::Version& version) override {
file_paths_.push_back(file_path); file_paths_.push_back(file_path);
versions_.push_back(version);
} }
const std::vector<base::FilePath>& file_paths() const { return file_paths_; } const std::vector<base::FilePath>& file_paths() const { return file_paths_; }
const std::vector<base::Version>& versions() const { return versions_; }
private: private:
std::vector<base::FilePath> file_paths_; std::vector<base::FilePath> file_paths_;
std::vector<base::Version> versions_;
}; };
// This class monitors the OnSortingLshClustersFileReady method calls. // This class monitors the OnSortingLshClustersFileReady method calls.
...@@ -72,14 +76,18 @@ class MockFlocSortingLshClustersService ...@@ -72,14 +76,18 @@ class MockFlocSortingLshClustersService
~MockFlocSortingLshClustersService() override = default; ~MockFlocSortingLshClustersService() override = default;
void OnSortingLshClustersFileReady(const base::FilePath& file_path) override { void OnSortingLshClustersFileReady(const base::FilePath& file_path,
const base::Version& version) override {
file_paths_.push_back(file_path); file_paths_.push_back(file_path);
versions_.push_back(version);
} }
const std::vector<base::FilePath>& file_paths() const { return file_paths_; } const std::vector<base::FilePath>& file_paths() const { return file_paths_; }
const std::vector<base::Version>& versions() const { return versions_; }
private: private:
std::vector<base::FilePath> file_paths_; std::vector<base::FilePath> file_paths_;
std::vector<base::Version> versions_;
}; };
} // namespace } // namespace
...@@ -203,10 +211,12 @@ TEST_F(FlocComponentInstallerTest, LoadFlocComponent) { ...@@ -203,10 +211,12 @@ TEST_F(FlocComponentInstallerTest, LoadFlocComponent) {
std::string contents = "abcd"; std::string contents = "abcd";
ASSERT_NO_FATAL_FAILURE(CreateTestFlocComponentFiles(contents, contents)); ASSERT_NO_FATAL_FAILURE(CreateTestFlocComponentFiles(contents, contents));
ASSERT_NO_FATAL_FAILURE(LoadFlocComponent( ASSERT_NO_FATAL_FAILURE(LoadFlocComponent(
"1.0.0", federated_learning::kCurrentFlocComponentFormatVersion)); "1.0.1", federated_learning::kCurrentFlocComponentFormatVersion));
ASSERT_EQ(blocklist_service()->file_paths().size(), 1u); ASSERT_EQ(blocklist_service()->file_paths().size(), 1u);
ASSERT_EQ(blocklist_service()->versions().size(), 1u);
ASSERT_EQ(sorting_lsh_clusters_service()->file_paths().size(), 1u); ASSERT_EQ(sorting_lsh_clusters_service()->file_paths().size(), 1u);
ASSERT_EQ(sorting_lsh_clusters_service()->versions().size(), 1u);
// Assert that the file path is the concatenation of |component_install_dir_| // Assert that the file path is the concatenation of |component_install_dir_|
// and the corresponding file name, which implies that the |version| argument // and the corresponding file name, which implies that the |version| argument
...@@ -222,6 +232,9 @@ TEST_F(FlocComponentInstallerTest, LoadFlocComponent) { ...@@ -222,6 +232,9 @@ TEST_F(FlocComponentInstallerTest, LoadFlocComponent) {
.Append(federated_learning::kSortingLshClustersFileName) .Append(federated_learning::kSortingLshClustersFileName)
.AsUTF8Unsafe()); .AsUTF8Unsafe());
EXPECT_EQ(blocklist_service()->versions()[0].GetString(), "1.0.1");
EXPECT_EQ(sorting_lsh_clusters_service()->versions()[0].GetString(), "1.0.1");
std::string actual_contents; std::string actual_contents;
ASSERT_TRUE(base::ReadFileToString(blocklist_service()->file_paths()[0], ASSERT_TRUE(base::ReadFileToString(blocklist_service()->file_paths()[0],
&actual_contents)); &actual_contents));
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "chrome/test/base/in_process_browser_test.h" #include "chrome/test/base/in_process_browser_test.h"
#include "chrome/test/base/ui_test_utils.h" #include "chrome/test/base/ui_test_utils.h"
#include "components/content_settings/core/browser/cookie_settings.h" #include "components/content_settings/core/browser/cookie_settings.h"
#include "components/federated_learning/floc_constants.h"
#include "components/history/core/test/fake_web_history_service.h" #include "components/history/core/test/fake_web_history_service.h"
#include "components/keyed_service/content/browser_context_dependency_manager.h" #include "components/keyed_service/content/browser_context_dependency_manager.h"
#include "components/sync/driver/test_sync_service.h" #include "components/sync/driver/test_sync_service.h"
...@@ -237,11 +238,25 @@ class FlocIdProviderWithCustomizedServicesBrowserTest ...@@ -237,11 +238,25 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
run_loop.Run(); run_loop.Run();
} }
void FinishOutstandingSortingLshQueries() {
base::RunLoop run_loop;
FlocId dummy_floc = FlocId(0u);
g_browser_process->floc_sorting_lsh_clusters_service()->ApplySortingLsh(
dummy_floc,
base::BindLambdaForTesting(
[&](FlocId floc, base::Optional<base::Version> version) {
run_loop.Quit();
}));
run_loop.Run();
}
void FinishOutstandingBlocklistQueries() { void FinishOutstandingBlocklistQueries() {
base::RunLoop run_loop; base::RunLoop run_loop;
FlocId dummy_unfiltered_floc = FlocId(0u); FlocId dummy_unfiltered_floc = FlocId(0u);
base::Optional<base::Version> no_need_to_verify_version = base::nullopt;
g_browser_process->floc_blocklist_service()->FilterByBlocklist( g_browser_process->floc_blocklist_service()->FilterByBlocklist(
dummy_unfiltered_floc, dummy_unfiltered_floc, no_need_to_verify_version,
base::BindLambdaForTesting( base::BindLambdaForTesting(
[&](FlocId filtered_floc) { run_loop.Quit(); })); [&](FlocId filtered_floc) { run_loop.Quit(); }));
run_loop.Run(); run_loop.Run();
...@@ -264,6 +279,30 @@ class FlocIdProviderWithCustomizedServicesBrowserTest ...@@ -264,6 +279,30 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
base::NumberToString(next_unique_file_suffix_++)); base::NumberToString(next_unique_file_suffix_++));
} }
base::FilePath CreateSortingLshFile(
const std::vector<uint32_t>& sorting_lsh_entries) {
base::ScopedAllowBlockingForTesting allow_blocking;
base::FilePath file_path = GetUniqueTemporaryPath();
base::File file(file_path, base::File::FLAG_CREATE | base::File::FLAG_READ |
base::File::FLAG_WRITE);
CHECK(file.IsValid());
CopyingFileOutputStream copying_stream(std::move(file));
google::protobuf::io::CopyingOutputStreamAdaptor zero_copy_stream_adaptor(
&copying_stream);
google::protobuf::io::CodedOutputStream output_stream(
&zero_copy_stream_adaptor);
for (uint32_t next : sorting_lsh_entries)
output_stream.WriteVarint32(next);
CHECK(!output_stream.HadError());
return file_path;
}
base::FilePath CreateBlocklistFile( base::FilePath CreateBlocklistFile(
const std::vector<uint64_t>& blocklist_entries) { const std::vector<uint64_t>& blocklist_entries) {
base::ScopedAllowBlockingForTesting allow_blocking; base::ScopedAllowBlockingForTesting allow_blocking;
...@@ -292,21 +331,36 @@ class FlocIdProviderWithCustomizedServicesBrowserTest ...@@ -292,21 +331,36 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
void FinishOutstandingAsyncQueries() { void FinishOutstandingAsyncQueries() {
FinishOutstandingRemotePermissionQueries(); FinishOutstandingRemotePermissionQueries();
FinishOutstandingHistoryQueries(); FinishOutstandingHistoryQueries();
FinishOutstandingSortingLshQueries();
FinishOutstandingBlocklistQueries(); FinishOutstandingBlocklistQueries();
} }
// Turn on sync-history, set up the blocklist file, and trigger the blocklist // Turn on sync-history, set up the blocklist and sorting-lsh file, and
// file-ready event. // trigger the blocklist file-ready event.
void InitializeBlocklist(const std::vector<uint64_t>& blocklist_entries) { void InitializeBlocklistAndSortingLsh(
const std::vector<uint64_t>& blocklist_entries,
base::Version blocklist_version,
const std::vector<uint32_t>& sorting_lsh_entries,
base::Version sorting_lsh_version) {
sync_service()->SetActiveDataTypes(syncer::ModelTypeSet::All()); sync_service()->SetActiveDataTypes(syncer::ModelTypeSet::All());
sync_service()->FireStateChanged(); sync_service()->FireStateChanged();
g_browser_process->floc_blocklist_service()->OnBlocklistFileReady( g_browser_process->floc_blocklist_service()->OnBlocklistFileReady(
CreateBlocklistFile(blocklist_entries)); CreateBlocklistFile(blocklist_entries), blocklist_version);
g_browser_process->floc_sorting_lsh_clusters_service()
->OnSortingLshClustersFileReady(
CreateSortingLshFile(sorting_lsh_entries), sorting_lsh_version);
FinishOutstandingAsyncQueries(); FinishOutstandingAsyncQueries();
} }
void InitializeBlocklist(const std::vector<uint64_t>& blocklist_entries) {
base::Version kDummyVersion("1.0.0");
InitializeBlocklistAndSortingLsh(blocklist_entries, kDummyVersion, {},
kDummyVersion);
}
history::HistoryService* history_service() { history::HistoryService* history_service() {
return HistoryServiceFactory::GetForProfile( return HistoryServiceFactory::GetForProfile(
browser()->profile(), ServiceAccessType::IMPLICIT_ACCESS); browser()->profile(), ServiceAccessType::IMPLICIT_ACCESS);
...@@ -658,4 +712,143 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest, ...@@ -658,4 +712,143 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
InvokeInterestCohortJsApi(web_contents())); InvokeInterestCohortJsApi(web_contents()));
} }
class FlocIdProviderSortingLshEnabledBrowserTest
: public FlocIdProviderWithCustomizedServicesBrowserTest {
public:
FlocIdProviderSortingLshEnabledBrowserTest() {
scoped_feature_list_.Reset();
scoped_feature_list_.InitWithFeatures(
{features::kFlocIdComputedEventLogging,
features::kFlocIdSortingLshBasedComputation,
features::kFlocIdBlocklistFiltering},
{});
}
};
IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
SingleSortingLshCluster) {
net::IPAddress::ConsiderLoopbackIPToBePubliclyRoutableForTesting();
ConfigureReplacementHostAndPortForRemotePermissionService();
std::string cookies_to_set = "/set-cookie?user_id=123";
ui_test_utils::NavigateToURL(
browser(), https_server_.GetURL(test_host(), cookies_to_set));
EXPECT_EQ(1u, GetHistoryUrls().size());
EXPECT_FALSE(GetFlocId().IsValid());
// All sim_hash will be encoded as 0 during sorting-lsh
std::vector<uint32_t> single_cluster_representation = {
kMaxNumberOfBitsInFloc};
InitializeBlocklistAndSortingLsh({}, base::Version("1.0.0"),
single_cluster_representation,
base::Version("1.0.0"));
// Expect that the FlocIdComputed user event is recorded.
ASSERT_EQ(1u, user_event_service()->GetRecordedUserEvents().size());
// Check that the original sim_hash is not 0.
EXPECT_NE(FlocId(0), FlocId::CreateFromHistory({test_host()}));
// Expect that the final id is 0 because the sorting-lsh was applied.
EXPECT_EQ(FlocId(0), GetFlocId());
}
IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
MismatchedBlocklistAndSortingLshVersion) {
net::IPAddress::ConsiderLoopbackIPToBePubliclyRoutableForTesting();
ConfigureReplacementHostAndPortForRemotePermissionService();
std::string cookies_to_set = "/set-cookie?user_id=123";
ui_test_utils::NavigateToURL(
browser(), https_server_.GetURL(test_host(), cookies_to_set));
EXPECT_EQ(1u, GetHistoryUrls().size());
EXPECT_FALSE(GetFlocId().IsValid());
// All sim_hash will be encoded as 0 during sorting-lsh
std::vector<uint32_t> single_cluster_representation = {
kMaxNumberOfBitsInFloc};
InitializeBlocklistAndSortingLsh({}, base::Version("1.0.1"),
single_cluster_representation,
base::Version("1.0.0"));
// Expect that the FlocIdComputed user event is recorded.
ASSERT_EQ(1u, user_event_service()->GetRecordedUserEvents().size());
// Check that the original sim_hash is not 0.
EXPECT_NE(FlocId(0), FlocId::CreateFromHistory({test_host()}));
// Expect that the final id is invalid because of version mismatch.
EXPECT_FALSE(GetFlocId().IsValid());
}
IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
SortingLshAndThenBlocked) {
net::IPAddress::ConsiderLoopbackIPToBePubliclyRoutableForTesting();
ConfigureReplacementHostAndPortForRemotePermissionService();
std::string cookies_to_set = "/set-cookie?user_id=123";
ui_test_utils::NavigateToURL(
browser(), https_server_.GetURL(test_host(), cookies_to_set));
EXPECT_EQ(1u, GetHistoryUrls().size());
EXPECT_FALSE(GetFlocId().IsValid());
// All sim_hash will be encoded as 0 during sorting-lsh
std::vector<uint32_t> single_cluster_representation = {
kMaxNumberOfBitsInFloc};
// Configure a blocklist that would block 0.
InitializeBlocklistAndSortingLsh({0}, base::Version("1.0.0"),
single_cluster_representation,
base::Version("1.0.0"));
// Expect that the FlocIdComputed user event is recorded.
ASSERT_EQ(1u, user_event_service()->GetRecordedUserEvents().size());
// Check that the original sim_hash is not 0.
EXPECT_NE(FlocId(0), FlocId::CreateFromHistory({test_host()}));
// Expect that the final id is invalid because it was blocked.
EXPECT_FALSE(GetFlocId().IsValid());
}
IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
CorruptedSortingLSH) {
net::IPAddress::ConsiderLoopbackIPToBePubliclyRoutableForTesting();
ConfigureReplacementHostAndPortForRemotePermissionService();
std::string cookies_to_set = "/set-cookie?user_id=123";
ui_test_utils::NavigateToURL(
browser(), https_server_.GetURL(test_host(), cookies_to_set));
EXPECT_EQ(1u, GetHistoryUrls().size());
EXPECT_FALSE(GetFlocId().IsValid());
// All sim_hash will be encoded as an invalid id.
std::vector<uint32_t> corrupted_sorting_lsh = {};
InitializeBlocklistAndSortingLsh({}, base::Version("1.0.0"),
corrupted_sorting_lsh,
base::Version("1.0.0"));
// Expect that the FlocIdComputed user event is recorded.
ASSERT_EQ(1u, user_event_service()->GetRecordedUserEvents().size());
// Expect that the final id is invalid due to unexpected sorting-lsh file
// format.
EXPECT_FALSE(GetFlocId().IsValid());
}
} // namespace federated_learning } // namespace federated_learning
...@@ -48,10 +48,16 @@ FlocIdProviderImpl::FlocIdProviderImpl( ...@@ -48,10 +48,16 @@ FlocIdProviderImpl::FlocIdProviderImpl(
user_event_service_(user_event_service) { user_event_service_(user_event_service) {
history_service->AddObserver(this); history_service->AddObserver(this);
sync_service_->AddObserver(this); sync_service_->AddObserver(this);
g_browser_process->floc_sorting_lsh_clusters_service()->AddObserver(this);
g_browser_process->floc_blocklist_service()->AddObserver(this); g_browser_process->floc_blocklist_service()->AddObserver(this);
OnStateChanged(sync_service); OnStateChanged(sync_service);
if (g_browser_process->floc_sorting_lsh_clusters_service()
->IsSortingLshClustersFileReady()) {
OnSortingLshClustersFileReady();
}
if (g_browser_process->floc_blocklist_service()->IsBlocklistFileReady()) if (g_browser_process->floc_blocklist_service()->IsBlocklistFileReady())
OnBlocklistFileReady(); OnBlocklistFileReady();
} }
...@@ -179,6 +185,15 @@ void FlocIdProviderImpl::OnURLsDeleted( ...@@ -179,6 +185,15 @@ void FlocIdProviderImpl::OnURLsDeleted(
ComputeFloc(ComputeFlocTrigger::kHistoryDelete); ComputeFloc(ComputeFlocTrigger::kHistoryDelete);
} }
void FlocIdProviderImpl::OnSortingLshClustersFileReady() {
if (first_sorting_lsh_file_ready_seen_)
return;
first_sorting_lsh_file_ready_seen_ = true;
MaybeTriggerFirstFlocComputation();
}
void FlocIdProviderImpl::OnBlocklistFileReady() { void FlocIdProviderImpl::OnBlocklistFileReady() {
if (first_blocklist_file_ready_seen_) if (first_blocklist_file_ready_seen_)
return; return;
...@@ -204,9 +219,17 @@ void FlocIdProviderImpl::MaybeTriggerFirstFlocComputation() { ...@@ -204,9 +219,17 @@ void FlocIdProviderImpl::MaybeTriggerFirstFlocComputation() {
if (first_floc_computation_triggered_) if (first_floc_computation_triggered_)
return; return;
if (!first_sync_history_enabled_seen_ || bool sorting_lsh_ready_or_not_required =
(base::FeatureList::IsEnabled(features::kFlocIdBlocklistFiltering) && !base::FeatureList::IsEnabled(
!first_blocklist_file_ready_seen_)) { features::kFlocIdSortingLshBasedComputation) ||
first_sorting_lsh_file_ready_seen_;
bool blocklist_ready_or_not_required =
!base::FeatureList::IsEnabled(features::kFlocIdBlocklistFiltering) ||
first_blocklist_file_ready_seen_;
if (!first_sync_history_enabled_seen_ || !sorting_lsh_ready_or_not_required ||
!blocklist_ready_or_not_required) {
return; return;
} }
...@@ -369,17 +392,41 @@ void FlocIdProviderImpl::ApplyAdditionalFiltering( ...@@ -369,17 +392,41 @@ void FlocIdProviderImpl::ApplyAdditionalFiltering(
const FlocId& sim_hash) { const FlocId& sim_hash) {
DCHECK(sim_hash.IsValid()); DCHECK(sim_hash.IsValid());
if (!base::FeatureList::IsEnabled(features::kFlocIdBlocklistFiltering)) { if (!base::FeatureList::IsEnabled(
std::move(callback).Run(ComputeFlocResult(sim_hash, sim_hash)); features::kFlocIdSortingLshBasedComputation)) {
SkippedOrAppliedSortingLsh(std::move(callback), sim_hash,
/*sim_hash_or_sorting_lsh=*/sim_hash,
/*version_to_validate=*/base::nullopt);
return; return;
} }
g_browser_process->floc_blocklist_service()->FilterByBlocklist( g_browser_process->floc_sorting_lsh_clusters_service()->ApplySortingLsh(
sim_hash, base::BindOnce(&FlocIdProviderImpl::DidApplyAdditionalFiltering, sim_hash, base::BindOnce(&FlocIdProviderImpl::SkippedOrAppliedSortingLsh,
weak_ptr_factory_.GetWeakPtr(), weak_ptr_factory_.GetWeakPtr(),
std::move(callback), sim_hash)); std::move(callback), sim_hash));
} }
void FlocIdProviderImpl::SkippedOrAppliedSortingLsh(
ComputeFlocCompletedCallback callback,
const FlocId& sim_hash,
FlocId sim_hash_or_sorting_lsh,
base::Optional<base::Version> version_to_validate) {
// |!sim_hash_or_sorting_lsh.IsValid()| indicates a missing or corrupted
// sorting-lsh file.
if (!base::FeatureList::IsEnabled(features::kFlocIdBlocklistFiltering) ||
!sim_hash_or_sorting_lsh.IsValid()) {
std::move(callback).Run(
ComputeFlocResult(sim_hash, sim_hash_or_sorting_lsh));
return;
}
g_browser_process->floc_blocklist_service()->FilterByBlocklist(
sim_hash_or_sorting_lsh, version_to_validate,
base::BindOnce(&FlocIdProviderImpl::DidApplyAdditionalFiltering,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
sim_hash));
}
void FlocIdProviderImpl::DidApplyAdditionalFiltering( void FlocIdProviderImpl::DidApplyAdditionalFiltering(
ComputeFlocCompletedCallback callback, ComputeFlocCompletedCallback callback,
FlocId sim_hash, FlocId sim_hash,
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "base/timer/timer.h" #include "base/timer/timer.h"
#include "chrome/browser/federated_learning/floc_id_provider.h" #include "chrome/browser/federated_learning/floc_id_provider.h"
#include "components/federated_learning/floc_blocklist_service.h" #include "components/federated_learning/floc_blocklist_service.h"
#include "components/federated_learning/floc_sorting_lsh_clusters_service.h"
#include "components/history/core/browser/history_service.h" #include "components/history/core/browser/history_service.h"
#include "components/history/core/browser/history_service_observer.h" #include "components/history/core/browser/history_service_observer.h"
#include "components/sync/driver/sync_service_observer.h" #include "components/sync/driver/sync_service_observer.h"
...@@ -44,6 +45,7 @@ class FlocRemotePermissionService; ...@@ -44,6 +45,7 @@ class FlocRemotePermissionService;
// the event of history deletion, the floc will be recomputed immediately and // the event of history deletion, the floc will be recomputed immediately and
// reset the timer of any currently scheduled computation to be 24 hours later. // reset the timer of any currently scheduled computation to be 24 hours later.
class FlocIdProviderImpl : public FlocIdProvider, class FlocIdProviderImpl : public FlocIdProvider,
public FlocSortingLshClustersService::Observer,
public FlocBlocklistService::Observer, public FlocBlocklistService::Observer,
public history::HistoryServiceObserver, public history::HistoryServiceObserver,
public syncer::SyncServiceObserver { public syncer::SyncServiceObserver {
...@@ -112,6 +114,9 @@ class FlocIdProviderImpl : public FlocIdProvider, ...@@ -112,6 +114,9 @@ class FlocIdProviderImpl : public FlocIdProvider,
void OnURLsDeleted(history::HistoryService* history_service, void OnURLsDeleted(history::HistoryService* history_service,
const history::DeletionInfo& deletion_info) override; const history::DeletionInfo& deletion_info) override;
// FlocSortingLshClustersService::Observer
void OnSortingLshClustersFileReady() override;
// FlocBlocklistService::Observer // FlocBlocklistService::Observer
void OnBlocklistFileReady() override; void OnBlocklistFileReady() override;
...@@ -143,6 +148,11 @@ class FlocIdProviderImpl : public FlocIdProvider, ...@@ -143,6 +148,11 @@ class FlocIdProviderImpl : public FlocIdProvider,
// history. For example, invalidate it if it's in the blocklist. // history. For example, invalidate it if it's in the blocklist.
void ApplyAdditionalFiltering(ComputeFlocCompletedCallback callback, void ApplyAdditionalFiltering(ComputeFlocCompletedCallback callback,
const FlocId& sim_hash); const FlocId& sim_hash);
void SkippedOrAppliedSortingLsh(
ComputeFlocCompletedCallback callback,
const FlocId& sim_hash,
FlocId sim_hash_or_sorting_lsh,
base::Optional<base::Version> version_to_validate);
void DidApplyAdditionalFiltering(ComputeFlocCompletedCallback callback, void DidApplyAdditionalFiltering(ComputeFlocCompletedCallback callback,
FlocId sim_hash, FlocId sim_hash,
FlocId final_hash); FlocId final_hash);
...@@ -158,6 +168,7 @@ class FlocIdProviderImpl : public FlocIdProvider, ...@@ -158,6 +168,7 @@ class FlocIdProviderImpl : public FlocIdProvider,
// loggings, updates, etc.), and compute again. // loggings, updates, etc.), and compute again.
base::Optional<ComputeFlocTrigger> pending_recompute_event_; base::Optional<ComputeFlocTrigger> pending_recompute_event_;
bool first_sorting_lsh_file_ready_seen_ = false;
bool first_blocklist_file_ready_seen_ = false; bool first_blocklist_file_ready_seen_ = false;
bool first_sync_history_enabled_seen_ = false; bool first_sync_history_enabled_seen_ = false;
......
...@@ -32,17 +32,22 @@ namespace { ...@@ -32,17 +32,22 @@ namespace {
using ComputeFlocTrigger = FlocIdProviderImpl::ComputeFlocTrigger; using ComputeFlocTrigger = FlocIdProviderImpl::ComputeFlocTrigger;
using ComputeFlocResult = FlocIdProviderImpl::ComputeFlocResult; using ComputeFlocResult = FlocIdProviderImpl::ComputeFlocResult;
using ComputeFlocCompletedCallback =
FlocIdProviderImpl::ComputeFlocCompletedCallback;
using CanComputeFlocCallback = FlocIdProviderImpl::CanComputeFlocCallback;
class MockFlocBlocklistService : public FlocBlocklistService { class MockFlocBlocklistService : public FlocBlocklistService {
public: public:
using FlocBlocklistService::FlocBlocklistService; using FlocBlocklistService::FlocBlocklistService;
void ConfigureFlocToBlock(const FlocId& floc_to_block) { void ConfigureBlocklist(const FlocId& floc_to_block) {
floc_to_block_ = floc_to_block; floc_to_block_ = floc_to_block;
} }
void FilterByBlocklist(const FlocId& unfiltered_floc, void FilterByBlocklist(
FilterByBlocklistCallback callback) override { const FlocId& unfiltered_floc,
const base::Optional<base::Version>& version_to_validate,
FilterByBlocklistCallback callback) override {
if (floc_to_block_ == unfiltered_floc) { if (floc_to_block_ == unfiltered_floc) {
std::move(callback).Run(FlocId()); std::move(callback).Run(FlocId());
return; return;
...@@ -54,6 +59,33 @@ class MockFlocBlocklistService : public FlocBlocklistService { ...@@ -54,6 +59,33 @@ class MockFlocBlocklistService : public FlocBlocklistService {
FlocId floc_to_block_; FlocId floc_to_block_;
}; };
class MockFlocSortingLshService : public FlocSortingLshClustersService {
public:
using FlocSortingLshClustersService::FlocSortingLshClustersService;
void ConfigureSortingLsh(
const std::unordered_map<uint64_t, FlocId>& sorting_lsh_map,
const base::Version& version) {
sorting_lsh_map_ = sorting_lsh_map;
version_ = version;
}
void ApplySortingLsh(const FlocId& raw_floc_id,
ApplySortingLshCallback callback) override {
if (sorting_lsh_map_.count(raw_floc_id.ToUint64())) {
std::move(callback).Run(sorting_lsh_map_.at(raw_floc_id.ToUint64()),
version_);
return;
}
std::move(callback).Run(FlocId(), version_);
}
private:
std::unordered_map<uint64_t, FlocId> sorting_lsh_map_;
base::Version version_;
};
class FakeFlocRemotePermissionService : public FlocRemotePermissionService { class FakeFlocRemotePermissionService : public FlocRemotePermissionService {
public: public:
using FlocRemotePermissionService::FlocRemotePermissionService; using FlocRemotePermissionService::FlocRemotePermissionService;
...@@ -203,6 +235,11 @@ class FlocIdProviderUnitTest : public testing::Test { ...@@ -203,6 +235,11 @@ class FlocIdProviderUnitTest : public testing::Test {
&prefs_, /*is_off_the_record=*/false, /*store_last_modified=*/false, &prefs_, /*is_off_the_record=*/false, /*store_last_modified=*/false,
/*restore_session=*/false); /*restore_session=*/false);
auto sorting_lsh_service = std::make_unique<MockFlocSortingLshService>();
sorting_lsh_service_ = sorting_lsh_service.get();
TestingBrowserProcess::GetGlobal()->SetFlocSortingLshClustersService(
std::move(sorting_lsh_service));
auto blocklist_service = std::make_unique<MockFlocBlocklistService>(); auto blocklist_service = std::make_unique<MockFlocBlocklistService>();
blocklist_service_ = blocklist_service.get(); blocklist_service_ = blocklist_service.get();
TestingBrowserProcess::GetGlobal()->SetFlocBlocklistService( TestingBrowserProcess::GetGlobal()->SetFlocBlocklistService(
...@@ -238,13 +275,16 @@ class FlocIdProviderUnitTest : public testing::Test { ...@@ -238,13 +275,16 @@ class FlocIdProviderUnitTest : public testing::Test {
history_service_->RemoveObserver(floc_id_provider_.get()); history_service_->RemoveObserver(floc_id_provider_.get());
} }
void CheckCanComputeFloc( void ApplyAdditionalFiltering(ComputeFlocCompletedCallback callback,
FlocIdProviderImpl::CanComputeFlocCallback callback) { const FlocId& sim_hash) {
floc_id_provider_->ApplyAdditionalFiltering(std::move(callback), sim_hash);
}
void CheckCanComputeFloc(CanComputeFlocCallback callback) {
floc_id_provider_->CheckCanComputeFloc(std::move(callback)); floc_id_provider_->CheckCanComputeFloc(std::move(callback));
} }
void IsSwaaNacAccountEnabled( void IsSwaaNacAccountEnabled(CanComputeFlocCallback callback) {
FlocIdProviderImpl::CanComputeFlocCallback callback) {
floc_id_provider_->IsSwaaNacAccountEnabled(std::move(callback)); floc_id_provider_->IsSwaaNacAccountEnabled(std::move(callback));
} }
...@@ -324,6 +364,7 @@ class FlocIdProviderUnitTest : public testing::Test { ...@@ -324,6 +364,7 @@ class FlocIdProviderUnitTest : public testing::Test {
scoped_refptr<FakeCookieSettings> fake_cookie_settings_; scoped_refptr<FakeCookieSettings> fake_cookie_settings_;
std::unique_ptr<MockFlocIdProvider> floc_id_provider_; std::unique_ptr<MockFlocIdProvider> floc_id_provider_;
MockFlocSortingLshService* sorting_lsh_service_;
MockFlocBlocklistService* blocklist_service_; MockFlocBlocklistService* blocklist_service_;
base::ScopedTempDir temp_dir_; base::ScopedTempDir temp_dir_;
...@@ -833,7 +874,7 @@ TEST_F(FlocIdProviderUnitTest, ...@@ -833,7 +874,7 @@ TEST_F(FlocIdProviderUnitTest,
// Trigger the blocklist ready event. The 1st floc computation should be // Trigger the blocklist ready event. The 1st floc computation should be
// triggered now as sync & sync-history are enabled the blocklist is ready. // triggered now as sync & sync-history are enabled the blocklist is ready.
blocklist_service_->OnBlocklistFileReady(base::FilePath()); blocklist_service_->OnBlocklistFileReady(base::FilePath(), base::Version());
EXPECT_TRUE(first_floc_computation_triggered()); EXPECT_TRUE(first_floc_computation_triggered());
} }
...@@ -845,7 +886,7 @@ TEST_F(FlocIdProviderUnitTest, ...@@ -845,7 +886,7 @@ TEST_F(FlocIdProviderUnitTest,
// Trigger the blocklist ready event. The 1st floc computation should not be // Trigger the blocklist ready event. The 1st floc computation should not be
// triggered as sync & sync-history are not enabled yet. // triggered as sync & sync-history are not enabled yet.
blocklist_service_->OnBlocklistFileReady(base::FilePath()); blocklist_service_->OnBlocklistFileReady(base::FilePath(), base::Version());
EXPECT_FALSE(first_floc_computation_triggered()); EXPECT_FALSE(first_floc_computation_triggered());
...@@ -876,7 +917,7 @@ TEST_F(FlocIdProviderUnitTest, BlocklistFilteringEnabled_BlockedFloc) { ...@@ -876,7 +917,7 @@ TEST_F(FlocIdProviderUnitTest, BlocklistFilteringEnabled_BlockedFloc) {
// Trigger the blocklist ready event, and turn on sync & sync-history to // Trigger the blocklist ready event, and turn on sync & sync-history to
// trigger the 1st floc computation. // trigger the 1st floc computation.
blocklist_service_->OnBlocklistFileReady(base::FilePath()); blocklist_service_->OnBlocklistFileReady(base::FilePath(), base::Version());
test_sync_service_->SetTransportState( test_sync_service_->SetTransportState(
syncer::SyncService::TransportState::ACTIVE); syncer::SyncService::TransportState::ACTIVE);
...@@ -895,7 +936,7 @@ TEST_F(FlocIdProviderUnitTest, BlocklistFilteringEnabled_BlockedFloc) { ...@@ -895,7 +936,7 @@ TEST_F(FlocIdProviderUnitTest, BlocklistFilteringEnabled_BlockedFloc) {
EXPECT_EQ(floc_from_history, floc_id()); EXPECT_EQ(floc_from_history, floc_id());
// Set the blocklist to block |floc_from_history|. // Set the blocklist to block |floc_from_history|.
blocklist_service_->ConfigureFlocToBlock(floc_from_history); blocklist_service_->ConfigureBlocklist(floc_from_history);
task_environment_.FastForwardBy(base::TimeDelta::FromDays(1)); task_environment_.FastForwardBy(base::TimeDelta::FromDays(1));
...@@ -927,7 +968,7 @@ TEST_F(FlocIdProviderUnitTest, BlocklistFilteringEnabled_BlockedFloc) { ...@@ -927,7 +968,7 @@ TEST_F(FlocIdProviderUnitTest, BlocklistFilteringEnabled_BlockedFloc) {
EXPECT_EQ(floc_from_history.ToUint64(), event.floc_id()); EXPECT_EQ(floc_from_history.ToUint64(), event.floc_id());
// Reset the blocklist to block nothing. // Reset the blocklist to block nothing.
blocklist_service_->ConfigureFlocToBlock(FlocId()); blocklist_service_->ConfigureBlocklist(FlocId());
task_environment_.FastForwardBy(base::TimeDelta::FromDays(1)); task_environment_.FastForwardBy(base::TimeDelta::FromDays(1));
...@@ -1148,4 +1189,52 @@ TEST_F(FlocIdProviderUnitTest, ScheduledUpdateDuringInProgressComputation) { ...@@ -1148,4 +1189,52 @@ TEST_F(FlocIdProviderUnitTest, ScheduledUpdateDuringInProgressComputation) {
EXPECT_EQ(FlocId::CreateFromHistory({domain1}), floc_id()); EXPECT_EQ(FlocId::CreateFromHistory({domain1}), floc_id());
} }
TEST_F(FlocIdProviderUnitTest, ApplyAdditionalFiltering_SortingLsh) {
base::test::ScopedFeatureList feature_list;
feature_list.InitWithFeatures({features::kFlocIdSortingLshBasedComputation},
{});
bool callback_called = false;
auto callback = base::BindLambdaForTesting([&](ComputeFlocResult result) {
EXPECT_FALSE(callback_called);
EXPECT_EQ(result.sim_hash, FlocId(3));
EXPECT_EQ(result.final_hash, FlocId(2));
callback_called = true;
});
// Map 3 to 2
sorting_lsh_service_->OnSortingLshClustersFileReady(base::FilePath(),
base::Version());
sorting_lsh_service_->ConfigureSortingLsh({{3, FlocId(2)}},
base::Version("3.4.5"));
FlocId sim_hash(3);
ApplyAdditionalFiltering(std::move(callback), sim_hash);
task_environment_.RunUntilIdle();
EXPECT_TRUE(callback_called);
}
TEST_F(FlocIdProviderUnitTest, ApplyAdditionalFiltering_SortingLshCorrupted) {
base::test::ScopedFeatureList feature_list;
feature_list.InitWithFeatures({features::kFlocIdSortingLshBasedComputation},
{});
bool callback_called = false;
auto callback = base::BindLambdaForTesting([&](ComputeFlocResult result) {
EXPECT_FALSE(callback_called);
EXPECT_EQ(result.sim_hash, FlocId(3));
EXPECT_EQ(result.final_hash, FlocId());
callback_called = true;
});
sorting_lsh_service_->OnSortingLshClustersFileReady(base::FilePath(),
base::Version());
sorting_lsh_service_->ConfigureSortingLsh({}, base::Version("3.4.5"));
FlocId sim_hash(3);
ApplyAdditionalFiltering(std::move(callback), sim_hash);
task_environment_.RunUntilIdle();
EXPECT_TRUE(callback_called);
}
} // namespace federated_learning } // namespace federated_learning
...@@ -385,6 +385,11 @@ const base::Feature kFlocIdComputedEventLogging{ ...@@ -385,6 +385,11 @@ const base::Feature kFlocIdComputedEventLogging{
const base::Feature kFlocIdBlocklistFiltering{ const base::Feature kFlocIdBlocklistFiltering{
"FlocIdBlocklistFiltering", base::FEATURE_DISABLED_BY_DEFAULT}; "FlocIdBlocklistFiltering", base::FEATURE_DISABLED_BY_DEFAULT};
// If enabled, the sim-hash floc computed from history will be further encoded
// based on the sorting-lsh.
const base::Feature kFlocIdSortingLshBasedComputation{
"FlocIdSortingLshBasedComputation", base::FEATURE_DISABLED_BY_DEFAULT};
// Enables Focus Mode which brings up a PWA-like window look. // Enables Focus Mode which brings up a PWA-like window look.
const base::Feature kFocusMode{"FocusMode", base::FEATURE_DISABLED_BY_DEFAULT}; const base::Feature kFocusMode{"FocusMode", base::FEATURE_DISABLED_BY_DEFAULT};
......
...@@ -248,6 +248,9 @@ extern const base::Feature kFlocIdComputedEventLogging; ...@@ -248,6 +248,9 @@ extern const base::Feature kFlocIdComputedEventLogging;
COMPONENT_EXPORT(CHROME_FEATURES) COMPONENT_EXPORT(CHROME_FEATURES)
extern const base::Feature kFlocIdBlocklistFiltering; extern const base::Feature kFlocIdBlocklistFiltering;
COMPONENT_EXPORT(CHROME_FEATURES)
extern const base::Feature kFlocIdSortingLshBasedComputation;
COMPONENT_EXPORT(CHROME_FEATURES) COMPONENT_EXPORT(CHROME_FEATURES)
extern const base::Feature kFocusMode; extern const base::Feature kFocusMode;
......
...@@ -90,12 +90,14 @@ void FlocBlocklistService::RemoveObserver(Observer* observer) { ...@@ -90,12 +90,14 @@ void FlocBlocklistService::RemoveObserver(Observer* observer) {
} }
bool FlocBlocklistService::IsBlocklistFileReady() const { bool FlocBlocklistService::IsBlocklistFileReady() const {
return blocklist_file_path_.has_value(); return first_file_ready_seen_;
} }
void FlocBlocklistService::OnBlocklistFileReady( void FlocBlocklistService::OnBlocklistFileReady(const base::FilePath& file_path,
const base::FilePath& file_path) { const base::Version& version) {
blocklist_file_path_ = file_path; blocklist_file_path_ = file_path;
blocklist_version_ = version;
first_file_ready_seen_ = true;
for (auto& observer : observers_) for (auto& observer : observers_)
observer.OnBlocklistFileReady(); observer.OnBlocklistFileReady();
...@@ -103,13 +105,20 @@ void FlocBlocklistService::OnBlocklistFileReady( ...@@ -103,13 +105,20 @@ void FlocBlocklistService::OnBlocklistFileReady(
void FlocBlocklistService::FilterByBlocklist( void FlocBlocklistService::FilterByBlocklist(
const FlocId& unfiltered_floc, const FlocId& unfiltered_floc,
const base::Optional<base::Version>& version_to_validate,
FilterByBlocklistCallback callback) { FilterByBlocklistCallback callback) {
DCHECK(unfiltered_floc.IsValid()); DCHECK(unfiltered_floc.IsValid());
DCHECK(blocklist_file_path_.has_value()); DCHECK(first_file_ready_seen_);
if (version_to_validate &&
version_to_validate.value().CompareTo(blocklist_version_) != 0) {
std::move(callback).Run(FlocId());
return;
}
base::PostTaskAndReplyWithResult( base::PostTaskAndReplyWithResult(
background_task_runner_.get(), FROM_HERE, background_task_runner_.get(), FROM_HERE,
base::BindOnce(&FilterByBlocklistOnBackgroundThread, unfiltered_floc, base::BindOnce(&FilterByBlocklistOnBackgroundThread, unfiltered_floc,
blocklist_file_path_.value()), blocklist_file_path_),
std::move(callback)); std::move(callback));
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/observer_list.h" #include "base/observer_list.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/version.h"
#include "components/federated_learning/floc_id.h" #include "components/federated_learning/floc_id.h"
namespace base { namespace base {
...@@ -47,11 +48,14 @@ class FlocBlocklistService { ...@@ -47,11 +48,14 @@ class FlocBlocklistService {
bool IsBlocklistFileReady() const; bool IsBlocklistFileReady() const;
// Virtual for testing. // Virtual for testing.
virtual void OnBlocklistFileReady(const base::FilePath& file_path); virtual void OnBlocklistFileReady(const base::FilePath& file_path,
const base::Version& version);
// Virtual for testing. // Virtual for testing.
virtual void FilterByBlocklist(const FlocId& unfiltered_floc, virtual void FilterByBlocklist(
FilterByBlocklistCallback callback); const FlocId& unfiltered_floc,
const base::Optional<base::Version>& version_to_validate,
FilterByBlocklistCallback callback);
void SetBackgroundTaskRunnerForTesting( void SetBackgroundTaskRunnerForTesting(
scoped_refptr<base::SequencedTaskRunner> background_task_runner); scoped_refptr<base::SequencedTaskRunner> background_task_runner);
...@@ -66,7 +70,9 @@ class FlocBlocklistService { ...@@ -66,7 +70,9 @@ class FlocBlocklistService {
base::ObserverList<Observer>::Unchecked observers_; base::ObserverList<Observer>::Unchecked observers_;
base::Optional<base::FilePath> blocklist_file_path_; bool first_file_ready_seen_ = false;
base::FilePath blocklist_file_path_;
base::Version blocklist_version_;
base::WeakPtrFactory<FlocBlocklistService> weak_ptr_factory_; base::WeakPtrFactory<FlocBlocklistService> weak_ptr_factory_;
}; };
......
...@@ -21,6 +21,10 @@ ...@@ -21,6 +21,10 @@
namespace federated_learning { namespace federated_learning {
namespace {
base::Version kDummyVersion = base::Version("1.0.0");
class CopyingFileOutputStream class CopyingFileOutputStream
: public google::protobuf::io::CopyingOutputStream { : public google::protobuf::io::CopyingOutputStream {
public: public:
...@@ -41,6 +45,8 @@ class CopyingFileOutputStream ...@@ -41,6 +45,8 @@ class CopyingFileOutputStream
base::File file_; base::File file_;
}; };
} // namespace
class FlocBlocklistServiceTest : public ::testing::Test { class FlocBlocklistServiceTest : public ::testing::Test {
public: public:
FlocBlocklistServiceTest() FlocBlocklistServiceTest()
...@@ -86,7 +92,7 @@ class FlocBlocklistServiceTest : public ::testing::Test { ...@@ -86,7 +92,7 @@ class FlocBlocklistServiceTest : public ::testing::Test {
base::FilePath InitializeBlocklistFile( base::FilePath InitializeBlocklistFile(
const std::vector<uint64_t>& blocklist) { const std::vector<uint64_t>& blocklist) {
base::FilePath file_path = CreateBlocklistFile(blocklist); base::FilePath file_path = CreateBlocklistFile(blocklist);
service()->OnBlocklistFileReady(file_path); service()->OnBlocklistFileReady(file_path, kDummyVersion);
EXPECT_TRUE(blocklist_file_path().has_value()); EXPECT_TRUE(blocklist_file_path().has_value());
return file_path; return file_path;
} }
...@@ -95,7 +101,10 @@ class FlocBlocklistServiceTest : public ::testing::Test { ...@@ -95,7 +101,10 @@ class FlocBlocklistServiceTest : public ::testing::Test {
FlocBlocklistService* service() { return service_.get(); } FlocBlocklistService* service() { return service_.get(); }
const base::Optional<base::FilePath>& blocklist_file_path() { base::Optional<base::FilePath> blocklist_file_path() {
if (!service()->first_file_ready_seen_)
return base::nullopt;
return service()->blocklist_file_path_; return service()->blocklist_file_path_;
} }
...@@ -108,7 +117,7 @@ class FlocBlocklistServiceTest : public ::testing::Test { ...@@ -108,7 +117,7 @@ class FlocBlocklistServiceTest : public ::testing::Test {
run_loop.Quit(); run_loop.Quit();
}); });
service()->FilterByBlocklist(unfiltered_floc, std::move(cb)); service()->FilterByBlocklist(unfiltered_floc, kDummyVersion, std::move(cb));
background_task_runner_->RunPendingTasks(); background_task_runner_->RunPendingTasks();
run_loop.Run(); run_loop.Run();
...@@ -167,7 +176,7 @@ TEST_F(FlocBlocklistServiceTest, List_MaxFlocPlus1) { ...@@ -167,7 +176,7 @@ TEST_F(FlocBlocklistServiceTest, List_MaxFlocPlus1) {
TEST_F(FlocBlocklistServiceTest, NonExistentBlocklist_Blocked) { TEST_F(FlocBlocklistServiceTest, NonExistentBlocklist_Blocked) {
base::FilePath file_path = GetUniqueTemporaryPath(); base::FilePath file_path = GetUniqueTemporaryPath();
service()->OnBlocklistFileReady(file_path); service()->OnBlocklistFileReady(file_path, kDummyVersion);
EXPECT_EQ(FlocId(), FilterByBlocklist(FlocId(3))); EXPECT_EQ(FlocId(), FilterByBlocklist(FlocId(3)));
} }
......
...@@ -116,29 +116,46 @@ void FlocSortingLshClustersService::RemoveObserver(Observer* observer) { ...@@ -116,29 +116,46 @@ void FlocSortingLshClustersService::RemoveObserver(Observer* observer) {
observers_.RemoveObserver(observer); observers_.RemoveObserver(observer);
} }
void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting( bool FlocSortingLshClustersService::IsSortingLshClustersFileReady() const {
scoped_refptr<base::SequencedTaskRunner> background_task_runner) { return first_file_ready_seen_;
background_task_runner_ = background_task_runner; }
void FlocSortingLshClustersService::OnSortingLshClustersFileReady(
const base::FilePath& file_path,
const base::Version& version) {
sorting_lsh_clusters_file_path_ = file_path;
sorting_lsh_clusters_version_ = version;
first_file_ready_seen_ = true;
for (auto& observer : observers_)
observer.OnSortingLshClustersFileReady();
} }
void FlocSortingLshClustersService::ApplySortingLsh( void FlocSortingLshClustersService::ApplySortingLsh(
const FlocId& raw_floc_id, const FlocId& raw_floc_id,
ApplySortingLshCallback callback) { ApplySortingLshCallback callback) {
DCHECK(raw_floc_id.IsValid()); DCHECK(raw_floc_id.IsValid());
DCHECK(sorting_lsh_clusters_file_path_.has_value()); DCHECK(first_file_ready_seen_);
base::PostTaskAndReplyWithResult( base::PostTaskAndReplyWithResult(
background_task_runner_.get(), FROM_HERE, background_task_runner_.get(), FROM_HERE,
base::BindOnce(&ApplySortingLshOnBackgroundThread, raw_floc_id, base::BindOnce(&ApplySortingLshOnBackgroundThread, raw_floc_id,
sorting_lsh_clusters_file_path_.value()), sorting_lsh_clusters_file_path_),
std::move(callback)); base::BindOnce(&FlocSortingLshClustersService::DidApplySortingLsh,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
sorting_lsh_clusters_version_));
} }
void FlocSortingLshClustersService::OnSortingLshClustersFileReady( void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting(
const base::FilePath& file_path) { scoped_refptr<base::SequencedTaskRunner> background_task_runner) {
sorting_lsh_clusters_file_path_ = file_path; background_task_runner_ = background_task_runner;
}
for (auto& observer : observers_) void FlocSortingLshClustersService::DidApplySortingLsh(
observer.OnSortingLshClustersFileReady(); ApplySortingLshCallback callback,
base::Version version,
FlocId floc_id) {
std::move(callback).Run(std::move(floc_id), std::move(version));
} }
} // namespace federated_learning } // namespace federated_learning
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/observer_list.h" #include "base/observer_list.h"
#include "base/optional.h" #include "base/optional.h"
#include "base/version.h"
#include "components/federated_learning/floc_id.h" #include "components/federated_learning/floc_id.h"
namespace base { namespace base {
...@@ -28,7 +29,8 @@ namespace federated_learning { ...@@ -28,7 +29,8 @@ namespace federated_learning {
// File reading and parsing is posted to |background_task_runner_|. // File reading and parsing is posted to |background_task_runner_|.
class FlocSortingLshClustersService { class FlocSortingLshClustersService {
public: public:
using ApplySortingLshCallback = base::OnceCallback<void(FlocId)>; using ApplySortingLshCallback =
base::OnceCallback<void(FlocId, base::Optional<base::Version>)>;
class Observer { class Observer {
public: public:
...@@ -46,24 +48,34 @@ class FlocSortingLshClustersService { ...@@ -46,24 +48,34 @@ class FlocSortingLshClustersService {
void AddObserver(Observer* observer); void AddObserver(Observer* observer);
void RemoveObserver(Observer* observer); void RemoveObserver(Observer* observer);
void SetBackgroundTaskRunnerForTesting( bool IsSortingLshClustersFileReady() const;
scoped_refptr<base::SequencedTaskRunner> background_task_runner);
void ApplySortingLsh(const FlocId& raw_floc_id, // Virtual for testing.
ApplySortingLshCallback callback); virtual void OnSortingLshClustersFileReady(const base::FilePath& file_path,
const base::Version& version);
// Virtual for testing. // Virtual for testing.
virtual void OnSortingLshClustersFileReady(const base::FilePath& file_path); virtual void ApplySortingLsh(const FlocId& raw_floc_id,
ApplySortingLshCallback callback);
void SetBackgroundTaskRunnerForTesting(
scoped_refptr<base::SequencedTaskRunner> background_task_runner);
private: private:
friend class FlocSortingLshClustersServiceTest; friend class FlocSortingLshClustersServiceTest;
void DidApplySortingLsh(ApplySortingLshCallback callback,
base::Version version,
FlocId floc_id);
// Runner for tasks that do not influence user experience. // Runner for tasks that do not influence user experience.
scoped_refptr<base::SequencedTaskRunner> background_task_runner_; scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
base::ObserverList<Observer>::Unchecked observers_; base::ObserverList<Observer>::Unchecked observers_;
base::Optional<base::FilePath> sorting_lsh_clusters_file_path_; bool first_file_ready_seen_ = false;
base::FilePath sorting_lsh_clusters_file_path_;
base::Version sorting_lsh_clusters_version_;
base::WeakPtrFactory<FlocSortingLshClustersService> weak_ptr_factory_; base::WeakPtrFactory<FlocSortingLshClustersService> weak_ptr_factory_;
}; };
......
...@@ -23,6 +23,8 @@ namespace federated_learning { ...@@ -23,6 +23,8 @@ namespace federated_learning {
namespace { namespace {
base::Version kDummyVersion = base::Version("1.2.3");
class CopyingFileOutputStream class CopyingFileOutputStream
: public google::protobuf::io::CopyingOutputStream { : public google::protobuf::io::CopyingOutputStream {
public: public:
...@@ -94,7 +96,7 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test { ...@@ -94,7 +96,7 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test {
const std::vector<uint32_t>& sorting_lsh_clusters) { const std::vector<uint32_t>& sorting_lsh_clusters) {
base::FilePath file_path = base::FilePath file_path =
CreateTestSortingLshClustersFile(sorting_lsh_clusters); CreateTestSortingLshClustersFile(sorting_lsh_clusters);
service()->OnSortingLshClustersFileReady(file_path); service()->OnSortingLshClustersFileReady(file_path, kDummyVersion);
EXPECT_TRUE(sorting_lsh_clusters_file_path().has_value()); EXPECT_TRUE(sorting_lsh_clusters_file_path().has_value());
return file_path; return file_path;
} }
...@@ -103,7 +105,10 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test { ...@@ -103,7 +105,10 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test {
FlocSortingLshClustersService* service() { return service_.get(); } FlocSortingLshClustersService* service() { return service_.get(); }
const base::Optional<base::FilePath>& sorting_lsh_clusters_file_path() { base::Optional<base::FilePath> sorting_lsh_clusters_file_path() {
if (!service()->first_file_ready_seen_)
return base::nullopt;
return service()->sorting_lsh_clusters_file_path_; return service()->sorting_lsh_clusters_file_path_;
} }
...@@ -111,10 +116,11 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test { ...@@ -111,10 +116,11 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test {
FlocId result; FlocId result;
base::RunLoop run_loop; base::RunLoop run_loop;
auto cb = base::BindLambdaForTesting([&](FlocId floc_id) { auto cb = base::BindLambdaForTesting(
result = floc_id; [&](FlocId floc_id, base::Optional<base::Version> version) {
run_loop.Quit(); result = floc_id;
}); run_loop.Quit();
});
service()->ApplySortingLsh(floc_id, std::move(cb)); service()->ApplySortingLsh(floc_id, std::move(cb));
background_task_runner_->RunPendingTasks(); background_task_runner_->RunPendingTasks();
...@@ -216,11 +222,12 @@ TEST_F(FlocSortingLshClustersServiceTest, ...@@ -216,11 +222,12 @@ TEST_F(FlocSortingLshClustersServiceTest,
base::FilePath file_path = InitializeSortingLshClustersFile({0}); base::FilePath file_path = InitializeSortingLshClustersFile({0});
base::RunLoop run_loop; base::RunLoop run_loop;
auto cb = base::BindLambdaForTesting([&](FlocId floc_id) { auto cb = base::BindLambdaForTesting(
// Since the file has been deleted, expect an invalid floc id. [&](FlocId floc_id, base::Optional<base::Version> version) {
EXPECT_EQ(FlocId(), floc_id); // Since the file has been deleted, expect an invalid floc id.
run_loop.Quit(); EXPECT_EQ(FlocId(), floc_id);
}); run_loop.Quit();
});
service()->ApplySortingLsh(FlocId(0), std::move(cb)); service()->ApplySortingLsh(FlocId(0), std::move(cb));
base::DeleteFile(file_path); base::DeleteFile(file_path);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment