Commit a927c079 authored by Yao Xiao's avatar Yao Xiao Committed by Commit Bot

Avoid recalculating floc on history-delete. Will either invalidate or no-op.

Why: we don't want floc to change more often than the intended
scheduled update cadence.

What:
- For each history-delete notification, if it's all-history or
the time-range overlaps with the time range of the history used to
compute the floc, we invalidate the floc. Otherwise, we keep using the
current floc.
- Remove the mechanism caching the swaa_nac_account_enabled status.
It's no longer needed as we don't need to query more often than the
scheduled update rate (% rare race condition).

How: Store history_begin_time_/history_end_time_ fields to FlocId.
Compare these fields with the history delete info.

Bug: 1143597
Change-Id: I685fd1a10cbc8044e799d7a644dfa2b7bb82ddd4
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2507454
Commit-Queue: Yao Xiao <yaoxia@chromium.org>
Reviewed-by: default avatarJosh Karlin <jkarlin@chromium.org>
Cr-Commit-Position: refs/heads/master@{#826680}
parent 1a5301a5
......@@ -214,6 +214,38 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
return enumerator.urls();
}
std::pair<base::Time, base::Time> GetHistoryTimeRange() {
history::QueryOptions options;
options.duplicate_policy = history::QueryOptions::KEEP_ALL_DUPLICATES;
base::Time history_begin_time = base::Time::Max();
base::Time history_end_time = base::Time::Min();
base::RunLoop run_loop;
base::CancelableTaskTracker tracker;
HistoryServiceFactory::GetForProfile(browser()->profile(),
ServiceAccessType::EXPLICIT_ACCESS)
->QueryHistory(
base::string16(), options,
base::BindLambdaForTesting([&](history::QueryResults results) {
for (const history::URLResult& url_result : results) {
if (!url_result.publicly_routable())
continue;
if (url_result.visit_time() < history_begin_time)
history_begin_time = url_result.visit_time();
if (url_result.visit_time() > history_end_time)
history_end_time = url_result.visit_time();
}
run_loop.Quit();
}),
&tracker);
run_loop.Run();
return {history_begin_time, history_end_time};
}
void FinishOutstandingRemotePermissionQueries() {
base::RunLoop run_loop;
FlocRemotePermissionServiceFactory::GetForProfile(browser()->profile())
......@@ -241,7 +273,8 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
const uint64_t dummy_sim_hash = 0u;
g_browser_process->floc_sorting_lsh_clusters_service()->ApplySortingLsh(
dummy_sim_hash,
base::BindLambdaForTesting([&](FlocId floc) { run_loop.Quit(); }));
base::BindLambdaForTesting(
[&](base::Optional<uint64_t>, base::Version) { run_loop.Quit(); }));
run_loop.Run();
}
......@@ -250,9 +283,10 @@ class FlocIdProviderWithCustomizedServicesBrowserTest
base::CancelableTaskTracker tracker;
HistoryServiceFactory::GetForProfile(browser()->profile(),
ServiceAccessType::EXPLICIT_ACCESS)
->ExpireHistoryBeforeForTesting(
end_time, base::BindLambdaForTesting([&]() { run_loop.Quit(); }),
&tracker);
->ExpireHistoryBetween(
/*restrict_urls=*/{}, /*begin_time=*/base::Time(), end_time,
/*user_initiated=*/true,
base::BindLambdaForTesting([&]() { run_loop.Quit(); }), &tracker);
run_loop.Run();
}
......@@ -528,8 +562,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
InitializeHistorySync();
// Promise resolved with the expected floc value.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(),
// Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(web_contents()));
}
......@@ -555,8 +591,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
content::RenderFrameHost* child =
content::ChildFrameAt(web_contents()->GetMainFrame(), 0);
// Promise resolved with the expected floc value.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(),
// Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(child));
}
......@@ -582,8 +620,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
content::RenderFrameHost* child =
content::ChildFrameAt(web_contents()->GetMainFrame(), 0);
// Promise resolved with the expected floc value.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(),
// Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(child));
}
......@@ -618,8 +658,10 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderWithCustomizedServicesBrowserTest,
// Promise rejected as the cookies permission disallows the child's host.
EXPECT_EQ("rejected", InvokeInterestCohortJsApi(child));
// Promise resolved with the expected floc value.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), 0).ToString(),
// Promise resolved with the expected string.
EXPECT_EQ(FlocId(FlocId::SimHashHistory({test_host()}), base::Time(),
base::Time(), 0)
.ToStringForJsApi(),
InvokeInterestCohortJsApi(web_contents()));
}
......@@ -646,6 +688,9 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
browser(), https_server_.GetURL(test_host(), cookies_to_set));
EXPECT_EQ(1u, GetHistoryUrls().size());
auto p = GetHistoryTimeRange();
base::Time history_begin_time = p.first;
base::Time history_end_time = p.second;
EXPECT_FALSE(GetFlocId().IsValid());
......@@ -659,7 +704,7 @@ IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
EXPECT_NE(0u, FlocId::SimHashHistory({test_host()}));
// Expect that the final id is 0 because the sorting-lsh was applied.
EXPECT_EQ(FlocId(0, 9), GetFlocId());
EXPECT_EQ(FlocId(0, history_begin_time, history_end_time, 9), GetFlocId());
}
IN_PROC_BROWSER_TEST_F(FlocIdProviderSortingLshEnabledBrowserTest,
......
......@@ -29,8 +29,6 @@ constexpr size_t kMinHistoryDomainSizeToReportFlocId = 1;
constexpr base::TimeDelta kFlocScheduledUpdateInterval =
base::TimeDelta::FromDays(1);
constexpr int kQueryHistoryWindowInDays = 7;
constexpr base::TimeDelta kSwaaNacAccountEnabledCachePeriod =
base::TimeDelta::FromHours(12);
// The placeholder sorting-lsh version when the sorting-lsh feature is disabled.
constexpr uint32_t kSortingLshVersionPlaceholder = 0;
......@@ -81,7 +79,7 @@ std::string FlocIdProviderImpl::GetInterestCohortForJsApi(
if (!floc_id_.IsValid())
return std::string();
return floc_id_.ToString();
return floc_id_.ToStringForJsApi();
}
void FlocIdProviderImpl::OnComputeFlocCompleted(ComputeFlocTrigger trigger,
......@@ -89,12 +87,11 @@ void FlocIdProviderImpl::OnComputeFlocCompleted(ComputeFlocTrigger trigger,
DCHECK(floc_computation_in_progress_);
floc_computation_in_progress_ = false;
// Some recompute event came in when this computation was in progress. Ignore
// this computation completely. Handle the pending one.
if (pending_recompute_event_) {
ComputeFlocTrigger recompute_trigger = pending_recompute_event_.value();
pending_recompute_event_.reset();
ComputeFloc(recompute_trigger);
// History-delete event came in when this computation was in progress. Ignore
// this computation completely and recompute.
if (need_recompute_) {
need_recompute_ = false;
ComputeFloc(trigger);
return;
}
......@@ -167,18 +164,34 @@ void FlocIdProviderImpl::Shutdown() {
void FlocIdProviderImpl::OnURLsDeleted(
history::HistoryService* history_service,
const history::DeletionInfo& deletion_info) {
// Set a pending event or override the existing one, that will get run when
// the in-progress computation finishes.
// Set the |need_recompute_| flag so that we will recompute the floc
// immediately after the in-progress one finishes, so as to avoid potential
// data races.
if (floc_computation_in_progress_) {
DCHECK(first_floc_computation_triggered_);
pending_recompute_event_ = ComputeFlocTrigger::kHistoryDelete;
need_recompute_ = true;
return;
}
if (!first_floc_computation_triggered_ || !floc_id_.IsValid())
if (!floc_id_.IsValid())
return;
// Only invalidate the floc if it's delete-all or if the time range overlaps
// with the time range of the history used to compute the current floc.
if (!deletion_info.IsAllHistory() && !deletion_info.time_range().IsValid()) {
return;
}
if (deletion_info.time_range().begin() > floc_id_.history_end_time() ||
deletion_info.time_range().end() < floc_id_.history_begin_time()) {
return;
}
ComputeFloc(ComputeFlocTrigger::kHistoryDelete);
// We log the invalidation event although it's technically not a recompute.
// It'd give us a better idea how often the floc is invalidated due to
// history-delete.
LogFlocComputedEvent(ComputeFlocTrigger::kHistoryDelete, ComputeFlocResult());
floc_id_ = FlocId();
}
void FlocIdProviderImpl::OnSortingLshClustersFileReady() {
......@@ -218,21 +231,15 @@ void FlocIdProviderImpl::MaybeTriggerFirstFlocComputation() {
}
void FlocIdProviderImpl::OnComputeFlocScheduledUpdate() {
// It's fine to skip the scheduled update as long as there's one in progress.
// We won't be losing the recomputing frequency, as the in-progress one only
// occurs sooner and when it finishes a new compute-floc task will be
// scheduled.
if (floc_computation_in_progress_)
return;
DCHECK(!pending_recompute_event_);
DCHECK(!floc_computation_in_progress_);
ComputeFloc(ComputeFlocTrigger::kScheduledUpdate);
}
void FlocIdProviderImpl::ComputeFloc(ComputeFlocTrigger trigger) {
DCHECK_NE(trigger == ComputeFlocTrigger::kBrowserStart,
first_floc_computation_triggered_);
DCHECK(trigger == ComputeFlocTrigger::kBrowserStart ||
(trigger == ComputeFlocTrigger::kScheduledUpdate &&
first_floc_computation_triggered_));
DCHECK(!floc_computation_in_progress_);
floc_computation_in_progress_ = true;
......@@ -285,14 +292,6 @@ bool FlocIdProviderImpl::AreThirdPartyCookiesAllowed() const {
void FlocIdProviderImpl::IsSwaaNacAccountEnabled(
CanComputeFlocCallback callback) {
if (!last_swaa_nac_account_enabled_query_time_.is_null() &&
last_swaa_nac_account_enabled_query_time_ +
kSwaaNacAccountEnabledCachePeriod >
base::TimeTicks::Now()) {
std::move(callback).Run(cached_swaa_nac_account_enabled_);
return;
}
net::PartialNetworkTrafficAnnotationTag partial_traffic_annotation =
net::DefinePartialNetworkTrafficAnnotation(
"floc_id_provider_impl", "floc_remote_permission_service",
......@@ -323,17 +322,7 @@ void FlocIdProviderImpl::IsSwaaNacAccountEnabled(
})");
floc_remote_permission_service_->QueryFlocPermission(
base::BindOnce(&FlocIdProviderImpl::OnCheckSwaaNacAccountEnabledCompleted,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)),
partial_traffic_annotation);
}
void FlocIdProviderImpl::OnCheckSwaaNacAccountEnabledCompleted(
CanComputeFlocCallback callback,
bool enabled) {
cached_swaa_nac_account_enabled_ = enabled;
last_swaa_nac_account_enabled_query_time_ = base::TimeTicks::Now();
std::move(callback).Run(enabled);
std::move(callback), partial_traffic_annotation);
}
void FlocIdProviderImpl::GetRecentlyVisitedURLs(
......@@ -350,10 +339,20 @@ void FlocIdProviderImpl::OnGetRecentlyVisitedURLsCompleted(
ComputeFlocCompletedCallback callback,
history::QueryResults results) {
std::unordered_set<std::string> domains;
base::Time history_begin_time = base::Time::Max();
base::Time history_end_time = base::Time::Min();
for (const history::URLResult& url_result : results) {
if (!url_result.publicly_routable())
continue;
if (url_result.visit_time() < history_begin_time)
history_begin_time = url_result.visit_time();
if (url_result.visit_time() > history_end_time)
history_end_time = url_result.visit_time();
domains.insert(net::registry_controlled_domains::GetDomainAndRegistry(
url_result.url(),
net::registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES));
......@@ -365,16 +364,20 @@ void FlocIdProviderImpl::OnGetRecentlyVisitedURLsCompleted(
}
ApplySortingLshPostProcessing(std::move(callback),
FlocId::SimHashHistory(domains));
FlocId::SimHashHistory(domains),
history_begin_time, history_end_time);
}
void FlocIdProviderImpl::ApplySortingLshPostProcessing(
ComputeFlocCompletedCallback callback,
uint64_t sim_hash) {
uint64_t sim_hash,
base::Time history_begin_time,
base::Time history_end_time) {
if (!base::FeatureList::IsEnabled(
features::kFlocIdSortingLshBasedComputation)) {
std::move(callback).Run(ComputeFlocResult(
sim_hash, FlocId(sim_hash, kSortingLshVersionPlaceholder)));
sim_hash, FlocId(sim_hash, history_begin_time, history_end_time,
kSortingLshVersionPlaceholder)));
return;
}
......@@ -382,14 +385,24 @@ void FlocIdProviderImpl::ApplySortingLshPostProcessing(
sim_hash,
base::BindOnce(&FlocIdProviderImpl::DidApplySortingLshPostProcessing,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
sim_hash));
sim_hash, history_begin_time, history_end_time));
}
void FlocIdProviderImpl::DidApplySortingLshPostProcessing(
ComputeFlocCompletedCallback callback,
uint64_t sim_hash,
FlocId floc_id) {
std::move(callback).Run(ComputeFlocResult(sim_hash, std::move(floc_id)));
base::Time history_begin_time,
base::Time history_end_time,
base::Optional<uint64_t> final_hash,
base::Version version) {
if (!final_hash) {
std::move(callback).Run(ComputeFlocResult(sim_hash, FlocId()));
return;
}
std::move(callback).Run(ComputeFlocResult(
sim_hash, FlocId(final_hash.value(), history_begin_time, history_end_time,
version.components().front())));
}
} // namespace federated_learning
......@@ -42,8 +42,9 @@ class FlocRemotePermissionService;
//
// The floc will be first computed after sync & sync-history are enabled. After
// each computation, another computation will be scheduled 24 hours later. In
// the event of history deletion, the floc will be recomputed immediately and
// reset the timer of any currently scheduled computation to be 24 hours later.
// the event of history deletion, the floc will be invalidated immediately
// if the time range of the deletion overlaps with the time range used to
// compute the existing floc.
class FlocIdProviderImpl : public FlocIdProvider,
public FlocSortingLshClustersService::Observer,
public history::HistoryServiceObserver,
......@@ -134,8 +135,6 @@ class FlocIdProviderImpl : public FlocIdProvider,
bool AreThirdPartyCookiesAllowed() const;
void IsSwaaNacAccountEnabled(CanComputeFlocCallback callback);
void OnCheckSwaaNacAccountEnabledCompleted(CanComputeFlocCallback callback,
bool enabled);
void GetRecentlyVisitedURLs(GetRecentlyVisitedURLsCallback callback);
void OnGetRecentlyVisitedURLsCompleted(ComputeFlocCompletedCallback callback,
......@@ -145,10 +144,15 @@ class FlocIdProviderImpl : public FlocIdProvider,
// The final floc may be invalid if the file is corrupted or the floc end up
// being blocked.
void ApplySortingLshPostProcessing(ComputeFlocCompletedCallback callback,
uint64_t sim_hash);
uint64_t sim_hash,
base::Time history_begin_time,
base::Time history_end_time);
void DidApplySortingLshPostProcessing(ComputeFlocCompletedCallback callback,
uint64_t sim_hash,
FlocId floc_id);
base::Time history_begin_time,
base::Time history_end_time,
base::Optional<uint64_t> final_hash,
base::Version version);
// The id to be exposed to the JS API.
FlocId floc_id_;
......@@ -156,19 +160,17 @@ class FlocIdProviderImpl : public FlocIdProvider,
bool floc_computation_in_progress_ = false;
bool first_floc_computation_triggered_ = false;
// We store a pending event if it arrives during an in-progress computation.
// When the in-progress one finishes, we would disregard the result (no
// loggings, updates, etc.), and compute again.
base::Optional<ComputeFlocTrigger> pending_recompute_event_;
// True if history-delete occurs during an in-progress computation. When the
// in-progress one finishes, we would disregard the result (i.e. no loggings
// or floc update), and compute again. Potentially we could maintain extra
// states to tell if the history-delete would have impact on the in-progress
// result, but since this would only happen in rare race situations, we just
// always recompute to keep things simple.
bool need_recompute_ = false;
bool first_sorting_lsh_file_ready_seen_ = false;
bool first_sync_history_enabled_seen_ = false;
// For the swaa/nac/account_type permission, we will use a cached status to
// avoid querying too often.
bool cached_swaa_nac_account_enabled_ = false;
base::TimeTicks last_swaa_nac_account_enabled_query_time_;
syncer::SyncService* sync_service_;
scoped_refptr<content_settings::CookieSettings> cookie_settings_;
FlocRemotePermissionService* floc_remote_permission_service_;
......
......@@ -26,8 +26,14 @@ uint64_t FlocId::SimHashHistory(
FlocId::FlocId() = default;
FlocId::FlocId(uint64_t id, uint32_t sorting_lsh_version)
: id_(id), sorting_lsh_version_(sorting_lsh_version) {}
FlocId::FlocId(uint64_t id,
base::Time history_begin_time,
base::Time history_end_time,
uint32_t sorting_lsh_version)
: id_(id),
history_begin_time_(history_begin_time),
history_end_time_(history_end_time),
sorting_lsh_version_(sorting_lsh_version) {}
FlocId::FlocId(const FlocId& id) = default;
......@@ -42,19 +48,16 @@ bool FlocId::IsValid() const {
}
bool FlocId::operator==(const FlocId& other) const {
return id_ == other.id_ && sorting_lsh_version_ == other.sorting_lsh_version_;
return id_ == other.id_ && history_begin_time_ == other.history_begin_time_ &&
history_end_time_ == other.history_end_time_ &&
sorting_lsh_version_ == other.sorting_lsh_version_;
}
bool FlocId::operator!=(const FlocId& other) const {
return !(*this == other);
}
uint64_t FlocId::ToUint64() const {
DCHECK(id_.has_value());
return id_.value();
}
std::string FlocId::ToString() const {
std::string FlocId::ToStringForJsApi() const {
DCHECK(id_.has_value());
return base::StrCat({base::NumberToString(id_.value()), ".",
......
......@@ -6,6 +6,7 @@
#define COMPONENTS_FEDERATED_LEARNING_FLOC_ID_H_
#include "base/optional.h"
#include "base/time/time.h"
#include "base/version.h"
#include <stdint.h>
......@@ -24,9 +25,13 @@ class FlocId {
const std::unordered_set<std::string>& domains);
FlocId();
explicit FlocId(uint64_t id, uint32_t sorting_lsh_version);
FlocId(const FlocId& id);
explicit FlocId(uint64_t id,
base::Time history_begin_time,
base::Time history_end_time,
uint32_t sorting_lsh_version);
FlocId(const FlocId& id);
~FlocId();
FlocId& operator=(const FlocId& id);
FlocId& operator=(FlocId&& id);
......@@ -35,16 +40,24 @@ class FlocId {
bool operator!=(const FlocId& other) const;
bool IsValid() const;
uint64_t ToUint64() const;
// The id, followed by the chrome floc version, followed by the async floc
// component versions (i.e. model and sorting-lsh). This is the format to be
// exposed to the JS API. Precondition: |id_| must be valid.
std::string ToString() const;
std::string ToStringForJsApi() const;
base::Time history_begin_time() const { return history_begin_time_; }
base::Time history_end_time() const { return history_end_time_; }
private:
base::Optional<uint64_t> id_;
// The time range of the actual history used to compute the floc. This should
// always be within the time range of each history query.
base::Time history_begin_time_;
base::Time history_end_time_;
// The main version (i.e. 1st int) of the sorting lsh component version.
uint32_t sorting_lsh_version_ = 0;
};
......
......@@ -8,32 +8,33 @@
namespace federated_learning {
const base::Time kTime0 = base::Time();
const base::Time kTime1 = base::Time::FromTimeT(1);
const base::Time kTime2 = base::Time::FromTimeT(2);
TEST(FlocIdTest, IsValid) {
EXPECT_FALSE(FlocId().IsValid());
EXPECT_TRUE(FlocId(0, 0).IsValid());
EXPECT_TRUE(FlocId(0, 1).IsValid());
}
TEST(FlocIdTest, ToUint64) {
EXPECT_EQ(0u, FlocId(0, 0).ToUint64());
EXPECT_EQ(1u, FlocId(1, 0).ToUint64());
EXPECT_EQ(1u, FlocId(1, 1).ToUint64());
EXPECT_TRUE(FlocId(0, kTime0, kTime0, 0).IsValid());
EXPECT_TRUE(FlocId(0, kTime1, kTime2, 1).IsValid());
}
TEST(FlocIdTest, Comparison) {
EXPECT_EQ(FlocId(), FlocId());
EXPECT_EQ(FlocId(0, 0), FlocId(0, 0));
EXPECT_EQ(FlocId(0, 1), FlocId(0, 1));
EXPECT_NE(FlocId(), FlocId(0, 0));
EXPECT_NE(FlocId(0, 0), FlocId(1, 0));
EXPECT_NE(FlocId(0, 0), FlocId(0, 1));
EXPECT_EQ(FlocId(0, kTime0, kTime0, 0), FlocId(0, kTime0, kTime0, 0));
EXPECT_EQ(FlocId(0, kTime1, kTime1, 1), FlocId(0, kTime1, kTime1, 1));
EXPECT_EQ(FlocId(0, kTime1, kTime2, 1), FlocId(0, kTime1, kTime2, 1));
EXPECT_NE(FlocId(), FlocId(0, kTime0, kTime0, 0));
EXPECT_NE(FlocId(0, kTime0, kTime0, 0), FlocId(1, kTime0, kTime0, 0));
EXPECT_NE(FlocId(0, kTime0, kTime1, 0), FlocId(0, kTime1, kTime1, 0));
EXPECT_NE(FlocId(0, kTime0, kTime0, 0), FlocId(0, kTime0, kTime0, 1));
}
TEST(FlocIdTest, ToString) {
EXPECT_EQ("0.1.0", FlocId(0, 0).ToString());
EXPECT_EQ("12345.1.0", FlocId(12345, 0).ToString());
EXPECT_EQ("12345.1.2", FlocId(12345, 2).ToString());
TEST(FlocIdTest, ToStringForJsApi) {
EXPECT_EQ("0.1.0", FlocId(0, kTime0, kTime0, 0).ToStringForJsApi());
EXPECT_EQ("12345.1.0", FlocId(12345, kTime0, kTime0, 0).ToStringForJsApi());
EXPECT_EQ("12345.1.2", FlocId(12345, kTime1, kTime1, 2).ToStringForJsApi());
}
} // namespace federated_learning
......@@ -36,13 +36,13 @@ class CopyingFileInputStream : public google::protobuf::io::CopyingInputStream {
base::File file_;
};
FlocId ApplySortingLshOnBackgroundThread(uint64_t sim_hash,
const base::FilePath& file_path,
const base::Version& version) {
base::Optional<uint64_t> ApplySortingLshOnBackgroundThread(
uint64_t sim_hash,
const base::FilePath& file_path) {
base::File sorting_lsh_clusters_file(
file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
if (!sorting_lsh_clusters_file.IsValid())
return FlocId();
return base::nullopt;
CopyingFileInputStream copying_stream(std::move(sorting_lsh_clusters_file));
google::protobuf::io::CopyingInputStreamAdaptor zero_copy_stream_adaptor(
......@@ -88,33 +88,33 @@ FlocId ApplySortingLshOnBackgroundThread(uint64_t sim_hash,
for (uint64_t index = 0; input_stream.ReadVarint32(&next_combined); ++index) {
// Sanitizing error: the entry used more than |kSortingLshMaxBits| bits.
if ((next_combined >> kSortingLshMaxBits) > 0)
return FlocId();
return base::nullopt;
bool is_blocked = next_combined & kSortingLshBlockedMask;
uint32_t next = next_combined & kSortingLshSizeMask;
// Sanitizing error
if (next > kMaxNumberOfBitsInFloc)
return FlocId();
return base::nullopt;
cumulative_sum += (1ULL << next);
// Sanitizing error
if (cumulative_sum > kExpectedFinalCumulativeSum)
return FlocId();
return base::nullopt;
// Found the sim-hash upper bound. Use the index as the new floc.
if (cumulative_sum > sim_hash) {
if (is_blocked)
return FlocId();
return base::nullopt;
return FlocId(index, version.components().front());
return index;
}
}
// Sanitizing error: we didn't find a sim-hash upper bound, but we expect to
// always find it after finish iterating through the list.
return FlocId();
return base::nullopt;
}
} // namespace
......@@ -158,9 +158,10 @@ void FlocSortingLshClustersService::ApplySortingLsh(
base::PostTaskAndReplyWithResult(
background_task_runner_.get(), FROM_HERE,
base::BindOnce(&ApplySortingLshOnBackgroundThread, sim_hash,
sorting_lsh_clusters_file_path_,
sorting_lsh_clusters_version_),
std::move(callback));
sorting_lsh_clusters_file_path_),
base::BindOnce(&FlocSortingLshClustersService::DidApplySortingLsh,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
sorting_lsh_clusters_version_));
}
void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting(
......@@ -168,4 +169,11 @@ void FlocSortingLshClustersService::SetBackgroundTaskRunnerForTesting(
background_task_runner_ = background_task_runner;
}
void FlocSortingLshClustersService::DidApplySortingLsh(
ApplySortingLshCallback callback,
base::Version version,
base::Optional<uint64_t> final_hash) {
std::move(callback).Run(std::move(final_hash), std::move(version));
}
} // namespace federated_learning
......@@ -29,7 +29,8 @@ namespace federated_learning {
// File reading and parsing is posted to |background_task_runner_|.
class FlocSortingLshClustersService {
public:
using ApplySortingLshCallback = base::OnceCallback<void(FlocId)>;
using ApplySortingLshCallback =
base::OnceCallback<void(base::Optional<uint64_t>, base::Version)>;
class Observer {
public:
......@@ -63,6 +64,10 @@ class FlocSortingLshClustersService {
private:
friend class FlocSortingLshClustersServiceTest;
void DidApplySortingLsh(ApplySortingLshCallback callback,
base::Version version,
base::Optional<uint64_t> final_hash);
// Runner for tasks that do not influence user experience.
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
......
......@@ -45,6 +45,11 @@ class CopyingFileOutputStream
base::File file_;
};
struct ApplySortingLshResult {
base::Optional<uint64_t> final_hash;
base::Version version;
};
} // namespace
class FlocSortingLshClustersServiceTest : public ::testing::Test {
......@@ -117,14 +122,16 @@ class FlocSortingLshClustersServiceTest : public ::testing::Test {
return service()->sorting_lsh_clusters_file_path_;
}
FlocId ApplySortingLsh(uint64_t sim_hash) {
FlocId result;
ApplySortingLshResult ApplySortingLsh(uint64_t sim_hash) {
ApplySortingLshResult result;
base::RunLoop run_loop;
auto cb = base::BindLambdaForTesting([&](FlocId floc_id) {
result = floc_id;
run_loop.Quit();
});
auto cb = base::BindLambdaForTesting(
[&](base::Optional<uint64_t> final_hash, base::Version version) {
result.final_hash = final_hash;
result.version = version;
run_loop.Quit();
});
service()->ApplySortingLsh(sim_hash, std::move(cb));
background_task_runner_->RunPendingTasks();
......@@ -149,83 +156,86 @@ TEST_F(FlocSortingLshClustersServiceTest, NoFilePath) {
TEST_F(FlocSortingLshClustersServiceTest, EmptyList) {
InitializeSortingLshClustersFile({}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(), ApplySortingLsh(0));
EXPECT_EQ(FlocId(), ApplySortingLsh(1));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(base::nullopt, ApplySortingLsh(0).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_0) {
InitializeSortingLshClustersFile({{0, false}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(), ApplySortingLsh(1));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(base::Version("2.3.4"), ApplySortingLsh(0).version);
EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_0_Blocked) {
InitializeSortingLshClustersFile({{0, true}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(), ApplySortingLsh(0));
EXPECT_EQ(FlocId(), ApplySortingLsh(1));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(base::nullopt, ApplySortingLsh(0).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_UnexpectedNumber) {
InitializeSortingLshClustersFile({{1 << 8, false}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(), ApplySortingLsh(0));
EXPECT_EQ(FlocId(), ApplySortingLsh(1));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(base::nullopt, ApplySortingLsh(0).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(1).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_1) {
InitializeSortingLshClustersFile({{1, false}}, base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1));
EXPECT_EQ(FlocId(), ApplySortingLsh(2));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(base::nullopt, ApplySortingLsh(2).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_0_0) {
InitializeSortingLshClustersFile({{0, false}, {0, false}},
base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(1));
EXPECT_EQ(FlocId(), ApplySortingLsh(2));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(base::nullopt, ApplySortingLsh(2).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_0_1) {
InitializeSortingLshClustersFile({{0, false}, {1, false}},
base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(1));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(2));
EXPECT_EQ(FlocId(), ApplySortingLsh(3));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(2).final_hash.value());
EXPECT_EQ(base::nullopt, ApplySortingLsh(3).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_1_0) {
InitializeSortingLshClustersFile({{1, false}, {0, false}},
base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(2));
EXPECT_EQ(FlocId(), ApplySortingLsh(3));
EXPECT_EQ(FlocId(), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(2).final_hash.value());
EXPECT_EQ(base::nullopt, ApplySortingLsh(3).final_hash);
EXPECT_EQ(base::nullopt, ApplySortingLsh(kMaxSimHash).final_hash);
}
TEST_F(FlocSortingLshClustersServiceTest, List_SingleCluster) {
InitializeSortingLshClustersFile({{kMaxNumberOfBitsInFloc, false}},
base::Version("2.3.4"));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(12345));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(12345).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(kMaxSimHash).final_hash.value());
}
TEST_F(FlocSortingLshClustersServiceTest, List_TwoClustersEqualSize) {
......@@ -234,12 +244,12 @@ TEST_F(FlocSortingLshClustersServiceTest, List_TwoClustersEqualSize) {
base::Version("2.3.4"));
uint64_t middle_value = (1ULL << (kMaxNumberOfBitsInFloc - 1));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(0));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(1));
EXPECT_EQ(FlocId(0, 2), ApplySortingLsh(middle_value - 1));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(middle_value));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(middle_value + 1));
EXPECT_EQ(FlocId(1, 2), ApplySortingLsh(kMaxSimHash));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(1).final_hash.value());
EXPECT_EQ(0u, ApplySortingLsh(middle_value - 1).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(middle_value).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(middle_value + 1).final_hash.value());
EXPECT_EQ(1u, ApplySortingLsh(kMaxSimHash).final_hash.value());
}
TEST_F(FlocSortingLshClustersServiceTest,
......@@ -248,11 +258,13 @@ TEST_F(FlocSortingLshClustersServiceTest,
InitializeSortingLshClustersFile({{0, false}}, base::Version("2.3.4"));
base::RunLoop run_loop;
auto cb = base::BindLambdaForTesting([&](FlocId floc_id) {
// Since the file has been deleted, expect an invalid floc id.
EXPECT_EQ(FlocId(), floc_id);
run_loop.Quit();
});
auto cb = base::BindLambdaForTesting(
[&](base::Optional<uint64_t> final_hash, base::Version version) {
// Since the file has been deleted, expect an invalid final_hash.
EXPECT_EQ(base::nullopt, final_hash);
EXPECT_EQ(base::Version("2.3.4"), version);
run_loop.Quit();
});
service()->ApplySortingLsh(/*sim_hash=*/0, std::move(cb));
base::DeleteFile(file_path);
......@@ -264,7 +276,9 @@ TEST_F(FlocSortingLshClustersServiceTest,
TEST_F(FlocSortingLshClustersServiceTest, MultipleUpdate_LatestOneUsed) {
InitializeSortingLshClustersFile({}, base::Version("2.3.4"));
InitializeSortingLshClustersFile({{0, false}}, base::Version("6.7.8.9"));
EXPECT_EQ(FlocId(0, 6), ApplySortingLsh(0));
EXPECT_EQ(0u, ApplySortingLsh(0).final_hash.value());
EXPECT_EQ(base::Version("6.7.8.9"), ApplySortingLsh(0).version);
}
} // namespace federated_learning
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment