Commit f180cd87 authored by Brian Geffon's avatar Brian Geffon Committed by Commit Bot

Mojo: Add support for versioned NodeChannel messages

To prepare for adding new channel types (ie. fast posix channels) we
will need a way to extend existing Mojo NodeChannel messages.

Bug: b:173022729
Change-Id: I56b80ed58f28d058927d94ef9421319fe547c719
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2538166
Commit-Queue: Brian Geffon <bgeffon@chromium.org>
Reviewed-by: default avatarKen Rockot <rockot@google.com>
Cr-Commit-Position: refs/heads/master@{#827995}
parent 7c9776b5
...@@ -47,39 +47,45 @@ enum class MessageType : uint32_t { ...@@ -47,39 +47,45 @@ enum class MessageType : uint32_t {
#pragma pack(push, 1) #pragma pack(push, 1)
struct Header { struct alignas(8) Header {
MessageType type; MessageType type;
uint32_t padding;
}; };
static_assert(IsAlignedForChannelMessage(sizeof(Header)), static_assert(IsAlignedForChannelMessage(sizeof(Header)),
"Invalid header size."); "Invalid header size.");
struct AcceptInviteeData { struct alignas(8) AcceptInviteeDataV0 {
ports::NodeName inviter_name; ports::NodeName inviter_name;
ports::NodeName token; ports::NodeName token;
}; };
struct AcceptInvitationData { using AcceptInviteeData = AcceptInviteeDataV0;
struct alignas(8) AcceptInvitationDataV0 {
ports::NodeName token; ports::NodeName token;
ports::NodeName invitee_name; ports::NodeName invitee_name;
}; };
struct AcceptPeerData { using AcceptInvitationData = AcceptInvitationDataV0;
struct alignas(8) AcceptPeerDataV0 {
ports::NodeName token; ports::NodeName token;
ports::NodeName peer_name; ports::NodeName peer_name;
ports::PortName port_name; ports::PortName port_name;
}; };
// This message may include a process handle on plaforms that require it. using AcceptPeerData = AcceptPeerDataV0;
struct AddBrokerClientData {
// This message may include a process handle on platforms that require it.
struct alignas(8) AddBrokerClientDataV0 {
ports::NodeName client_name; ports::NodeName client_name;
#if !defined(OS_WIN) #if !defined(OS_WIN)
uint32_t process_handle; uint32_t process_handle = 0;
uint32_t padding;
#endif #endif
}; };
using AddBrokerClientData = AddBrokerClientDataV0;
#if !defined(OS_WIN) #if !defined(OS_WIN)
static_assert(sizeof(base::ProcessHandle) == sizeof(uint32_t), static_assert(sizeof(base::ProcessHandle) == sizeof(uint32_t),
"Unexpected pid size"); "Unexpected pid size");
...@@ -88,19 +94,24 @@ static_assert(sizeof(AddBrokerClientData) % kChannelMessageAlignment == 0, ...@@ -88,19 +94,24 @@ static_assert(sizeof(AddBrokerClientData) % kChannelMessageAlignment == 0,
#endif #endif
// This data is followed by a platform channel handle to the broker. // This data is followed by a platform channel handle to the broker.
struct BrokerClientAddedData { struct alignas(8) BrokerClientAddedDataV0 {
ports::NodeName client_name; ports::NodeName client_name;
}; };
using BrokerClientAddedData = BrokerClientAddedDataV0;
// This data may be followed by a platform channel handle to the broker. If not, // This data may be followed by a platform channel handle to the broker. If not,
// then the inviter is the broker and its channel should be used as such. // then the inviter is the broker and its channel should be used as such.
struct AcceptBrokerClientData { struct alignas(8) AcceptBrokerClientDataV0 {
ports::NodeName broker_name; ports::NodeName broker_name;
}; };
using AcceptBrokerClientData = AcceptBrokerClientDataV0;
// This is followed by arbitrary payload data which is interpreted as a token // This is followed by arbitrary payload data which is interpreted as a token
// string for port location. // string for port location.
struct RequestPortMergeData { // NOTE: Because this field is variable length it cannot be versioned.
struct alignas(8) RequestPortMergeData {
ports::PortName connector_port_name; ports::PortName connector_port_name;
}; };
...@@ -110,35 +121,41 @@ struct RequestPortMergeData { ...@@ -110,35 +121,41 @@ struct RequestPortMergeData {
// the receiver may use to communicate with the named node directly, or an // the receiver may use to communicate with the named node directly, or an
// invalid platform handle if the node is unknown to the sender or otherwise // invalid platform handle if the node is unknown to the sender or otherwise
// cannot be introduced. // cannot be introduced.
struct IntroductionData { struct alignas(8) IntroductionDataV0 {
ports::NodeName name; ports::NodeName name;
}; };
// This message is just a PlatformHandle. The data struct here has only a using IntroductionData = IntroductionDataV0;
// padding field to ensure an aligned, non-zero-length payload.
struct BindBrokerHostData { // This message is just a PlatformHandle. The data struct alignas(8) here has
uint64_t padding; // only a padding field to ensure an aligned, non-zero-length payload.
}; struct alignas(8) BindBrokerHostDataV0 {};
using BindBrokerHostData = BindBrokerHostDataV0;
#if defined(OS_WIN) #if defined(OS_WIN)
// This struct is followed by the full payload of a message to be relayed. // This struct alignas(8) is followed by the full payload of a message to be
struct RelayEventMessageData { // relayed.
// NOTE: Because this field is variable length it cannot be versioned.
struct alignas(8) RelayEventMessageData {
ports::NodeName destination; ports::NodeName destination;
}; };
// This struct is followed by the full payload of a relayed message. // This struct alignas(8) is followed by the full payload of a relayed message.
struct EventMessageFromRelayData { struct alignas(8) EventMessageFromRelayDataV0 {
ports::NodeName source; ports::NodeName source;
}; };
using EventMessageFromRelayData = EventMessageFromRelayDataV0;
#endif #endif
#pragma pack(pop) #pragma pack(pop)
template <typename DataType>
Channel::MessagePtr CreateMessage(MessageType type, Channel::MessagePtr CreateMessage(MessageType type,
size_t payload_size, size_t payload_size,
size_t num_handles, size_t num_handles,
DataType** out_data, void** out_data,
size_t capacity = 0) { size_t capacity = 0) {
const size_t total_size = payload_size + sizeof(Header); const size_t total_size = payload_size + sizeof(Header);
if (capacity == 0) if (capacity == 0)
...@@ -148,21 +165,49 @@ Channel::MessagePtr CreateMessage(MessageType type, ...@@ -148,21 +165,49 @@ Channel::MessagePtr CreateMessage(MessageType type,
auto message = auto message =
std::make_unique<Channel::Message>(capacity, total_size, num_handles); std::make_unique<Channel::Message>(capacity, total_size, num_handles);
Header* header = reinterpret_cast<Header*>(message->mutable_payload()); Header* header = reinterpret_cast<Header*>(message->mutable_payload());
// Make sure any header padding gets zeroed.
memset(header, 0, sizeof(Header));
header->type = type; header->type = type;
header->padding = 0;
*out_data = reinterpret_cast<DataType*>(&header[1]); // The out_data starts beyond the header.
*out_data = reinterpret_cast<void*>(header + 1);
return message; return message;
} }
template <typename DataType>
Channel::MessagePtr CreateMessage(MessageType type,
size_t payload_size,
size_t num_handles,
DataType** out_data,
size_t capacity = 0) {
auto msg_ptr = CreateMessage(type, payload_size, num_handles,
reinterpret_cast<void**>(out_data), capacity);
// Since we know the type let's make sure any padding areas are zeroed.
memset(*out_data, 0, sizeof(DataType));
return msg_ptr;
}
template <typename DataType> template <typename DataType>
bool GetMessagePayload(const void* bytes, bool GetMessagePayload(const void* bytes,
size_t num_bytes, size_t num_bytes,
DataType** out_data) { DataType* out_data) {
static_assert(sizeof(DataType) > 0, "DataType must have non-zero size."); static_assert(sizeof(DataType) > 0, "DataType must have non-zero size.");
if (num_bytes < sizeof(Header) + sizeof(DataType)) // We should have at least 1 byte to contribute towards DataType.
if (num_bytes <= sizeof(Header))
return false; return false;
*out_data = reinterpret_cast<const DataType*>(
static_cast<const char*>(bytes) + sizeof(Header)); // Always make sure that the full object is zeored and default constructed as
// we may not have the complete type. The default construction allows fields
// to be default initialized to be resilient to older message versions.
memset(out_data, 0, sizeof(*out_data));
new (out_data) DataType;
// Overwrite any fields we received.
memcpy(out_data, static_cast<const uint8_t*>(bytes) + sizeof(Header),
std::min(sizeof(DataType), num_bytes - sizeof(Header)));
return true; return true;
} }
...@@ -307,7 +352,6 @@ void NodeChannel::AddBrokerClient(const ports::NodeName& client_name, ...@@ -307,7 +352,6 @@ void NodeChannel::AddBrokerClient(const ports::NodeName& client_name,
data->client_name = client_name; data->client_name = client_name;
#if !defined(OS_WIN) #if !defined(OS_WIN)
data->process_handle = process_handle.get(); data->process_handle = process_handle.get();
data->padding = 0;
#endif #endif
WriteChannelMessage(std::move(message)); WriteChannelMessage(std::move(message));
} }
...@@ -394,7 +438,6 @@ void NodeChannel::BindBrokerHost(PlatformHandle broker_host_handle) { ...@@ -394,7 +438,6 @@ void NodeChannel::BindBrokerHost(PlatformHandle broker_host_handle) {
Channel::MessagePtr message = Channel::MessagePtr message =
CreateMessage(MessageType::BIND_BROKER_HOST, sizeof(BindBrokerHostData), CreateMessage(MessageType::BIND_BROKER_HOST, sizeof(BindBrokerHostData),
handles.size(), &data); handles.size(), &data);
data->padding = 0;
message->SetHandles(std::move(handles)); message->SetHandles(std::move(handles));
WriteChannelMessage(std::move(message)); WriteChannelMessage(std::move(message));
#endif #endif
...@@ -403,7 +446,6 @@ void NodeChannel::BindBrokerHost(PlatformHandle broker_host_handle) { ...@@ -403,7 +446,6 @@ void NodeChannel::BindBrokerHost(PlatformHandle broker_host_handle) {
#if defined(OS_WIN) #if defined(OS_WIN)
void NodeChannel::RelayEventMessage(const ports::NodeName& destination, void NodeChannel::RelayEventMessage(const ports::NodeName& destination,
Channel::MessagePtr message) { Channel::MessagePtr message) {
#if defined(OS_WIN)
DCHECK(message->has_handles()); DCHECK(message->has_handles());
// Note that this is only used on Windows, and on Windows all platform // Note that this is only used on Windows, and on Windows all platform
...@@ -429,24 +471,6 @@ void NodeChannel::RelayEventMessage(const ports::NodeName& destination, ...@@ -429,24 +471,6 @@ void NodeChannel::RelayEventMessage(const ports::NodeName& destination,
for (auto& handle : handles) for (auto& handle : handles)
handle.TakeHandle().release(); handle.TakeHandle().release();
#else
DCHECK(message->has_mach_ports());
// On OSX, the handles are extracted from the relayed message and attached to
// the wrapper. The broker then takes the handles attached to the wrapper and
// moves them back to the relayed message. This is necessary because the
// message may contain fds which need to be attached to the outer message so
// that they can be transferred to the broker.
std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
size_t num_bytes = sizeof(RelayEventMessageData) + message->data_num_bytes();
RelayEventMessageData* data;
Channel::MessagePtr relay_message = CreateMessage(
MessageType::RELAY_EVENT_MESSAGE, num_bytes, handles.size(), &data);
data->destination = destination;
memcpy(data + 1, message->data(), message->data_num_bytes());
relay_message->SetHandles(std::move(handles));
#endif // defined(OS_WIN)
WriteChannelMessage(std::move(relay_message)); WriteChannelMessage(std::move(relay_message));
} }
...@@ -515,42 +539,42 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -515,42 +539,42 @@ void NodeChannel::OnChannelMessage(const void* payload,
const Header* header = static_cast<const Header*>(payload); const Header* header = static_cast<const Header*>(payload);
switch (header->type) { switch (header->type) {
case MessageType::ACCEPT_INVITEE: { case MessageType::ACCEPT_INVITEE: {
const AcceptInviteeData* data; AcceptInviteeData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
delegate_->OnAcceptInvitee(remote_node_name_, data->inviter_name, delegate_->OnAcceptInvitee(remote_node_name_, data.inviter_name,
data->token); data.token);
return; return;
} }
break; break;
} }
case MessageType::ACCEPT_INVITATION: { case MessageType::ACCEPT_INVITATION: {
const AcceptInvitationData* data; AcceptInvitationData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
delegate_->OnAcceptInvitation(remote_node_name_, data->token, delegate_->OnAcceptInvitation(remote_node_name_, data.token,
data->invitee_name); data.invitee_name);
return; return;
} }
break; break;
} }
case MessageType::ADD_BROKER_CLIENT: { case MessageType::ADD_BROKER_CLIENT: {
const AddBrokerClientData* data; AddBrokerClientData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
#if defined(OS_WIN) #if defined(OS_WIN)
if (handles.size() != 1) { if (handles.size() != 1) {
DLOG(ERROR) << "Dropping invalid AddBrokerClient message."; DLOG(ERROR) << "Dropping invalid AddBrokerClient message.";
break; break;
} }
delegate_->OnAddBrokerClient(remote_node_name_, data->client_name, delegate_->OnAddBrokerClient(remote_node_name_, data.client_name,
handles[0].ReleaseHandle()); handles[0].ReleaseHandle());
#else #else
if (!handles.empty()) { if (!handles.empty()) {
DLOG(ERROR) << "Dropping invalid AddBrokerClient message."; DLOG(ERROR) << "Dropping invalid AddBrokerClient message.";
break; break;
} }
delegate_->OnAddBrokerClient(remote_node_name_, data->client_name, delegate_->OnAddBrokerClient(remote_node_name_, data.client_name,
data->process_handle); data.process_handle);
#endif #endif
return; return;
} }
...@@ -558,13 +582,13 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -558,13 +582,13 @@ void NodeChannel::OnChannelMessage(const void* payload,
} }
case MessageType::BROKER_CLIENT_ADDED: { case MessageType::BROKER_CLIENT_ADDED: {
const BrokerClientAddedData* data; BrokerClientAddedData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
if (handles.size() != 1) { if (handles.size() != 1) {
DLOG(ERROR) << "Dropping invalid BrokerClientAdded message."; DLOG(ERROR) << "Dropping invalid BrokerClientAdded message.";
break; break;
} }
delegate_->OnBrokerClientAdded(remote_node_name_, data->client_name, delegate_->OnBrokerClientAdded(remote_node_name_, data.client_name,
std::move(handles[0])); std::move(handles[0]));
return; return;
} }
...@@ -572,7 +596,7 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -572,7 +596,7 @@ void NodeChannel::OnChannelMessage(const void* payload,
} }
case MessageType::ACCEPT_BROKER_CLIENT: { case MessageType::ACCEPT_BROKER_CLIENT: {
const AcceptBrokerClientData* data; AcceptBrokerClientData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
PlatformHandle broker_channel; PlatformHandle broker_channel;
if (handles.size() > 1) { if (handles.size() > 1) {
...@@ -582,7 +606,7 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -582,7 +606,7 @@ void NodeChannel::OnChannelMessage(const void* payload,
if (handles.size() == 1) if (handles.size() == 1)
broker_channel = std::move(handles[0]); broker_channel = std::move(handles[0]);
delegate_->OnAcceptBrokerClient(remote_node_name_, data->broker_name, delegate_->OnAcceptBrokerClient(remote_node_name_, data.broker_name,
std::move(broker_channel)); std::move(broker_channel));
return; return;
} }
...@@ -599,31 +623,33 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -599,31 +623,33 @@ void NodeChannel::OnChannelMessage(const void* payload,
} }
case MessageType::REQUEST_PORT_MERGE: { case MessageType::REQUEST_PORT_MERGE: {
const RequestPortMergeData* data; RequestPortMergeData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
// Don't accept an empty token. // Don't accept an empty token.
size_t token_size = payload_size - sizeof(*data) - sizeof(Header); size_t token_size = payload_size - sizeof(data) - sizeof(Header);
if (token_size == 0) if (token_size == 0)
break; break;
std::string token(reinterpret_cast<const char*>(data + 1), token_size); std::string token(reinterpret_cast<const char*>(payload) +
sizeof(Header) + sizeof(data),
token_size);
delegate_->OnRequestPortMerge(remote_node_name_, delegate_->OnRequestPortMerge(remote_node_name_,
data->connector_port_name, token); data.connector_port_name, token);
return; return;
} }
break; break;
} }
case MessageType::REQUEST_INTRODUCTION: { case MessageType::REQUEST_INTRODUCTION: {
const IntroductionData* data; IntroductionData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
delegate_->OnRequestIntroduction(remote_node_name_, data->name); delegate_->OnRequestIntroduction(remote_node_name_, data.name);
return; return;
} }
break; break;
} }
case MessageType::INTRODUCE: { case MessageType::INTRODUCE: {
const IntroductionData* data; IntroductionData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
if (handles.size() > 1) { if (handles.size() > 1) {
DLOG(ERROR) << "Dropping invalid introduction message."; DLOG(ERROR) << "Dropping invalid introduction message.";
...@@ -633,7 +659,7 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -633,7 +659,7 @@ void NodeChannel::OnChannelMessage(const void* payload,
if (handles.size() == 1) if (handles.size() == 1)
channel_handle = std::move(handles[0]); channel_handle = std::move(handles[0]);
delegate_->OnIntroduce(remote_node_name_, data->name, delegate_->OnIntroduce(remote_node_name_, data.name,
std::move(channel_handle)); std::move(channel_handle));
return; return;
} }
...@@ -650,22 +676,23 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -650,22 +676,23 @@ void NodeChannel::OnChannelMessage(const void* payload,
// |remote_process_handle_| is never reset once set. // |remote_process_handle_| is never reset once set.
from_process = remote_process_handle_.get(); from_process = remote_process_handle_.get();
} }
const RelayEventMessageData* data; RelayEventMessageData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
// Don't try to relay an empty message. // Don't try to relay an empty message.
if (payload_size <= sizeof(Header) + sizeof(RelayEventMessageData)) if (payload_size <= sizeof(Header) + sizeof(data))
break; break;
const void* message_start = data + 1; const void* message_start = reinterpret_cast<const uint8_t*>(payload) +
sizeof(Header) + sizeof(data);
Channel::MessagePtr message = Channel::Message::Deserialize( Channel::MessagePtr message = Channel::Message::Deserialize(
message_start, payload_size - sizeof(Header) - sizeof(*data), message_start, payload_size - sizeof(Header) - sizeof(data),
from_process); from_process);
if (!message) { if (!message) {
DLOG(ERROR) << "Dropping invalid relay message."; DLOG(ERROR) << "Dropping invalid relay message.";
break; break;
} }
delegate_->OnRelayEventMessage(remote_node_name_, from_process, delegate_->OnRelayEventMessage(remote_node_name_, from_process,
data->destination, std::move(message)); data.destination, std::move(message));
return; return;
} }
break; break;
...@@ -689,19 +716,22 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -689,19 +716,22 @@ void NodeChannel::OnChannelMessage(const void* payload,
#if defined(OS_WIN) #if defined(OS_WIN)
case MessageType::EVENT_MESSAGE_FROM_RELAY: { case MessageType::EVENT_MESSAGE_FROM_RELAY: {
const EventMessageFromRelayData* data; EventMessageFromRelayData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
size_t num_bytes = payload_size - sizeof(*data); if (payload_size < (sizeof(Header) + sizeof(data)))
if (num_bytes < sizeof(Header))
break; break;
num_bytes -= sizeof(Header);
size_t num_bytes = payload_size - sizeof(data) - sizeof(Header);
Channel::MessagePtr message( Channel::MessagePtr message(
new Channel::Message(num_bytes, handles.size())); new Channel::Message(num_bytes, handles.size()));
message->SetHandles(std::move(handles)); message->SetHandles(std::move(handles));
if (num_bytes) if (num_bytes)
memcpy(message->mutable_payload(), data + 1, num_bytes); memcpy(message->mutable_payload(),
delegate_->OnEventMessageFromRelay(remote_node_name_, data->source, static_cast<const uint8_t*>(payload) + sizeof(Header) +
sizeof(data),
num_bytes);
delegate_->OnEventMessageFromRelay(remote_node_name_, data.source,
std::move(message)); std::move(message));
return; return;
} }
...@@ -710,10 +740,10 @@ void NodeChannel::OnChannelMessage(const void* payload, ...@@ -710,10 +740,10 @@ void NodeChannel::OnChannelMessage(const void* payload,
#endif // defined(OS_WIN) #endif // defined(OS_WIN)
case MessageType::ACCEPT_PEER: { case MessageType::ACCEPT_PEER: {
const AcceptPeerData* data; AcceptPeerData data;
if (GetMessagePayload(payload, payload_size, &data)) { if (GetMessagePayload(payload, payload_size, &data)) {
delegate_->OnAcceptPeer(remote_node_name_, data->token, data->peer_name, delegate_->OnAcceptPeer(remote_node_name_, data.token, data.peer_name,
data->port_name); data.port_name);
return; return;
} }
break; break;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment