Commit 6405d08c authored by Jimmy Gong's avatar Jimmy Gong Committed by Commit Bot

Add discovery and share callback vectors to SmbShareFinder

- Adds vectors for both discovery and share callbacks.
- Solves the race condition in which multiple
  dialogs of "Add Smb Share" would use a shared instance of each
  HostLocator.
- If a SmbShareFinder::GatherSharesInNetwork call comes in when one is
  already running, the callbacks are stored and run when the original
  GatherSharesInNetwork call finishes.
- Adds additional unit test to reflect on the changes.

Bug: chromium:892287
Test: end-to-end
Change-Id: I9551532dc08fd9338803786bdd89f0398f4ce7db
Reviewed-on: https://chromium-review.googlesource.com/c/1336435Reviewed-by: default avatarBailey Berro <baileyberro@chromium.org>
Reviewed-by: default avatarZentaro Kavanagh <zentaro@chromium.org>
Commit-Queue: jimmy gong <jimmyxgong@chromium.org>
Cr-Commit-Position: refs/heads/master@{#611409}
parent 98e5e693
......@@ -11,6 +11,8 @@ namespace chromeos {
namespace smb_client {
InMemoryHostLocator::InMemoryHostLocator() = default;
InMemoryHostLocator::InMemoryHostLocator(bool should_run_synchronously)
: should_run_synchronously_(should_run_synchronously) {}
InMemoryHostLocator::~InMemoryHostLocator() = default;
void InMemoryHostLocator::AddHost(const Hostname& hostname,
......@@ -29,7 +31,17 @@ void InMemoryHostLocator::RemoveHost(const Hostname& hostname) {
}
void InMemoryHostLocator::FindHosts(FindHostsCallback callback) {
if (should_run_synchronously_) {
std::move(callback).Run(true /* success */, host_map_);
} else {
stored_callback_ = std::move(callback);
}
}
void InMemoryHostLocator::RunCallback() {
DCHECK(!should_run_synchronously_);
std::move(stored_callback_).Run(true /* success */, host_map_);
}
} // namespace smb_client
......
......@@ -15,6 +15,8 @@ namespace smb_client {
class InMemoryHostLocator : public HostLocator {
public:
InMemoryHostLocator();
explicit InMemoryHostLocator(bool should_run_synchronously);
~InMemoryHostLocator() override;
// Adds host with |hostname| and |address| to host_map_.
......@@ -29,8 +31,13 @@ class InMemoryHostLocator : public HostLocator {
// HostLocator override.
void FindHosts(FindHostsCallback callback) override;
// Runs the callback, |stored_callback_|.
void RunCallback();
private:
HostMap host_map_;
FindHostsCallback stored_callback_;
bool should_run_synchronously_ = true;
DISALLOW_COPY_AND_ASSIGN(InMemoryHostLocator);
};
......
......@@ -37,12 +37,16 @@ NetworkScanner::NetworkScanner() = default;
NetworkScanner::~NetworkScanner() = default;
void NetworkScanner::FindHostsInNetwork(FindHostsCallback callback) {
DCHECK(!running_);
if (locators_.empty()) {
// Fire the callback immediately if there are no registered HostLocators.
std::move(callback).Run(false /* success */, HostMap());
return;
}
running_ = true;
const uint32_t request_id = AddNewRequest(std::move(callback));
for (const auto& locator : locators_) {
locator->FindHosts(
......@@ -118,6 +122,7 @@ void NetworkScanner::FireCallbackIfFinished(uint32_t request_id) {
found_hosts_ = std::move(info.hosts_found);
find_hosts_returned_ = true;
running_ = false;
std::move(info.callback).Run(true /* success */, found_hosts_);
}
}
......
......@@ -90,6 +90,12 @@ class NetworkScanner : public base::SupportsWeakPtr<NetworkScanner> {
// regardless if any hosts are found.
bool find_hosts_returned_ = false;
// True if FindHostsInNetwork() has been called and is waiting for
// FindHostsCallback to be invoked. This is to prevent multiple calls of
// FindHostsInNetwork() from concurrently executing. Used only for DCHECKing
// if FindHostsInNetwork() is already running.
bool running_ = false;
DISALLOW_COPY_AND_ASSIGN(NetworkScanner);
};
......
......@@ -18,10 +18,23 @@ SmbShareFinder::~SmbShareFinder() = default;
void SmbShareFinder::GatherSharesInNetwork(
HostDiscoveryResponse discovery_callback,
GatherSharesResponse shares_callback) {
scanner_.FindHostsInNetwork(base::BindOnce(
&SmbShareFinder::OnHostsFound, AsWeakPtr(), std::move(discovery_callback),
std::move(shares_callback)));
GatherSharesInNetworkResponse shares_callback) {
const bool should_start_discovery = share_callbacks_.empty();
if (discovery_callbacks_.empty() && !should_start_discovery) {
// Host discovery completed but shares callback is still in progress.
std::move(discovery_callback).Run();
InsertShareCallback(std::move(shares_callback));
} else {
// Either GatherSharesInNetwork has not been called yet or it has and
// discovery has not yet finished.
InsertDiscoveryAndShareCallbacks(std::move(discovery_callback),
std::move(shares_callback));
}
if (should_start_discovery) {
scanner_.FindHostsInNetwork(
base::BindOnce(&SmbShareFinder::OnHostsFound, AsWeakPtr()));
}
}
void SmbShareFinder::DiscoverHostsInNetwork(
......@@ -52,55 +65,89 @@ void SmbShareFinder::OnHostsDiscovered(HostDiscoveryResponse discovery_callback,
std::move(discovery_callback).Run();
}
void SmbShareFinder::OnHostsFound(HostDiscoveryResponse discovery_callback,
GatherSharesResponse shares_callback,
bool success,
const HostMap& hosts) {
std::move(discovery_callback).Run();
void SmbShareFinder::OnHostsFound(bool success, const HostMap& hosts) {
DCHECK_EQ(0u, host_counter_);
RunDiscoveryCallbacks();
if (!success) {
LOG(ERROR) << "SmbShareFinder failed to find hosts";
shares_callback.Run(std::vector<SmbUrl>());
RunEmptySharesCallbacks();
return;
}
if (hosts.empty()) {
shares_callback.Run(std::vector<SmbUrl>());
RunEmptySharesCallbacks();
return;
}
host_counter_ = hosts.size();
for (const auto& host : hosts) {
const std::string& host_name = host.first;
const std::string& resolved_address = host.second;
const base::FilePath server_url(kSmbSchemePrefix + resolved_address);
client_->GetShares(
server_url, base::BindOnce(&SmbShareFinder::OnSharesFound, AsWeakPtr(),
host_name, shares_callback));
server_url,
base::BindOnce(&SmbShareFinder::OnSharesFound, AsWeakPtr(), host_name));
}
}
void SmbShareFinder::OnSharesFound(
const std::string& host_name,
GatherSharesResponse shares_callback,
smbprovider::ErrorType error,
const smbprovider::DirectoryEntryListProto& entries) {
DCHECK_GT(host_counter_, 0u);
--host_counter_;
if (error != smbprovider::ErrorType::ERROR_OK) {
LOG(ERROR) << "Error finding shares: " << error;
shares_callback.Run(std::vector<SmbUrl>());
return;
}
std::vector<SmbUrl> shares;
for (const smbprovider::DirectoryEntryProto& entry : entries.entries()) {
SmbUrl url(kSmbSchemePrefix + host_name + "/" + entry.name());
if (url.IsValid()) {
shares.push_back(std::move(url));
shares_.push_back(std::move(url));
} else {
LOG(WARNING) << "URL found is not valid";
}
}
shares_callback.Run(shares);
if (host_counter_ == 0) {
RunSharesCallbacks(shares_);
}
}
void SmbShareFinder::RunDiscoveryCallbacks() {
for (auto& callback : discovery_callbacks_) {
std::move(callback).Run();
}
discovery_callbacks_.clear();
}
void SmbShareFinder::RunSharesCallbacks(const std::vector<SmbUrl>& shares) {
for (auto& share_callback : share_callbacks_) {
std::move(share_callback).Run(shares);
}
share_callbacks_.clear();
shares_.clear();
}
void SmbShareFinder::RunEmptySharesCallbacks() {
RunSharesCallbacks(std::vector<SmbUrl>());
}
void SmbShareFinder::InsertDiscoveryAndShareCallbacks(
HostDiscoveryResponse discovery_callback,
GatherSharesInNetworkResponse share_callback) {
discovery_callbacks_.push_back(std::move(discovery_callback));
share_callbacks_.push_back(std::move(share_callback));
}
void SmbShareFinder::InsertShareCallback(
GatherSharesInNetworkResponse share_callback) {
share_callbacks_.push_back(std::move(share_callback));
}
} // namespace smb_client
......
......@@ -23,6 +23,11 @@ namespace smb_client {
using GatherSharesResponse =
base::RepeatingCallback<void(const std::vector<SmbUrl>& shares_gathered)>;
// The callback that will be passed to GatherSharesInNetwork. Used to implicitly
// convert GatherSharesResponse to a OnceCallback.
using GatherSharesInNetworkResponse =
base::OnceCallback<void(const std::vector<SmbUrl>& shares_gathered)>;
// The callback run to indicate the scan for hosts on the network is complete.
using HostDiscoveryResponse = base::OnceClosure;
......@@ -35,10 +40,11 @@ class SmbShareFinder : public base::SupportsWeakPtr<SmbShareFinder> {
// Gathers the hosts in the network using |scanner_| and gets the shares for
// each of the hosts found. |discovery_callback| runs once when host
// disovery is complete. |shares_callback| runs once per host and will contain
// the paths to the shares found (e.g. "smb://host/share").
// disovery is complete. |shares_callback| only runs once when all entries
// from hosts are stored to |shares| and will contain the paths to the shares
// found (e.g. "smb://host/share").
void GatherSharesInNetwork(HostDiscoveryResponse discovery_callback,
GatherSharesResponse shares_callback);
GatherSharesInNetworkResponse shares_callback);
// Gathers the hosts in the network using |scanner_|. Runs
// |discovery_callback| upon completion. No data is returned to the caller,
......@@ -59,21 +65,43 @@ class SmbShareFinder : public base::SupportsWeakPtr<SmbShareFinder> {
const HostMap& hosts);
// Handles the response from finding hosts in the network.
void OnHostsFound(HostDiscoveryResponse discovery_callback,
GatherSharesResponse shares_callback,
bool success,
const HostMap& hosts);
void OnHostsFound(bool success, const HostMap& hosts);
// Handles the response from getting shares for a given host.
void OnSharesFound(const std::string& host_name,
GatherSharesResponse shares_callback,
smbprovider::ErrorType error,
const smbprovider::DirectoryEntryListProto& entries);
// Executes all the DiscoveryCallbacks inside |discovery_callbacks_|.
void RunDiscoveryCallbacks();
// Executes all the SharesCallback inside |shares_callback_|.
void RunSharesCallbacks(const std::vector<SmbUrl>& shares);
// Executes all the SharesCallback inside |shares_callback_| with an empty
// vector of SmbUrl.
void RunEmptySharesCallbacks();
// Inserts HostDiscoveryResponse in |discovery_callbacks_| and inserts
// GatherSharesInNetworkResponse in |shares_callbacks_|.
void InsertDiscoveryAndShareCallbacks(
HostDiscoveryResponse discovery_callback,
GatherSharesInNetworkResponse shares_callback);
// Inserts |shares_callback| to |share_callbacks_|.
void InsertShareCallback(GatherSharesInNetworkResponse shares_callback);
NetworkScanner scanner_;
SmbProviderClient* client_; // Not owned.
uint32_t host_counter_ = 0u;
std::vector<HostDiscoveryResponse> discovery_callbacks_;
std::vector<GatherSharesInNetworkResponse> share_callbacks_;
std::vector<SmbUrl> shares_;
DISALLOW_COPY_AND_ASSIGN(SmbShareFinder);
};
......
......@@ -31,17 +31,12 @@ constexpr char kDefaultResolvedUrl[] = "smb://1.2.3.4";
class SmbShareFinderTest : public testing::Test {
public:
SmbShareFinderTest() {
auto host_locator = std::make_unique<InMemoryHostLocator>();
host_locator_ = host_locator.get();
fake_client_ = std::make_unique<FakeSmbProviderClient>();
share_finder_ = std::make_unique<SmbShareFinder>(fake_client_.get());
share_finder_->RegisterHostLocator(std::move(host_locator));
SetupShareFinderTest(true /* should_run_synchronously */);
}
~SmbShareFinderTest() override = default;
protected:
void TearDown() override { fake_client_->ClearShares(); }
// Adds host with |hostname| and |address| as the resolved url.
......@@ -67,23 +62,39 @@ class SmbShareFinderTest : public testing::Test {
}
// Helper function when expecting shares to be found in the network.
void ExpectSharesFound() {
void StartDiscoveryWhileExpectingSharesFound() {
share_finder_->GatherSharesInNetwork(
base::BindOnce(&SmbShareFinderTest::HostsDiscoveredCallback,
base::Unretained(this)),
base::BindRepeating(&SmbShareFinderTest::SharesFoundCallback,
base::BindOnce(&SmbShareFinderTest::SharesFoundCallback,
base::Unretained(this)));
EXPECT_TRUE(discovery_callback_called_);
}
// Helper function when expecting no shares to be found in the network.
void ExpectNoSharesFound() {
StartDiscoveryWhileExpectingEmptyShares();
EXPECT_TRUE(discovery_callback_called_);
}
// Helper function to call SmbShareFinder::GatherSharesInNetwork. Asserts that
// there are no shares discovered from the EmptySharesCallback.
void StartDiscoveryWhileExpectingEmptyShares() {
share_finder_->GatherSharesInNetwork(
base::BindOnce(&SmbShareFinderTest::HostsDiscoveredCallback,
base::Unretained(this)),
base::BindRepeating(&SmbShareFinderTest::EmptySharesCallback,
base::BindOnce(&SmbShareFinderTest::EmptySharesCallback,
base::Unretained(this)));
}
// Helper function to call SmbShareFinder::GatherSharesInNetwork. Asserts that
// shares are found, but does not remove them.
void StartDiscoveryWhileGatheringShares() {
share_finder_->GatherSharesInNetwork(
base::BindOnce(&SmbShareFinderTest::HostsDiscoveredCallback,
base::Unretained(this)),
base::BindOnce(&SmbShareFinderTest::SharesFoundSizeCallback,
base::Unretained(this)));
EXPECT_TRUE(discovery_callback_called_);
}
// Helper function that expects expected_shares_ to be empty.
......@@ -94,8 +105,33 @@ class SmbShareFinderTest : public testing::Test {
EXPECT_EQ(expected, share_finder_->GetResolvedUrl(url));
}
void ExpectDiscoveryCalled(int32_t expected) {
EXPECT_EQ(expected, discovery_callback_counter_);
}
void FinishHostDiscoveryOnHostLocator() { host_locator_->RunCallback(); }
void FinishShareDiscoveryOnSmbProviderClient() {
fake_client_->RunStoredReadDirCallback();
}
void SetupShareFinderTest(bool should_run_synchronously) {
auto host_locator =
std::make_unique<InMemoryHostLocator>(should_run_synchronously);
host_locator_ = host_locator.get();
fake_client_ =
std::make_unique<FakeSmbProviderClient>(should_run_synchronously);
share_finder_ = std::make_unique<SmbShareFinder>(fake_client_.get());
share_finder_->RegisterHostLocator(std::move(host_locator));
}
private:
void HostsDiscoveredCallback() { discovery_callback_called_ = true; }
void HostsDiscoveredCallback() {
discovery_callback_called_ = true;
++discovery_callback_counter_;
}
// Removes shares discovered from |expected_shares_|.
void SharesFoundCallback(const std::vector<SmbUrl>& shares_found) {
......@@ -106,6 +142,10 @@ class SmbShareFinderTest : public testing::Test {
}
}
void SharesFoundSizeCallback(const std::vector<SmbUrl>& shares_found) {
EXPECT_GE(shares_found.size(), 0u);
}
void EmptySharesCallback(const std::vector<SmbUrl>& shares_found) {
EXPECT_EQ(0u, shares_found.size());
}
......@@ -115,6 +155,8 @@ class SmbShareFinderTest : public testing::Test {
// Keeps track of expected shares across multiple hosts.
std::set<std::string> expected_shares_;
int32_t discovery_callback_counter_ = 0;
InMemoryHostLocator* host_locator_;
std::unique_ptr<FakeSmbProviderClient> fake_client_;
std::unique_ptr<SmbShareFinder> share_finder_;
......@@ -142,7 +184,7 @@ TEST_F(SmbShareFinderTest, SharesFoundWithSingleHost) {
AddShareToDefaultHost("share1");
AddShareToDefaultHost("share2");
ExpectSharesFound();
StartDiscoveryWhileExpectingSharesFound();
ExpectAllSharesHaveBeenFound();
}
......@@ -158,7 +200,7 @@ TEST_F(SmbShareFinderTest, SharesFoundWithMultipleHosts) {
AddHost(host2, address2);
AddShare(resolved_server_url2, server_url2, share2);
ExpectSharesFound();
StartDiscoveryWhileExpectingSharesFound();
ExpectAllSharesHaveBeenFound();
}
......@@ -168,7 +210,7 @@ TEST_F(SmbShareFinderTest, SharesFoundOnOneHostWithMultipleHosts) {
AddHost("host2", "4.5.6.7");
ExpectSharesFound();
StartDiscoveryWhileExpectingSharesFound();
ExpectAllSharesHaveBeenFound();
}
......@@ -177,7 +219,7 @@ TEST_F(SmbShareFinderTest, ResolvesHostToOriginalUrlIfNoHostFound) {
SmbUrl smb_url(url);
// Trigger the NetworkScanner to scan the network with its HostLocators.
ExpectSharesFound();
StartDiscoveryWhileExpectingSharesFound();
ExpectResolvedHost(smb_url, url);
}
......@@ -186,7 +228,7 @@ TEST_F(SmbShareFinderTest, ResolvesHost) {
AddDefaultHost();
// Trigger the NetworkScanner to scan the network with its HostLocators.
ExpectSharesFound();
StartDiscoveryWhileExpectingSharesFound();
SmbUrl url(std::string(kDefaultUrl) + "share");
ExpectResolvedHost(url, std::string(kDefaultResolvedUrl) + "/share");
......@@ -197,11 +239,55 @@ TEST_F(SmbShareFinderTest, ResolvesHostWithMultipleHosts) {
AddHost("host2", "4.5.6.7");
// Trigger the NetworkScanner to scan the network with its HostLocators.
ExpectSharesFound();
StartDiscoveryWhileExpectingSharesFound();
SmbUrl url("smb://host2/share");
ExpectResolvedHost(url, "smb://4.5.6.7/share");
}
TEST_F(SmbShareFinderTest, TestNonEmptyDiscoveryWithNonEmptyShareCallback) {
SetupShareFinderTest(false /* should_run_synchronoulsy */);
AddDefaultHost();
// Start discovery twice before host discovery is compeleted.
StartDiscoveryWhileExpectingEmptyShares();
StartDiscoveryWhileExpectingEmptyShares();
// Assert discovery callback has not been called.
ExpectDiscoveryCalled(0 /* expected */);
FinishHostDiscoveryOnHostLocator();
ExpectDiscoveryCalled(2 /* expected */);
}
TEST_F(SmbShareFinderTest, TestEmptyDiscoveryWithNonEmptyShareCallback) {
SetupShareFinderTest(false /* should_run_synchronoulsy */);
AddDefaultHost();
AddShareToDefaultHost("share1");
AddShareToDefaultHost("share2");
// Makes call to start discovery once. Share discovery will not run and be in
// a pending state.
StartDiscoveryWhileGatheringShares();
FinishHostDiscoveryOnHostLocator();
ExpectDiscoveryCalled(1 /* expected */);
// Host discovery will complete immediately while share discoveries will
// remain pending.
StartDiscoveryWhileExpectingSharesFound();
ExpectDiscoveryCalled(2 /* expected */);
// Run shares callback.
FinishShareDiscoveryOnSmbProviderClient();
ExpectAllSharesHaveBeenFound();
}
} // namespace smb_client
} // namespace chromeos
......@@ -30,6 +30,9 @@ void AddDirectoryEntryToList(smbprovider::DirectoryEntryListProto* entry_list,
FakeSmbProviderClient::FakeSmbProviderClient() {}
FakeSmbProviderClient::FakeSmbProviderClient(bool should_run_synchronously)
: should_run_synchronously_(should_run_synchronously) {}
FakeSmbProviderClient::~FakeSmbProviderClient() {}
void FakeSmbProviderClient::AddNetBiosPacketParsingForTesting(
......@@ -184,7 +187,12 @@ void FakeSmbProviderClient::GetShares(const base::FilePath& server_url,
AddDirectoryEntryToList(&entry_list, share);
}
if (should_run_synchronously_) {
std::move(callback).Run(smbprovider::ERROR_OK, entry_list);
} else {
stored_readdir_callback_ =
base::BindOnce(std::move(callback), smbprovider::ERROR_OK, entry_list);
}
}
void FakeSmbProviderClient::SetupKerberos(const std::string& account_id,
......@@ -255,4 +263,8 @@ void FakeSmbProviderClient::ClearShares() {
shares_.clear();
}
void FakeSmbProviderClient::RunStoredReadDirCallback() {
std::move(stored_readdir_callback_).Run();
}
} // namespace chromeos
......@@ -17,6 +17,7 @@ namespace chromeos {
class CHROMEOS_EXPORT FakeSmbProviderClient : public SmbProviderClient {
public:
FakeSmbProviderClient();
explicit FakeSmbProviderClient(bool should_run_synchronously);
~FakeSmbProviderClient() override;
// Adds an entry in the |netbios_parse_results_| map for <packetid,
......@@ -135,7 +136,15 @@ class CHROMEOS_EXPORT FakeSmbProviderClient : public SmbProviderClient {
// Clears |shares_|.
void ClearShares();
// Runs |stored_callback_|.
void RunStoredReadDirCallback();
private:
// Controls whether |stored_readdir_callback_| should run synchronously.
bool should_run_synchronously_ = true;
base::OnceClosure stored_readdir_callback_;
std::map<uint8_t, std::vector<std::string>> netbios_parse_results_;
// Mapping of a server url to its shares.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment