Commit 5d5c9a3a authored by Troy Hildebrandt's avatar Troy Hildebrandt Committed by Commit Bot

Fix SharedProtoDatabase WeakPtr dereferences/multiple init issues.

Introduces the concept of pending client initializations so that we
don't fail in odd ways if multiple initializations are in flight.

Also fixes some of the initialization PostTasks to ensure we're always
dereferencing our WeakPtrs on the same task runnerand ensuring that we
actually post our callbacks on the calling task runner.

Removes the WeakPtrFactory from ProtoLevelDBWrapper as well, which
involves a substantial refactoring to get DB init status back. A new
InitStatusCallback has been created so the 2 param Init and
InitWithDatabase calls give me an InitStatus. IsCorrupt was removed,
since setting the corruption state was the only thing that required
using a WeakPtr in the first place. This gives us the freedom to make
calls to the wrapper from any sequence regardless of what it was
created on, and not have the WeakPtrFactory cause problems when it's
destructed.

Bug: 912117,870813
Change-Id: Ic7931a543b4d3d09714184dfb335311130bc7667
Reviewed-on: https://chromium-review.googlesource.com/c/1364074
Commit-Queue: Troy Hildebrandt <thildebr@chromium.org>
Reviewed-by: default avatarTommy Nyquist <nyquist@chromium.org>
Cr-Commit-Position: refs/heads/master@{#616072}
parent ad0bd35f
......@@ -121,7 +121,8 @@ void OnDBInitComplete(
FeatureVector feature_filter,
PersistentAvailabilityStore::OnLoadedCallback on_loaded_callback,
uint32_t current_day,
bool success) {
leveldb_proto::Enums::InitStatus status) {
bool success = status == leveldb_proto::Enums::InitStatus::kOK;
stats::RecordDbInitEvent(success, stats::StoreType::AVAILABILITY_STORE);
if (!success) {
......
......@@ -99,7 +99,7 @@ TEST_F(PersistentAvailabilityStoreTest, InitFail) {
storage_dir_, CreateDB(), FeatureVector(), std::move(load_callback_),
14u);
db_->InitCallback(false);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kError);
EXPECT_TRUE(load_successful_.has_value());
EXPECT_FALSE(load_successful_.value());
......@@ -112,7 +112,7 @@ TEST_F(PersistentAvailabilityStoreTest, LoadFail) {
storage_dir_, CreateDB(), FeatureVector(), std::move(load_callback_),
14u);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
db_->LoadCallback(false);
......@@ -128,7 +128,7 @@ TEST_F(PersistentAvailabilityStoreTest, EmptyDBEmptyFeatureFilterUpdateFailed) {
storage_dir_, CreateDB(), FeatureVector(), std::move(load_callback_),
14u);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
db_->LoadCallback(true);
......@@ -147,7 +147,7 @@ TEST_F(PersistentAvailabilityStoreTest, EmptyDBEmptyFeatureFilterUpdateOK) {
storage_dir_, CreateDB(), FeatureVector(), std::move(load_callback_),
14u);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
db_->LoadCallback(true);
......@@ -174,7 +174,7 @@ TEST_F(PersistentAvailabilityStoreTest, AllNewFeatures) {
PersistentAvailabilityStore::LoadAndUpdateStore(
storage_dir_, CreateDB(), feature_filter, std::move(load_callback_), 14u);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
db_->LoadCallback(true);
......@@ -221,7 +221,7 @@ TEST_F(PersistentAvailabilityStoreTest, TestAllFilterCombinations) {
PersistentAvailabilityStore::LoadAndUpdateStore(
storage_dir_, CreateDB(), feature_filter, std::move(load_callback_), 14u);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
db_->LoadCallback(true);
......@@ -269,7 +269,7 @@ TEST_F(PersistentAvailabilityStoreTest, TestAllCombinationsEmptyFilter) {
storage_dir_, CreateDB(), FeatureVector(), std::move(load_callback_),
14u);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
db_->LoadCallback(true);
......
......@@ -8,6 +8,7 @@
#include "base/bind.h"
#include "components/feature_engagement/internal/stats.h"
#include "components/leveldb_proto/proto_database.h"
namespace feature_engagement {
namespace {
......@@ -62,8 +63,10 @@ void PersistentEventStore::DeleteEvent(const std::string& event_name) {
base::BindOnce(&NoopUpdateCallback));
}
void PersistentEventStore::OnInitComplete(const OnLoadedCallback& callback,
bool success) {
void PersistentEventStore::OnInitComplete(
const OnLoadedCallback& callback,
leveldb_proto::Enums::InitStatus status) {
bool success = status == leveldb_proto::Enums::InitStatus::kOK;
stats::RecordDbInitEvent(success, stats::StoreType::EVENTS_STORE);
if (!success) {
......
......@@ -34,7 +34,8 @@ class PersistentEventStore : public EventStore {
void DeleteEvent(const std::string& event_name) override;
private:
void OnInitComplete(const OnLoadedCallback& callback, bool success);
void OnInitComplete(const OnLoadedCallback& callback,
leveldb_proto::Enums::InitStatus status);
void OnLoadComplete(const OnLoadedCallback& callback,
bool success,
std::unique_ptr<std::vector<Event>> entries);
......
......@@ -78,7 +78,7 @@ TEST_F(PersistentEventStoreTest, SuccessfulInitAndLoadEmptyStore) {
store_->Load(load_callback_);
// The initialize should not trigger a response to the callback.
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
// The load should trigger a response to the callback.
......@@ -117,7 +117,7 @@ TEST_F(PersistentEventStoreTest, SuccessfulInitAndLoadWithEvents) {
// The initialize should not trigger a response to the callback.
store_->Load(load_callback_);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
// The load should trigger a response to the callback.
......@@ -143,7 +143,7 @@ TEST_F(PersistentEventStoreTest, SuccessfulInitBadLoad) {
store_->Load(load_callback_);
// The initialize should not trigger a response to the callback.
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(load_successful_.has_value());
// The load will fail and should trigger the callback.
......@@ -166,7 +166,7 @@ TEST_F(PersistentEventStoreTest, BadInit) {
store_->Load(load_callback_);
// The initialize will fail and should trigger the callback.
db_->InitCallback(false);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kError);
EXPECT_FALSE(load_successful_.value());
EXPECT_FALSE(store_->IsReady());
......@@ -185,7 +185,7 @@ TEST_F(PersistentEventStoreTest, IsReady) {
store_->Load(load_callback_);
EXPECT_FALSE(store_->IsReady());
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
EXPECT_FALSE(store_->IsReady());
db_->LoadCallback(true);
......@@ -196,7 +196,7 @@ TEST_F(PersistentEventStoreTest, WriteEvent) {
SetUpDB();
store_->Load(load_callback_);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
db_->LoadCallback(true);
Event event;
......@@ -217,7 +217,7 @@ TEST_F(PersistentEventStoreTest, WriteAndDeleteEvent) {
SetUpDB();
store_->Load(load_callback_);
db_->InitCallback(true);
db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
db_->LoadCallback(true);
Event event;
......
......@@ -24,4 +24,16 @@ leveldb_env::Options CreateSimpleOptions() {
return options;
}
// static
Enums::InitStatus Util::ConvertLevelDBStatusToInitStatus(
const leveldb::Status& status) {
if (status.ok())
return Enums::InitStatus::kOK;
if (status.IsCorruption())
return Enums::InitStatus::kCorrupt;
if (status.IsNotSupportedError() || status.IsInvalidArgument())
return Enums::InitStatus::kInvalidOperation;
return Enums::InitStatus::kError;
}
} // namespace leveldb_proto
......@@ -11,111 +11,152 @@
#include "base/sequenced_task_runner.h"
#include "base/threading/thread_checker.h"
#include "components/leveldb_proto/leveldb_database.h"
#include "components/leveldb_proto/proto_leveldb_wrapper.h"
namespace leveldb_proto {
class ProtoLevelDBWrapper;
class Enums {
public:
enum InitStatus {
kError = -1,
kNotInitialized = 0,
kOK = 1,
kCorrupt = 2,
kInvalidOperation = 3,
};
};
class Callbacks {
public:
using InitCallback = base::OnceCallback<void(bool)>;
using InitStatusCallback = base::OnceCallback<void(Enums::InitStatus)>;
using UpdateCallback = base::OnceCallback<void(bool)>;
using LoadKeysCallback =
base::OnceCallback<void(bool, std::unique_ptr<std::vector<std::string>>)>;
using DestroyCallback = base::OnceCallback<void(bool)>;
using OnCreateCallback = base::OnceCallback<void(ProtoLevelDBWrapper*)>;
template <typename T>
class Internal {
public:
using LoadCallback =
base::OnceCallback<void(bool, std::unique_ptr<std::vector<T>>)>;
using GetCallback = base::OnceCallback<void(bool, std::unique_ptr<T>)>;
using LoadKeysAndEntriesCallback =
base::OnceCallback<void(bool,
std::unique_ptr<std::map<std::string, T>>)>;
};
};
class Util {
public:
static Enums::InitStatus ConvertLevelDBStatusToInitStatus(
const leveldb::Status& status);
template <typename T>
class Internal {
public:
// A list of key-value (string, T) tuples.
using KeyEntryVector = std::vector<std::pair<std::string, T>>;
};
};
// Interface for classes providing persistent storage of Protocol Buffer
// entries (T must be a Proto type extending MessageLite).
template <typename T>
class ProtoDatabase {
public:
using InitCallback = base::OnceCallback<void(bool success)>;
using UpdateCallback = base::OnceCallback<void(bool success)>;
using LoadCallback =
base::OnceCallback<void(bool success, std::unique_ptr<std::vector<T>>)>;
using LoadKeysCallback =
base::OnceCallback<void(bool success,
std::unique_ptr<std::vector<std::string>>)>;
using LoadKeysAndEntriesCallback =
base::OnceCallback<void(bool success,
std::unique_ptr<std::map<std::string, T>>)>;
using GetCallback =
base::OnceCallback<void(bool success, std::unique_ptr<T>)>;
using DestroyCallback = base::OnceCallback<void(bool success)>;
// A list of key-value (string, T) tuples.
using KeyEntryVector = std::vector<std::pair<std::string, T>>;
// For compatibility:
using KeyEntryVector = typename Util::Internal<T>::KeyEntryVector;
virtual ~ProtoDatabase() = default;
// Asynchronously initializes the object with the specified |options|.
// |callback| will be invoked on the calling thread when complete.
virtual void Init(const std::string& client_name, InitCallback callback) = 0;
virtual void Init(const std::string& client_name,
Callbacks::InitStatusCallback callback) = 0;
// This version of Init is for compatibility, since many of the current
// proto database clients still use this.
virtual void Init(const char* client_name,
const base::FilePath& database_dir,
const leveldb_env::Options& options,
InitCallback callback) = 0;
Callbacks::InitCallback callback) = 0;
virtual void InitWithDatabase(LevelDB* database,
const base::FilePath& database_dir,
const leveldb_env::Options& options,
InitCallback callback) = 0;
Callbacks::InitStatusCallback callback) = 0;
// Asynchronously saves |entries_to_save| and deletes entries from
// |keys_to_remove| from the database. |callback| will be invoked on the
// calling thread when complete.
virtual void UpdateEntries(
std::unique_ptr<KeyEntryVector> entries_to_save,
std::unique_ptr<typename Util::Internal<T>::KeyEntryVector>
entries_to_save,
std::unique_ptr<std::vector<std::string>> keys_to_remove,
UpdateCallback callback) = 0;
Callbacks::UpdateCallback callback) = 0;
// Asynchronously saves |entries_to_save| and deletes entries that satisfies
// the |delete_key_filter| from the database. |callback| will be invoked on
// the calling thread when complete. The filter will be called on
// ProtoDatabase's taskrunner.
virtual void UpdateEntriesWithRemoveFilter(
std::unique_ptr<KeyEntryVector> entries_to_save,
std::unique_ptr<typename Util::Internal<T>::KeyEntryVector>
entries_to_save,
const LevelDB::KeyFilter& delete_key_filter,
UpdateCallback callback) = 0;
Callbacks::UpdateCallback callback) = 0;
virtual void UpdateEntriesWithRemoveFilter(
std::unique_ptr<KeyEntryVector> entries_to_save,
std::unique_ptr<typename Util::Internal<T>::KeyEntryVector>
entries_to_save,
const LevelDB::KeyFilter& delete_key_filter,
const std::string& target_prefix,
UpdateCallback callback) = 0;
Callbacks::UpdateCallback callback) = 0;
// Asynchronously loads all entries from the database and invokes |callback|
// when complete.
virtual void LoadEntries(LoadCallback callback) = 0;
virtual void LoadEntries(
typename Callbacks::Internal<T>::LoadCallback callback) = 0;
// Asynchronously loads entries that satisfies the |filter| from the database
// and invokes |callback| when complete. The filter will be called on
// ProtoDatabase's taskrunner.
virtual void LoadEntriesWithFilter(const LevelDB::KeyFilter& filter,
LoadCallback callback) = 0;
virtual void LoadEntriesWithFilter(const LevelDB::KeyFilter& key_filter,
const leveldb::ReadOptions& options,
const std::string& target_prefix,
LoadCallback callback) = 0;
virtual void LoadEntriesWithFilter(
const LevelDB::KeyFilter& filter,
typename Callbacks::Internal<T>::LoadCallback callback) = 0;
virtual void LoadEntriesWithFilter(
const LevelDB::KeyFilter& key_filter,
const leveldb::ReadOptions& options,
const std::string& target_prefix,
typename Callbacks::Internal<T>::LoadCallback callback) = 0;
virtual void LoadKeysAndEntries(LoadKeysAndEntriesCallback callback) = 0;
virtual void LoadKeysAndEntries(
typename Callbacks::Internal<T>::LoadKeysAndEntriesCallback callback) = 0;
virtual void LoadKeysAndEntriesWithFilter(
const LevelDB::KeyFilter& filter,
LoadKeysAndEntriesCallback callback) = 0;
typename Callbacks::Internal<T>::LoadKeysAndEntriesCallback callback) = 0;
virtual void LoadKeysAndEntriesWithFilter(
const LevelDB::KeyFilter& filter,
const leveldb::ReadOptions& options,
const std::string& target_prefix,
LoadKeysAndEntriesCallback callback) = 0;
typename Callbacks::Internal<T>::LoadKeysAndEntriesCallback callback) = 0;
// Asynchronously loads all keys from the database and invokes |callback| with
// those keys when complete.
virtual void LoadKeys(LoadKeysCallback callback) = 0;
virtual void LoadKeys(typename Callbacks::LoadKeysCallback callback) = 0;
virtual void LoadKeys(const std::string& target_prefix,
LoadKeysCallback callback) = 0;
typename Callbacks::LoadKeysCallback callback) = 0;
// Asynchronously loads a single entry, identified by |key|, from the database
// and invokes |callback| when complete. If no entry with |key| is found,
// a nullptr is passed to the callback, but the success flag is still true.
virtual void GetEntry(const std::string& key, GetCallback callback) = 0;
virtual void GetEntry(
const std::string& key,
typename Callbacks::Internal<T>::GetCallback callback) = 0;
// Asynchronously destroys the database.
virtual void Destroy(DestroyCallback callback) = 0;
virtual bool IsCorrupt() = 0;
virtual void Destroy(Callbacks::DestroyCallback callback) = 0;
protected:
ProtoDatabase() = default;
......
......@@ -7,8 +7,11 @@
#include "base/files/file_path.h"
#include "base/sequenced_task_runner.h"
#include "base/synchronization/lock.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "components/leveldb_proto/shared_proto_database.h"
namespace leveldb_proto {
namespace {
const char kSharedProtoDatabaseClientName[] = "SharedProtoDB";
......@@ -16,8 +19,6 @@ const char kSharedProtoDatabaseDirectory[] = "shared_proto_db";
} // namespace
namespace leveldb_proto {
ProtoDatabaseProvider::ProtoDatabaseProvider(const base::FilePath& profile_dir)
: profile_dir_(profile_dir),
task_runner_(base::CreateSequencedTaskRunnerWithTraits(
......@@ -35,30 +36,19 @@ ProtoDatabaseProvider* ProtoDatabaseProvider::Create(
void ProtoDatabaseProvider::GetSharedDBInstance(
GetSharedDBInstanceCallback callback) {
task_runner_->PostTaskAndReply(
FROM_HERE,
base::BindOnce(
&ProtoDatabaseProvider::PrepareSharedDBInstanceOnTaskRunner,
weak_factory_.GetWeakPtr()),
base::BindOnce(&ProtoDatabaseProvider::RunGetSharedDBInstanceCallback,
weak_factory_.GetWeakPtr(), std::move(callback)));
}
void ProtoDatabaseProvider::PrepareSharedDBInstanceOnTaskRunner() {
if (db_)
return;
db_ = base::WrapRefCounted(new SharedProtoDatabase(
base::CreateSequencedTaskRunnerWithTraits(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT,
base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}),
kSharedProtoDatabaseClientName,
profile_dir_.AppendASCII(std::string(kSharedProtoDatabaseDirectory))));
}
void ProtoDatabaseProvider::RunGetSharedDBInstanceCallback(
GetSharedDBInstanceCallback callback) {
std::move(callback).Run(db_);
DCHECK(base::SequencedTaskRunnerHandle::IsSet());
auto callback_task_runner = base::SequencedTaskRunnerHandle::Get();
{
base::AutoLock lock(get_db_lock_);
if (!db_) {
db_ = base::WrapRefCounted(new SharedProtoDatabase(
kSharedProtoDatabaseClientName, profile_dir_.AppendASCII(std::string(
kSharedProtoDatabaseDirectory))));
}
}
callback_task_runner->PostTask(FROM_HERE,
base::BindOnce(std::move(callback), db_));
}
} // namespace leveldb_proto
......@@ -57,6 +57,7 @@ class ProtoDatabaseProvider : public KeyedService {
base::FilePath profile_dir_;
scoped_refptr<SharedProtoDatabase> db_;
base::Lock get_db_lock_;
// The SequencedTaskRunner used to ensure thread-safe behaviour for
// GetSharedDBInstance.
scoped_refptr<base::SequencedTaskRunner> task_runner_;
......
......@@ -4,137 +4,106 @@
#include "components/leveldb_proto/proto_leveldb_wrapper.h"
#include "base/sequenced_task_runner.h"
#include "base/task/post_task.h"
#include "base/task/task_traits.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "components/leveldb_proto/proto_leveldb_wrapper_metrics.h"
namespace leveldb_proto {
namespace {
inline void InitFromTaskRunner(LevelDB* database,
const base::FilePath& database_dir,
const leveldb_env::Options& options,
bool destroy_on_corruption,
leveldb::Status* status,
const std::string& client_id) {
DCHECK(status);
Enums::InitStatus InitFromTaskRunner(LevelDB* database,
const base::FilePath& database_dir,
const leveldb_env::Options& options,
bool destroy_on_corruption,
const std::string& client_id) {
// TODO(cjhopman): Histogram for database size.
*status = database->Init(database_dir, options, destroy_on_corruption);
ProtoLevelDBWrapperMetrics::RecordInit(client_id, *status);
}
auto status = database->Init(database_dir, options, destroy_on_corruption);
ProtoLevelDBWrapperMetrics::RecordInit(client_id, status);
void RunDestroyCallback(typename ProtoLevelDBWrapper::DestroyCallback callback,
const bool* success) {
std::move(callback).Run(*success);
return Util::ConvertLevelDBStatusToInitStatus(status);
}
inline void DestroyFromTaskRunner(LevelDB* database,
bool* success,
const std::string& client_id) {
CHECK(success);
bool DestroyFromTaskRunner(LevelDB* database, const std::string& client_id) {
auto status = database->Destroy();
*success = status.ok();
ProtoLevelDBWrapperMetrics::RecordDestroy(client_id, *success);
}
bool success = status.ok();
ProtoLevelDBWrapperMetrics::RecordDestroy(client_id, success);
void RunLoadKeysCallback(
typename ProtoLevelDBWrapper::LoadKeysCallback callback,
std::unique_ptr<bool> success,
std::unique_ptr<std::vector<std::string>> keys) {
std::move(callback).Run(*success, std::move(keys));
return success;
}
inline void LoadKeysFromTaskRunner(LevelDB* database,
const std::string& target_prefix,
std::vector<std::string>* keys,
bool* success,
const std::string& client_id) {
DCHECK(success);
DCHECK(keys);
keys->clear();
*success = database->LoadKeys(target_prefix, keys);
ProtoLevelDBWrapperMetrics::RecordLoadKeys(client_id, *success);
void LoadKeysFromTaskRunner(
LevelDB* database,
const std::string& target_prefix,
const std::string& client_id,
Callbacks::LoadKeysCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner) {
auto keys = std::make_unique<std::vector<std::string>>();
bool success = database->LoadKeys(target_prefix, keys.get());
ProtoLevelDBWrapperMetrics::RecordLoadKeys(client_id, success);
callback_task_runner->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), success, std::move(keys)));
}
} // namespace
ProtoLevelDBWrapper::ProtoLevelDBWrapper(
const scoped_refptr<base::SequencedTaskRunner>& task_runner)
: task_runner_(task_runner), weak_ptr_factory_(this) {
: task_runner_(task_runner) {
DETACH_FROM_SEQUENCE(sequence_checker_);
}
ProtoLevelDBWrapper::ProtoLevelDBWrapper(
const scoped_refptr<base::SequencedTaskRunner>& task_runner,
LevelDB* db)
: task_runner_(task_runner), db_(db), weak_ptr_factory_(this) {
: task_runner_(task_runner), db_(db) {
DETACH_FROM_SEQUENCE(sequence_checker_);
}
ProtoLevelDBWrapper::~ProtoLevelDBWrapper() = default;
void ProtoLevelDBWrapper::RunInitCallback(
typename ProtoLevelDBWrapper::InitCallback callback,
const leveldb::Status* status) {
is_corrupt_ = status->IsCorruption();
std::move(callback).Run(status->ok());
}
void ProtoLevelDBWrapper::InitWithDatabase(
LevelDB* database,
const base::FilePath& database_dir,
const leveldb_env::Options& options,
bool destroy_on_corruption,
typename ProtoLevelDBWrapper::InitCallback callback) {
Callbacks::InitStatusCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(!db_);
DCHECK(database);
db_ = database;
leveldb::Status* status = new leveldb::Status();
task_runner_->PostTaskAndReply(
FROM_HERE,
base::PostTaskAndReplyWithResult(
task_runner_.get(), FROM_HERE,
base::BindOnce(InitFromTaskRunner, base::Unretained(db_), database_dir,
options, destroy_on_corruption, status, metrics_id_),
base::BindOnce(&ProtoLevelDBWrapper::RunInitCallback,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
base::Owned(status)));
options, destroy_on_corruption, metrics_id_),
std::move(callback));
}
void ProtoLevelDBWrapper::Destroy(
typename ProtoLevelDBWrapper::DestroyCallback callback) {
void ProtoLevelDBWrapper::Destroy(Callbacks::DestroyCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(db_);
bool* success = new bool(false);
task_runner_->PostTaskAndReply(
FROM_HERE,
base::BindOnce(DestroyFromTaskRunner, base::Unretained(db_), success,
metrics_id_),
base::BindOnce(RunDestroyCallback, std::move(callback),
base::Owned(success)));
base::PostTaskAndReplyWithResult(
task_runner_.get(), FROM_HERE,
base::BindOnce(DestroyFromTaskRunner, base::Unretained(db_), metrics_id_),
std::move(callback));
}
void ProtoLevelDBWrapper::LoadKeys(
typename ProtoLevelDBWrapper::LoadKeysCallback callback) {
typename Callbacks::LoadKeysCallback callback) {
LoadKeys(std::string(), std::move(callback));
}
void ProtoLevelDBWrapper::LoadKeys(
const std::string& target_prefix,
typename ProtoLevelDBWrapper::LoadKeysCallback callback) {
typename Callbacks::LoadKeysCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto success = std::make_unique<bool>(false);
auto keys = std::make_unique<std::vector<std::string>>();
bool* success_ptr = success.get();
std::vector<std::string>* keys_ptr = keys.get();
task_runner_->PostTaskAndReply(
FROM_HERE,
base::BindOnce(LoadKeysFromTaskRunner, base::Unretained(db_),
target_prefix, base::Unretained(keys_ptr),
base::Unretained(success_ptr), metrics_id_),
base::BindOnce(RunLoadKeysCallback, std::move(callback),
std::move(success), std::move(keys)));
task_runner_->PostTask(
FROM_HERE, base::BindOnce(LoadKeysFromTaskRunner, base::Unretained(db_),
target_prefix, metrics_id_, std::move(callback),
base::SequencedTaskRunnerHandle::Get()));
}
void ProtoLevelDBWrapper::SetMetricsId(const std::string& id) {
......@@ -148,10 +117,6 @@ bool ProtoLevelDBWrapper::GetApproximateMemoryUse(uint64_t* approx_mem_use) {
return db_->GetApproximateMemoryUse(approx_mem_use);
}
bool ProtoLevelDBWrapper::IsCorrupt() {
return is_corrupt_;
}
const scoped_refptr<base::SequencedTaskRunner>&
ProtoLevelDBWrapper::task_runner() {
return task_runner_;
......
......@@ -14,11 +14,17 @@
namespace leveldb_proto {
SharedProtoDatabase::SharedProtoDatabase(
const scoped_refptr<base::SequencedTaskRunner>& task_runner,
const std::string& client_name,
const base::FilePath& db_dir)
: task_runner_(task_runner),
inline void RunCallbackOnCallingSequence(
Callbacks::InitCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
bool success) {
callback_task_runner->PostTask(FROM_HERE,
base::BindOnce(std::move(callback), success));
}
SharedProtoDatabase::SharedProtoDatabase(const std::string& client_name,
const base::FilePath& db_dir)
: task_runner_(base::SequencedTaskRunnerHandle::Get()),
db_dir_(db_dir),
db_wrapper_(std::make_unique<ProtoLevelDBWrapper>(task_runner_)),
db_(std::make_unique<LevelDB>(client_name.c_str())),
......@@ -30,17 +36,21 @@ SharedProtoDatabase::SharedProtoDatabase(
// this after a database Init will receive the correct status of the database.
// PostTaskAndReply is used to ensure that we call the Init callback on its
// original calling thread.
void SharedProtoDatabase::GetDatabaseInitStateAsync(
ProtoLevelDBWrapper::InitCallback callback) {
task_runner_->PostTaskAndReply(
FROM_HERE, base::DoNothing(),
base::BindOnce(&SharedProtoDatabase::RunInitCallback,
weak_factory_.GetWeakPtr(), std::move(callback)));
void SharedProtoDatabase::GetDatabaseInitStatusAsync(
Callbacks::InitStatusCallback callback) {
DCHECK(base::SequencedTaskRunnerHandle::IsSet());
auto current_task_runner = base::SequencedTaskRunnerHandle::Get();
task_runner_->PostTask(
FROM_HERE, base::BindOnce(&SharedProtoDatabase::RunInitCallback,
weak_factory_.GetWeakPtr(), std::move(callback),
std::move(current_task_runner)));
}
void SharedProtoDatabase::RunInitCallback(
ProtoLevelDBWrapper::InitCallback callback) {
std::move(callback).Run(init_state_ == InitState::kSuccess);
Callbacks::InitStatusCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner) {
callback_task_runner->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), init_status_));
}
// Setting |create_if_missing| to false allows us to test whether or not the
......@@ -52,7 +62,7 @@ void SharedProtoDatabase::RunInitCallback(
// with this set to true, and others false.
void SharedProtoDatabase::Init(
bool create_if_missing,
ProtoLevelDBWrapper::InitCallback callback,
Callbacks::InitStatusCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner) {
DCHECK_CALLED_ON_VALID_SEQUENCE(on_task_runner_);
......@@ -60,7 +70,14 @@ void SharedProtoDatabase::Init(
// continue to try initialization for every new request.
if (init_state_ == InitState::kSuccess) {
callback_task_runner->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), true /* success */));
FROM_HERE, base::BindOnce(std::move(callback),
Enums::InitStatus::kOK /* status */));
return;
}
if (init_state_ == InitState::kInProgress) {
outstanding_init_requests_.emplace(
std::make_pair(std::move(callback), std::move(callback_task_runner)));
return;
}
......@@ -78,23 +95,31 @@ void SharedProtoDatabase::Init(
callback_task_runner));
}
void SharedProtoDatabase::ProcessInitRequests(Enums::InitStatus status) {
// The pairs are stored as (callback, callback_task_runner).
while (!outstanding_init_requests_.empty()) {
auto request = std::move(outstanding_init_requests_.front());
auto task_runner = std::move(request.second);
task_runner->PostTask(FROM_HERE,
base::BindOnce(std::move(request.first), status));
outstanding_init_requests_.pop();
}
}
void SharedProtoDatabase::OnDatabaseInit(
ProtoLevelDBWrapper::InitCallback callback,
Callbacks::InitStatusCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
bool success) {
Enums::InitStatus status) {
DCHECK_CALLED_ON_VALID_SEQUENCE(on_task_runner_);
init_state_ = success ? InitState::kSuccess : InitState::kFailure;
// TODO(thildebr): Check the db_wrapper_->IsCorrupt() and store corruption
// information to inform clients they may have lost data.
init_state_ = status == Enums::InitStatus::kOK ? InitState::kSuccess
: InitState::kFailure;
ProcessInitRequests(status);
callback_task_runner->PostTask(FROM_HERE,
base::BindOnce(std::move(callback), success));
base::BindOnce(std::move(callback), status));
}
SharedProtoDatabase::~SharedProtoDatabase() {
DCHECK_CALLED_ON_VALID_SEQUENCE(on_creation_sequence_);
}
SharedProtoDatabase::~SharedProtoDatabase() = default;
LevelDB* SharedProtoDatabase::GetLevelDBForTesting() const {
return db_.get();
......
......@@ -20,12 +20,20 @@
namespace leveldb_proto {
template <typename T>
void GetClientInitCallback(
base::OnceCallback<void(std::unique_ptr<SharedProtoDatabaseClient<T>>)>
callback,
std::unique_ptr<SharedProtoDatabaseClient<T>> client,
Enums::InitStatus status);
// Controls a single LevelDB database to be used by many clients, and provides
// a way to get SharedProtoDatabaseClients that allow shared access to the
// underlying single database.
class SharedProtoDatabase : public base::RefCounted<SharedProtoDatabase> {
class SharedProtoDatabase
: public base::RefCountedThreadSafe<SharedProtoDatabase> {
public:
void GetDatabaseInitStateAsync(ProtoLevelDBWrapper::InitCallback callback);
void GetDatabaseInitStatusAsync(Callbacks::InitStatusCallback callback);
// Always returns a SharedProtoDatabaseClient pointer, but that should ONLY
// be used if the callback returns success.
......@@ -34,12 +42,13 @@ class SharedProtoDatabase : public base::RefCounted<SharedProtoDatabase> {
const std::string& client_namespace,
const std::string& type_prefix,
bool create_if_missing,
ProtoLevelDBWrapper::InitCallback callback);
Callbacks::InitStatusCallback callback);
private:
friend class base::RefCounted<SharedProtoDatabase>;
friend class base::RefCountedThreadSafe<SharedProtoDatabase>;
friend class ProtoDatabaseProvider;
friend class ProtoDatabaseWrapperTest;
friend class SharedProtoDatabaseTest;
friend class SharedProtoDatabaseClientTest;
......@@ -52,26 +61,33 @@ class SharedProtoDatabase : public base::RefCounted<SharedProtoDatabase> {
// Private since we only want to create a singleton of it.
SharedProtoDatabase(
const scoped_refptr<base::SequencedTaskRunner>& task_runner,
const std::string& client_name,
const base::FilePath& db_dir);
virtual ~SharedProtoDatabase();
void ProcessInitRequests(Enums::InitStatus status);
template <typename T>
std::unique_ptr<SharedProtoDatabaseClient<T>> GetClientInternal(
const std::string& client_namespace,
const std::string& type_prefix);
// |callback_task_runner| should be the same sequence that Init was called
// from.
void Init(bool create_if_missing,
ProtoLevelDBWrapper::InitCallback callback,
Callbacks::InitStatusCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner);
void OnDatabaseInit(
ProtoLevelDBWrapper::InitCallback callback,
Callbacks::InitStatusCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
bool success);
void RunInitCallback(ProtoLevelDBWrapper::InitCallback callback);
Enums::InitStatus status);
void RunInitCallback(
Callbacks::InitStatusCallback callback,
scoped_refptr<base::SequencedTaskRunner> callback_task_runner);
LevelDB* GetLevelDBForTesting() const;
SEQUENCE_CHECKER(on_creation_sequence_);
SEQUENCE_CHECKER(on_task_runner_);
InitState init_state_ = InitState::kNone;
......@@ -84,6 +100,14 @@ class SharedProtoDatabase : public base::RefCounted<SharedProtoDatabase> {
std::unique_ptr<ProtoLevelDBWrapper> db_wrapper_;
std::unique_ptr<LevelDB> db_;
// Used to return to the Init callback in the case of an error, so we can
// report corruptions.
Enums::InitStatus init_status_ = Enums::InitStatus::kNotInitialized;
std::queue<std::pair<Callbacks::InitStatusCallback,
scoped_refptr<base::SequencedTaskRunner>>>
outstanding_init_requests_;
base::WeakPtrFactory<SharedProtoDatabase> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(SharedProtoDatabase);
......@@ -94,7 +118,7 @@ std::unique_ptr<SharedProtoDatabaseClient<T>> SharedProtoDatabase::GetClient(
const std::string& client_namespace,
const std::string& type_prefix,
bool create_if_missing,
ProtoLevelDBWrapper::InitCallback callback) {
Callbacks::InitStatusCallback callback) {
DCHECK(base::SequencedTaskRunnerHandle::IsSet());
auto current_task_runner = base::SequencedTaskRunnerHandle::Get();
task_runner_->PostTask(
......@@ -102,6 +126,13 @@ std::unique_ptr<SharedProtoDatabaseClient<T>> SharedProtoDatabase::GetClient(
base::BindOnce(&SharedProtoDatabase::Init, weak_factory_.GetWeakPtr(),
create_if_missing, std::move(callback),
std::move(current_task_runner)));
return GetClientInternal<T>(client_namespace, type_prefix);
}
template <typename T>
std::unique_ptr<SharedProtoDatabaseClient<T>>
SharedProtoDatabase::GetClientInternal(const std::string& client_namespace,
const std::string& type_prefix) {
return base::WrapUnique(new SharedProtoDatabaseClient<T>(
std::make_unique<ProtoLevelDBWrapper>(task_runner_, db_.get()),
client_namespace, type_prefix, this));
......
......@@ -32,10 +32,10 @@ bool KeyFilterStripPrefix(const LevelDB::KeyFilter& key_filter,
return key_filter.Run(StripPrefix(key, prefix));
}
void GetSharedDatabaseInitStateAsync(
void GetSharedDatabaseInitStatusAsync(
const scoped_refptr<SharedProtoDatabase>& shared_db,
ProtoLevelDBWrapper::InitCallback callback) {
shared_db->GetDatabaseInitStateAsync(std::move(callback));
Callbacks::InitStatusCallback callback) {
shared_db->GetDatabaseInitStatusAsync(std::move(callback));
}
} // namespace leveldb_proto
......@@ -31,9 +31,8 @@ class SharedProtoDatabaseClientTest : public testing::Test {
void SetUp() override {
temp_dir_.reset(new base::ScopedTempDir());
ASSERT_TRUE(temp_dir_->CreateUniqueTempDir());
db_ = base::WrapRefCounted(new SharedProtoDatabase(
scoped_task_environment_.GetMainThreadTaskRunner(), "client",
temp_dir_->GetPath()));
db_ = base::WrapRefCounted(
new SharedProtoDatabase("client", temp_dir_->GetPath()));
}
void TearDown() override {
......@@ -51,12 +50,13 @@ class SharedProtoDatabaseClientTest : public testing::Test {
const std::string& client_namespace,
const std::string& type_prefix,
bool create_if_missing,
ProtoLevelDBWrapper::InitCallback callback) {
Callbacks::InitStatusCallback callback) {
return db_->GetClient<T>(
client_namespace, type_prefix, create_if_missing,
base::BindOnce([](ProtoLevelDBWrapper::InitCallback callback,
bool success) { std::move(callback).Run(success); },
std::move(callback)));
base::BindOnce(
[](Callbacks::InitStatusCallback callback,
Enums::InitStatus status) { std::move(callback).Run(status); },
std::move(callback)));
}
template <typename T>
......@@ -64,16 +64,17 @@ class SharedProtoDatabaseClientTest : public testing::Test {
const std::string& client_namespace,
const std::string& type_prefix,
bool create_if_missing,
bool* success) {
Enums::InitStatus* status) {
base::RunLoop loop;
auto client = GetClient<T>(
client_namespace, type_prefix, create_if_missing,
base::BindOnce(
[](bool* success_out, base::OnceClosure closure, bool success) {
*success_out = success;
[](Enums::InitStatus* status_out, base::OnceClosure closure,
Enums::InitStatus status) {
*status_out = status;
std::move(closure).Run();
},
success, loop.QuitClosure()));
status, loop.QuitClosure()));
loop.Run();
return client;
}
......@@ -242,35 +243,37 @@ class SharedProtoDatabaseClientTest : public testing::Test {
};
TEST_F(SharedProtoDatabaseClientTest, InitSuccess) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
client->Init("client",
base::BindOnce([](bool success) { ASSERT_TRUE(success); }));
client->Init("client", base::BindOnce([](Enums::InitStatus status) {
ASSERT_EQ(status, Enums::InitStatus::kOK);
}));
}
TEST_F(SharedProtoDatabaseClientTest, InitFail) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
false /* create_if_missing */, &success);
ASSERT_FALSE(success);
false /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kInvalidOperation);
client->Init("client",
base::BindOnce([](bool success) { ASSERT_FALSE(success); }));
client->Init("client", base::BindOnce([](Enums::InitStatus status) {
ASSERT_EQ(status, Enums::InitStatus::kError);
}));
}
// Ensure that our LevelDB contains the properly prefixed entries and also
// removes prefixed entries correctly.
TEST_F(SharedProtoDatabaseClientTest, UpdateEntriesAppropriatePrefix) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list = {"entry1", "entry2", "entry3"};
UpdateEntries(client.get(), key_list, leveldb_proto::KeyVector(), true);
......@@ -296,15 +299,15 @@ TEST_F(SharedProtoDatabaseClientTest, UpdateEntriesAppropriatePrefix) {
TEST_F(SharedProtoDatabaseClientTest,
UpdateEntries_DeletesCorrectClientEntries) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client_a =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
auto client_b =
GetClientAndWait<TestProto>(kDefaultNamespace2, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list = {"entry1", "entry2", "entry3"};
UpdateEntries(client_a.get(), key_list, leveldb_proto::KeyVector(), true);
......@@ -330,11 +333,11 @@ TEST_F(SharedProtoDatabaseClientTest,
TEST_F(SharedProtoDatabaseClientTest,
UpdateEntriesWithRemoveFilter_DeletesCorrectEntries) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list = {"entry1", "entry2", "testentry3"};
UpdateEntries(client.get(), key_list, leveldb_proto::KeyVector(), true);
......@@ -361,15 +364,15 @@ TEST_F(SharedProtoDatabaseClientTest,
}
TEST_F(SharedProtoDatabaseClientTest, LoadEntriesWithFilter) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client_a =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
auto client_b =
GetClientAndWait<TestProto>(kDefaultNamespace2, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list_a = {"entry123", "entry2123", "testentry3"};
UpdateEntries(client_a.get(), key_list_a, leveldb_proto::KeyVector(), true);
......@@ -404,15 +407,15 @@ TEST_F(SharedProtoDatabaseClientTest, LoadEntriesWithFilter) {
}
TEST_F(SharedProtoDatabaseClientTest, LoadKeys) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client_a =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
auto client_b =
GetClientAndWait<TestProto>(kDefaultNamespace2, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list_a = {"entry123", "entry2123", "testentry3",
"testing"};
......@@ -431,15 +434,15 @@ TEST_F(SharedProtoDatabaseClientTest, LoadKeys) {
}
TEST_F(SharedProtoDatabaseClientTest, GetEntry) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client_a =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
auto client_b =
GetClientAndWait<TestProto>(kDefaultNamespace2, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list = {"a", "b", "c"};
// Add the same entries to both because we want to make sure we only get the
......@@ -461,15 +464,15 @@ TEST_F(SharedProtoDatabaseClientTest, GetEntry) {
}
TEST_F(SharedProtoDatabaseClientTest, TestDestroy) {
bool success = false;
auto status = Enums::InitStatus::kError;
auto client_a =
GetClientAndWait<TestProto>(kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
auto client_b =
GetClientAndWait<TestProto>(kDefaultNamespace2, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::vector<std::string> key_list = {"a", "b", "c"};
// Add the same entries to both because we want to make sure we only destroy
......
......@@ -22,6 +22,17 @@ const std::string kDefaultNamespace = "ns";
const std::string kDefaultNamespace2 = "ns2";
const std::string kDefaultTypePrefix = "tp";
inline void GetClientFromTaskRunner(SharedProtoDatabase* db,
const std::string& client_namespace,
const std::string& type_prefix,
base::OnceClosure closure) {
db->GetClient<TestProto>(
client_namespace, type_prefix, true /* create_if_missing */,
base::BindOnce([](base::OnceClosure closure,
Enums::InitStatus status) { std::move(closure).Run(); },
std::move(closure)));
}
} // namespace
class SharedProtoDatabaseTest : public testing::Test {
......@@ -29,30 +40,19 @@ class SharedProtoDatabaseTest : public testing::Test {
void SetUp() override {
temp_dir_ = std::make_unique<base::ScopedTempDir>();
ASSERT_TRUE(temp_dir_->CreateUniqueTempDir());
db_thread_ = std::make_unique<base::Thread>("db_thread");
ASSERT_TRUE(db_thread_->Start());
db_ = base::WrapRefCounted(new SharedProtoDatabase(
db_thread_->task_runner(), "client", temp_dir_->GetPath()));
db_ = base::WrapRefCounted(
new SharedProtoDatabase("client", temp_dir_->GetPath()));
}
void TearDown() override {}
void InitDB(bool create_if_missing,
ProtoLevelDBWrapper::InitCallback callback) {
void InitDB(bool create_if_missing, Callbacks::InitStatusCallback callback) {
db_->Init(create_if_missing, std::move(callback),
scoped_task_environment_.GetMainThreadTaskRunner());
}
void KillDB() { db_.reset(); }
scoped_refptr<SharedProtoDatabase> CreateDatabase(
const scoped_refptr<base::SequencedTaskRunner>& task_runner,
const char* client_name,
const base::FilePath& db_dir) {
return base::WrapRefCounted(
new SharedProtoDatabase(task_runner, client_name, db_dir));
}
bool IsDatabaseInitialized(SharedProtoDatabase* db) {
return db->init_state_ == SharedProtoDatabase::InitState::kSuccess;
}
......@@ -63,16 +63,17 @@ class SharedProtoDatabaseTest : public testing::Test {
const std::string& client_namespace,
const std::string& type_prefix,
bool create_if_missing,
bool* success) {
Enums::InitStatus* status) {
base::RunLoop loop;
auto client = db->GetClient<T>(
client_namespace, type_prefix, create_if_missing,
base::BindOnce(
[](bool* success_out, base::OnceClosure closure, bool success) {
*success_out = success;
[](Enums::InitStatus* status_out, base::OnceClosure closure,
Enums::InitStatus status) {
*status_out = status;
std::move(closure).Run();
},
success, loop.QuitClosure()));
status, loop.QuitClosure()));
loop.Run();
return client;
}
......@@ -88,26 +89,14 @@ class SharedProtoDatabaseTest : public testing::Test {
base::test::ScopedTaskEnvironment scoped_task_environment_;
std::unique_ptr<base::ScopedTempDir> temp_dir_;
std::unique_ptr<base::Thread> db_thread_;
scoped_refptr<SharedProtoDatabase> db_;
};
inline void GetClientFromTaskRunner(SharedProtoDatabase* db,
const std::string& client_namespace,
const std::string& type_prefix,
base::OnceClosure closure) {
db->GetClient<TestProto>(
client_namespace, type_prefix, true /* create_if_missing */,
base::BindOnce([](base::OnceClosure closure,
bool success) { std::move(closure).Run(); },
std::move(closure)));
}
TEST_F(SharedProtoDatabaseTest, CreateClient_SucceedsWithCreate) {
bool success = false;
auto status = Enums::InitStatus::kError;
GetClientAndWait<TestProto>(db(), kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
}
// TODO(912117): Fix flaky test!
......@@ -116,28 +105,28 @@ TEST_F(SharedProtoDatabaseTest, DISABLED_CreateClient_FailsWithoutCreate) {
#else
TEST_F(SharedProtoDatabaseTest, CreateClient_FailsWithoutCreate) {
#endif
bool success = false;
auto status = Enums::InitStatus::kError;
GetClientAndWait<TestProto>(db(), kDefaultNamespace, kDefaultTypePrefix,
false /* create_if_missing */, &success);
ASSERT_FALSE(success);
false /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kInvalidOperation);
}
TEST_F(SharedProtoDatabaseTest,
CreateClient_SucceedsWithoutCreateIfAlreadyCreated) {
bool success = false;
auto status = Enums::InitStatus::kError;
GetClientAndWait<TestProto>(db(), kDefaultNamespace2, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
GetClientAndWait<TestProto>(db(), kDefaultNamespace, kDefaultTypePrefix,
false /* create_if_missing */, &success);
ASSERT_TRUE(success);
false /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
}
TEST_F(SharedProtoDatabaseTest, GetClient_DifferentThreads) {
bool success = false;
auto status = Enums::InitStatus::kError;
GetClientAndWait<TestProto>(db(), kDefaultNamespace, kDefaultTypePrefix,
true /* create_if_missing */, &success);
ASSERT_TRUE(success);
true /* create_if_missing */, &status);
ASSERT_EQ(status, Enums::InitStatus::kOK);
base::Thread t("test_thread");
ASSERT_TRUE(t.Start());
......@@ -158,8 +147,8 @@ TEST_F(SharedProtoDatabaseTest, TestDBDestructionAfterInit) {
base::RunLoop run_init_loop;
InitDB(true /* create_if_missing */,
base::BindOnce(
[](base::OnceClosure signal, bool success) {
ASSERT_TRUE(success);
[](base::OnceClosure signal, Enums::InitStatus status) {
ASSERT_EQ(status, Enums::InitStatus::kOK);
std::move(signal).Run();
},
run_init_loop.QuitClosure()));
......
This diff is collapsed.
......@@ -85,6 +85,7 @@ class MockDB : public LevelDB {
class MockDatabaseCaller {
public:
MOCK_METHOD1(InitCallback, void(bool));
MOCK_METHOD1(InitStatusCallback, void(Enums::InitStatus));
MOCK_METHOD1(DestroyCallback, void(bool));
MOCK_METHOD1(SaveCallback, void(bool));
void LoadCallback(bool success,
......@@ -209,10 +210,10 @@ TEST_F(UniqueProtoDatabaseTest, TestDBInitSuccess) {
.WillOnce(Return(leveldb::Status()));
MockDatabaseCaller caller;
EXPECT_CALL(caller, InitCallback(true));
EXPECT_CALL(caller, InitStatusCallback(Enums::InitStatus::kOK));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
base::RunLoop().RunUntilIdle();
......@@ -229,10 +230,10 @@ TEST_F(UniqueProtoDatabaseTest, TestDBInitFailure) {
Return(leveldb::Status::IOError(leveldb::Slice(), leveldb::Slice())));
MockDatabaseCaller caller;
EXPECT_CALL(caller, InitCallback(false));
EXPECT_CALL(caller, InitStatusCallback(Enums::InitStatus::kError));
db_->InitWithDatabase(mock_db.get(), path, options,
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
base::RunLoop().RunUntilIdle();
......@@ -246,10 +247,10 @@ TEST_F(UniqueProtoDatabaseTest, TestDBDestroySuccess) {
.WillOnce(Return(leveldb::Status()));
MockDatabaseCaller caller;
EXPECT_CALL(caller, InitCallback(true));
EXPECT_CALL(caller, InitStatusCallback(Enums::InitStatus::kOK));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
EXPECT_CALL(caller, DestroyCallback(true));
......@@ -268,10 +269,10 @@ TEST_F(UniqueProtoDatabaseTest, TestDBDestroyFailure) {
.WillOnce(Return(leveldb::Status()));
MockDatabaseCaller caller;
EXPECT_CALL(caller, InitCallback(true));
EXPECT_CALL(caller, InitStatusCallback(Enums::InitStatus::kOK));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
EXPECT_CALL(caller, DestroyCallback(false));
......@@ -325,9 +326,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBLoadSuccess) {
EntryMap model = GetSmallModel();
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
EXPECT_CALL(*mock_db, LoadKeysAndEntriesWithFilter(_, _, _, _))
......@@ -348,9 +349,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBLoadFailure) {
MockDatabaseCaller caller;
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
EXPECT_CALL(*mock_db, LoadWithFilter(_, _, _, _)).WillOnce(Return(false));
......@@ -388,9 +389,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBGetSuccess) {
EntryMap model = GetSmallModel();
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
std::string key("1");
......@@ -420,10 +421,12 @@ TEST_F(UniqueProtoDatabaseLevelDBTest, TestDBSaveAndLoadKeys) {
std::unique_ptr<UniqueProtoDatabase<TestProto>> db(
new UniqueProtoDatabase<TestProto>(db_thread.task_runner()));
auto expect_init_success =
base::BindOnce([](bool success) { EXPECT_TRUE(success); });
MockDatabaseCaller caller;
EXPECT_CALL(caller, InitCallback(true));
db->Init(kTestLevelDBClientName, temp_dir.GetPath(), CreateSimpleOptions(),
std::move(expect_init_success));
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::Unretained(&caller)));
base::RunLoop run_update_entries;
auto expect_update_success = base::BindOnce(
......@@ -456,10 +459,6 @@ TEST_F(UniqueProtoDatabaseLevelDBTest, TestDBSaveAndLoadKeys) {
// Shutdown database.
db.reset();
base::RunLoop run_destruction;
db_thread.task_runner()->PostTaskAndReply(FROM_HERE, base::DoNothing(),
run_destruction.QuitClosure());
run_destruction.Run();
}
TEST_F(UniqueProtoDatabaseTest, TestDBGetNotFound) {
......@@ -470,9 +469,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBGetNotFound) {
EntryMap model = GetSmallModel();
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
std::string key("does_not_exist");
......@@ -493,9 +492,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBGetFailure) {
EntryMap model = GetSmallModel();
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
std::string key("does_not_exist");
......@@ -538,9 +537,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBSaveSuccess) {
EntryMap model = GetSmallModel();
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
std::unique_ptr<ProtoDatabase<TestProto>::KeyEntryVector> entries(
......@@ -569,9 +568,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBSaveFailure) {
std::unique_ptr<KeyVector> keys_to_remove(new KeyVector());
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
EXPECT_CALL(*mock_db, Save(_, _, _)).WillOnce(Return(false));
......@@ -594,9 +593,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBRemoveSuccess) {
EntryMap model = GetSmallModel();
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
std::unique_ptr<ProtoDatabase<TestProto>::KeyEntryVector> entries(
......@@ -625,9 +624,9 @@ TEST_F(UniqueProtoDatabaseTest, TestDBRemoveFailure) {
std::unique_ptr<KeyVector> keys_to_remove(new KeyVector());
EXPECT_CALL(*mock_db, Init(_, options_, _));
EXPECT_CALL(caller, InitCallback(_));
EXPECT_CALL(caller, InitStatusCallback(_));
db_->InitWithDatabase(mock_db.get(), path, CreateSimpleOptions(),
base::BindOnce(&MockDatabaseCaller::InitCallback,
base::BindOnce(&MockDatabaseCaller::InitStatusCallback,
base::Unretained(&caller)));
EXPECT_CALL(*mock_db, Save(_, _, _)).WillOnce(Return(false));
......@@ -667,7 +666,6 @@ TEST(UniqueProtoDatabaseThreadingTest, TestDBDestruction) {
init_loop.Run();
db.reset();
base::RunLoop run_loop;
db_thread.task_runner()->PostTaskAndReply(FROM_HERE, base::DoNothing(),
run_loop.QuitClosure());
......@@ -705,8 +703,6 @@ TEST(UniqueProtoDatabaseThreadingTest, TestDBDestroy) {
db->Destroy(base::BindOnce(&MockDatabaseCaller::DestroyCallback,
base::Unretained(&caller)));
db.reset();
base::RunLoop run_loop;
db_thread.task_runner()->PostTaskAndReply(FROM_HERE, base::DoNothing(),
run_loop.QuitClosure());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment