Commit 96693856 authored by morrita's avatar morrita Committed by Commit bot

IPC: Get rid of FileDescriptor usage from FileDescriptorSet and Message

This is a step toward to killing FileDescriptor.
This change lets FiileDescriptorSet have both Files (for owning fds)
and PlatformFiles (for non-owning fds). Doing this, we no longer
need FileDescriptor which provides |auto_close| flag.

BUG=415294
TEST=ipc_tests, ipc_mojo_unittests
R=agl@chromium.org, jam@hcromium.org, viettrungluu@chromium.org

Review URL: https://codereview.chromium.org/583473002

Cr-Commit-Position: refs/heads/master@{#296498}
parent 33169d9f
...@@ -16,56 +16,54 @@ FileDescriptorSet::FileDescriptorSet() ...@@ -16,56 +16,54 @@ FileDescriptorSet::FileDescriptorSet()
} }
FileDescriptorSet::~FileDescriptorSet() { FileDescriptorSet::~FileDescriptorSet() {
if (consumed_descriptor_highwater_ == descriptors_.size()) if (consumed_descriptor_highwater_ == size())
return; return;
LOG(WARNING) << "FileDescriptorSet destroyed with unconsumed descriptors"; // We close all the owning descriptors. If this message should have
// We close all the descriptors where the close flag is set. If this // been transmitted, then closing those with close flags set mirrors
// message should have been transmitted, then closing those with close // the expected behaviour.
// flags set mirrors the expected behaviour.
// //
// If this message was received with more descriptors than expected // If this message was received with more descriptors than expected
// (which could a DOS against the browser by a rogue renderer) then all // (which could a DOS against the browser by a rogue renderer) then all
// the descriptors have their close flag set and we free all the extra // the descriptors have their close flag set and we free all the extra
// kernel resources. // kernel resources.
for (unsigned i = consumed_descriptor_highwater_; LOG(WARNING) << "FileDescriptorSet destroyed with unconsumed descriptors: "
i < descriptors_.size(); ++i) { << consumed_descriptor_highwater_ << "/" << size();
if (descriptors_[i].auto_close)
if (IGNORE_EINTR(close(descriptors_[i].fd)) < 0)
PLOG(ERROR) << "close";
}
} }
bool FileDescriptorSet::Add(int fd) { bool FileDescriptorSet::AddToBorrow(base::PlatformFile fd) {
if (descriptors_.size() == kMaxDescriptorsPerMessage) { DCHECK_EQ(consumed_descriptor_highwater_, 0u);
if (size() == kMaxDescriptorsPerMessage) {
DLOG(WARNING) << "Cannot add file descriptor. FileDescriptorSet full."; DLOG(WARNING) << "Cannot add file descriptor. FileDescriptorSet full.";
return false; return false;
} }
struct base::FileDescriptor sd; descriptors_.push_back(fd);
sd.fd = fd;
sd.auto_close = false;
descriptors_.push_back(sd);
return true; return true;
} }
bool FileDescriptorSet::AddAndAutoClose(int fd) { bool FileDescriptorSet::AddToOwn(base::ScopedFD fd) {
if (descriptors_.size() == kMaxDescriptorsPerMessage) { DCHECK_EQ(consumed_descriptor_highwater_, 0u);
if (size() == kMaxDescriptorsPerMessage) {
DLOG(WARNING) << "Cannot add file descriptor. FileDescriptorSet full."; DLOG(WARNING) << "Cannot add file descriptor. FileDescriptorSet full.";
return false; return false;
} }
struct base::FileDescriptor sd; descriptors_.push_back(fd.get());
sd.fd = fd; owned_descriptors_.push_back(new base::ScopedFD(fd.Pass()));
sd.auto_close = true; DCHECK(size() <= kMaxDescriptorsPerMessage);
descriptors_.push_back(sd);
DCHECK(descriptors_.size() <= kMaxDescriptorsPerMessage);
return true; return true;
} }
int FileDescriptorSet::GetDescriptorAt(unsigned index) const { base::PlatformFile FileDescriptorSet::TakeDescriptorAt(unsigned index) {
if (index >= descriptors_.size()) if (index >= size()) {
DLOG(WARNING) << "Accessing out of bound index:"
<< index << "/" << size();
return -1; return -1;
}
// We should always walk the descriptors in order, so it's reasonable to // We should always walk the descriptors in order, so it's reasonable to
// enforce this. Consider the case where a compromised renderer sends us // enforce this. Consider the case where a compromised renderer sends us
...@@ -86,6 +84,8 @@ int FileDescriptorSet::GetDescriptorAt(unsigned index) const { ...@@ -86,6 +84,8 @@ int FileDescriptorSet::GetDescriptorAt(unsigned index) const {
// There's one more wrinkle: When logging messages, we may reparse them. So // There's one more wrinkle: When logging messages, we may reparse them. So
// we have an exception: When the consumed_descriptor_highwater_ is at the // we have an exception: When the consumed_descriptor_highwater_ is at the
// end of the array and index 0 is requested, we reset the highwater value. // end of the array and index 0 is requested, we reset the highwater value.
// TODO(morrita): This is absurd. This "wringle" disallow to introduce clearer
// ownership model. Only client is NaclIPCAdapter. See crbug.com/415294
if (index == 0 && consumed_descriptor_highwater_ == descriptors_.size()) if (index == 0 && consumed_descriptor_highwater_ == descriptors_.size())
consumed_descriptor_highwater_ = 0; consumed_descriptor_highwater_ = 0;
...@@ -93,22 +93,37 @@ int FileDescriptorSet::GetDescriptorAt(unsigned index) const { ...@@ -93,22 +93,37 @@ int FileDescriptorSet::GetDescriptorAt(unsigned index) const {
return -1; return -1;
consumed_descriptor_highwater_ = index + 1; consumed_descriptor_highwater_ = index + 1;
return descriptors_[index].fd;
}
void FileDescriptorSet::GetDescriptors(int* buffer) const { base::PlatformFile file = descriptors_[index];
for (std::vector<base::FileDescriptor>::const_iterator
i = descriptors_.begin(); i != descriptors_.end(); ++i) { // TODO(morrita): In production, descriptors_.size() should be same as
*(buffer++) = i->fd; // owned_descriptors_.size() as all read descriptors are owned by Message.
// We have to do this because unit test breaks this assumption. It should be
// changed to exercise with own-able descriptors.
for (ScopedVector<base::ScopedFD>::const_iterator i =
owned_descriptors_.begin();
i != owned_descriptors_.end();
++i) {
if ((*i)->get() == file) {
ignore_result((*i)->release());
break;
}
} }
return file;
}
void FileDescriptorSet::PeekDescriptors(base::PlatformFile* buffer) const {
std::copy(descriptors_.begin(), descriptors_.end(), buffer);
} }
bool FileDescriptorSet::ContainsDirectoryDescriptor() const { bool FileDescriptorSet::ContainsDirectoryDescriptor() const {
struct stat st; struct stat st;
for (std::vector<base::FileDescriptor>::const_iterator for (std::vector<base::PlatformFile>::const_iterator i = descriptors_.begin();
i = descriptors_.begin(); i != descriptors_.end(); ++i) { i != descriptors_.end();
if (fstat(i->fd, &st) == 0 && S_ISDIR(st.st_mode)) ++i) {
if (fstat(*i, &st) == 0 && S_ISDIR(st.st_mode))
return true; return true;
} }
...@@ -116,36 +131,32 @@ bool FileDescriptorSet::ContainsDirectoryDescriptor() const { ...@@ -116,36 +131,32 @@ bool FileDescriptorSet::ContainsDirectoryDescriptor() const {
} }
void FileDescriptorSet::CommitAll() { void FileDescriptorSet::CommitAll() {
for (std::vector<base::FileDescriptor>::iterator
i = descriptors_.begin(); i != descriptors_.end(); ++i) {
if (i->auto_close)
if (IGNORE_EINTR(close(i->fd)) < 0)
PLOG(ERROR) << "close";
}
descriptors_.clear(); descriptors_.clear();
owned_descriptors_.clear();
consumed_descriptor_highwater_ = 0; consumed_descriptor_highwater_ = 0;
} }
void FileDescriptorSet::ReleaseFDsToClose(std::vector<int>* fds) { void FileDescriptorSet::ReleaseFDsToClose(
for (std::vector<base::FileDescriptor>::iterator std::vector<base::PlatformFile>* fds) {
i = descriptors_.begin(); i != descriptors_.end(); ++i) { for (ScopedVector<base::ScopedFD>::iterator i = owned_descriptors_.begin();
if (i->auto_close) i != owned_descriptors_.end();
fds->push_back(i->fd); ++i) {
fds->push_back((*i)->release());
} }
descriptors_.clear();
consumed_descriptor_highwater_ = 0; CommitAll();
} }
void FileDescriptorSet::SetDescriptors(const int* buffer, unsigned count) { void FileDescriptorSet::AddDescriptorsToOwn(const base::PlatformFile* buffer,
unsigned count) {
DCHECK(count <= kMaxDescriptorsPerMessage); DCHECK(count <= kMaxDescriptorsPerMessage);
DCHECK_EQ(descriptors_.size(), 0u); DCHECK_EQ(size(), 0u);
DCHECK_EQ(consumed_descriptor_highwater_, 0u); DCHECK_EQ(consumed_descriptor_highwater_, 0u);
descriptors_.reserve(count); descriptors_.reserve(count);
owned_descriptors_.reserve(count);
for (unsigned i = 0; i < count; ++i) { for (unsigned i = 0; i < count; ++i) {
struct base::FileDescriptor sd; descriptors_.push_back(buffer[i]);
sd.fd = buffer[i]; owned_descriptors_.push_back(new base::ScopedFD(buffer[i]));
sd.auto_close = true;
descriptors_.push_back(sd);
} }
} }
...@@ -8,8 +8,9 @@ ...@@ -8,8 +8,9 @@
#include <vector> #include <vector>
#include "base/basictypes.h" #include "base/basictypes.h"
#include "base/file_descriptor_posix.h" #include "base/files/file.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/memory/scoped_vector.h"
#include "ipc/ipc_export.h" #include "ipc/ipc_export.h"
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
...@@ -36,10 +37,10 @@ class IPC_EXPORT FileDescriptorSet ...@@ -36,10 +37,10 @@ class IPC_EXPORT FileDescriptorSet
// Interfaces for building during message serialisation... // Interfaces for building during message serialisation...
// Add a descriptor to the end of the set. Returns false iff the set is full. // Add a descriptor to the end of the set. Returns false iff the set is full.
bool Add(int fd); bool AddToBorrow(base::PlatformFile fd);
// Add a descriptor to the end of the set and automatically close it after // Add a descriptor to the end of the set and automatically close it after
// transmission. Returns false iff the set is full. // transmission. Returns false iff the set is full.
bool AddAndAutoClose(int fd); bool AddToOwn(base::ScopedFD fd);
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -50,15 +51,15 @@ class IPC_EXPORT FileDescriptorSet ...@@ -50,15 +51,15 @@ class IPC_EXPORT FileDescriptorSet
// Return the number of descriptors // Return the number of descriptors
unsigned size() const { return descriptors_.size(); } unsigned size() const { return descriptors_.size(); }
// Return true if no unconsumed descriptors remain // Return true if no unconsumed descriptors remain
bool empty() const { return descriptors_.empty(); } bool empty() const { return 0 == size(); }
// Fetch the nth descriptor from the beginning of the set. Code using this // Take the nth descriptor from the beginning of the set,
// /must/ access the descriptors in order, except that it may wrap from the // transferring the ownership of the descriptor taken. Code using this
// end to index 0 again. // /must/ access the descriptors in order, and must do it at most once.
// //
// This interface is designed for the deserialising code as it doesn't // This interface is designed for the deserialising code as it doesn't
// support close flags. // support close flags.
// returns: file descriptor, or -1 on error // returns: file descriptor, or -1 on error
int GetDescriptorAt(unsigned n) const; base::PlatformFile TakeDescriptorAt(unsigned n);
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -69,9 +70,9 @@ class IPC_EXPORT FileDescriptorSet ...@@ -69,9 +70,9 @@ class IPC_EXPORT FileDescriptorSet
// Fill an array with file descriptors without 'consuming' them. CommitAll // Fill an array with file descriptors without 'consuming' them. CommitAll
// must be called after these descriptors have been transmitted. // must be called after these descriptors have been transmitted.
// buffer: (output) a buffer of, at least, size() integers. // buffer: (output) a buffer of, at least, size() integers.
void GetDescriptors(int* buffer) const; void PeekDescriptors(base::PlatformFile* buffer) const;
// This must be called after transmitting the descriptors returned by // This must be called after transmitting the descriptors returned by
// GetDescriptors. It marks all the descriptors as consumed and closes those // PeekDescriptors. It marks all the descriptors as consumed and closes those
// which are auto-close. // which are auto-close.
void CommitAll(); void CommitAll();
// Returns true if any contained file descriptors appear to be handles to a // Returns true if any contained file descriptors appear to be handles to a
...@@ -79,7 +80,7 @@ class IPC_EXPORT FileDescriptorSet ...@@ -79,7 +80,7 @@ class IPC_EXPORT FileDescriptorSet
bool ContainsDirectoryDescriptor() const; bool ContainsDirectoryDescriptor() const;
// Fetch all filedescriptors with the "auto close" property. // Fetch all filedescriptors with the "auto close" property.
// Used instead of CommitAll() when closing must be handled manually. // Used instead of CommitAll() when closing must be handled manually.
void ReleaseFDsToClose(std::vector<int>* fds); void ReleaseFDsToClose(std::vector<base::PlatformFile>* fds);
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -90,7 +91,7 @@ class IPC_EXPORT FileDescriptorSet ...@@ -90,7 +91,7 @@ class IPC_EXPORT FileDescriptorSet
// Set the contents of the set from the given buffer. This set must be empty // Set the contents of the set from the given buffer. This set must be empty
// before calling. The auto-close flag is set on all the descriptors so that // before calling. The auto-close flag is set on all the descriptors so that
// unconsumed descriptors are closed on destruction. // unconsumed descriptors are closed on destruction.
void SetDescriptors(const int* buffer, unsigned count); void AddDescriptorsToOwn(const base::PlatformFile* buffer, unsigned count);
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -103,7 +104,8 @@ class IPC_EXPORT FileDescriptorSet ...@@ -103,7 +104,8 @@ class IPC_EXPORT FileDescriptorSet
// these descriptors are sent as control data. After sending, any descriptors // these descriptors are sent as control data. After sending, any descriptors
// with a true flag are closed. If this message has been received, then these // with a true flag are closed. If this message has been received, then these
// are the descriptors which were received and all close flags are true. // are the descriptors which were received and all close flags are true.
std::vector<base::FileDescriptor> descriptors_; std::vector<base::PlatformFile> descriptors_;
ScopedVector<base::ScopedFD> owned_descriptors_;
// This contains the index of the next descriptor which should be consumed. // This contains the index of the next descriptor which should be consumed.
// It's used in a couple of ways. Firstly, at destruction we can check that // It's used in a couple of ways. Firstly, at destruction we can check that
......
...@@ -41,7 +41,7 @@ TEST(FileDescriptorSet, BasicAdd) { ...@@ -41,7 +41,7 @@ TEST(FileDescriptorSet, BasicAdd) {
ASSERT_EQ(set->size(), 0u); ASSERT_EQ(set->size(), 0u);
ASSERT_TRUE(set->empty()); ASSERT_TRUE(set->empty());
ASSERT_TRUE(set->Add(kFDBase)); ASSERT_TRUE(set->AddToBorrow(kFDBase));
ASSERT_EQ(set->size(), 1u); ASSERT_EQ(set->size(), 1u);
ASSERT_TRUE(!set->empty()); ASSERT_TRUE(!set->empty());
...@@ -56,7 +56,7 @@ TEST(FileDescriptorSet, BasicAddAndClose) { ...@@ -56,7 +56,7 @@ TEST(FileDescriptorSet, BasicAddAndClose) {
ASSERT_EQ(set->size(), 0u); ASSERT_EQ(set->size(), 0u);
ASSERT_TRUE(set->empty()); ASSERT_TRUE(set->empty());
const int fd = GetSafeFd(); const int fd = GetSafeFd();
ASSERT_TRUE(set->AddAndAutoClose(fd)); ASSERT_TRUE(set->AddToOwn(base::ScopedFD(fd)));
ASSERT_EQ(set->size(), 1u); ASSERT_EQ(set->size(), 1u);
ASSERT_TRUE(!set->empty()); ASSERT_TRUE(!set->empty());
...@@ -68,9 +68,9 @@ TEST(FileDescriptorSet, MaxSize) { ...@@ -68,9 +68,9 @@ TEST(FileDescriptorSet, MaxSize) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
for (size_t i = 0; i < FileDescriptorSet::kMaxDescriptorsPerMessage; ++i) for (size_t i = 0; i < FileDescriptorSet::kMaxDescriptorsPerMessage; ++i)
ASSERT_TRUE(set->Add(kFDBase + 1 + i)); ASSERT_TRUE(set->AddToBorrow(kFDBase + 1 + i));
ASSERT_TRUE(!set->Add(kFDBase)); ASSERT_TRUE(!set->AddToBorrow(kFDBase));
set->CommitAll(); set->CommitAll();
} }
...@@ -79,12 +79,12 @@ TEST(FileDescriptorSet, SetDescriptors) { ...@@ -79,12 +79,12 @@ TEST(FileDescriptorSet, SetDescriptors) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
ASSERT_TRUE(set->empty()); ASSERT_TRUE(set->empty());
set->SetDescriptors(NULL, 0); set->AddDescriptorsToOwn(NULL, 0);
ASSERT_TRUE(set->empty()); ASSERT_TRUE(set->empty());
const int fd = GetSafeFd(); const int fd = GetSafeFd();
static const int fds[] = {fd}; static const int fds[] = {fd};
set->SetDescriptors(fds, 1); set->AddDescriptorsToOwn(fds, 1);
ASSERT_TRUE(!set->empty()); ASSERT_TRUE(!set->empty());
ASSERT_EQ(set->size(), 1u); ASSERT_EQ(set->size(), 1u);
...@@ -93,15 +93,15 @@ TEST(FileDescriptorSet, SetDescriptors) { ...@@ -93,15 +93,15 @@ TEST(FileDescriptorSet, SetDescriptors) {
ASSERT_TRUE(VerifyClosed(fd)); ASSERT_TRUE(VerifyClosed(fd));
} }
TEST(FileDescriptorSet, GetDescriptors) { TEST(FileDescriptorSet, PeekDescriptors) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
set->GetDescriptors(NULL); set->PeekDescriptors(NULL);
ASSERT_TRUE(set->Add(kFDBase)); ASSERT_TRUE(set->AddToBorrow(kFDBase));
int fds[1]; int fds[1];
fds[0] = 0; fds[0] = 0;
set->GetDescriptors(fds); set->PeekDescriptors(fds);
ASSERT_EQ(fds[0], kFDBase); ASSERT_EQ(fds[0], kFDBase);
set->CommitAll(); set->CommitAll();
ASSERT_TRUE(set->empty()); ASSERT_TRUE(set->empty());
...@@ -110,13 +110,15 @@ TEST(FileDescriptorSet, GetDescriptors) { ...@@ -110,13 +110,15 @@ TEST(FileDescriptorSet, GetDescriptors) {
TEST(FileDescriptorSet, WalkInOrder) { TEST(FileDescriptorSet, WalkInOrder) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
ASSERT_TRUE(set->Add(kFDBase)); // TODO(morrita): This test is wrong. TakeDescriptorAt() shouldn't be
ASSERT_TRUE(set->Add(kFDBase + 1)); // used to retrieve borrowed descriptors. That never happens in production.
ASSERT_TRUE(set->Add(kFDBase + 2)); ASSERT_TRUE(set->AddToBorrow(kFDBase));
ASSERT_TRUE(set->AddToBorrow(kFDBase + 1));
ASSERT_TRUE(set->AddToBorrow(kFDBase + 2));
ASSERT_EQ(set->GetDescriptorAt(0), kFDBase); ASSERT_EQ(set->TakeDescriptorAt(0), kFDBase);
ASSERT_EQ(set->GetDescriptorAt(1), kFDBase + 1); ASSERT_EQ(set->TakeDescriptorAt(1), kFDBase + 1);
ASSERT_EQ(set->GetDescriptorAt(2), kFDBase + 2); ASSERT_EQ(set->TakeDescriptorAt(2), kFDBase + 2);
set->CommitAll(); set->CommitAll();
} }
...@@ -124,12 +126,14 @@ TEST(FileDescriptorSet, WalkInOrder) { ...@@ -124,12 +126,14 @@ TEST(FileDescriptorSet, WalkInOrder) {
TEST(FileDescriptorSet, WalkWrongOrder) { TEST(FileDescriptorSet, WalkWrongOrder) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
ASSERT_TRUE(set->Add(kFDBase)); // TODO(morrita): This test is wrong. TakeDescriptorAt() shouldn't be
ASSERT_TRUE(set->Add(kFDBase + 1)); // used to retrieve borrowed descriptors. That never happens in production.
ASSERT_TRUE(set->Add(kFDBase + 2)); ASSERT_TRUE(set->AddToBorrow(kFDBase));
ASSERT_TRUE(set->AddToBorrow(kFDBase + 1));
ASSERT_TRUE(set->AddToBorrow(kFDBase + 2));
ASSERT_EQ(set->GetDescriptorAt(0), kFDBase); ASSERT_EQ(set->TakeDescriptorAt(0), kFDBase);
ASSERT_EQ(set->GetDescriptorAt(2), -1); ASSERT_EQ(set->TakeDescriptorAt(2), -1);
set->CommitAll(); set->CommitAll();
} }
...@@ -137,19 +141,21 @@ TEST(FileDescriptorSet, WalkWrongOrder) { ...@@ -137,19 +141,21 @@ TEST(FileDescriptorSet, WalkWrongOrder) {
TEST(FileDescriptorSet, WalkCycle) { TEST(FileDescriptorSet, WalkCycle) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
ASSERT_TRUE(set->Add(kFDBase)); // TODO(morrita): This test is wrong. TakeDescriptorAt() shouldn't be
ASSERT_TRUE(set->Add(kFDBase + 1)); // used to retrieve borrowed descriptors. That never happens in production.
ASSERT_TRUE(set->Add(kFDBase + 2)); ASSERT_TRUE(set->AddToBorrow(kFDBase));
ASSERT_TRUE(set->AddToBorrow(kFDBase + 1));
ASSERT_EQ(set->GetDescriptorAt(0), kFDBase); ASSERT_TRUE(set->AddToBorrow(kFDBase + 2));
ASSERT_EQ(set->GetDescriptorAt(1), kFDBase + 1);
ASSERT_EQ(set->GetDescriptorAt(2), kFDBase + 2); ASSERT_EQ(set->TakeDescriptorAt(0), kFDBase);
ASSERT_EQ(set->GetDescriptorAt(0), kFDBase); ASSERT_EQ(set->TakeDescriptorAt(1), kFDBase + 1);
ASSERT_EQ(set->GetDescriptorAt(1), kFDBase + 1); ASSERT_EQ(set->TakeDescriptorAt(2), kFDBase + 2);
ASSERT_EQ(set->GetDescriptorAt(2), kFDBase + 2); ASSERT_EQ(set->TakeDescriptorAt(0), kFDBase);
ASSERT_EQ(set->GetDescriptorAt(0), kFDBase); ASSERT_EQ(set->TakeDescriptorAt(1), kFDBase + 1);
ASSERT_EQ(set->GetDescriptorAt(1), kFDBase + 1); ASSERT_EQ(set->TakeDescriptorAt(2), kFDBase + 2);
ASSERT_EQ(set->GetDescriptorAt(2), kFDBase + 2); ASSERT_EQ(set->TakeDescriptorAt(0), kFDBase);
ASSERT_EQ(set->TakeDescriptorAt(1), kFDBase + 1);
ASSERT_EQ(set->TakeDescriptorAt(2), kFDBase + 2);
set->CommitAll(); set->CommitAll();
} }
...@@ -158,7 +164,7 @@ TEST(FileDescriptorSet, DontClose) { ...@@ -158,7 +164,7 @@ TEST(FileDescriptorSet, DontClose) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
const int fd = GetSafeFd(); const int fd = GetSafeFd();
ASSERT_TRUE(set->Add(fd)); ASSERT_TRUE(set->AddToBorrow(fd));
set->CommitAll(); set->CommitAll();
ASSERT_FALSE(VerifyClosed(fd)); ASSERT_FALSE(VerifyClosed(fd));
...@@ -168,7 +174,7 @@ TEST(FileDescriptorSet, DoClose) { ...@@ -168,7 +174,7 @@ TEST(FileDescriptorSet, DoClose) {
scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet); scoped_refptr<FileDescriptorSet> set(new FileDescriptorSet);
const int fd = GetSafeFd(); const int fd = GetSafeFd();
ASSERT_TRUE(set->AddAndAutoClose(fd)); ASSERT_TRUE(set->AddToOwn(base::ScopedFD(fd)));
set->CommitAll(); set->CommitAll();
ASSERT_TRUE(VerifyClosed(fd)); ASSERT_TRUE(VerifyClosed(fd));
......
...@@ -283,7 +283,7 @@ bool ChannelNacl::ProcessOutgoingMessages() { ...@@ -283,7 +283,7 @@ bool ChannelNacl::ProcessOutgoingMessages() {
int fds[FileDescriptorSet::kMaxDescriptorsPerMessage]; int fds[FileDescriptorSet::kMaxDescriptorsPerMessage];
const size_t num_fds = msg->file_descriptor_set()->size(); const size_t num_fds = msg->file_descriptor_set()->size();
DCHECK(num_fds <= FileDescriptorSet::kMaxDescriptorsPerMessage); DCHECK(num_fds <= FileDescriptorSet::kMaxDescriptorsPerMessage);
msg->file_descriptor_set()->GetDescriptors(fds); msg->file_descriptor_set()->PeekDescriptors(fds);
NaClAbiNaClImcMsgIoVec iov = { NaClAbiNaClImcMsgIoVec iov = {
const_cast<void*>(msg->data()), msg->size() const_cast<void*>(msg->data()), msg->size()
...@@ -357,8 +357,8 @@ bool ChannelNacl::WillDispatchInputMessage(Message* msg) { ...@@ -357,8 +357,8 @@ bool ChannelNacl::WillDispatchInputMessage(Message* msg) {
// The shenaniganery below with &foo.front() requires input_fds_ to have // The shenaniganery below with &foo.front() requires input_fds_ to have
// contiguous underlying storage (such as a simple array or a std::vector). // contiguous underlying storage (such as a simple array or a std::vector).
// This is why the header warns not to make input_fds_ a deque<>. // This is why the header warns not to make input_fds_ a deque<>.
msg->file_descriptor_set()->SetDescriptors(&input_fds_.front(), msg->file_descriptor_set()->AddDescriptorsToOwn(&input_fds_.front(),
header_fds); header_fds);
input_fds_.clear(); input_fds_.clear();
return true; return true;
} }
......
...@@ -432,7 +432,7 @@ bool ChannelPosix::ProcessOutgoingMessages() { ...@@ -432,7 +432,7 @@ bool ChannelPosix::ProcessOutgoingMessages() {
cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS; cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(sizeof(int) * num_fds); cmsg->cmsg_len = CMSG_LEN(sizeof(int) * num_fds);
msg->file_descriptor_set()->GetDescriptors( msg->file_descriptor_set()->PeekDescriptors(
reinterpret_cast<int*>(CMSG_DATA(cmsg))); reinterpret_cast<int*>(CMSG_DATA(cmsg)));
msgh.msg_controllen = cmsg->cmsg_len; msgh.msg_controllen = cmsg->cmsg_len;
...@@ -769,8 +769,7 @@ void ChannelPosix::QueueHelloMessage() { ...@@ -769,8 +769,7 @@ void ChannelPosix::QueueHelloMessage() {
#if defined(IPC_USES_READWRITE) #if defined(IPC_USES_READWRITE)
scoped_ptr<Message> hello; scoped_ptr<Message> hello;
if (remote_fd_pipe_ != -1) { if (remote_fd_pipe_ != -1) {
if (!msg->WriteFileDescriptor(base::FileDescriptor(remote_fd_pipe_, if (!msg->WriteBorrowingFile(remote_fd_pipe_)) {
false))) {
NOTREACHED() << "Unable to pickle hello message file descriptors"; NOTREACHED() << "Unable to pickle hello message file descriptors";
} }
DCHECK_EQ(msg->file_descriptor_set()->size(), 1U); DCHECK_EQ(msg->file_descriptor_set()->size(), 1U);
...@@ -896,8 +895,8 @@ bool ChannelPosix::WillDispatchInputMessage(Message* msg) { ...@@ -896,8 +895,8 @@ bool ChannelPosix::WillDispatchInputMessage(Message* msg) {
// The shenaniganery below with &foo.front() requires input_fds_ to have // The shenaniganery below with &foo.front() requires input_fds_ to have
// contiguous underlying storage (such as a simple array or a std::vector). // contiguous underlying storage (such as a simple array or a std::vector).
// This is why the header warns not to make input_fds_ a deque<>. // This is why the header warns not to make input_fds_ a deque<>.
msg->file_descriptor_set()->SetDescriptors(&input_fds_.front(), msg->file_descriptor_set()->AddDescriptorsToOwn(&input_fds_.front(),
header_fds); header_fds);
input_fds_.erase(input_fds_.begin(), input_fds_.begin() + header_fds); input_fds_.erase(input_fds_.begin(), input_fds_.begin() + header_fds);
return true; return true;
} }
...@@ -991,12 +990,11 @@ void ChannelPosix::HandleInternalMessage(const Message& msg) { ...@@ -991,12 +990,11 @@ void ChannelPosix::HandleInternalMessage(const Message& msg) {
// server also contains the fd_pipe_, which will be used for all // server also contains the fd_pipe_, which will be used for all
// subsequent file descriptor passing. // subsequent file descriptor passing.
DCHECK_EQ(msg.file_descriptor_set()->size(), 1U); DCHECK_EQ(msg.file_descriptor_set()->size(), 1U);
base::FileDescriptor descriptor; base::ScopedFD descriptor;
if (!msg.ReadFileDescriptor(&iter, &descriptor)) { if (!msg.ReadFile(&iter, &descriptor)) {
NOTREACHED(); NOTREACHED();
} }
fd_pipe_ = descriptor.fd; fd_pipe_ = descriptor.release();
CHECK(descriptor.auto_close);
} }
#endif // IPC_USES_READWRITE #endif // IPC_USES_READWRITE
peer_pid_ = pid; peer_pid_ = pid;
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "build/build_config.h" #include "build/build_config.h"
#if defined(OS_POSIX) #if defined(OS_POSIX)
#include "base/file_descriptor_posix.h"
#include "ipc/file_descriptor_set_posix.h" #include "ipc/file_descriptor_set_posix.h"
#endif #endif
...@@ -122,19 +123,21 @@ void Message::set_received_time(int64 time) const { ...@@ -122,19 +123,21 @@ void Message::set_received_time(int64 time) const {
#endif #endif
#if defined(OS_POSIX) #if defined(OS_POSIX)
bool Message::WriteFileDescriptor(const base::FileDescriptor& descriptor) { bool Message::WriteFile(base::ScopedFD descriptor) {
// We write the index of the descriptor so that we don't have to // We write the index of the descriptor so that we don't have to
// keep the current descriptor as extra decoding state when deserialising. // keep the current descriptor as extra decoding state when deserialising.
WriteInt(file_descriptor_set()->size()); WriteInt(file_descriptor_set()->size());
if (descriptor.auto_close) { return file_descriptor_set()->AddToOwn(descriptor.Pass());
return file_descriptor_set()->AddAndAutoClose(descriptor.fd);
} else {
return file_descriptor_set()->Add(descriptor.fd);
}
} }
bool Message::ReadFileDescriptor(PickleIterator* iter, bool Message::WriteBorrowingFile(const base::PlatformFile& descriptor) {
base::FileDescriptor* descriptor) const { // We write the index of the descriptor so that we don't have to
// keep the current descriptor as extra decoding state when deserialising.
WriteInt(file_descriptor_set()->size());
return file_descriptor_set()->AddToBorrow(descriptor);
}
bool Message::ReadFile(PickleIterator* iter, base::ScopedFD* descriptor) const {
int descriptor_index; int descriptor_index;
if (!ReadInt(iter, &descriptor_index)) if (!ReadInt(iter, &descriptor_index))
return false; return false;
...@@ -143,10 +146,13 @@ bool Message::ReadFileDescriptor(PickleIterator* iter, ...@@ -143,10 +146,13 @@ bool Message::ReadFileDescriptor(PickleIterator* iter,
if (!file_descriptor_set) if (!file_descriptor_set)
return false; return false;
descriptor->fd = file_descriptor_set->GetDescriptorAt(descriptor_index); base::PlatformFile file =
descriptor->auto_close = true; file_descriptor_set->TakeDescriptorAt(descriptor_index);
if (file < 0)
return false;
return descriptor->fd >= 0; descriptor->reset(file);
return true;
} }
bool Message::HasFileDescriptors() const { bool Message::HasFileDescriptors() const {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "base/basictypes.h" #include "base/basictypes.h"
#include "base/debug/trace_event.h" #include "base/debug/trace_event.h"
#include "base/files/file.h"
#include "base/pickle.h" #include "base/pickle.h"
#include "ipc/ipc_export.h" #include "ipc/ipc_export.h"
...@@ -20,10 +21,6 @@ ...@@ -20,10 +21,6 @@
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#endif #endif
namespace base {
struct FileDescriptor;
}
class FileDescriptorSet; class FileDescriptorSet;
namespace IPC { namespace IPC {
...@@ -178,12 +175,12 @@ class IPC_EXPORT Message : public Pickle { ...@@ -178,12 +175,12 @@ class IPC_EXPORT Message : public Pickle {
// This is used to pass a file descriptor to the peer of an IPC channel. // This is used to pass a file descriptor to the peer of an IPC channel.
// Add a descriptor to the end of the set. Returns false if the set is full. // Add a descriptor to the end of the set. Returns false if the set is full.
bool WriteFileDescriptor(const base::FileDescriptor& descriptor); bool WriteFile(base::ScopedFD descriptor);
bool WriteBorrowingFile(const base::PlatformFile& descriptor);
// Get a file descriptor from the message. Returns false on error. // Get a file descriptor from the message. Returns false on error.
// iter: a Pickle iterator to the current location in the message. // iter: a Pickle iterator to the current location in the message.
bool ReadFileDescriptor(PickleIterator* iter, bool ReadFile(PickleIterator* iter, base::ScopedFD* file) const;
base::FileDescriptor* descriptor) const;
// Returns true if there are any file descriptors in this message. // Returns true if there are any file descriptors in this message.
bool HasFileDescriptors() const; bool HasFileDescriptors() const;
......
...@@ -462,8 +462,14 @@ void ParamTraits<base::FileDescriptor>::Write(Message* m, const param_type& p) { ...@@ -462,8 +462,14 @@ void ParamTraits<base::FileDescriptor>::Write(Message* m, const param_type& p) {
const bool valid = p.fd >= 0; const bool valid = p.fd >= 0;
WriteParam(m, valid); WriteParam(m, valid);
if (valid) { if (!valid)
if (!m->WriteFileDescriptor(p)) return;
if (p.auto_close) {
if (!m->WriteFile(base::ScopedFD(p.fd)))
NOTREACHED();
} else {
if (!m->WriteBorrowingFile(p.fd))
NOTREACHED(); NOTREACHED();
} }
} }
...@@ -471,17 +477,22 @@ void ParamTraits<base::FileDescriptor>::Write(Message* m, const param_type& p) { ...@@ -471,17 +477,22 @@ void ParamTraits<base::FileDescriptor>::Write(Message* m, const param_type& p) {
bool ParamTraits<base::FileDescriptor>::Read(const Message* m, bool ParamTraits<base::FileDescriptor>::Read(const Message* m,
PickleIterator* iter, PickleIterator* iter,
param_type* r) { param_type* r) {
*r = base::FileDescriptor();
bool valid; bool valid;
if (!ReadParam(m, iter, &valid)) if (!ReadParam(m, iter, &valid))
return false; return false;
if (!valid) { // TODO(morrita): Seems like this should return false.
r->fd = -1; if (!valid)
r->auto_close = false;
return true; return true;
}
return m->ReadFileDescriptor(iter, r); base::ScopedFD fd;
if (!m->ReadFile(iter, &fd))
return false;
*r = base::FileDescriptor(fd.release(), true);
return true;
} }
void ParamTraits<base::FileDescriptor>::Log(const param_type& p, void ParamTraits<base::FileDescriptor>::Log(const param_type& p,
......
...@@ -229,7 +229,8 @@ MojoResult ChannelMojo::WriteToFileDescriptorSet( ...@@ -229,7 +229,8 @@ MojoResult ChannelMojo::WriteToFileDescriptorSet(
return unwrap_result; return unwrap_result;
} }
bool ok = message->file_descriptor_set()->Add(platform_handle.release().fd); bool ok = message->file_descriptor_set()->AddToOwn(
base::ScopedFD(platform_handle.release().fd));
DCHECK(ok); DCHECK(ok);
} }
...@@ -238,17 +239,20 @@ MojoResult ChannelMojo::WriteToFileDescriptorSet( ...@@ -238,17 +239,20 @@ MojoResult ChannelMojo::WriteToFileDescriptorSet(
// static // static
MojoResult ChannelMojo::ReadFromFileDescriptorSet( MojoResult ChannelMojo::ReadFromFileDescriptorSet(
const Message& message, Message* message,
std::vector<MojoHandle>* handles) { std::vector<MojoHandle>* handles) {
// We dup() the handles in IPC::Message to transmit. // We dup() the handles in IPC::Message to transmit.
// IPC::FileDescriptorSet has intricate lifecycle semantics // IPC::FileDescriptorSet has intricate lifecycle semantics
// of FDs, so just to dup()-and-own them is the safest option. // of FDs, so just to dup()-and-own them is the safest option.
if (message.HasFileDescriptors()) { if (message->HasFileDescriptors()) {
const FileDescriptorSet* fdset = message.file_descriptor_set(); FileDescriptorSet* fdset = message->file_descriptor_set();
for (size_t i = 0; i < fdset->size(); ++i) { std::vector<base::PlatformFile> fds_to_send(fdset->size());
int fd_to_send = dup(fdset->GetDescriptorAt(i)); fdset->PeekDescriptors(&fds_to_send[0]);
for (size_t i = 0; i < fds_to_send.size(); ++i) {
int fd_to_send = dup(fds_to_send[i]);
if (-1 == fd_to_send) { if (-1 == fd_to_send) {
DPLOG(WARNING) << "Failed to dup FD to transmit."; DPLOG(WARNING) << "Failed to dup FD to transmit.";
fdset->CommitAll();
return MOJO_RESULT_UNKNOWN; return MOJO_RESULT_UNKNOWN;
} }
...@@ -260,11 +264,14 @@ MojoResult ChannelMojo::ReadFromFileDescriptorSet( ...@@ -260,11 +264,14 @@ MojoResult ChannelMojo::ReadFromFileDescriptorSet(
if (MOJO_RESULT_OK != wrap_result) { if (MOJO_RESULT_OK != wrap_result) {
DLOG(WARNING) << "Pipe failed to wrap handles. Closing: " DLOG(WARNING) << "Pipe failed to wrap handles. Closing: "
<< wrap_result; << wrap_result;
fdset->CommitAll();
return wrap_result; return wrap_result;
} }
handles->push_back(wrapped_handle); handles->push_back(wrapped_handle);
} }
fdset->CommitAll();
} }
return MOJO_RESULT_OK; return MOJO_RESULT_OK;
......
...@@ -100,7 +100,7 @@ class IPC_MOJO_EXPORT ChannelMojo : public Channel, ...@@ -100,7 +100,7 @@ class IPC_MOJO_EXPORT ChannelMojo : public Channel,
static MojoResult WriteToFileDescriptorSet( static MojoResult WriteToFileDescriptorSet(
const std::vector<MojoHandle>& handle_buffer, const std::vector<MojoHandle>& handle_buffer,
Message* message); Message* message);
static MojoResult ReadFromFileDescriptorSet(const Message& message, static MojoResult ReadFromFileDescriptorSet(Message* message,
std::vector<MojoHandle>* handles); std::vector<MojoHandle>* handles);
#endif // defined(OS_POSIX) && !defined(OS_NACL) #endif // defined(OS_POSIX) && !defined(OS_NACL)
......
...@@ -138,7 +138,7 @@ bool MessageReader::Send(scoped_ptr<Message> message) { ...@@ -138,7 +138,7 @@ bool MessageReader::Send(scoped_ptr<Message> message) {
std::vector<MojoHandle> handles; std::vector<MojoHandle> handles;
#if defined(OS_POSIX) && !defined(OS_NACL) #if defined(OS_POSIX) && !defined(OS_NACL)
MojoResult read_result = MojoResult read_result =
ChannelMojo::ReadFromFileDescriptorSet(*message, &handles); ChannelMojo::ReadFromFileDescriptorSet(message.get(), &handles);
if (read_result != MOJO_RESULT_OK) { if (read_result != MOJO_RESULT_OK) {
std::for_each(handles.begin(), handles.end(), &MojoClose); std::for_each(handles.begin(), handles.end(), &MojoClose);
CloseWithError(read_result); CloseWithError(read_result);
......
...@@ -302,10 +302,11 @@ class ListenerThatExpectsFile : public IPC::Listener { ...@@ -302,10 +302,11 @@ class ListenerThatExpectsFile : public IPC::Listener {
virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE { virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE {
PickleIterator iter(message); PickleIterator iter(message);
base::FileDescriptor desc;
EXPECT_TRUE(message.ReadFileDescriptor(&iter, &desc)); base::ScopedFD fd;
EXPECT_TRUE(message.ReadFile(&iter, &fd));
base::File file(fd.release());
std::string content(GetSendingFileContent().size(), ' '); std::string content(GetSendingFileContent().size(), ' ');
base::File file(desc.fd);
file.Read(0, &content[0], content.size()); file.Read(0, &content[0], content.size());
EXPECT_EQ(content, GetSendingFileContent()); EXPECT_EQ(content, GetSendingFileContent());
base::MessageLoop::current()->Quit(); base::MessageLoop::current()->Quit();
...@@ -334,8 +335,7 @@ class ListenerThatExpectsFile : public IPC::Listener { ...@@ -334,8 +335,7 @@ class ListenerThatExpectsFile : public IPC::Listener {
file.Flush(); file.Flush();
IPC::Message* message = new IPC::Message( IPC::Message* message = new IPC::Message(
0, 2, IPC::Message::PRIORITY_NORMAL); 0, 2, IPC::Message::PRIORITY_NORMAL);
message->WriteFileDescriptor( message->WriteFile(base::ScopedFD(file.TakePlatformFile()));
base::FileDescriptor(file.TakePlatformFile(), false));
ASSERT_TRUE(sender->Send(message)); ASSERT_TRUE(sender->Send(message));
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment