diff options
-rw-r--r-- | openstack/common/rpc/amqp.py | 46 | ||||
-rw-r--r-- | openstack/common/rpc/common.py | 4 | ||||
-rw-r--r-- | tests/unit/rpc/amqp.py | 32 |
3 files changed, 82 insertions, 0 deletions
diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py index 98cfd9d..2ae7af2 100644 --- a/openstack/common/rpc/amqp.py +++ b/openstack/common/rpc/amqp.py @@ -25,6 +25,7 @@ Specifically, this includes impl_kombu and impl_qpid. impl_carrot also uses AMQP, but is deprecated and predates this code. """ +import collections import inspect import sys import uuid @@ -54,6 +55,7 @@ amqp_opts = [ cfg.CONF.register_opts(amqp_opts) +UNIQUE_ID = '_unique_id' LOG = logging.getLogger(__name__) @@ -236,6 +238,7 @@ def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None, 'failure': failure} if ending: msg['ending'] = True + _add_unique_id(msg) # If a reply_q exists, add the msg_id to the reply and pass the # reply_q to direct_send() to use it as the response queue. # Otherwise use the msg_id for backward compatibilty. @@ -302,6 +305,37 @@ def pack_context(msg, context): msg.update(context_d) +class _MsgIdCache(object): + """This class checks any duplicate messages.""" + + # NOTE: This value is considered can be a configuration item, but + # it is not necessary to change its value in most cases, + # so let this value as static for now. + DUP_MSG_CHECK_SIZE = 16 + + def __init__(self, **kwargs): + self.prev_msgids = collections.deque([], + maxlen=self.DUP_MSG_CHECK_SIZE) + + def check_duplicate_message(self, message_data): + """AMQP consumers may read same message twice when exceptions occur + before ack is returned. This method prevents doing it. + """ + if UNIQUE_ID in message_data: + msg_id = message_data[UNIQUE_ID] + if msg_id not in self.prev_msgids: + self.prev_msgids.append(msg_id) + else: + raise rpc_common.DuplicateMessageError(msg_id=msg_id) + + +def _add_unique_id(msg): + """Add unique_id for checking duplicate messages.""" + unique_id = uuid.uuid4().hex + msg.update({UNIQUE_ID: unique_id}) + LOG.debug(_('UNIQUE_ID is %s.') % (unique_id)) + + class _ThreadPoolWithWait(object): """Base class for a delayed invocation manager used by the Connection class to start up green threads @@ -349,6 +383,7 @@ class ProxyCallback(_ThreadPoolWithWait): connection_pool=connection_pool, ) self.proxy = proxy + self.msg_id_cache = _MsgIdCache() def __call__(self, message_data): """Consumer callback to call a method on a proxy object. @@ -368,6 +403,7 @@ class ProxyCallback(_ThreadPoolWithWait): if hasattr(local.store, 'context'): del local.store.context rpc_common._safe_log(LOG.debug, _('received %s'), message_data) + self.msg_id_cache.check_duplicate_message(message_data) ctxt = unpack_context(self.conf, message_data) method = message_data.get('method') args = message_data.get('args', {}) @@ -422,6 +458,7 @@ class MulticallProxyWaiter(object): self._dataqueue = queue.LightQueue() # Add this caller to the reply proxy's call_waiters self._reply_proxy.add_call_waiter(self, self._msg_id) + self.msg_id_cache = _MsgIdCache() def put(self, data): self._dataqueue.put(data) @@ -435,6 +472,7 @@ class MulticallProxyWaiter(object): def _process_data(self, data): result = None + self.msg_id_cache.check_duplicate_message(data) if data['failure']: failure = data['failure'] result = rpc_common.deserialize_remote_exception(self._conf, @@ -479,6 +517,7 @@ class MulticallWaiter(object): self._done = False self._got_ending = False self._conf = conf + self.msg_id_cache = _MsgIdCache() def done(self): if self._done: @@ -490,6 +529,7 @@ class MulticallWaiter(object): def __call__(self, data): """The consume() callback will call this. Store the result.""" + self.msg_id_cache.check_duplicate_message(data) if data['failure']: failure = data['failure'] self._result = rpc_common.deserialize_remote_exception(self._conf, @@ -542,6 +582,7 @@ def multicall(conf, context, topic, msg, timeout, connection_pool): msg_id = uuid.uuid4().hex msg.update({'_msg_id': msg_id}) LOG.debug(_('MSG_ID is %s') % (msg_id)) + _add_unique_id(msg) pack_context(msg, context) # TODO(pekowski): Remove this flag and the code under the if clause @@ -575,6 +616,7 @@ def call(conf, context, topic, msg, timeout, connection_pool): def cast(conf, context, topic, msg, connection_pool): """Sends a message on a topic without waiting for a response.""" LOG.debug(_('Making asynchronous cast on %s...'), topic) + _add_unique_id(msg) pack_context(msg, context) with ConnectionContext(conf, connection_pool) as conn: conn.topic_send(topic, rpc_common.serialize_msg(msg)) @@ -583,6 +625,7 @@ def cast(conf, context, topic, msg, connection_pool): def fanout_cast(conf, context, topic, msg, connection_pool): """Sends a message on a fanout exchange without waiting for a response.""" LOG.debug(_('Making asynchronous fanout cast...')) + _add_unique_id(msg) pack_context(msg, context) with ConnectionContext(conf, connection_pool) as conn: conn.fanout_send(topic, rpc_common.serialize_msg(msg)) @@ -590,6 +633,7 @@ def fanout_cast(conf, context, topic, msg, connection_pool): def cast_to_server(conf, context, server_params, topic, msg, connection_pool): """Sends a message on a topic to a specific server.""" + _add_unique_id(msg) pack_context(msg, context) with ConnectionContext(conf, connection_pool, pooled=False, server_params=server_params) as conn: @@ -599,6 +643,7 @@ def cast_to_server(conf, context, server_params, topic, msg, connection_pool): def fanout_cast_to_server(conf, context, server_params, topic, msg, connection_pool): """Sends a message on a fanout exchange to a specific server.""" + _add_unique_id(msg) pack_context(msg, context) with ConnectionContext(conf, connection_pool, pooled=False, server_params=server_params) as conn: @@ -610,6 +655,7 @@ def notify(conf, context, topic, msg, connection_pool, envelope): LOG.debug(_('Sending %(event_type)s on %(topic)s'), dict(event_type=msg.get('event_type'), topic=topic)) + _add_unique_id(msg) pack_context(msg, context) with ConnectionContext(conf, connection_pool) as conn: if envelope: diff --git a/openstack/common/rpc/common.py b/openstack/common/rpc/common.py index 6c52bd8..5661bef 100644 --- a/openstack/common/rpc/common.py +++ b/openstack/common/rpc/common.py @@ -125,6 +125,10 @@ class Timeout(RPCException): message = _("Timeout while waiting on RPC response.") +class DuplicateMessageError(RPCException): + message = _("Found duplicate message(%(msg_id)s). Skipping it.") + + class InvalidRPCConnectionReuse(RPCException): message = _("Invalid reuse of an RPC connection.") diff --git a/tests/unit/rpc/amqp.py b/tests/unit/rpc/amqp.py index 1e4733c..4a96ce0 100644 --- a/tests/unit/rpc/amqp.py +++ b/tests/unit/rpc/amqp.py @@ -81,6 +81,11 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase): def fake_notify_send(_conn, topic, msg): self.test_msg = msg + def remove_unique_id(msg): + oslo_msg = jsonutils.loads(msg['oslo.message']) + oslo_msg.pop('_unique_id') + msg['oslo.message'] = jsonutils.dumps(oslo_msg) + self.stubs.Set(self.rpc.Connection, 'notify_send', fake_notify_send) self.rpc.notify(FLAGS, self.context, 'notifications.info', raw_msg, @@ -100,6 +105,8 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase): } self.rpc.notify(FLAGS, self.context, 'notifications.info', raw_msg, envelope=True) + remove_unique_id(self.test_msg) + remove_unique_id(msg) self.assertEqual(self.test_msg, msg) # Make sure envelopes are still on notifications, even if turned off @@ -107,6 +114,7 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase): self.stubs.Set(rpc_common, '_SEND_RPC_ENVELOPE', False) self.rpc.notify(FLAGS, self.context, 'notifications.info', raw_msg, envelope=True) + remove_unique_id(self.test_msg) self.assertEqual(self.test_msg, msg) def test_single_reply_queue_on_has_ids( @@ -233,3 +241,27 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase): self.config(amqp_rpc_single_reply_queue=True) self.test_multithreaded_resp_routing() self.config(amqp_rpc_single_reply_queue=False) + + def test_duplicate_message_check(self): + """Test sending *not-dict* to a topic exchange/queue""" + + conn = self.rpc.create_connection(FLAGS) + message = {'args': 'topic test message', '_unique_id': 'aaaabbbbcccc'} + + self.received_message = None + cache = rpc_amqp._MsgIdCache() + self.exc_raised = False + + def _callback(message): + try: + cache.check_duplicate_message(message) + except rpc_common.DuplicateMessageError: + self.exc_raised = True + + conn.declare_topic_consumer('a_topic', _callback) + conn.topic_send('a_topic', rpc_common.serialize_msg(message)) + conn.topic_send('a_topic', rpc_common.serialize_msg(message)) + conn.consume(limit=2) + conn.close() + + self.assertTrue(self.exc_raised) |