Commit e86ad221 authored by qsr's avatar qsr Committed by Commit bot

mojo: Add router for python bindings.

This is a reland of https://codereview.chromium.org/607513003 with the
fix for the failing test.

BUG=417707
TBR=sdefresne@chromium.org

Review URL: https://codereview.chromium.org/612443002

Cr-Commit-Position: refs/heads/master@{#296958}
parent f7d78211
......@@ -5,18 +5,140 @@
"""Utility classes to handle sending and receiving messages."""
import struct
import weakref
# pylint: disable=F0401
import mojo.bindings.serialization as serialization
import mojo.system as system
# The flag values for a message header.
NO_FLAG = 0
MESSAGE_EXPECTS_RESPONSE_FLAG = 1 << 0
MESSAGE_IS_RESPONSE_FLAG = 1 << 1
class MessageHeader(object):
"""The header of a mojo message."""
_SIMPLE_MESSAGE_NUM_FIELDS = 2
_SIMPLE_MESSAGE_STRUCT = struct.Struct("=IIII")
_REQUEST_ID_STRUCT = struct.Struct("=Q")
_REQUEST_ID_OFFSET = _SIMPLE_MESSAGE_STRUCT.size
_MESSAGE_WITH_REQUEST_ID_NUM_FIELDS = 3
_MESSAGE_WITH_REQUEST_ID_SIZE = (
_SIMPLE_MESSAGE_STRUCT.size + _REQUEST_ID_STRUCT.size)
def __init__(self, message_type, flags, request_id=0, data=None):
self._message_type = message_type
self._flags = flags
self._request_id = request_id
self._data = data
@classmethod
def Deserialize(cls, data):
buf = buffer(data)
if len(data) < cls._SIMPLE_MESSAGE_STRUCT.size:
raise serialization.DeserializationException('Header is too short.')
(size, version, message_type, flags) = (
cls._SIMPLE_MESSAGE_STRUCT.unpack_from(buf))
if (version < cls._SIMPLE_MESSAGE_NUM_FIELDS):
raise serialization.DeserializationException('Incorrect version.')
request_id = 0
if _HasRequestId(flags):
if version < cls._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS:
raise serialization.DeserializationException('Incorrect version.')
if (size < cls._MESSAGE_WITH_REQUEST_ID_SIZE or
len(data) < cls._MESSAGE_WITH_REQUEST_ID_SIZE):
raise serialization.DeserializationException('Header is too short.')
(request_id, ) = cls._REQUEST_ID_STRUCT.unpack_from(
buf, cls._REQUEST_ID_OFFSET)
return MessageHeader(message_type, flags, request_id, data)
@property
def message_type(self):
return self._message_type
# pylint: disable=E0202
@property
def request_id(self):
assert self.has_request_id
return self._request_id
# pylint: disable=E0202
@request_id.setter
def request_id(self, request_id):
assert self.has_request_id
self._request_id = request_id
self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET,
request_id)
@property
def has_request_id(self):
return _HasRequestId(self._flags)
@property
def expects_response(self):
return self._HasFlag(MESSAGE_EXPECTS_RESPONSE_FLAG)
@property
def is_response(self):
return self._HasFlag(MESSAGE_IS_RESPONSE_FLAG)
@property
def size(self):
if self.has_request_id:
return self._MESSAGE_WITH_REQUEST_ID_SIZE
return self._SIMPLE_MESSAGE_STRUCT.size
def Serialize(self):
if not self._data:
self._data = bytearray(self.size)
version = self._SIMPLE_MESSAGE_NUM_FIELDS
size = self._SIMPLE_MESSAGE_STRUCT.size
if self.has_request_id:
version = self._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS
size = self._MESSAGE_WITH_REQUEST_ID_SIZE
self._SIMPLE_MESSAGE_STRUCT.pack_into(self._data, 0, size, version,
self._message_type, self._flags)
if self.has_request_id:
self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET,
self._request_id)
return self._data
def _HasFlag(self, flag):
return self._flags & flag != 0
class Message(object):
"""A message for a message pipe. This contains data and handles."""
def __init__(self, data=None, handles=None):
self.data = data
self.handles = handles
self._header = None
self._payload = None
@property
def header(self):
if self._header is None:
self._header = MessageHeader.Deserialize(self.data)
return self._header
@property
def payload(self):
if self._payload is None:
self._payload = Message(self.data[self.header.size:], self.handles)
return self._payload
def SetRequestId(self, request_id):
header = self.header
header.request_id = request_id
(data, _) = header.Serialize()
self.data[:header.Size] = data[:header.Size]
class MessageReceiver(object):
......@@ -111,6 +233,12 @@ class Connector(MessageReceiver):
result = self._handle.WriteMessage(message.data, message.handles)
return result == system.RESULT_OK
def Close(self):
if self._cancellable:
self._cancellable()
self._cancellable = None
self._handle.Close()
def _OnAsyncWaiterResult(self, result):
self._cancellable = None
if result == system.RESULT_OK:
......@@ -141,6 +269,96 @@ class Connector(MessageReceiver):
self._OnError(result)
class Router(MessageReceiverWithResponder):
"""
A Router will handle mojo message and forward those to a Connector. It deals
with parsing of headers and adding of request ids in order to be able to match
a response to a request.
"""
def __init__(self, handle):
MessageReceiverWithResponder.__init__(self)
self._incoming_message_receiver = None
self._next_request_id = 1
self._responders = {}
self._connector = Connector(handle)
self._connector.SetIncomingMessageReceiver(
ForwardingMessageReceiver(self._HandleIncomingMessage))
def Start(self):
self._connector.Start()
def SetIncomingMessageReceiver(self, message_receiver):
"""
Set the MessageReceiver that will receive message from the owned message
pipe.
"""
self._incoming_message_receiver = message_receiver
def SetErrorHandler(self, error_handler):
"""
Set the ConnectionErrorHandler that will be notified of errors on the owned
message pipe.
"""
self._connector.SetErrorHandler(error_handler)
def Accept(self, message):
# A message without responder is directly forwarded to the connector.
return self._connector.Accept(message)
def AcceptWithResponder(self, message, responder):
# The message must have a header.
header = message.header
assert header.expects_response
request_id = self.NextRequestId()
header.request_id = request_id
if not self._connector.Accept(message):
return False
self._responders[request_id] = responder
return True
def Close(self):
self._connector.Close()
def _HandleIncomingMessage(self, message):
header = message.header
if header.expects_response:
if self._incoming_message_receiver:
return self._incoming_message_receiver.AcceptWithResponder(
message, self)
# If we receive a request expecting a response when the client is not
# listening, then we have no choice but to tear down the pipe.
self.Close()
return False
if header.is_response:
request_id = header.request_id
responder = self._responders.pop(request_id, None)
if responder is None:
return False
return responder.Accept(message)
if self._incoming_message_receiver:
return self._incoming_message_receiver.Accept(message)
# Ok to drop the message
return False
def NextRequestId(self):
request_id = self._next_request_id
while request_id == 0 or request_id in self._responders:
request_id = (request_id + 1) % (1 << 64)
self._next_request_id = (request_id + 1) % (1 << 64)
return request_id
class ForwardingMessageReceiver(MessageReceiver):
"""A MessageReceiver that forward calls to |Accept| to a callable."""
def __init__(self, callback):
MessageReceiver.__init__(self)
self._callback = callback
def Accept(self, message):
return self._callback(message)
def _WeakCallback(callback):
func = callback.im_func
self = callback.im_self
......@@ -165,3 +383,5 @@ def _ReadAndDispatchMessage(handle, message_receiver):
message_receiver.Accept(Message(data[0], data[1]))
return result
def _HasRequestId(flags):
return flags & (MESSAGE_EXPECTS_RESPONSE_FLAG|MESSAGE_IS_RESPONSE_FLAG) != 0
......@@ -10,16 +10,6 @@ from mojo.bindings import messaging
from mojo import system
class _ForwardingMessageReceiver(messaging.MessageReceiver):
def __init__(self, callback):
self._callback = callback
def Accept(self, message):
self._callback(message)
return True
class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler):
def __init__(self, callback):
......@@ -29,7 +19,7 @@ class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler):
self._callback(result)
class MessagingTest(unittest.TestCase):
class ConnectorTest(unittest.TestCase):
def setUp(self):
mojo.embedder.Init()
......@@ -38,12 +28,13 @@ class MessagingTest(unittest.TestCase):
self.received_errors = []
def _OnMessage(message):
self.received_messages.append(message)
return True
def _OnError(result):
self.received_errors.append(result)
handles = system.MessagePipe()
self.connector = messaging.Connector(handles.handle1)
self.connector.SetIncomingMessageReceiver(
_ForwardingMessageReceiver(_OnMessage))
messaging.ForwardingMessageReceiver(_OnMessage))
self.connector.SetErrorHandler(
_ForwardingConnectionErrorHandler(_OnError))
self.connector.Start()
......@@ -79,3 +70,138 @@ class MessagingTest(unittest.TestCase):
self.connector = None
(result, _, _) = self.handle.ReadMessage()
self.assertEquals(result, system.RESULT_FAILED_PRECONDITION)
class HeaderTest(unittest.TestCase):
def testSimpleMessageHeader(self):
header = messaging.MessageHeader(0xdeadbeaf, messaging.NO_FLAG)
self.assertEqual(header.message_type, 0xdeadbeaf)
self.assertFalse(header.has_request_id)
self.assertFalse(header.expects_response)
self.assertFalse(header.is_response)
data = header.Serialize()
other_header = messaging.MessageHeader.Deserialize(data)
self.assertEqual(other_header.message_type, 0xdeadbeaf)
self.assertFalse(other_header.has_request_id)
self.assertFalse(other_header.expects_response)
self.assertFalse(other_header.is_response)
def testMessageHeaderWithRequestID(self):
# Request message.
header = messaging.MessageHeader(0xdeadbeaf,
messaging.MESSAGE_EXPECTS_RESPONSE_FLAG)
self.assertEqual(header.message_type, 0xdeadbeaf)
self.assertTrue(header.has_request_id)
self.assertTrue(header.expects_response)
self.assertFalse(header.is_response)
self.assertEqual(header.request_id, 0)
data = header.Serialize()
other_header = messaging.MessageHeader.Deserialize(data)
self.assertEqual(other_header.message_type, 0xdeadbeaf)
self.assertTrue(other_header.has_request_id)
self.assertTrue(other_header.expects_response)
self.assertFalse(other_header.is_response)
self.assertEqual(other_header.request_id, 0)
header.request_id = 0xdeadbeafdeadbeaf
data = header.Serialize()
other_header = messaging.MessageHeader.Deserialize(data)
self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf)
# Response message.
header = messaging.MessageHeader(0xdeadbeaf,
messaging.MESSAGE_IS_RESPONSE_FLAG,
0xdeadbeafdeadbeaf)
self.assertEqual(header.message_type, 0xdeadbeaf)
self.assertTrue(header.has_request_id)
self.assertFalse(header.expects_response)
self.assertTrue(header.is_response)
self.assertEqual(header.request_id, 0xdeadbeafdeadbeaf)
data = header.Serialize()
other_header = messaging.MessageHeader.Deserialize(data)
self.assertEqual(other_header.message_type, 0xdeadbeaf)
self.assertTrue(other_header.has_request_id)
self.assertFalse(other_header.expects_response)
self.assertTrue(other_header.is_response)
self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf)
class RouterTest(unittest.TestCase):
def setUp(self):
mojo.embedder.Init()
self.loop = system.RunLoop()
self.received_messages = []
self.received_errors = []
def _OnMessage(message):
self.received_messages.append(message)
return True
def _OnError(result):
self.received_errors.append(result)
handles = system.MessagePipe()
self.router = messaging.Router(handles.handle1)
self.router.SetIncomingMessageReceiver(
messaging.ForwardingMessageReceiver(_OnMessage))
self.router.SetErrorHandler(
_ForwardingConnectionErrorHandler(_OnError))
self.router.Start()
self.handle = handles.handle0
def tearDown(self):
self.router = None
self.handle = None
self.loop = None
def testSimpleMessage(self):
header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize()
message = messaging.Message(header_data)
self.router.Accept(message)
self.loop.RunUntilIdle()
self.assertFalse(self.received_errors)
self.assertFalse(self.received_messages)
(res, data, _) = self.handle.ReadMessage(bytearray(len(header_data)))
self.assertEquals(system.RESULT_OK, res)
self.assertEquals(data[0], header_data)
def testSimpleReception(self):
header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize()
self.handle.WriteMessage(header_data)
self.loop.RunUntilIdle()
self.assertFalse(self.received_errors)
self.assertEquals(len(self.received_messages), 1)
self.assertEquals(self.received_messages[0].data, header_data)
def testRequestResponse(self):
header_data = messaging.MessageHeader(
0, messaging.MESSAGE_EXPECTS_RESPONSE_FLAG).Serialize()
message = messaging.Message(header_data)
back_messages = []
def OnBackMessage(message):
back_messages.append(message)
self.router.AcceptWithResponder(message,
messaging.ForwardingMessageReceiver(
OnBackMessage))
self.loop.RunUntilIdle()
self.assertFalse(self.received_errors)
self.assertFalse(self.received_messages)
(res, data, _) = self.handle.ReadMessage(bytearray(len(header_data)))
self.assertEquals(system.RESULT_OK, res)
message_header = messaging.MessageHeader.Deserialize(data[0])
self.assertNotEquals(message_header.request_id, 0)
response_header_data = messaging.MessageHeader(
0,
messaging.MESSAGE_IS_RESPONSE_FLAG,
message_header.request_id).Serialize()
self.handle.WriteMessage(response_header_data)
self.loop.RunUntilIdle()
self.assertFalse(self.received_errors)
self.assertEquals(len(back_messages), 1)
self.assertEquals(back_messages[0].data, response_header_data)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment