Commit a927c079 authored by Yao Xiao's avatar Yao Xiao Committed by Commit Bot

Avoid recalculating floc on history-delete. Will either invalidate or no-op.

Why: we don't want floc to change more often than the intended
scheduled update cadence.

What:
- For each history-delete notification, if it's all-history or
the time-range overlaps with the time range of the history used to
compute the floc, we invalidate the floc. Otherwise, we keep using the
current floc.
- Remove the mechanism caching the swaa_nac_account_enabled status.
It's no longer needed as we don't need to query more often than the
scheduled update rate (% rare race condition).

How: Store history_begin_time_/history_end_time_ fields to FlocId.
Compare these fields with the history delete info.

Bug: 1143597
Change-Id: I685fd1a10cbc8044e799d7a644dfa2b7bb82ddd4
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2507454
Commit-Queue: Yao Xiao <yaoxia@chromium.org>
Reviewed-by: default avatarJosh Karlin <jkarlin@chromium.org>
Cr-Commit-Position: refs/heads/master@{#826680}
parent 1a5301a5
...@@ -214,6 +214,38 @@ class FlocIdProviderWithCustomizedServicesBrowserTest ...@@ -214,6 +214,38 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
return enumerator.urls(); return enumerator.urls();
} }
std::pair<base::Time, base::Time> GetHistoryTimeRange() {
history::QueryOptions options;
options.duplicate_policy = history::QueryOptions::KEEP_ALL_DUPLICATES;
base::Time history_begin_time = base::Time::Max();
base::Time history_end_time = base::Time::Min();
base::RunLoop run_loop;
base::CancelableTaskTracker tracker;
HistoryServiceFactory::GetForProfile(browser()->profile(),
ServiceAccessType::EXPLICIT_ACCESS)
->QueryHistory(
base::string16(), options,
base::BindLambdaForTesting([&](history::QueryResults results) {
for (const history::URLResult& url_result : results) {
if (!url_result.publicly_routable())
continue;
if (url_result.visit_time() < history_begin_time)
history_begin_time = url_result.visit_time();
if (url_result.visit_time() > history_end_time)
history_end_time = url_result.visit_time();
}
run_loop.Quit();
}),
&tracker);
run_loop.Run();
return {history_begin_time, history_end_time};
}
void FinishOutstandingRemotePermissionQueries() { void FinishOutstandingRemotePermissionQueries() {
base::RunLoop run_loop; base::RunLoop run_loop;
FlocRemotePermissionServiceFactory::GetForProfile(browser()->profile()) FlocRemotePermissionServiceFactory::GetForProfile(browser()->profile())
...@@ -241,7 +273,8 @@ class FlocIdProviderWithCustomizedServicesBrowserTest ...@@ -241,7 +273,8 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
const uint64_t dummy_sim_hash = 0u; const uint64_t dummy_sim_hash = 0u;
g_browser_process->floc_sorting_lsh_clusters_service()->ApplySortingLsh( g_browser_process->floc_sorting_lsh_clusters_service()->ApplySortingLsh(
dummy_sim_hash, dummy_sim_hash,
base::BindLambdaForTesting([&](FlocId floc) { run_loop.Quit(); })); base::BindLambdaForTesting(
[&](base::Optional<uint64_t>, base::Version) { run_loop.Quit(); }));
run_loop.Run(); run_loop.Run();
} }
...@@ -250,9 +283,10 @@ class FlocIdProviderWithCustomizedServicesBrowserTest ...@@ -250,9 +283,10 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
base::CancelableTaskTracker tracker; base::CancelableTaskTracker tracker;
HistoryServiceFactory::GetForProfile(browser()->profile(), HistoryServiceFactory::GetForProfile(browser()->profile(),
ServiceAccessType::EXPLICIT_ACCESS) ServiceAccessType::EXPLICIT_ACCESS)
->ExpireHistoryBeforeForTesting( ->ExpireHistoryBetween(
end_time, base::BindLambdaForTesting([&]() { run_loop.Quit(); }), /*restrict_urls=*/{}, /*begin_time=*/base::Time(), end_time,
&tracker); /*user_initiated=*/true,
base::BindLambdaForTesting([&]() { run_loop.Quit(); }), &tracker);
run_loop.Run(); run_loop.Run();
} }
...@@ -528,8 +562,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest, ...@@ -528,8 +562,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
InitializeHistorySync(); InitializeHistorySync();
// Promise resolved with the expected floc value. // Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(), EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(web_contents())); InvokeInterestCohortJsApi(web_contents()));
} }
...@@ -555,8 +591,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest, ...@@ -555,8 +591,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
content::RenderFrameHost* child = content::RenderFrameHost* child =
content::ChildFrameAt(web_contents()->GetMainFrame(), 0); content::ChildFrameAt(web_contents()->GetMainFrame(), 0);
// Promise resolved with the expected floc value. // Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(), EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(child)); InvokeInterestCohortJsApi(child));
} }
...@@ -582,8 +620,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest, ...@@ -582,8 +620,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
content::RenderFrameHost* child = content::RenderFrameHost* child =
content::ChildFrameAt(web_contents()->GetMainFrame(), 0); content::ChildFrameAt(web_contents()->GetMainFrame(), 0);
// Promise resolved with the expected floc value. // Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(), EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(child)); InvokeInterestCohortJsApi(child));
} }
...@@ -618,8 +658,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest, ...@@ -618,8 +658,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
// Promise rejected as the cookies permission disallows the child's host. // Promise rejected as the cookies permission disallows the child's host.
EXPECT_EQ("rejected", InvokeInterestCohortJsApi(child)); EXPECT_EQ("rejected", InvokeInterestCohortJsApi(child));
// Promise resolved with the expected floc value. // Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(), EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(web_contents())); InvokeInterestCohortJsApi(web_contents()));
} }
...@@ -646,6 +688,9 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest, ...@@ -646,6 +688,9 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
browser(), https_server_.GetURL(test_host(), cookies_to_set)); browser(), https_server_.GetURL(test_host(), cookies_to_set));
EXPECT_EQ(1u, GetHistoryUrls().size()); EXPECT_EQ(1u, GetHistoryUrls().size());
auto p = GetHistoryTimeRange();
base::Time history_begin_time = p.first;
base::Time history_end_time = p.second;
EXPECT_FALSE(GetFlocId().IsValid()); EXPECT_FALSE(GetFlocId().IsValid());
...@@ -659,7 +704,7 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest, ...@@ -659,7 +704,7 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
EXPECT_NE(0u, FlocId::SimHashHistory({test_host()})); EXPECT_NE(0u, FlocId::SimHashHistory({test_host()}));
// Expect that the final id is 0 because the sorting-lsh was applied. // Expect that the final id is 0 because the sorting-lsh was applied.
EXPECT_EQ(FlocId(0, 9), GetFlocId()); EXPECT_EQ(FlocId(0, history_begin_time, history_end_time, 9), GetFlocId());
} }
IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest, IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
......
...@@ -29,8 +29,6 @@ constexpr size_t kMinHistoryDomainSizeToReportFlocId = 1; ...@@ -29,8 +29,6 @@ constexpr size_t kMinHistoryDomainSizeToReportFlocId = 1;
constexpr base::TimeDelta kFlocScheduledUpdateInterval = constexpr base::TimeDelta kFlocScheduledUpdateInterval =
base::TimeDelta::FromDays(1); base::TimeDelta::FromDays(1);
constexpr int kQueryHistoryWindowInDays = 7; constexpr int kQueryHistoryWindowInDays = 7;
constexpr base::TimeDelta kSwaaNacAccountEnabledCachePeriod =
base::TimeDelta::FromHours(12);
// The placeholder sorting-lsh version when the sorting-lsh feature is disabled. // The placeholder sorting-lsh version when the sorting-lsh feature is disabled.
constexpr uint32_t kSortingLshVersionPlaceholder = 0; constexpr uint32_t kSortingLshVersionPlaceholder = 0;
...@@ -81,7 +79,7 @@ std::string FlocIdProviderImpl::GetInterestCohortForJsApi( ...@@ -81,7 +79,7 @@ std::string FlocIdProviderImpl::GetInterestCohortForJsApi(
if (!floc_id_.IsValid()) if (!floc_id_.IsValid())
return std::string(); return std::string();
return floc_id_.ToString(); return floc_id_.ToStringForJsApi();
} }
void FlocIdProviderImpl::OnComputeFlocCompleted(ComputeFlocTrigger trigger, void FlocIdProviderImpl::OnComputeFlocCompleted(ComputeFlocTrigger trigger,
...@@ -89,12 +87,11 @@ void FlocIdProviderImpl::OnComputeFlocCompleted(ComputeFlocTrigger trigger, ...@@ -89,12 +87,11 @@ void FlocIdProviderImpl::OnComputeFlocCompleted(ComputeFlocTrigger trigger,
DCHECK(floc_computation_in_progress_); DCHECK(floc_computation_in_progress_);
floc_computation_in_progress_ = false; floc_computation_in_progress_ = false;
// Some recompute event came in when this computation was in progress. Ignore // History-delete event came in when this computation was in progress. Ignore
// this computation completely. Handle the pending one. // this computation completely and recompute.
if (pending_recompute_event_) { if (need_recompute_) {
ComputeFlocTrigger recompute_trigger = pending_recompute_event_.value(); need_recompute_ = false;
pending_recompute_event_.reset(); ComputeFloc(trigger);
ComputeFloc(recompute_trigger);
return; return;
} }
...@@ -167,18 +164,34 @@ void FlocIdProviderImpl::Shutdown() { ...@@ -167,18 +164,34 @@ void FlocIdProviderImpl::Shutdown() {
void FlocIdProviderImpl::OnURLsDeleted( void FlocIdProviderImpl::OnURLsDeleted(
history::HistoryService* history_service, history::HistoryService* history_service,
const history::DeletionInfo& deletion_info) { const history::DeletionInfo& deletion_info) {
// Set a pending event or override the existing one, that will get run when // Set the |need_recompute_| flag so that we will recompute the floc
// the in-progress computation finishes. // immediately after the in-progress one finishes, so as to avoid potential
// data races.
if (floc_computation_in_progress_) { if (floc_computation_in_progress_) {
DCHECK(first_floc_computation_triggered_); DCHECK(first_floc_computation_triggered_);
pending_recompute_event_ = ComputeFlocTrigger::kHistoryDelete; need_recompute_ = true;
return; return;
} }
if (!first_floc_computation_triggered_ || !floc_id_.IsValid()) if (!floc_id_.IsValid())
return;
// Only invalidate the floc if it's delete-all or if the time range overlaps
// with the time range of the history used to compute the current floc.
if (!deletion_info.IsAllHistory() && !deletion_info.time_range().IsValid()) {
return; return;
}
if (deletion_info.time_range().begin() > floc_id_.history_end_time() ||
deletion_info.time_range().end() < floc_id_.history_begin_time()) {
return;
}
ComputeFloc(ComputeFlocTrigger::kHistoryDelete); // We log the invalidation event although it's technically not a recompute.
// It'd give us a better idea how often the floc is invalidated due to
// history-delete.
LogFlocComputedEvent(ComputeFlocTrigger::kHistoryDelete, ComputeFlocResult());
floc_id_ = FlocId();
} }
void FlocIdProviderImpl::OnSortingLshClustersFileReady() { void FlocIdProviderImpl::OnSortingLshClustersFileReady() {
...@@ -218,21 +231,15 @@ void FlocIdProviderImpl::MaybeTriggerFirstFlocComputation() { ...@@ -218,21 +231,15 @@ void FlocIdProviderImpl::MaybeTriggerFirstFlocComputation() {
} }
void FlocIdProviderImpl::OnComputeFlocScheduledUpdate() { void FlocIdProviderImpl::OnComputeFlocScheduledUpdate() {
// It's fine to skip the scheduled update as long as there's one in progress. DCHECK(!floc_computation_in_progress_);
// We won't be losing the recomputing frequency, as the in-progress one only
// occurs sooner and when it finishes a new compute-floc task will be
// scheduled.
if (floc_computation_in_progress_)
return;
DCHECK(!pending_recompute_event_);
ComputeFloc(ComputeFlocTrigger::kScheduledUpdate); ComputeFloc(ComputeFlocTrigger::kScheduledUpdate);
} }
void FlocIdProviderImpl::ComputeFloc(ComputeFlocTrigger trigger) { void FlocIdProviderImpl::ComputeFloc(ComputeFlocTrigger trigger) {
DCHECK_NE(trigger == ComputeFlocTrigger::kBrowserStart, DCHECK(trigger == ComputeFlocTrigger::kBrowserStart ||
first_floc_computation_triggered_); (trigger == ComputeFlocTrigger::kScheduledUpdate &&
first_floc_computation_triggered_));
DCHECK(!floc_computation_in_progress_); DCHECK(!floc_computation_in_progress_);
floc_computation_in_progress_ = true; floc_computation_in_progress_ = true;
...@@ -285,14 +292,6 @@ bool FlocIdProviderImpl::AreThirdPartyCookiesAllowed() const { ...@@ -285,14 +292,6 @@ bool FlocIdProviderImpl::AreThirdPartyCookiesAllowed() const {
void FlocIdProviderImpl::IsSwaaNacAccountEnabled( void FlocIdProviderImpl::IsSwaaNacAccountEnabled(
CanComputeFlocCallback callback) { CanComputeFlocCallback callback) {
if (!last_swaa_nac_account_enabled_query_time_.is_null() &&
last_swaa_nac_account_enabled_query_time_ +
kSwaaNacAccountEnabledCachePeriod >
base::TimeTicks::Now()) {
std::move(callback).Run(cached_swaa_nac_account_enabled_);
return;
}
net::PartialNetworkTrafficAnnotationTag partial_traffic_annotation = net::PartialNetworkTrafficAnnotationTag partial_traffic_annotation =
net::DefinePartialNetworkTrafficAnnotation( net::DefinePartialNetworkTrafficAnnotation(
"floc_id_provider_impl", "floc_remote_permission_service", "floc_id_provider_impl", "floc_remote_permission_service",
...@@ -323,17 +322,7 @@ void FlocIdProviderImpl::IsSwaaNacAccountEnabled( ...@@ -323,17 +322,7 @@ void FlocIdProviderImpl::IsSwaaNacAccountEnabled(
})"); })");
floc_remote_permission_service_->QueryFlocPermission( floc_remote_permission_service_->QueryFlocPermission(
base::BindOnce(&FlocIdProviderImpl::OnCheckSwaaNacAccountEnabledCompleted, std::move(callback), partial_traffic_annotation);
weak_ptr_factory_.GetWeakPtr(), std::move(callback)),
partial_traffic_annotation);
}
void FlocIdProviderImpl::OnCheckSwaaNacAccountEnabledCompleted(
CanComputeFlocCallback callback,
bool enabled) {
cached_swaa_nac_account_enabled_ = enabled;
last_swaa_nac_account_enabled_query_time_ = base::TimeTicks::Now();
std::move(callback).Run(enabled);
} }
void FlocIdProviderImpl::GetRecentlyVisitedURLs( void FlocIdProviderImpl::GetRecentlyVisitedURLs(
...@@ -350,10 +339,20 @@ void FlocIdProviderImpl::OnGetRecentlyVisitedURLsCompleted( ...@@ -350,10 +339,20 @@ void FlocIdProviderImpl::OnGetRecentlyVisitedURLsCompleted(
ComputeFlocCompletedCallback callback, ComputeFlocCompletedCallback callback,
history::QueryResults results) { history::QueryResults results) {
std::unordered_set<std::string> domains; std::unordered_set<std::string> domains;
base::Time history_begin_time = base::Time::Max();
base::Time history_end_time = base::Time::Min();
for (const history::URLResult& url_result : results) { for (const history::URLResult& url_result : results) {
if (!url_result.publicly_routable()) if (!url_result.publicly_routable())
continue; continue;
if (url_result.visit_time() < history_begin_time)
history_begin_time = url_result.visit_time();
if (url_result.visit_time() > history_end_time)
history_end_time = url_result.visit_time();
domains.insert(net::registry_controlled_domains::GetDomainAndRegistry( domains.insert(net::registry_controlled_domains::GetDomainAndRegistry(
url_result.url(), url_result.url(),
net::registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES)); net::registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES));
...@@ -365,16 +364,20 @@ void FlocIdProviderImpl::OnGetRecentlyVisitedURLsCompleted( ...@@ -365,16 +364,20 @@ void FlocIdProviderImpl::OnGetRecentlyVisitedURLsCompleted(
} }
ApplySortingLshPostProcessing(std::move(callback), ApplySortingLshPostProcessing(std::move(callback),
FlocId::SimHashHistory(domains)); FlocId::SimHashHistory(domains),
history_begin_time, history_end_time);
} }
void FlocIdProviderImpl::ApplySortingLshPostProcessing( void FlocIdProviderImpl::ApplySortingLshPostProcessing(
ComputeFlocCompletedCallback callback, ComputeFlocCompletedCallback callback,
uint64_t sim_hash) { uint64_t sim_hash,
base::Time history_begin_time,
base::Time history_end_time) {
if (!base::FeatureList::IsEnabled( if (!base::FeatureList::IsEnabled(
features::kFlocIdSortingLshBasedComputation)) { features::kFlocIdSortingLshBasedComputation)) {
std::move(callback).Run(ComputeFlocResult( std::move(callback).Run(ComputeFlocResult(
sim_hash, FlocId(sim_hash, kSortingLshVersionPlaceholder))); sim_hash, FlocId(sim_hash, history_begin_time, history_end_time,
kSortingLshVersionPlaceholder)));
return; return;
} }
...@@ -382,14 +385,24 @@ void FlocIdProviderImpl::ApplySortingLshPostProcessing( ...@@ -382,14 +385,24 @@ void FlocIdProviderImpl::ApplySortingLshPostProcessing(
sim_hash, sim_hash,
base::BindOnce(&FlocIdProviderImpl::DidApplySortingLshPostProcessing, base::BindOnce(&FlocIdProviderImpl::DidApplySortingLshPostProcessing,
weak_ptr_factory_.GetWeakPtr(), std::move(callback), weak_ptr_factory_.GetWeakPtr(), std::move(callback),
sim_hash)); sim_hash, history_begin_time, history_end_time));
} }
void FlocIdProviderImpl::DidApplySortingLshPostProcessing( void FlocIdProviderImpl::DidApplySortingLshPostProcessing(
ComputeFlocCompletedCallback callback, ComputeFlocCompletedCallback callback,
uint64_t sim_hash, uint64_t sim_hash,
FlocId floc_id) { base::Time history_begin_time,
std::move(callback).Run(ComputeFlocResult(sim_hash, std::move(floc_id))); base::Time history_end_time,
base::Optional<uint64_t> final_hash,
base::Version version) {
if (!final_hash) {
std::move(callback).Run(ComputeFlocResult(sim_hash, FlocId()));
return;
}
std::move(callback).Run(ComputeFlocResult(
sim_hash, FlocId(final_hash.value(), history_begin_time, history_end_time,
version.components().front())));
} }
} // namespace federated_learning } // namespace federated_learning
...@@ -42,8 +42,9 @@ class FlocRemotePermissionService; ...@@ -42,8 +42,9 @@ class FlocRemotePermissionService;
// //
// The floc will be first computed after sync & sync-history are enabled. After // The floc will be first computed after sync & sync-history are enabled. After
// each computation, another computation will be scheduled 24 hours later. In // each computation, another computation will be scheduled 24 hours later. In
// the event of history deletion, the floc will be recomputed immediately and // the event of history deletion, the floc will be invalidated immediately
// reset the timer of any currently scheduled computation to be 24 hours later. // if the time range of the deletion overlaps with the time range used to
// compute the existing floc.
class FlocIdProviderImpl : public FlocIdProvider, class FlocIdProviderImpl : public FlocIdProvider,
public FlocSortingLshClustersService::Observer, public FlocSortingLshClustersService::Observer,
public history::HistoryServiceObserver, public history::HistoryServiceObserver,
...@@ -134,8 +135,6 @@ class FlocIdProviderImpl : public FlocIdProvider, ...@@ -134,8 +135,6 @@ class FlocIdProviderImpl : public FlocIdProvider,
bool AreThirdPartyCookiesAllowed() const; bool AreThirdPartyCookiesAllowed() const;
void IsSwaaNacAccountEnabled(CanComputeFlocCallback callback); void IsSwaaNacAccountEnabled(CanComputeFlocCallback callback);
void OnCheckSwaaNacAccountEnabledCompleted(CanComputeFlocCallback callback,
bool enabled);
void GetRecentlyVisitedURLs(GetRecentlyVisitedURLsCallback callback); void GetRecentlyVisitedURLs(GetRecentlyVisitedURLsCallback callback);
void OnGetRecentlyVisitedURLsCompleted(ComputeFlocCompletedCallback callback, void OnGetRecentlyVisitedURLsCompleted(ComputeFlocCompletedCallback callback,
...@@ -145,10 +144,15 @@ class FlocIdProviderImpl : public FlocIdProvider, ...@@ -145,10 +144,15 @@ class FlocIdProviderImpl : public FlocIdProvider,
// The final floc may be invalid if the file is corrupted or the floc end up // The final floc may be invalid if the file is corrupted or the floc end up
// being blocked. // being blocked.
void ApplySortingLshPostProcessing(ComputeFlocCompletedCallback callback, void ApplySortingLshPostProcessing(ComputeFlocCompletedCallback callback,
uint64_t sim_hash); uint64_t sim_hash,
base::Time history_begin_time,
base::Time history_end_time);
void DidApplySortingLshPostProcessing(ComputeFlocCompletedCallback callback, void DidApplySortingLshPostProcessing(ComputeFlocCompletedCallback callback,
uint64_t sim_hash, uint64_t sim_hash,
FlocId floc_id); base::Time history_begin_time,
base::Time history_end_time,
base::Optional<uint64_t> final_hash,
base::Version version);
// The id to be exposed to the JS API. // The id to be exposed to the JS API.
FlocId floc_id_; FlocId floc_id_;
...@@ -156,19 +160,17 @@ class FlocIdProviderImpl : public FlocIdProvider, ...@@ -156,19 +160,17 @@ class FlocIdProviderImpl : public FlocIdProvider,
bool floc_computation_in_progress_ = false; bool floc_computation_in_progress_ = false;
bool first_floc_computation_triggered_ = false; bool first_floc_computation_triggered_ = false;
// We store a pending event if it arrives during an in-progress computation. // True if history-delete occurs during an in-progress computation. When the
// When the in-progress one finishes, we would disregard the result (no // in-progress one finishes, we would disregard the result (i.e. no loggings
// loggings, updates, etc.), and compute again. // or floc update), and compute again. Potentially we could maintain extra
base::Optional<ComputeFlocTrigger> pending_recompute_event_; // states to tell if the history-delete would have impact on the in-progress
// result, but since this would only happen in rare race situations, we just
// always recompute to keep things simple.
bool need_recompute_ = false;
bool first_sorting_lsh_file_ready_seen_ = false; bool first_sorting_lsh_file_ready_seen_ = false;
bool first_sync_history_enabled_seen_ = false; bool first_sync_history_enabled_seen_ = false;
// For the swaa/nac/account_type permission, we will use a cached status to
// avoid querying too often.
bool cached_swaa_nac_account_enabled_ = false;
base::TimeTicks last_swaa_nac_account_enabled_query_time_;
syncer::SyncService* sync_service_; syncer::SyncService* sync_service_;
scoped_refptr<content_settings::CookieSettings> cookie_settings_; scoped_refptr<content_settings::CookieSettings> cookie_settings_;
FlocRemotePermissionService* floc_remote_permission_service_; FlocRemotePermissionService* floc_remote_permission_service_;
......
...@@ -26,8 +26,14 @@ uint64_t FlocId::SimHashHistory( ...@@ -26,8 +26,14 @@ uint64_t FlocId::SimHashHistory(
FlocId::FlocId() = default; FlocId::FlocId() = default;
FlocId::FlocId(uint64_t id, uint32_t sorting_lsh_version) FlocId::FlocId(uint64_t id,
: id_(id), sorting_lsh_version_(sorting_lsh_version) {} base::Time history_begin_time,
base::Time history_end_time,
uint32_t sorting_lsh_version)
: id_(id),
history_begin_time_(history_begin_time),
history_end_time_(history_end_time),
sorting_lsh_version_(sorting_lsh_version) {}
FlocId::FlocId(const FlocId& id) = default; FlocId::FlocId(const FlocId& id) = default;
...@@ -42,19 +48,16 @@ bool FlocId::IsValid() const { ...@@ -42,19 +48,16 @@ bool FlocId::IsValid() const {
} }
bool FlocId::operator==(const FlocId& other) const { bool FlocId::operator==(const FlocId& other) const {
return id_ == other.id_ && sorting_lsh_version_ == other.sorting_lsh_version_; return id_ == other.id_ && history_begin_time_ == other.history_begin_time_ &&
history_end_time_ == other.history_end_time_ &&
sorting_lsh_version_ == other.sorting_lsh_version_;
} }
bool FlocId::operator!=(const FlocId& other) const { bool FlocId::operator!=(const FlocId& other) const {
return !(*this == other); return !(*this == other);
} }
uint64_t FlocId::ToUint64() const { std::string FlocId::ToStringForJsApi() const {
DCHECK(id_.has_value());
return id_.value();
}
std::string FlocId::ToString() const {
DCHECK(id_.has_value()); DCHECK(id_.has_value());
return base::StrCat({base::NumberToString(id_.value()), ".", return base::StrCat({base::NumberToString(id_.value()), ".",
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#define COMPONENTS_FEDERATED_LEARNING_FLOC_ID_H_ #define COMPONENTS_FEDERATED_LEARNING_FLOC_ID_H_
#include "base/optional.h" #include "base/optional.h"
#include "base/time/time.h"
#include "base/version.h" #include "base/version.h"
#include <stdint.h> #include <stdint.h>
...@@ -24,9 +25,13 @@ class FlocId { ...@@ -24,9 +25,13 @@ class FlocId {
const std::unordered_set<std::string>& domains); const std::unordered_set<std::string>& domains);
FlocId(); FlocId();
explicit FlocId(uint64_t id, uint32_t sorting_lsh_version);
FlocId(const FlocId& id);
explicit FlocId(uint64_t id,
base::Time history_begin_time,
base::Time history_end_time,
uint32_t sorting_lsh_version);
FlocId(const FlocId& id);
~FlocId(); ~FlocId();
FlocId& operator=(const FlocId& id); FlocId& operator=(const FlocId& id);
FlocId& operator=(FlocId&& id); FlocId& operator=(FlocId&& id);
...@@ -35,16 +40,24 @@ class FlocId { ...@@ -35,16 +40,24 @@ class FlocId {
bool operator!=(const FlocId& other) const; bool operator!=(const FlocId& other) const;
bool IsValid() const; bool IsValid() const;
uint64_t ToUint64() const;
// The id, followed by the chrome floc version, followed by the async floc // The id, followed by the chrome floc version, followed by the async floc
// component versions (i.e. model and sorting-lsh). This is the format to be // component versions (i.e. model and sorting-lsh). This is the format to be
// exposed to the JS API. Precondition: |id_| must be valid. // exposed to the JS API. Precondition: |id_| must be valid.
std::string ToString() const; std::string ToStringForJsApi() const;
base::Time history_begin_time() const { return history_begin_time_; }
base::Time history_end_time() const { return history_end_time_; }
private: private:
base::Optional<uint64_t> id_; base::Optional<uint64_t> id_;
// The time range of the actual history used to compute the floc. This should
// always be within the time range of each history query.
base::Time history_begin_time_;
base::Time history_end_time_;
// The main version (i.e. 1st int) of the sorting lsh component version. // The main version (i.e. 1st int) of the sorting lsh component version.
uint32_t sorting_lsh_version_ = 0; uint32_t sorting_lsh_version_ = 0;
}; };
......
...@@ -8,32 +8,33 @@ ...@@ -8,32 +8,33 @@
namespace federated_learning { namespace federated_learning {
const base::Time kTime0 = base::Time();
const base::Time kTime1 = base::Time::FromTimeT(1);
const base::Time kTime2 = base::Time::FromTimeT(2);
TEST(FlocIdTest, IsValid) { TEST(FlocIdTest, IsValid) {
EXPECT_FALSE(FlocId().IsValid()); EXPECT_FALSE(FlocId().IsValid());
EXPECT_TRUE(FlocId(0, 0).IsValid()); EXPECT_TRUE(FlocId(0, kTime0, kTime0, 0).IsValid());
EXPECT_TRUE(FlocId(0, 1).IsValid()); EXPECT_TRUE(FlocId(0, kTime1, kTime2, 1).IsValid());
}
TEST(FlocIdTest, ToUint64) {
EXPECT_EQ(0u, FlocId(0, 0).ToUint64());
EXPECT_EQ(1u, FlocId(1, 0).ToUint64());
EXPECT_EQ(1u, FlocId(1, 1).ToUint64());
} }
TEST(FlocIdTest, Comparison) { TEST(FlocIdTest, Comparison) {
EXPECT_EQ(FlocId(), FlocId()); EXPECT_EQ(FlocId(), FlocId());
EXPECT_EQ(FlocId(0, 0), FlocId(0, 0));
EXPECT_EQ(FlocId(0, 1), FlocId(0, 1));
EXPECT_NE(FlocId(), FlocId(0, 0)); EXPECT_EQ(FlocId(0, kTime0, kTime0, 0), FlocId(0, kTime0, kTime0, 0));
EXPECT_NE(FlocId(0, 0), FlocId(1, 0)); EXPECT_EQ(FlocId(0, kTime1, kTime1, 1), FlocId(0, kTime1, kTime1, 1));
EXPECT_NE(FlocId(0, 0), FlocId(0, 1)); EXPECT_EQ(FlocId(0, kTime1, kTime2, 1), FlocId(0, kTime1, kTime2, 1));
EXPECT_NE(FlocId(), FlocId(0, kTime0, kTime0, 0));
EXPECT_NE(FlocId(0, kTime0, kTime0, 0), FlocId(1, kTime0, kTime0, 0));
EXPECT_NE(FlocId(0, kTime0, kTime1, 0), FlocId(0, kTime1, kTime1, 0));
EXPECT_NE(FlocId(0, kTime0, kTime0, 0), FlocId(0, kTime0, kTime0, 1));
} }
TEST(FlocIdTest, ToString) { TEST(FlocIdTest, ToStringForJsApi) {
EXPECT_EQ("0.1.0", FlocId(0, 0).ToString()); EXPECT_EQ("0.1.0", FlocId(0, kTime0, kTime0, 0).ToStringForJsApi());
EXPECT_EQ("12345.1.0", FlocId(12345, 0).ToString()); EXPECT_EQ("12345.1.0", FlocId(12345, kTime0, kTime0, 0).ToStringForJsApi());
EXPECT_EQ("12345.1.2", FlocId(12345, 2).ToString()); EXPECT_EQ("12345.1.2", FlocId(12345, kTime1, kTime1, 2).ToStringForJsApi());
} }
} // namespace federated_learning } // namespace federated_learning
...@@ -36,13 +36,13 @@ class CopyingFileInputStream : public google::protobuf::io::CopyingInputStream { ...@@ -36,13 +36,13 @@ class CopyingFileInputStream : public google::protobuf::io::CopyingInputStream {
base::File file_; base::File file_;
}; };
FlocId ApplySortingLshOnBackgroundThread(uint64_t sim_hash, base::Optional<uint64_t> ApplySortingLshOnBackgroundThread(
const base::FilePath& file_path, uint64_t sim_hash,
const base::Version& version) { const base::FilePath& file_path) {
base::File sorting_lsh_clusters_file( base::File sorting_lsh_clusters_file(
file_path, base::File::FLAG_OPEN | base::File::FLAG_READ); file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
if (!sorting_lsh_clusters_file.IsValid()) if (!sorting_lsh_clusters_file.IsValid())
return FlocId(); return base::nullopt;
CopyingFileInputStream copying_stream(std::move(sorting_lsh_clusters_file)); CopyingFileInputStream copying_stream(std::move(sorting_lsh_clusters_file));
google::protobuf::io::CopyingInputStreamAdaptor zero_copy_stream_adaptor( google::protobuf::io::CopyingInputStreamAdaptor zero_copy_stream_adaptor(
...@@ -88,33 +88,33 @@ FlocId ApplySortingLshOnBackgroundThread(uint64_t sim_hash, ...@@ -88,33 +88,33 @@ FlocId ApplySortingLshOnBackgroundThread(uint64_t sim_hash,
for (uint64_t index = 0; input_stream.ReadVarint32(&next_combined); ++index) { for (uint64_t index = 0; input_stream.ReadVarint32(&next_combined); ++index) {
// Sanitizing error: the entry used more than |kSortingLshMaxBits| bits. // Sanitizing error: the entry used more than |kSortingLshMaxBits| bits.
if ((next_combined >> kSortingLshMaxBits) > 0) if ((next_combined >> kSortingLshMaxBits) > 0)
return FlocId(); return base::nullopt;
bool is_blocked = next_combined & kSortingLshBlockedMask; bool is_blocked = next_combined & kSortingLshBlockedMask;
uint32_t next = next_combined & kSortingLshSizeMask; uint32_t next = next_combined & kSortingLshSizeMask;
// Sanitizing error // Sanitizing error
if (next > kMaxNumberOfBitsInFloc) if (next > kMaxNumberOfBitsInFloc)
return FlocId(); return base::nullopt;
cumulative_sum += (1ULL << next); cumulative_sum += (1ULL << next);
// Sanitizing error // Sanitizing error
if (cumulative_sum > kExpectedFinalCumulativeSum) if (cumulative_sum > kExpectedFinalCumulativeSum)
return FlocId(); return base::nullopt;
// Found the sim-hash upper bound. Use the index as the new floc. // Found the sim-hash upper bound. Use the index as the new floc.
if (cumulative_sum > sim_hash) { if (cumulative_sum > sim_hash) {
if (is_blocked) if (is_blocked)
return FlocId(); return base::nullopt;
return FlocId(index, version.components().front()); return index;
} }
} }
// Sanitizing error: we didn't find a sim-hash upper bound, but we expect to // Sanitizing error: we didn't find a sim-hash upper bound, but we expect to
// always find it after finish iterating through the list. // always find it after finish iterating through the list.
return FlocId(); return base::nullopt;
} }
} // namespace } // namespace
...@@ -158,9 +158,10 @@ void FlocSortingLshClustersService::ApplySortingLsh( ...@@ -158,9 +158,10 @@ void FlocSortingLshClustersService::ApplySortingLsh(
base::PostTaskAndReplyWithResult( base::PostTaskAndReplyWithResult(
background_task_runner_.get(), FROM_HERE, background_task_runner_.get(), FROM_HERE,
base::BindOnce(&ApplySortingLshOnBackgroundThread, sim_hash, base::BindOnce(&ApplySortingLshOnBackgroundThread, sim_hash,
sorting_lsh_clusters_file_path_, sorting_lsh_clusters_file_path_),
sorting_lsh_clusters_version_), base::BindOnce(&FlocSortingLshClustersService::DidApplySortingLsh,
std::move(callback)); weak_ptr_factory_.GetWeakPtr(), std::move(callback),
sorting_lsh_clusters_version_));
} }
void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting( void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting(
...@@ -168,4 +169,11 @@ void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting( ...@@ -168,4 +169,11 @@ void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting(
background_task_runner_ = background_task_runner; background_task_runner_ = background_task_runner;
} }
void FlocSortingLshClustersService::DidApplySortingLsh(
ApplySortingLshCallback callback,
base::Version version,
base::Optional<uint64_t> final_hash) {
std::move(callback).Run(std::move(final_hash), std::move(version));
}
} // namespace federated_learning } // namespace federated_learning
...@@ -29,7 +29,8 @@ namespace federated_learning { ...@@ -29,7 +29,8 @@ namespace federated_learning {
// File reading and parsing is posted to |background_task_runner_|. // File reading and parsing is posted to |background_task_runner_|.
class FlocSortingLshClustersService { class FlocSortingLshClustersService {
public: public:
using ApplySortingLshCallback = base::OnceCallback<void(FlocId)>; using ApplySortingLshCallback =
base::OnceCallback<void(base::Optional<uint64_t>, base::Version)>;
class Observer { class Observer {
public: public:
...@@ -63,6 +64,10 @@ class FlocSortingLshClustersService { ...@@ -63,6 +64,10 @@ class FlocSortingLshClustersService {
private: private:
friend class FlocSortingLshClustersServiceTest; friend class FlocSortingLshClustersServiceTest;
void DidApplySortingLsh(ApplySortingLshCallback callback,
base::Version version,
base::Optional<uint64_t> final_hash);
// Runner for tasks that do not influence user experience. // Runner for tasks that do not influence user experience.
scoped_refptr<base::SequencedTaskRunner> background_task_runner_; scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
......
...@@ -45,6 +45,11 @@ class CopyingFileOutputStream ...@@ -45,6 +45,11 @@ class CopyingFileOutputStream
base::File file_; base::File file_;
}; };
struct ApplySortingLshResult {
base::Optional<uint64_t> final_hash;
base::Version version;
};
} // namespace } // namespace
class FlocSortingLshClustersServiceTest : public ::testing::Test { class FlocSortingLshClustersServiceTest : public ::testing::Test {
...@@ -117,14 +122,16 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test { ...@@ -117,14 +122,16 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test {
return service()->sorting_lsh_clusters_file_path_; return service()->sorting_lsh_clusters_file_path_;
} }
FlocId ApplySortingLsh(uint64_t sim_hash) { ApplySortingLshResult ApplySortingLsh(uint64_t sim_hash) {
FlocId result; ApplySortingLshResult result;
base::RunLoop run_loop; base::RunLoop run_loop;
auto cb = base::BindLambdaForTesting([&](FlocId floc_id) { auto cb = base::BindLambdaForTesting(
result = floc_id; [&](base::Optional<uint64_t> final_hash, base::Version version) {
run_loop.Quit(); result.final_hash = final_hash;
}); result.version = version;
run_loop.Quit();
});
service()->ApplySortingLsh(sim_hash, std::move(cb)); service()->ApplySortingLsh(sim_hash, std::move(cb));
background_task_runner_->RunPendingTasks(); background_task_runner_->RunPendingTasks();
...@@ -149,83 +156,86 @@ TEST_F(FlocSortingLshClustersServiceTest, NoFilePath) { ...@@ -149,83 +156,86 @@ TEST_F(FlocSortingLshClustersServiceTest, NoFilePath) {
TEST_F(FlocSortingLshClustersServiceTest, EmptyList) { TEST_F(FlocSortingLshClustersServiceTest, EmptyList) {
InitializeSortingLshClustersFile({}, base::Version("2.3.4")); InitializeSortingLshClustersFile({}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(), ApplySortingLsh(0));
EXPECT_EQ(FlocId(), ApplySortingLsh(1)); EXPECT_EQ(base::nullopt, ApplySortingLsh(0).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_0) { TEST_F(FlocSortingLshClustersServiceTest, List_0) {
InitializeSortingLshClustersFile({{0, false}}, base::Version("2.3.4")); InitializeSortingLshClustersFile({{0, false}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(), ApplySortingLsh(1)); EXPECT_EQ(base::Version("2.3.4"), ApplySortingLsh(0).version);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_0_Blocked) { TEST_F(FlocSortingLshClustersServiceTest, List_0_Blocked) {
InitializeSortingLshClustersFile({{0, true}}, base::Version("2.3.4")); InitializeSortingLshClustersFile({{0, true}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(), ApplySortingLsh(0)); EXPECT_EQ(base::nullopt, ApplySortingLsh(0).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(1)); EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_UnexpectedNumber) { TEST_F(FlocSortingLshClustersServiceTest, List_UnexpectedNumber) {
InitializeSortingLshClustersFile({{1 << 8, false}}, base::Version("2.3.4")); InitializeSortingLshClustersFile({{1 << 8, false}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(), ApplySortingLsh(0)); EXPECT_EQ(base::nullopt, ApplySortingLsh(0).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(1)); EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_1) { TEST_F(FlocSortingLshClustersServiceTest, List_1) {
InitializeSortingLshClustersFile({{1, false}}, base::Version("2.3.4")); InitializeSortingLshClustersFile({{1, false}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1)); EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(FlocId(), ApplySortingLsh(2)); EXPECT_EQ(base::nullopt, ApplySortingLsh(2).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_0_0) { TEST_F(FlocSortingLshClustersServiceTest, List_0_0) {
InitializeSortingLshClustersFile({{0, false}, {0, false}}, InitializeSortingLshClustersFile({{0, false}, {0, false}},
base::Version("2.3.4")); base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(1)); EXPECT_EQ(1u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(FlocId(), ApplySortingLsh(2)); EXPECT_EQ(base::nullopt, ApplySortingLsh(2).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_0_1) { TEST_F(FlocSortingLshClustersServiceTest, List_0_1) {
InitializeSortingLshClustersFile({{0, false}, {1, false}}, InitializeSortingLshClustersFile({{0, false}, {1, false}},
base::Version("2.3.4")); base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(1)); EXPECT_EQ(1u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(2)); EXPECT_EQ(1u, ApplySortingLsh(2).final_hash.value());
EXPECT_EQ(FlocId(), ApplySortingLsh(3)); EXPECT_EQ(base::nullopt, ApplySortingLsh(3).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_1_0) { TEST_F(FlocSortingLshClustersServiceTest, List_1_0) {
InitializeSortingLshClustersFile({{1, false}, {0, false}}, InitializeSortingLshClustersFile({{1, false}, {0, false}},
base::Version("2.3.4")); base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1)); EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(2)); EXPECT_EQ(1u, ApplySortingLsh(2).final_hash.value());
EXPECT_EQ(FlocId(), ApplySortingLsh(3)); EXPECT_EQ(base::nullopt, ApplySortingLsh(3).final_hash);
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
} }
TEST_F(FlocSortingLshClustersServiceTest, List_SingleCluster) { TEST_F(FlocSortingLshClustersServiceTest, List_SingleCluster) {
InitializeSortingLshClustersFile({{kMaxNumberOfBitsInFloc, false}}, InitializeSortingLshClustersFile({{kMaxNumberOfBitsInFloc, false}},
base::Version("2.3.4")); base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(12345)); EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(0u, ApplySortingLsh(12345).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(kMaxSimHash).final_hash.value());
} }
TEST_F(FlocSortingLshClustersServiceTest, List_TwoClustersEqualSize) { TEST_F(FlocSortingLshClustersServiceTest, List_TwoClustersEqualSize) {
...@@ -234,12 +244,12 @@ TEST_F(FlocSortingLshClustersServiceTest, List_TwoClustersEqualSize) { ...@@ -234,12 +244,12 @@ TEST_F(FlocSortingLshClustersServiceTest, List_TwoClustersEqualSize) {
base::Version("2.3.4")); base::Version("2.3.4"));
uint64_t middle_value = (1ULL << (kMaxNumberOfBitsInFloc - 1)); uint64_t middle_value = (1ULL << (kMaxNumberOfBitsInFloc - 1));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0)); EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1)); EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(middle_value - 1)); EXPECT_EQ(0u, ApplySortingLsh(middle_value - 1).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(middle_value)); EXPECT_EQ(1u, ApplySortingLsh(middle_value).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(middle_value + 1)); EXPECT_EQ(1u, ApplySortingLsh(middle_value + 1).final_hash.value());
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(kMaxSimHash)); EXPECT_EQ(1u, ApplySortingLsh(kMaxSimHash).final_hash.value());
} }
TEST_F(FlocSortingLshClustersServiceTest, TEST_F(FlocSortingLshClustersServiceTest,
...@@ -248,11 +258,13 @@ TEST_F(FlocSortingLshClustersServiceTest, ...@@ -248,11 +258,13 @@ TEST_F(FlocSortingLshClustersServiceTest,
InitializeSortingLshClustersFile({{0, false}}, base::Version("2.3.4")); InitializeSortingLshClustersFile({{0, false}}, base::Version("2.3.4"));
base::RunLoop run_loop; base::RunLoop run_loop;
auto cb = base::BindLambdaForTesting([&](FlocId floc_id) { auto cb = base::BindLambdaForTesting(
// Since the file has been deleted, expect an invalid floc id. [&](base::Optional<uint64_t> final_hash, base::Version version) {
EXPECT_EQ(FlocId(), floc_id); // Since the file has been deleted, expect an invalid final_hash.
run_loop.Quit(); EXPECT_EQ(base::nullopt, final_hash);
}); EXPECT_EQ(base::Version("2.3.4"), version);
run_loop.Quit();
});
service()->ApplySortingLsh(/*sim_hash=*/0, std::move(cb)); service()->ApplySortingLsh(/*sim_hash=*/0, std::move(cb));
base::DeleteFile(file_path); base::DeleteFile(file_path);
...@@ -264,7 +276,9 @@ TEST_F(FlocSortingLshClustersServiceTest, ...@@ -264,7 +276,9 @@ TEST_F(FlocSortingLshClustersServiceTest,
TEST_F(FlocSortingLshClustersServiceTest, MultipleUpdate_LatestOneUsed) { TEST_F(FlocSortingLshClustersServiceTest, MultipleUpdate_LatestOneUsed) {
InitializeSortingLshClustersFile({}, base::Version("2.3.4")); InitializeSortingLshClustersFile({}, base::Version("2.3.4"));
InitializeSortingLshClustersFile({{0, false}}, base::Version("6.7.8.9")); InitializeSortingLshClustersFile({{0, false}}, base::Version("6.7.8.9"));
EXPECT_EQ(FlocId(0, 6), ApplySortingLsh(0));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(base::Version("6.7.8.9"), ApplySortingLsh(0).version);
} }
} // namespace federated_learning } // namespace federated_learning
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment