summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--openstack/common/rpc/__init__.py248
-rw-r--r--openstack/common/rpc/amqp.py416
-rw-r--r--openstack/common/rpc/common.py316
-rw-r--r--openstack/common/rpc/dispatcher.py105
-rw-r--r--openstack/common/rpc/impl_fake.py184
-rw-r--r--openstack/common/rpc/impl_kombu.py758
-rw-r--r--openstack/common/rpc/impl_qpid.py580
-rw-r--r--openstack/common/rpc/matchmaker.py257
-rw-r--r--openstack/common/rpc/proxy.py161
-rw-r--r--tests/unit/rpc/__init__.py15
-rw-r--r--tests/unit/rpc/common.py322
-rw-r--r--tests/unit/rpc/test_common.py150
-rw-r--r--tests/unit/rpc/test_dispatcher.py110
-rw-r--r--tests/unit/rpc/test_fake.py32
-rw-r--r--tests/unit/rpc/test_kombu.py414
-rw-r--r--tests/unit/rpc/test_kombu_ssl.py82
-rw-r--r--tests/unit/rpc/test_matchmaker.py60
-rw-r--r--tests/unit/rpc/test_proxy.py128
-rw-r--r--tests/unit/rpc/test_qpid.py377
-rw-r--r--tools/pip-requires1
20 files changed, 4716 insertions, 0 deletions
diff --git a/openstack/common/rpc/__init__.py b/openstack/common/rpc/__init__.py
new file mode 100644
index 0000000..116aa84
--- /dev/null
+++ b/openstack/common/rpc/__init__.py
@@ -0,0 +1,248 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+A remote procedure call (rpc) abstraction.
+
+For some wrappers that add message versioning to rpc, see:
+ rpc.dispatcher
+ rpc.proxy
+"""
+
+from openstack.common import cfg
+from openstack.common import importutils
+
+
+rpc_opts = [
+ cfg.StrOpt('rpc_backend',
+ default='nova.rpc.impl_kombu',
+ help="The messaging module to use, defaults to kombu."),
+ cfg.IntOpt('rpc_thread_pool_size',
+ default=64,
+ help='Size of RPC thread pool'),
+ cfg.IntOpt('rpc_conn_pool_size',
+ default=30,
+ help='Size of RPC connection pool'),
+ cfg.IntOpt('rpc_response_timeout',
+ default=60,
+ help='Seconds to wait for a response from call or multicall'),
+ cfg.ListOpt('allowed_rpc_exception_modules',
+ default=['openstack.common.exception', 'nova.exception'],
+ help='Modules of exceptions that are permitted to be recreated'
+ 'upon receiving exception data from an rpc call.'),
+ cfg.StrOpt('control_exchange',
+ default='nova',
+ help='AMQP exchange to connect to if using RabbitMQ or Qpid'),
+ cfg.BoolOpt('fake_rabbit',
+ default=False,
+ help='If passed, use a fake RabbitMQ provider'),
+ ]
+
+cfg.CONF.register_opts(rpc_opts)
+
+
+def create_connection(new=True):
+ """Create a connection to the message bus used for rpc.
+
+ For some example usage of creating a connection and some consumers on that
+ connection, see nova.service.
+
+ :param new: Whether or not to create a new connection. A new connection
+ will be created by default. If new is False, the
+ implementation is free to return an existing connection from a
+ pool.
+
+ :returns: An instance of nova.rpc.common.Connection
+ """
+ return _get_impl().create_connection(cfg.CONF, new=new)
+
+
+def call(context, topic, msg, timeout=None):
+ """Invoke a remote method that returns something.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ nova.rpc.common.Connection.create_consumer() and only applies
+ when the consumer was created with fanout=False.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+ :param timeout: int, number of seconds to use for a response timeout.
+ If set, this overrides the rpc_response_timeout option.
+
+ :returns: A dict from the remote method.
+
+ :raises: nova.rpc.common.Timeout if a complete response is not received
+ before the timeout is reached.
+ """
+ return _get_impl().call(cfg.CONF, context, topic, msg, timeout)
+
+
+def cast(context, topic, msg):
+ """Invoke a remote method that does not return anything.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ nova.rpc.common.Connection.create_consumer() and only applies
+ when the consumer was created with fanout=False.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().cast(cfg.CONF, context, topic, msg)
+
+
+def fanout_cast(context, topic, msg):
+ """Broadcast a remote method invocation with no return.
+
+ This method will get invoked on all consumers that were set up with this
+ topic name and fanout=True.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ nova.rpc.common.Connection.create_consumer() and only applies
+ when the consumer was created with fanout=True.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().fanout_cast(cfg.CONF, context, topic, msg)
+
+
+def multicall(context, topic, msg, timeout=None):
+ """Invoke a remote method and get back an iterator.
+
+ In this case, the remote method will be returning multiple values in
+ separate messages, so the return values can be processed as the come in via
+ an iterator.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ nova.rpc.common.Connection.create_consumer() and only applies
+ when the consumer was created with fanout=False.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+ :param timeout: int, number of seconds to use for a response timeout.
+ If set, this overrides the rpc_response_timeout option.
+
+ :returns: An iterator. The iterator will yield a tuple (N, X) where N is
+ an index that starts at 0 and increases by one for each value
+ returned and X is the Nth value that was returned by the remote
+ method.
+
+ :raises: nova.rpc.common.Timeout if a complete response is not received
+ before the timeout is reached.
+ """
+ return _get_impl().multicall(cfg.CONF, context, topic, msg, timeout)
+
+
+def notify(context, topic, msg):
+ """Send notification event.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the notification to.
+ :param msg: This is a dict of content of event.
+
+ :returns: None
+ """
+ return _get_impl().notify(cfg.CONF, context, topic, msg)
+
+
+def cleanup():
+ """Clean up resoruces in use by implementation.
+
+ Clean up any resources that have been allocated by the RPC implementation.
+ This is typically open connections to a messaging service. This function
+ would get called before an application using this API exits to allow
+ connections to get torn down cleanly.
+
+ :returns: None
+ """
+ return _get_impl().cleanup()
+
+
+def cast_to_server(context, server_params, topic, msg):
+ """Invoke a remote method that does not return anything.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param server_params: Connection information
+ :param topic: The topic to send the notification to.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().cast_to_server(cfg.CONF, context, server_params, topic,
+ msg)
+
+
+def fanout_cast_to_server(context, server_params, topic, msg):
+ """Broadcast to a remote method invocation with no return.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param server_params: Connection information
+ :param topic: The topic to send the notification to.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().fanout_cast_to_server(cfg.CONF, context, server_params,
+ topic, msg)
+
+
+def queue_get_for(context, topic, host):
+ """Get a queue name for a given topic + host.
+
+ This function only works if this naming convention is followed on the
+ consumer side, as well. For example, in nova, every instance of the
+ nova-foo service calls create_consumer() for two topics:
+
+ foo
+ foo.<host>
+
+ Messages sent to the 'foo' topic are distributed to exactly one instance of
+ the nova-foo service. The services are chosen in a round-robin fashion.
+ Messages sent to the 'foo.<host>' topic are sent to the nova-foo service on
+ <host>.
+ """
+ return '%s.%s' % (topic, host)
+
+
+_RPCIMPL = None
+
+
+def _get_impl():
+ """Delay import of rpc_backend until configuration is loaded."""
+ global _RPCIMPL
+ if _RPCIMPL is None:
+ _RPCIMPL = importutils.import_module(cfg.CONF.rpc_backend)
+ return _RPCIMPL
diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py
new file mode 100644
index 0000000..a79a3aa
--- /dev/null
+++ b/openstack/common/rpc/amqp.py
@@ -0,0 +1,416 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 - 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Shared code between AMQP based nova.rpc implementations.
+
+The code in this module is shared between the rpc implemenations based on AMQP.
+Specifically, this includes impl_kombu and impl_qpid. impl_carrot also uses
+AMQP, but is deprecated and predates this code.
+"""
+
+import inspect
+import logging
+import sys
+import uuid
+
+from eventlet import greenpool
+from eventlet import pools
+from eventlet import semaphore
+
+from openstack.common import excutils
+from openstack.common import local
+import openstack.common.rpc.common as rpc_common
+
+
+LOG = logging.getLogger(__name__)
+
+
+class Pool(pools.Pool):
+ """Class that implements a Pool of Connections."""
+ def __init__(self, conf, connection_cls, *args, **kwargs):
+ self.connection_cls = connection_cls
+ self.conf = conf
+ kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size)
+ kwargs.setdefault("order_as_stack", True)
+ super(Pool, self).__init__(*args, **kwargs)
+
+ # TODO(comstud): Timeout connections not used in a while
+ def create(self):
+ LOG.debug('Pool creating new connection')
+ return self.connection_cls(self.conf)
+
+ def empty(self):
+ while self.free_items:
+ self.get().close()
+
+
+_pool_create_sem = semaphore.Semaphore()
+
+
+def get_connection_pool(conf, connection_cls):
+ with _pool_create_sem:
+ # Make sure only one thread tries to create the connection pool.
+ if not connection_cls.pool:
+ connection_cls.pool = Pool(conf, connection_cls)
+ return connection_cls.pool
+
+
+class ConnectionContext(rpc_common.Connection):
+ """The class that is actually returned to the caller of
+ create_connection(). This is essentially a wrapper around
+ Connection that supports 'with'. It can also return a new
+ Connection, or one from a pool. The function will also catch
+ when an instance of this class is to be deleted. With that
+ we can return Connections to the pool on exceptions and so
+ forth without making the caller be responsible for catching
+ them. If possible the function makes sure to return a
+ connection to the pool.
+ """
+
+ def __init__(self, conf, connection_pool, pooled=True, server_params=None):
+ """Create a new connection, or get one from the pool"""
+ self.connection = None
+ self.conf = conf
+ self.connection_pool = connection_pool
+ if pooled:
+ self.connection = connection_pool.get()
+ else:
+ self.connection = connection_pool.connection_cls(conf,
+ server_params=server_params)
+ self.pooled = pooled
+
+ def __enter__(self):
+ """When with ConnectionContext() is used, return self"""
+ return self
+
+ def _done(self):
+ """If the connection came from a pool, clean it up and put it back.
+ If it did not come from a pool, close it.
+ """
+ if self.connection:
+ if self.pooled:
+ # Reset the connection so it's ready for the next caller
+ # to grab from the pool
+ self.connection.reset()
+ self.connection_pool.put(self.connection)
+ else:
+ try:
+ self.connection.close()
+ except Exception:
+ pass
+ self.connection = None
+
+ def __exit__(self, exc_type, exc_value, tb):
+ """End of 'with' statement. We're done here."""
+ self._done()
+
+ def __del__(self):
+ """Caller is done with this connection. Make sure we cleaned up."""
+ self._done()
+
+ def close(self):
+ """Caller is done with this connection."""
+ self._done()
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ self.connection.create_consumer(topic, proxy, fanout)
+
+ def create_worker(self, topic, proxy, pool_name):
+ self.connection.create_worker(topic, proxy, pool_name)
+
+ def consume_in_thread(self):
+ self.connection.consume_in_thread()
+
+ def __getattr__(self, key):
+ """Proxy all other calls to the Connection instance"""
+ if self.connection:
+ return getattr(self.connection, key)
+ else:
+ raise rpc_common.InvalidRPCConnectionReuse()
+
+
+def msg_reply(conf, msg_id, connection_pool, reply=None, failure=None,
+ ending=False):
+ """Sends a reply or an error on the channel signified by msg_id.
+
+ Failure should be a sys.exc_info() tuple.
+
+ """
+ with ConnectionContext(conf, connection_pool) as conn:
+ if failure:
+ failure = rpc_common.serialize_remote_exception(failure)
+
+ try:
+ msg = {'result': reply, 'failure': failure}
+ except TypeError:
+ msg = {'result': dict((k, repr(v))
+ for k, v in reply.__dict__.iteritems()),
+ 'failure': failure}
+ if ending:
+ msg['ending'] = True
+ conn.direct_send(msg_id, msg)
+
+
+class RpcContext(rpc_common.CommonRpcContext):
+ """Context that supports replying to a rpc.call"""
+ def __init__(self, **kwargs):
+ self.msg_id = kwargs.pop('msg_id', None)
+ self.conf = kwargs.pop('conf')
+ super(RpcContext, self).__init__(**kwargs)
+
+ def deepcopy(self):
+ values = self.to_dict()
+ values['conf'] = self.conf
+ values['msg_id'] = self.msg_id
+ return self.__class__(**values)
+
+ def reply(self, reply=None, failure=None, ending=False,
+ connection_pool=None):
+ if self.msg_id:
+ msg_reply(self.conf, self.msg_id, connection_pool, reply, failure,
+ ending)
+ if ending:
+ self.msg_id = None
+
+
+def unpack_context(conf, msg):
+ """Unpack context from msg."""
+ context_dict = {}
+ for key in list(msg.keys()):
+ # NOTE(vish): Some versions of python don't like unicode keys
+ # in kwargs.
+ key = str(key)
+ if key.startswith('_context_'):
+ value = msg.pop(key)
+ context_dict[key[9:]] = value
+ context_dict['msg_id'] = msg.pop('_msg_id', None)
+ context_dict['conf'] = conf
+ ctx = RpcContext.from_dict(context_dict)
+ rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict())
+ return ctx
+
+
+def pack_context(msg, context):
+ """Pack context into msg.
+
+ Values for message keys need to be less than 255 chars, so we pull
+ context out into a bunch of separate keys. If we want to support
+ more arguments in rabbit messages, we may want to do the same
+ for args at some point.
+
+ """
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.to_dict().iteritems()])
+ msg.update(context_d)
+
+
+class ProxyCallback(object):
+ """Calls methods on a proxy object based on method and args."""
+
+ def __init__(self, conf, proxy, connection_pool):
+ self.proxy = proxy
+ self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size)
+ self.connection_pool = connection_pool
+ self.conf = conf
+
+ def __call__(self, message_data):
+ """Consumer callback to call a method on a proxy object.
+
+ Parses the message for validity and fires off a thread to call the
+ proxy object method.
+
+ Message data should be a dictionary with two keys:
+ method: string representing the method to call
+ args: dictionary of arg: value
+
+ Example: {'method': 'echo', 'args': {'value': 42}}
+
+ """
+ # It is important to clear the context here, because at this point
+ # the previous context is stored in local.store.context
+ if hasattr(local.store, 'context'):
+ del local.store.context
+ rpc_common._safe_log(LOG.debug, _('received %s'), message_data)
+ ctxt = unpack_context(self.conf, message_data)
+ method = message_data.get('method')
+ args = message_data.get('args', {})
+ version = message_data.get('version', None)
+ if not method:
+ LOG.warn(_('no method for message: %s') % message_data)
+ ctxt.reply(_('No method for message: %s') % message_data,
+ connection_pool=self.connection_pool)
+ return
+ self.pool.spawn_n(self._process_data, ctxt, version, method, args)
+
+ def _process_data(self, ctxt, version, method, args):
+ """Process a message in a new thread.
+
+ If the proxy object we have has a dispatch method
+ (see rpc.dispatcher.RpcDispatcher), pass it the version,
+ method, and args and let it dispatch as appropriate. If not, use
+ the old behavior of magically calling the specified method on the
+ proxy we have here.
+ """
+ ctxt.update_store()
+ try:
+ rval = self.proxy.dispatch(ctxt, version, method, **args)
+ # Check if the result was a generator
+ if inspect.isgenerator(rval):
+ for x in rval:
+ ctxt.reply(x, None, connection_pool=self.connection_pool)
+ else:
+ ctxt.reply(rval, None, connection_pool=self.connection_pool)
+ # This final None tells multicall that it is done.
+ ctxt.reply(ending=True, connection_pool=self.connection_pool)
+ except Exception as e:
+ LOG.exception('Exception during message handling')
+ ctxt.reply(None, sys.exc_info(),
+ connection_pool=self.connection_pool)
+
+
+class MulticallWaiter(object):
+ def __init__(self, conf, connection, timeout):
+ self._connection = connection
+ self._iterator = connection.iterconsume(
+ timeout=timeout or conf.rpc_response_timeout)
+ self._result = None
+ self._done = False
+ self._got_ending = False
+ self._conf = conf
+
+ def done(self):
+ if self._done:
+ return
+ self._done = True
+ self._iterator.close()
+ self._iterator = None
+ self._connection.close()
+
+ def __call__(self, data):
+ """The consume() callback will call this. Store the result."""
+ if data['failure']:
+ failure = data['failure']
+ self._result = rpc_common.deserialize_remote_exception(self._conf,
+ failure)
+
+ elif data.get('ending', False):
+ self._got_ending = True
+ else:
+ self._result = data['result']
+
+ def __iter__(self):
+ """Return a result until we get a 'None' response from consumer"""
+ if self._done:
+ raise StopIteration
+ while True:
+ try:
+ self._iterator.next()
+ except Exception:
+ with excutils.save_and_reraise_exception():
+ self.done()
+ if self._got_ending:
+ self.done()
+ raise StopIteration
+ result = self._result
+ if isinstance(result, Exception):
+ self.done()
+ raise result
+ yield result
+
+
+def create_connection(conf, new, connection_pool):
+ """Create a connection"""
+ return ConnectionContext(conf, connection_pool, pooled=not new)
+
+
+def multicall(conf, context, topic, msg, timeout, connection_pool):
+ """Make a call that returns multiple times."""
+ # Can't use 'with' for multicall, as it returns an iterator
+ # that will continue to use the connection. When it's done,
+ # connection.close() will get called which will put it back into
+ # the pool
+ LOG.debug(_('Making asynchronous call on %s ...'), topic)
+ msg_id = uuid.uuid4().hex
+ msg.update({'_msg_id': msg_id})
+ LOG.debug(_('MSG_ID is %s') % (msg_id))
+ pack_context(msg, context)
+
+ conn = ConnectionContext(conf, connection_pool)
+ wait_msg = MulticallWaiter(conf, conn, timeout)
+ conn.declare_direct_consumer(msg_id, wait_msg)
+ conn.topic_send(topic, msg)
+ return wait_msg
+
+
+def call(conf, context, topic, msg, timeout, connection_pool):
+ """Sends a message on a topic and wait for a response."""
+ rv = multicall(conf, context, topic, msg, timeout, connection_pool)
+ # NOTE(vish): return the last result from the multicall
+ rv = list(rv)
+ if not rv:
+ return
+ return rv[-1]
+
+
+def cast(conf, context, topic, msg, connection_pool):
+ """Sends a message on a topic without waiting for a response."""
+ LOG.debug(_('Making asynchronous cast on %s...'), topic)
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool) as conn:
+ conn.topic_send(topic, msg)
+
+
+def fanout_cast(conf, context, topic, msg, connection_pool):
+ """Sends a message on a fanout exchange without waiting for a response."""
+ LOG.debug(_('Making asynchronous fanout cast...'))
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool) as conn:
+ conn.fanout_send(topic, msg)
+
+
+def cast_to_server(conf, context, server_params, topic, msg, connection_pool):
+ """Sends a message on a topic to a specific server."""
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool, pooled=False,
+ server_params=server_params) as conn:
+ conn.topic_send(topic, msg)
+
+
+def fanout_cast_to_server(conf, context, server_params, topic, msg,
+ connection_pool):
+ """Sends a message on a fanout exchange to a specific server."""
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool, pooled=False,
+ server_params=server_params) as conn:
+ conn.fanout_send(topic, msg)
+
+
+def notify(conf, context, topic, msg, connection_pool):
+ """Sends a notification event on a topic."""
+ event_type = msg.get('event_type')
+ LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals())
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool) as conn:
+ conn.notify_send(topic, msg)
+
+
+def cleanup(connection_pool):
+ if connection_pool:
+ connection_pool.empty()
diff --git a/openstack/common/rpc/common.py b/openstack/common/rpc/common.py
new file mode 100644
index 0000000..6acd72c
--- /dev/null
+++ b/openstack/common/rpc/common.py
@@ -0,0 +1,316 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import copy
+import logging
+import sys
+import traceback
+
+from openstack.common import cfg
+from openstack.common import importutils
+from openstack.common import jsonutils
+from openstack.common import local
+
+
+LOG = logging.getLogger(__name__)
+
+
+class RPCException(Exception):
+ message = _("An unknown RPC related exception occurred.")
+
+ def __init__(self, message=None, **kwargs):
+ self.kwargs = kwargs
+
+ if not message:
+ try:
+ message = self.message % kwargs
+
+ except Exception as e:
+ # kwargs doesn't match a variable in the message
+ # log the issue and the kwargs
+ LOG.exception(_('Exception in string format operation'))
+ for name, value in kwargs.iteritems():
+ LOG.error("%s: %s" % (name, value))
+ # at least get the core message out if something happened
+ message = self.message
+
+ super(RPCException, self).__init__(message)
+
+
+class RemoteError(RPCException):
+ """Signifies that a remote class has raised an exception.
+
+ Contains a string representation of the type of the original exception,
+ the value of the original exception, and the traceback. These are
+ sent to the parent as a joined string so printing the exception
+ contains all of the relevant info.
+
+ """
+ message = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.")
+
+ def __init__(self, exc_type=None, value=None, traceback=None):
+ self.exc_type = exc_type
+ self.value = value
+ self.traceback = traceback
+ super(RemoteError, self).__init__(exc_type=exc_type,
+ value=value,
+ traceback=traceback)
+
+
+class Timeout(RPCException):
+ """Signifies that a timeout has occurred.
+
+ This exception is raised if the rpc_response_timeout is reached while
+ waiting for a response from the remote side.
+ """
+ message = _("Timeout while waiting on RPC response.")
+
+
+class InvalidRPCConnectionReuse(RPCException):
+ message = _("Invalid reuse of an RPC connection.")
+
+
+class UnsupportedRpcVersion(RPCException):
+ message = _("Specified RPC version, %(version)s, not supported by "
+ "this endpoint.")
+
+
+class Connection(object):
+ """A connection, returned by rpc.create_connection().
+
+ This class represents a connection to the message bus used for rpc.
+ An instance of this class should never be created by users of the rpc API.
+ Use rpc.create_connection() instead.
+ """
+ def close(self):
+ """Close the connection.
+
+ This method must be called when the connection will no longer be used.
+ It will ensure that any resources associated with the connection, such
+ as a network connection, and cleaned up.
+ """
+ raise NotImplementedError()
+
+ def create_consumer(self, conf, topic, proxy, fanout=False):
+ """Create a consumer on this connection.
+
+ A consumer is associated with a message queue on the backend message
+ bus. The consumer will read messages from the queue, unpack them, and
+ dispatch them to the proxy object. The contents of the message pulled
+ off of the queue will determine which method gets called on the proxy
+ object.
+
+ :param conf: An openstack.common.cfg configuration object.
+ :param topic: This is a name associated with what to consume from.
+ Multiple instances of a service may consume from the same
+ topic. For example, all instances of nova-compute consume
+ from a queue called "compute". In that case, the
+ messages will get distributed amongst the consumers in a
+ round-robin fashion if fanout=False. If fanout=True,
+ every consumer associated with this topic will get a
+ copy of every message.
+ :param proxy: The object that will handle all incoming messages.
+ :param fanout: Whether or not this is a fanout topic. See the
+ documentation for the topic parameter for some
+ additional comments on this.
+ """
+ raise NotImplementedError()
+
+ def create_worker(self, conf, topic, proxy, pool_name):
+ """Create a worker on this connection.
+
+ A worker is like a regular consumer of messages directed to a
+ topic, except that it is part of a set of such consumers (the
+ "pool") which may run in parallel. Every pool of workers will
+ receive a given message, but only one worker in the pool will
+ be asked to process it. Load is distributed across the members
+ of the pool in round-robin fashion.
+
+ :param conf: An openstack.common.cfg configuration object.
+ :param topic: This is a name associated with what to consume from.
+ Multiple instances of a service may consume from the same
+ topic.
+ :param proxy: The object that will handle all incoming messages.
+ :param pool_name: String containing the name of the pool of workers
+ """
+ raise NotImplementedError()
+
+ def consume_in_thread(self):
+ """Spawn a thread to handle incoming messages.
+
+ Spawn a thread that will be responsible for handling all incoming
+ messages for consumers that were set up on this connection.
+
+ Message dispatching inside of this is expected to be implemented in a
+ non-blocking manner. An example implementation would be having this
+ thread pull messages in for all of the consumers, but utilize a thread
+ pool for dispatching the messages to the proxy objects.
+ """
+ raise NotImplementedError()
+
+
+def _safe_log(log_func, msg, msg_data):
+ """Sanitizes the msg_data field before logging."""
+ SANITIZE = {
+ 'set_admin_password': ('new_pass',),
+ 'run_instance': ('admin_password',),
+ }
+
+ has_method = 'method' in msg_data and msg_data['method'] in SANITIZE
+ has_context_token = '_context_auth_token' in msg_data
+ has_token = 'auth_token' in msg_data
+
+ if not any([has_method, has_context_token, has_token]):
+ return log_func(msg, msg_data)
+
+ msg_data = copy.deepcopy(msg_data)
+
+ if has_method:
+ method = msg_data['method']
+ if method in SANITIZE:
+ args_to_sanitize = SANITIZE[method]
+ for arg in args_to_sanitize:
+ try:
+ msg_data['args'][arg] = "<SANITIZED>"
+ except KeyError:
+ pass
+
+ if has_context_token:
+ msg_data['_context_auth_token'] = '<SANITIZED>'
+
+ if has_token:
+ msg_data['auth_token'] = '<SANITIZED>'
+
+ return log_func(msg, msg_data)
+
+
+def serialize_remote_exception(failure_info):
+ """Prepares exception data to be sent over rpc.
+
+ Failure_info should be a sys.exc_info() tuple.
+
+ """
+ tb = traceback.format_exception(*failure_info)
+ failure = failure_info[1]
+ LOG.error(_("Returning exception %s to caller"), unicode(failure))
+ LOG.error(tb)
+
+ kwargs = {}
+ if hasattr(failure, 'kwargs'):
+ kwargs = failure.kwargs
+
+ data = {
+ 'class': str(failure.__class__.__name__),
+ 'module': str(failure.__class__.__module__),
+ 'message': unicode(failure),
+ 'tb': tb,
+ 'args': failure.args,
+ 'kwargs': kwargs
+ }
+
+ json_data = jsonutils.dumps(data)
+
+ return json_data
+
+
+def deserialize_remote_exception(conf, data):
+ failure = jsonutils.loads(str(data))
+
+ trace = failure.get('tb', [])
+ message = failure.get('message', "") + "\n" + "\n".join(trace)
+ name = failure.get('class')
+ module = failure.get('module')
+
+ # NOTE(ameade): We DO NOT want to allow just any module to be imported, in
+ # order to prevent arbitrary code execution.
+ if not module in conf.allowed_rpc_exception_modules:
+ return RemoteError(name, failure.get('message'), trace)
+
+ try:
+ mod = importutils.import_module(module)
+ klass = getattr(mod, name)
+ if not issubclass(klass, Exception):
+ raise TypeError("Can only deserialize Exceptions")
+
+ failure = klass(**failure.get('kwargs', {}))
+ except (AttributeError, TypeError, ImportError):
+ return RemoteError(name, failure.get('message'), trace)
+
+ ex_type = type(failure)
+ str_override = lambda self: message
+ new_ex_type = type(ex_type.__name__ + "_Remote", (ex_type,),
+ {'__str__': str_override, '__unicode__': str_override})
+ try:
+ # NOTE(ameade): Dynamically create a new exception type and swap it in
+ # as the new type for the exception. This only works on user defined
+ # Exceptions and not core python exceptions. This is important because
+ # we cannot necessarily change an exception message so we must override
+ # the __str__ method.
+ failure.__class__ = new_ex_type
+ except TypeError as e:
+ # NOTE(ameade): If a core exception then just add the traceback to the
+ # first exception argument.
+ failure.args = (message,) + failure.args[1:]
+ return failure
+
+
+class CommonRpcContext(object):
+ def __init__(self, **kwargs):
+ self.values = kwargs
+
+ def __getattr__(self, key):
+ try:
+ return self.values[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def to_dict(self):
+ return copy.deepcopy(self.values)
+
+ @classmethod
+ def from_dict(cls, values):
+ return cls(**values)
+
+ def deepcopy(self):
+ return self.from_dict(self.to_dict())
+
+ def update_store(self):
+ local.store.context = self
+
+ def elevated(self, read_deleted=None, overwrite=False):
+ """Return a version of this context with admin flag set."""
+ # TODO(russellb) This method is a bit of a nova-ism. It makes
+ # some assumptions about the data in the request context sent
+ # across rpc, while the rest of this class does not. We could get
+ # rid of this if we changed the nova code that uses this to
+ # convert the RpcContext back to its native RequestContext doing
+ # something like nova.context.RequestContext.from_dict(ctxt.to_dict())
+
+ context = self.deepcopy()
+ context.values['is_admin'] = True
+
+ context.values.setdefault('roles', [])
+
+ if 'admin' not in context.values['roles']:
+ context.values['roles'].append('admin')
+
+ if read_deleted is not None:
+ context.values['read_deleted'] = read_deleted
+
+ return context
diff --git a/openstack/common/rpc/dispatcher.py b/openstack/common/rpc/dispatcher.py
new file mode 100644
index 0000000..7319eb2
--- /dev/null
+++ b/openstack/common/rpc/dispatcher.py
@@ -0,0 +1,105 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Code for rpc message dispatching.
+
+Messages that come in have a version number associated with them. RPC API
+version numbers are in the form:
+
+ Major.Minor
+
+For a given message with version X.Y, the receiver must be marked as able to
+handle messages of version A.B, where:
+
+ A = X
+
+ B >= Y
+
+The Major version number would be incremented for an almost completely new API.
+The Minor version number would be incremented for backwards compatible changes
+to an existing API. A backwards compatible change could be something like
+adding a new method, adding an argument to an existing method (but not
+requiring it), or changing the type for an existing argument (but still
+handling the old type as well).
+
+The conversion over to a versioned API must be done on both the client side and
+server side of the API at the same time. However, as the code stands today,
+there can be both versioned and unversioned APIs implemented in the same code
+base.
+"""
+
+from openstack.common.rpc import common as rpc_common
+
+
+class RpcDispatcher(object):
+ """Dispatch rpc messages according to the requested API version.
+
+ This class can be used as the top level 'manager' for a service. It
+ contains a list of underlying managers that have an API_VERSION attribute.
+ """
+
+ def __init__(self, callbacks):
+ """Initialize the rpc dispatcher.
+
+ :param callbacks: List of proxy objects that are an instance
+ of a class with rpc methods exposed. Each proxy
+ object should have an RPC_API_VERSION attribute.
+ """
+ self.callbacks = callbacks
+ super(RpcDispatcher, self).__init__()
+
+ @staticmethod
+ def _is_compatible(mversion, version):
+ """Determine whether versions are compatible.
+
+ :param mversion: The API version implemented by a callback.
+ :param version: The API version requested by an incoming message.
+ """
+ version_parts = version.split('.')
+ mversion_parts = mversion.split('.')
+ if int(version_parts[0]) != int(mversion_parts[0]): # Major
+ return False
+ if int(version_parts[1]) > int(mversion_parts[1]): # Minor
+ return False
+ return True
+
+ def dispatch(self, ctxt, version, method, **kwargs):
+ """Dispatch a message based on a requested version.
+
+ :param ctxt: The request context
+ :param version: The requested API version from the incoming message
+ :param method: The method requested to be called by the incoming
+ message.
+ :param kwargs: A dict of keyword arguments to be passed to the method.
+
+ :returns: Whatever is returned by the underlying method that gets
+ called.
+ """
+ if not version:
+ version = '1.0'
+
+ for proxyobj in self.callbacks:
+ if hasattr(proxyobj, 'RPC_API_VERSION'):
+ rpc_api_version = proxyobj.RPC_API_VERSION
+ else:
+ rpc_api_version = '1.0'
+ if not hasattr(proxyobj, method):
+ continue
+ if self._is_compatible(rpc_api_version, version):
+ return getattr(proxyobj, method)(ctxt, **kwargs)
+
+ raise rpc_common.UnsupportedRpcVersion(version=version)
diff --git a/openstack/common/rpc/impl_fake.py b/openstack/common/rpc/impl_fake.py
new file mode 100644
index 0000000..fba20c9
--- /dev/null
+++ b/openstack/common/rpc/impl_fake.py
@@ -0,0 +1,184 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""Fake RPC implementation which calls proxy methods directly with no
+queues. Casts will block, but this is very useful for tests.
+"""
+
+import inspect
+import json
+import time
+
+import eventlet
+
+from openstack.common.rpc import common as rpc_common
+
+CONSUMERS = {}
+
+
+class RpcContext(rpc_common.CommonRpcContext):
+ def __init__(self, **kwargs):
+ super(RpcContext, self).__init__(**kwargs)
+ self._response = []
+ self._done = False
+
+ def deepcopy(self):
+ values = self.to_dict()
+ new_inst = self.__class__(**values)
+ new_inst._response = self._response
+ new_inst._done = self._done
+ return new_inst
+
+ def reply(self, reply=None, failure=None, ending=False):
+ if ending:
+ self._done = True
+ if not self._done:
+ self._response.append((reply, failure))
+
+
+class Consumer(object):
+ def __init__(self, topic, proxy):
+ self.topic = topic
+ self.proxy = proxy
+
+ def call(self, context, version, method, args, timeout):
+ done = eventlet.event.Event()
+
+ def _inner():
+ ctxt = RpcContext.from_dict(context.to_dict())
+ try:
+ rval = self.proxy.dispatch(context, version, method, **args)
+ res = []
+ # Caller might have called ctxt.reply() manually
+ for (reply, failure) in ctxt._response:
+ if failure:
+ raise failure[0], failure[1], failure[2]
+ res.append(reply)
+ # if ending not 'sent'...we might have more data to
+ # return from the function itself
+ if not ctxt._done:
+ if inspect.isgenerator(rval):
+ for val in rval:
+ res.append(val)
+ else:
+ res.append(rval)
+ done.send(res)
+ except Exception as e:
+ done.send_exception(e)
+
+ thread = eventlet.greenthread.spawn(_inner)
+
+ if timeout:
+ start_time = time.time()
+ while not done.ready():
+ eventlet.greenthread.sleep(1)
+ cur_time = time.time()
+ if (cur_time - start_time) > timeout:
+ thread.kill()
+ raise rpc_common.Timeout()
+
+ return done.wait()
+
+
+class Connection(object):
+ """Connection object."""
+
+ def __init__(self):
+ self.consumers = []
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ consumer = Consumer(topic, proxy)
+ self.consumers.append(consumer)
+ if topic not in CONSUMERS:
+ CONSUMERS[topic] = []
+ CONSUMERS[topic].append(consumer)
+
+ def close(self):
+ for consumer in self.consumers:
+ CONSUMERS[consumer.topic].remove(consumer)
+ self.consumers = []
+
+ def consume_in_thread(self):
+ pass
+
+
+def create_connection(conf, new=True):
+ """Create a connection"""
+ return Connection()
+
+
+def check_serialize(msg):
+ """Make sure a message intended for rpc can be serialized."""
+ json.dumps(msg)
+
+
+def multicall(conf, context, topic, msg, timeout=None):
+ """Make a call that returns multiple times."""
+
+ check_serialize(msg)
+
+ method = msg.get('method')
+ if not method:
+ return
+ args = msg.get('args', {})
+ version = msg.get('version', None)
+
+ try:
+ consumer = CONSUMERS[topic][0]
+ except (KeyError, IndexError):
+ return iter([None])
+ else:
+ return consumer.call(context, version, method, args, timeout)
+
+
+def call(conf, context, topic, msg, timeout=None):
+ """Sends a message on a topic and wait for a response."""
+ rv = multicall(conf, context, topic, msg, timeout)
+ # NOTE(vish): return the last result from the multicall
+ rv = list(rv)
+ if not rv:
+ return
+ return rv[-1]
+
+
+def cast(conf, context, topic, msg):
+ try:
+ call(conf, context, topic, msg)
+ except Exception:
+ pass
+
+
+def notify(conf, context, topic, msg):
+ check_serialize(msg)
+
+
+def cleanup():
+ pass
+
+
+def fanout_cast(conf, context, topic, msg):
+ """Cast to all consumers of a topic"""
+ check_serialize(msg)
+ method = msg.get('method')
+ if not method:
+ return
+ args = msg.get('args', {})
+ version = msg.get('version', None)
+
+ for consumer in CONSUMERS.get(topic, []):
+ try:
+ consumer.call(context, version, method, args, None)
+ except Exception:
+ pass
diff --git a/openstack/common/rpc/impl_kombu.py b/openstack/common/rpc/impl_kombu.py
new file mode 100644
index 0000000..c32497a
--- /dev/null
+++ b/openstack/common/rpc/impl_kombu.py
@@ -0,0 +1,758 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import itertools
+import socket
+import ssl
+import sys
+import time
+import uuid
+
+import eventlet
+import greenlet
+import kombu
+import kombu.connection
+import kombu.entity
+import kombu.messaging
+
+from openstack.common import cfg
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+
+kombu_opts = [
+ cfg.StrOpt('kombu_ssl_version',
+ default='',
+ help='SSL version to use (valid only if SSL enabled)'),
+ cfg.StrOpt('kombu_ssl_keyfile',
+ default='',
+ help='SSL key file (valid only if SSL enabled)'),
+ cfg.StrOpt('kombu_ssl_certfile',
+ default='',
+ help='SSL cert file (valid only if SSL enabled)'),
+ cfg.StrOpt('kombu_ssl_ca_certs',
+ default='',
+ help=('SSL certification authority file '
+ '(valid only if SSL enabled)')),
+ cfg.StrOpt('rabbit_host',
+ default='localhost',
+ help='the RabbitMQ host'),
+ cfg.IntOpt('rabbit_port',
+ default=5672,
+ help='the RabbitMQ port'),
+ cfg.BoolOpt('rabbit_use_ssl',
+ default=False,
+ help='connect over SSL for RabbitMQ'),
+ cfg.StrOpt('rabbit_userid',
+ default='guest',
+ help='the RabbitMQ userid'),
+ cfg.StrOpt('rabbit_password',
+ default='guest',
+ help='the RabbitMQ password'),
+ cfg.StrOpt('rabbit_virtual_host',
+ default='/',
+ help='the RabbitMQ virtual host'),
+ cfg.IntOpt('rabbit_retry_interval',
+ default=1,
+ help='how frequently to retry connecting with RabbitMQ'),
+ cfg.IntOpt('rabbit_retry_backoff',
+ default=2,
+ help='how long to backoff for between retries when connecting '
+ 'to RabbitMQ'),
+ cfg.IntOpt('rabbit_max_retries',
+ default=0,
+ help='maximum retries with trying to connect to RabbitMQ '
+ '(the default of 0 implies an infinite retry count)'),
+ cfg.BoolOpt('rabbit_durable_queues',
+ default=False,
+ help='use durable queues in RabbitMQ'),
+
+ ]
+
+cfg.CONF.register_opts(kombu_opts)
+
+LOG = rpc_common.LOG
+
+
+class ConsumerBase(object):
+ """Consumer base class."""
+
+ def __init__(self, channel, callback, tag, **kwargs):
+ """Declare a queue on an amqp channel.
+
+ 'channel' is the amqp channel to use
+ 'callback' is the callback to call when messages are received
+ 'tag' is a unique ID for the consumer on the channel
+
+ queue name, exchange name, and other kombu options are
+ passed in here as a dictionary.
+ """
+ self.callback = callback
+ self.tag = str(tag)
+ self.kwargs = kwargs
+ self.queue = None
+ self.reconnect(channel)
+
+ def reconnect(self, channel):
+ """Re-declare the queue after a rabbit reconnect"""
+ self.channel = channel
+ self.kwargs['channel'] = channel
+ self.queue = kombu.entity.Queue(**self.kwargs)
+ self.queue.declare()
+
+ def consume(self, *args, **kwargs):
+ """Actually declare the consumer on the amqp channel. This will
+ start the flow of messages from the queue. Using the
+ Connection.iterconsume() iterator will process the messages,
+ calling the appropriate callback.
+
+ If a callback is specified in kwargs, use that. Otherwise,
+ use the callback passed during __init__()
+
+ If kwargs['nowait'] is True, then this call will block until
+ a message is read.
+
+ Messages will automatically be acked if the callback doesn't
+ raise an exception
+ """
+
+ options = {'consumer_tag': self.tag}
+ options['nowait'] = kwargs.get('nowait', False)
+ callback = kwargs.get('callback', self.callback)
+ if not callback:
+ raise ValueError("No callback defined")
+
+ def _callback(raw_message):
+ message = self.channel.message_to_python(raw_message)
+ try:
+ callback(message.payload)
+ message.ack()
+ except Exception:
+ LOG.exception(_("Failed to process message... skipping it."))
+
+ self.queue.consume(*args, callback=_callback, **options)
+
+ def cancel(self):
+ """Cancel the consuming from the queue, if it has started"""
+ try:
+ self.queue.cancel(self.tag)
+ except KeyError, e:
+ # NOTE(comstud): Kludge to get around a amqplib bug
+ if str(e) != "u'%s'" % self.tag:
+ raise
+ self.queue = None
+
+
+class DirectConsumer(ConsumerBase):
+ """Queue/consumer class for 'direct'"""
+
+ def __init__(self, conf, channel, msg_id, callback, tag, **kwargs):
+ """Init a 'direct' queue.
+
+ 'channel' is the amqp channel to use
+ 'msg_id' is the msg_id to listen on
+ 'callback' is the callback to call when messages are received
+ 'tag' is a unique ID for the consumer on the channel
+
+ Other kombu options may be passed
+ """
+ # Default options
+ options = {'durable': False,
+ 'auto_delete': True,
+ 'exclusive': True}
+ options.update(kwargs)
+ exchange = kombu.entity.Exchange(
+ name=msg_id,
+ type='direct',
+ durable=options['durable'],
+ auto_delete=options['auto_delete'])
+ super(DirectConsumer, self).__init__(
+ channel,
+ callback,
+ tag,
+ name=msg_id,
+ exchange=exchange,
+ routing_key=msg_id,
+ **options)
+
+
+class TopicConsumer(ConsumerBase):
+ """Consumer class for 'topic'"""
+
+ def __init__(self, conf, channel, topic, callback, tag, name=None,
+ **kwargs):
+ """Init a 'topic' queue.
+
+ :param channel: the amqp channel to use
+ :param topic: the topic to listen on
+ :paramtype topic: str
+ :param callback: the callback to call when messages are received
+ :param tag: a unique ID for the consumer on the channel
+ :param name: optional queue name, defaults to topic
+ :paramtype name: str
+
+ Other kombu options may be passed as keyword arguments
+ """
+ # Default options
+ options = {'durable': conf.rabbit_durable_queues,
+ 'auto_delete': False,
+ 'exclusive': False}
+ options.update(kwargs)
+ exchange = kombu.entity.Exchange(
+ name=conf.control_exchange,
+ type='topic',
+ durable=options['durable'],
+ auto_delete=options['auto_delete'])
+ super(TopicConsumer, self).__init__(
+ channel,
+ callback,
+ tag,
+ name=name or topic,
+ exchange=exchange,
+ routing_key=topic,
+ **options)
+
+
+class FanoutConsumer(ConsumerBase):
+ """Consumer class for 'fanout'"""
+
+ def __init__(self, conf, channel, topic, callback, tag, **kwargs):
+ """Init a 'fanout' queue.
+
+ 'channel' is the amqp channel to use
+ 'topic' is the topic to listen on
+ 'callback' is the callback to call when messages are received
+ 'tag' is a unique ID for the consumer on the channel
+
+ Other kombu options may be passed
+ """
+ unique = uuid.uuid4().hex
+ exchange_name = '%s_fanout' % topic
+ queue_name = '%s_fanout_%s' % (topic, unique)
+
+ # Default options
+ options = {'durable': False,
+ 'auto_delete': True,
+ 'exclusive': True}
+ options.update(kwargs)
+ exchange = kombu.entity.Exchange(
+ name=exchange_name,
+ type='fanout',
+ durable=options['durable'],
+ auto_delete=options['auto_delete'])
+ super(FanoutConsumer, self).__init__(
+ channel,
+ callback,
+ tag,
+ name=queue_name,
+ exchange=exchange,
+ routing_key=topic,
+ **options)
+
+
+class Publisher(object):
+ """Base Publisher class"""
+
+ def __init__(self, channel, exchange_name, routing_key, **kwargs):
+ """Init the Publisher class with the exchange_name, routing_key,
+ and other options
+ """
+ self.exchange_name = exchange_name
+ self.routing_key = routing_key
+ self.kwargs = kwargs
+ self.reconnect(channel)
+
+ def reconnect(self, channel):
+ """Re-establish the Producer after a rabbit reconnection"""
+ self.exchange = kombu.entity.Exchange(name=self.exchange_name,
+ **self.kwargs)
+ self.producer = kombu.messaging.Producer(exchange=self.exchange,
+ channel=channel, routing_key=self.routing_key)
+
+ def send(self, msg):
+ """Send a message"""
+ self.producer.publish(msg)
+
+
+class DirectPublisher(Publisher):
+ """Publisher class for 'direct'"""
+ def __init__(self, conf, channel, msg_id, **kwargs):
+ """init a 'direct' publisher.
+
+ Kombu options may be passed as keyword args to override defaults
+ """
+
+ options = {'durable': False,
+ 'auto_delete': True,
+ 'exclusive': True}
+ options.update(kwargs)
+ super(DirectPublisher, self).__init__(channel,
+ msg_id,
+ msg_id,
+ type='direct',
+ **options)
+
+
+class TopicPublisher(Publisher):
+ """Publisher class for 'topic'"""
+ def __init__(self, conf, channel, topic, **kwargs):
+ """init a 'topic' publisher.
+
+ Kombu options may be passed as keyword args to override defaults
+ """
+ options = {'durable': conf.rabbit_durable_queues,
+ 'auto_delete': False,
+ 'exclusive': False}
+ options.update(kwargs)
+ super(TopicPublisher, self).__init__(channel,
+ conf.control_exchange,
+ topic,
+ type='topic',
+ **options)
+
+
+class FanoutPublisher(Publisher):
+ """Publisher class for 'fanout'"""
+ def __init__(self, conf, channel, topic, **kwargs):
+ """init a 'fanout' publisher.
+
+ Kombu options may be passed as keyword args to override defaults
+ """
+ options = {'durable': False,
+ 'auto_delete': True,
+ 'exclusive': True}
+ options.update(kwargs)
+ super(FanoutPublisher, self).__init__(channel,
+ '%s_fanout' % topic,
+ None,
+ type='fanout',
+ **options)
+
+
+class NotifyPublisher(TopicPublisher):
+ """Publisher class for 'notify'"""
+
+ def __init__(self, conf, channel, topic, **kwargs):
+ self.durable = kwargs.pop('durable', conf.rabbit_durable_queues)
+ super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs)
+
+ def reconnect(self, channel):
+ super(NotifyPublisher, self).reconnect(channel)
+
+ # NOTE(jerdfelt): Normally the consumer would create the queue, but
+ # we do this to ensure that messages don't get dropped if the
+ # consumer is started after we do
+ queue = kombu.entity.Queue(channel=channel,
+ exchange=self.exchange,
+ durable=self.durable,
+ name=self.routing_key,
+ routing_key=self.routing_key)
+ queue.declare()
+
+
+class Connection(object):
+ """Connection object."""
+
+ pool = None
+
+ def __init__(self, conf, server_params=None):
+ self.consumers = []
+ self.consumer_thread = None
+ self.conf = conf
+ self.max_retries = self.conf.rabbit_max_retries
+ # Try forever?
+ if self.max_retries <= 0:
+ self.max_retries = None
+ self.interval_start = self.conf.rabbit_retry_interval
+ self.interval_stepping = self.conf.rabbit_retry_backoff
+ # max retry-interval = 30 seconds
+ self.interval_max = 30
+ self.memory_transport = False
+
+ if server_params is None:
+ server_params = {}
+
+ # Keys to translate from server_params to kombu params
+ server_params_to_kombu_params = {'username': 'userid'}
+
+ params = {}
+ for sp_key, value in server_params.iteritems():
+ p_key = server_params_to_kombu_params.get(sp_key, sp_key)
+ params[p_key] = value
+
+ params.setdefault('hostname', self.conf.rabbit_host)
+ params.setdefault('port', self.conf.rabbit_port)
+ params.setdefault('userid', self.conf.rabbit_userid)
+ params.setdefault('password', self.conf.rabbit_password)
+ params.setdefault('virtual_host', self.conf.rabbit_virtual_host)
+
+ self.params = params
+
+ if self.conf.fake_rabbit:
+ self.params['transport'] = 'memory'
+ self.memory_transport = True
+ else:
+ self.memory_transport = False
+
+ if self.conf.rabbit_use_ssl:
+ self.params['ssl'] = self._fetch_ssl_params()
+
+ self.connection = None
+ self.reconnect()
+
+ def _fetch_ssl_params(self):
+ """Handles fetching what ssl params
+ should be used for the connection (if any)"""
+ ssl_params = dict()
+
+ # http://docs.python.org/library/ssl.html - ssl.wrap_socket
+ if self.conf.kombu_ssl_version:
+ ssl_params['ssl_version'] = self.conf.kombu_ssl_version
+ if self.conf.kombu_ssl_keyfile:
+ ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile
+ if self.conf.kombu_ssl_certfile:
+ ssl_params['certfile'] = self.conf.kombu_ssl_certfile
+ if self.conf.kombu_ssl_ca_certs:
+ ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs
+ # We might want to allow variations in the
+ # future with this?
+ ssl_params['cert_reqs'] = ssl.CERT_REQUIRED
+
+ if not ssl_params:
+ # Just have the default behavior
+ return True
+ else:
+ # Return the extended behavior
+ return ssl_params
+
+ def _connect(self):
+ """Connect to rabbit. Re-establish any queues that may have
+ been declared before if we are reconnecting. Exceptions should
+ be handled by the caller.
+ """
+ if self.connection:
+ LOG.info(_("Reconnecting to AMQP server on "
+ "%(hostname)s:%(port)d") % self.params)
+ try:
+ self.connection.close()
+ except self.connection_errors:
+ pass
+ # Setting this in case the next statement fails, though
+ # it shouldn't be doing any network operations, yet.
+ self.connection = None
+ self.connection = kombu.connection.BrokerConnection(
+ **self.params)
+ self.connection_errors = self.connection.connection_errors
+ if self.memory_transport:
+ # Kludge to speed up tests.
+ self.connection.transport.polling_interval = 0.0
+ self.consumer_num = itertools.count(1)
+ self.connection.connect()
+ self.channel = self.connection.channel()
+ # work around 'memory' transport bug in 1.1.3
+ if self.memory_transport:
+ self.channel._new_queue('ae.undeliver')
+ for consumer in self.consumers:
+ consumer.reconnect(self.channel)
+ LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d'),
+ self.params)
+
+ def reconnect(self):
+ """Handles reconnecting and re-establishing queues.
+ Will retry up to self.max_retries number of times.
+ self.max_retries = 0 means to retry forever.
+ Sleep between tries, starting at self.interval_start
+ seconds, backing off self.interval_stepping number of seconds
+ each attempt.
+ """
+
+ attempt = 0
+ while True:
+ attempt += 1
+ try:
+ self._connect()
+ return
+ except (self.connection_errors, IOError), e:
+ pass
+ except Exception, e:
+ # NOTE(comstud): Unfortunately it's possible for amqplib
+ # to return an error not covered by its transport
+ # connection_errors in the case of a timeout waiting for
+ # a protocol response. (See paste link in LP888621)
+ # So, we check all exceptions for 'timeout' in them
+ # and try to reconnect in this case.
+ if 'timeout' not in str(e):
+ raise
+
+ log_info = {}
+ log_info['err_str'] = str(e)
+ log_info['max_retries'] = self.max_retries
+ log_info.update(self.params)
+
+ if self.max_retries and attempt == self.max_retries:
+ LOG.exception(_('Unable to connect to AMQP server on '
+ '%(hostname)s:%(port)d after %(max_retries)d '
+ 'tries: %(err_str)s') % log_info)
+ # NOTE(comstud): Copied from original code. There's
+ # really no better recourse because if this was a queue we
+ # need to consume on, we have no way to consume anymore.
+ sys.exit(1)
+
+ if attempt == 1:
+ sleep_time = self.interval_start or 1
+ elif attempt > 1:
+ sleep_time += self.interval_stepping
+ if self.interval_max:
+ sleep_time = min(sleep_time, self.interval_max)
+
+ log_info['sleep_time'] = sleep_time
+ LOG.exception(_('AMQP server on %(hostname)s:%(port)d is'
+ ' unreachable: %(err_str)s. Trying again in '
+ '%(sleep_time)d seconds.') % log_info)
+ time.sleep(sleep_time)
+
+ def ensure(self, error_callback, method, *args, **kwargs):
+ while True:
+ try:
+ return method(*args, **kwargs)
+ except (self.connection_errors, socket.timeout, IOError), e:
+ pass
+ except Exception, e:
+ # NOTE(comstud): Unfortunately it's possible for amqplib
+ # to return an error not covered by its transport
+ # connection_errors in the case of a timeout waiting for
+ # a protocol response. (See paste link in LP888621)
+ # So, we check all exceptions for 'timeout' in them
+ # and try to reconnect in this case.
+ if 'timeout' not in str(e):
+ raise
+ if error_callback:
+ error_callback(e)
+ self.reconnect()
+
+ def get_channel(self):
+ """Convenience call for bin/clear_rabbit_queues"""
+ return self.channel
+
+ def close(self):
+ """Close/release this connection"""
+ self.cancel_consumer_thread()
+ self.connection.release()
+ self.connection = None
+
+ def reset(self):
+ """Reset a connection so it can be used again"""
+ self.cancel_consumer_thread()
+ self.channel.close()
+ self.channel = self.connection.channel()
+ # work around 'memory' transport bug in 1.1.3
+ if self.memory_transport:
+ self.channel._new_queue('ae.undeliver')
+ self.consumers = []
+
+ def declare_consumer(self, consumer_cls, topic, callback):
+ """Create a Consumer using the class that was passed in and
+ add it to our list of consumers
+ """
+
+ def _connect_error(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.error(_("Failed to declare consumer for topic '%(topic)s': "
+ "%(err_str)s") % log_info)
+
+ def _declare_consumer():
+ consumer = consumer_cls(self.conf, self.channel, topic, callback,
+ self.consumer_num.next())
+ self.consumers.append(consumer)
+ return consumer
+
+ return self.ensure(_connect_error, _declare_consumer)
+
+ def iterconsume(self, limit=None, timeout=None):
+ """Return an iterator that will consume from all queues/consumers"""
+
+ info = {'do_consume': True}
+
+ def _error_callback(exc):
+ if isinstance(exc, socket.timeout):
+ LOG.exception(_('Timed out waiting for RPC response: %s') %
+ str(exc))
+ raise rpc_common.Timeout()
+ else:
+ LOG.exception(_('Failed to consume message from queue: %s') %
+ str(exc))
+ info['do_consume'] = True
+
+ def _consume():
+ if info['do_consume']:
+ queues_head = self.consumers[:-1]
+ queues_tail = self.consumers[-1]
+ for queue in queues_head:
+ queue.consume(nowait=True)
+ queues_tail.consume(nowait=False)
+ info['do_consume'] = False
+ return self.connection.drain_events(timeout=timeout)
+
+ for iteration in itertools.count(0):
+ if limit and iteration >= limit:
+ raise StopIteration
+ yield self.ensure(_error_callback, _consume)
+
+ def cancel_consumer_thread(self):
+ """Cancel a consumer thread"""
+ if self.consumer_thread is not None:
+ self.consumer_thread.kill()
+ try:
+ self.consumer_thread.wait()
+ except greenlet.GreenletExit:
+ pass
+ self.consumer_thread = None
+
+ def publisher_send(self, cls, topic, msg, **kwargs):
+ """Send to a publisher based on the publisher class"""
+
+ def _error_callback(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.exception(_("Failed to publish message to topic "
+ "'%(topic)s': %(err_str)s") % log_info)
+
+ def _publish():
+ publisher = cls(self.conf, self.channel, topic, **kwargs)
+ publisher.send(msg)
+
+ self.ensure(_error_callback, _publish)
+
+ def declare_direct_consumer(self, topic, callback):
+ """Create a 'direct' queue.
+ In nova's use, this is generally a msg_id queue used for
+ responses for call/multicall
+ """
+ self.declare_consumer(DirectConsumer, topic, callback)
+
+ def declare_topic_consumer(self, topic, callback=None, queue_name=None):
+ """Create a 'topic' consumer."""
+ self.declare_consumer(functools.partial(TopicConsumer,
+ name=queue_name,
+ ),
+ topic, callback)
+
+ def declare_fanout_consumer(self, topic, callback):
+ """Create a 'fanout' consumer"""
+ self.declare_consumer(FanoutConsumer, topic, callback)
+
+ def direct_send(self, msg_id, msg):
+ """Send a 'direct' message"""
+ self.publisher_send(DirectPublisher, msg_id, msg)
+
+ def topic_send(self, topic, msg):
+ """Send a 'topic' message"""
+ self.publisher_send(TopicPublisher, topic, msg)
+
+ def fanout_send(self, topic, msg):
+ """Send a 'fanout' message"""
+ self.publisher_send(FanoutPublisher, topic, msg)
+
+ def notify_send(self, topic, msg, **kwargs):
+ """Send a notify message on a topic"""
+ self.publisher_send(NotifyPublisher, topic, msg, **kwargs)
+
+ def consume(self, limit=None):
+ """Consume from all queues/consumers"""
+ it = self.iterconsume(limit=limit)
+ while True:
+ try:
+ it.next()
+ except StopIteration:
+ return
+
+ def consume_in_thread(self):
+ """Consumer from all queues/consumers in a greenthread"""
+ def _consumer_thread():
+ try:
+ self.consume()
+ except greenlet.GreenletExit:
+ return
+ if self.consumer_thread is None:
+ self.consumer_thread = eventlet.spawn(_consumer_thread)
+ return self.consumer_thread
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ """Create a consumer that calls a method in a proxy object"""
+ proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+
+ if fanout:
+ self.declare_fanout_consumer(topic, proxy_cb)
+ else:
+ self.declare_topic_consumer(topic, proxy_cb)
+
+ def create_worker(self, topic, proxy, pool_name):
+ """Create a worker that calls a method in a proxy object"""
+ proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+ self.declare_topic_consumer(topic, proxy_cb, pool_name)
+
+
+def create_connection(conf, new=True):
+ """Create a connection"""
+ return rpc_amqp.create_connection(conf, new,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def multicall(conf, context, topic, msg, timeout=None):
+ """Make a call that returns multiple times."""
+ return rpc_amqp.multicall(conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def call(conf, context, topic, msg, timeout=None):
+ """Sends a message on a topic and wait for a response."""
+ return rpc_amqp.call(conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast(conf, context, topic, msg):
+ """Sends a message on a topic without waiting for a response."""
+ return rpc_amqp.cast(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast(conf, context, topic, msg):
+ """Sends a message on a fanout exchange without waiting for a response."""
+ return rpc_amqp.fanout_cast(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a topic to a specific server."""
+ return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a fanout exchange to a specific server."""
+ return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def notify(conf, context, topic, msg):
+ """Sends a notification event on a topic."""
+ return rpc_amqp.notify(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cleanup():
+ return rpc_amqp.cleanup(Connection.pool)
diff --git a/openstack/common/rpc/impl_qpid.py b/openstack/common/rpc/impl_qpid.py
new file mode 100644
index 0000000..3c46309
--- /dev/null
+++ b/openstack/common/rpc/impl_qpid.py
@@ -0,0 +1,580 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC
+# Copyright 2011 - 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import itertools
+import json
+import logging
+import time
+import uuid
+
+import eventlet
+import greenlet
+import qpid.messaging
+import qpid.messaging.exceptions
+
+from openstack.common import cfg
+from openstack.common.gettextutils import _
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+
+LOG = logging.getLogger(__name__)
+
+qpid_opts = [
+ cfg.StrOpt('qpid_hostname',
+ default='localhost',
+ help='Qpid broker hostname'),
+ cfg.StrOpt('qpid_port',
+ default='5672',
+ help='Qpid broker port'),
+ cfg.StrOpt('qpid_username',
+ default='',
+ help='Username for qpid connection'),
+ cfg.StrOpt('qpid_password',
+ default='',
+ help='Password for qpid connection'),
+ cfg.StrOpt('qpid_sasl_mechanisms',
+ default='',
+ help='Space separated list of SASL mechanisms to use for auth'),
+ cfg.BoolOpt('qpid_reconnect',
+ default=True,
+ help='Automatically reconnect'),
+ cfg.IntOpt('qpid_reconnect_timeout',
+ default=0,
+ help='Reconnection timeout in seconds'),
+ cfg.IntOpt('qpid_reconnect_limit',
+ default=0,
+ help='Max reconnections before giving up'),
+ cfg.IntOpt('qpid_reconnect_interval_min',
+ default=0,
+ help='Minimum seconds between reconnection attempts'),
+ cfg.IntOpt('qpid_reconnect_interval_max',
+ default=0,
+ help='Maximum seconds between reconnection attempts'),
+ cfg.IntOpt('qpid_reconnect_interval',
+ default=0,
+ help='Equivalent to setting max and min to the same value'),
+ cfg.IntOpt('qpid_heartbeat',
+ default=5,
+ help='Seconds between connection keepalive heartbeats'),
+ cfg.StrOpt('qpid_protocol',
+ default='tcp',
+ help="Transport to use, either 'tcp' or 'ssl'"),
+ cfg.BoolOpt('qpid_tcp_nodelay',
+ default=True,
+ help='Disable Nagle algorithm'),
+ ]
+
+cfg.CONF.register_opts(qpid_opts)
+
+
+class ConsumerBase(object):
+ """Consumer base class."""
+
+ def __init__(self, session, callback, node_name, node_opts,
+ link_name, link_opts):
+ """Declare a queue on an amqp session.
+
+ 'session' is the amqp session to use
+ 'callback' is the callback to call when messages are received
+ 'node_name' is the first part of the Qpid address string, before ';'
+ 'node_opts' will be applied to the "x-declare" section of "node"
+ in the address string.
+ 'link_name' goes into the "name" field of the "link" in the address
+ string
+ 'link_opts' will be applied to the "x-declare" section of "link"
+ in the address string.
+ """
+ self.callback = callback
+ self.receiver = None
+ self.session = None
+
+ addr_opts = {
+ "create": "always",
+ "node": {
+ "type": "topic",
+ "x-declare": {
+ "durable": True,
+ "auto-delete": True,
+ },
+ },
+ "link": {
+ "name": link_name,
+ "durable": True,
+ "x-declare": {
+ "durable": False,
+ "auto-delete": True,
+ "exclusive": False,
+ },
+ },
+ }
+ addr_opts["node"]["x-declare"].update(node_opts)
+ addr_opts["link"]["x-declare"].update(link_opts)
+
+ self.address = "%s ; %s" % (node_name, json.dumps(addr_opts))
+
+ self.reconnect(session)
+
+ def reconnect(self, session):
+ """Re-declare the receiver after a qpid reconnect"""
+ self.session = session
+ self.receiver = session.receiver(self.address)
+ self.receiver.capacity = 1
+
+ def consume(self):
+ """Fetch the message and pass it to the callback object"""
+ message = self.receiver.fetch()
+ self.callback(message.content)
+
+ def get_receiver(self):
+ return self.receiver
+
+
+class DirectConsumer(ConsumerBase):
+ """Queue/consumer class for 'direct'"""
+
+ def __init__(self, conf, session, msg_id, callback):
+ """Init a 'direct' queue.
+
+ 'session' is the amqp session to use
+ 'msg_id' is the msg_id to listen on
+ 'callback' is the callback to call when messages are received
+ """
+
+ super(DirectConsumer, self).__init__(session, callback,
+ "%s/%s" % (msg_id, msg_id),
+ {"type": "direct"},
+ msg_id,
+ {"exclusive": True})
+
+
+class TopicConsumer(ConsumerBase):
+ """Consumer class for 'topic'"""
+
+ def __init__(self, conf, session, topic, callback, name=None):
+ """Init a 'topic' queue.
+
+ :param session: the amqp session to use
+ :param topic: is the topic to listen on
+ :paramtype topic: str
+ :param callback: the callback to call when messages are received
+ :param name: optional queue name, defaults to topic
+ """
+
+ super(TopicConsumer, self).__init__(session, callback,
+ "%s/%s" % (conf.control_exchange, topic), {},
+ name or topic, {})
+
+
+class FanoutConsumer(ConsumerBase):
+ """Consumer class for 'fanout'"""
+
+ def __init__(self, conf, session, topic, callback):
+ """Init a 'fanout' queue.
+
+ 'session' is the amqp session to use
+ 'topic' is the topic to listen on
+ 'callback' is the callback to call when messages are received
+ """
+
+ super(FanoutConsumer, self).__init__(session, callback,
+ "%s_fanout" % topic,
+ {"durable": False, "type": "fanout"},
+ "%s_fanout_%s" % (topic, uuid.uuid4().hex),
+ {"exclusive": True})
+
+
+class Publisher(object):
+ """Base Publisher class"""
+
+ def __init__(self, session, node_name, node_opts=None):
+ """Init the Publisher class with the exchange_name, routing_key,
+ and other options
+ """
+ self.sender = None
+ self.session = session
+
+ addr_opts = {
+ "create": "always",
+ "node": {
+ "type": "topic",
+ "x-declare": {
+ "durable": False,
+ # auto-delete isn't implemented for exchanges in qpid,
+ # but put in here anyway
+ "auto-delete": True,
+ },
+ },
+ }
+ if node_opts:
+ addr_opts["node"]["x-declare"].update(node_opts)
+
+ self.address = "%s ; %s" % (node_name, json.dumps(addr_opts))
+
+ self.reconnect(session)
+
+ def reconnect(self, session):
+ """Re-establish the Sender after a reconnection"""
+ self.sender = session.sender(self.address)
+
+ def send(self, msg):
+ """Send a message"""
+ self.sender.send(msg)
+
+
+class DirectPublisher(Publisher):
+ """Publisher class for 'direct'"""
+ def __init__(self, conf, session, msg_id):
+ """Init a 'direct' publisher."""
+ super(DirectPublisher, self).__init__(session, msg_id,
+ {"type": "Direct"})
+
+
+class TopicPublisher(Publisher):
+ """Publisher class for 'topic'"""
+ def __init__(self, conf, session, topic):
+ """init a 'topic' publisher.
+ """
+ super(TopicPublisher, self).__init__(session,
+ "%s/%s" % (conf.control_exchange, topic))
+
+
+class FanoutPublisher(Publisher):
+ """Publisher class for 'fanout'"""
+ def __init__(self, conf, session, topic):
+ """init a 'fanout' publisher.
+ """
+ super(FanoutPublisher, self).__init__(session,
+ "%s_fanout" % topic, {"type": "fanout"})
+
+
+class NotifyPublisher(Publisher):
+ """Publisher class for notifications"""
+ def __init__(self, conf, session, topic):
+ """init a 'topic' publisher.
+ """
+ super(NotifyPublisher, self).__init__(session,
+ "%s/%s" % (conf.control_exchange, topic),
+ {"durable": True})
+
+
+class Connection(object):
+ """Connection object."""
+
+ pool = None
+
+ def __init__(self, conf, server_params=None):
+ self.session = None
+ self.consumers = {}
+ self.consumer_thread = None
+ self.conf = conf
+
+ if server_params is None:
+ server_params = {}
+
+ default_params = dict(hostname=self.conf.qpid_hostname,
+ port=self.conf.qpid_port,
+ username=self.conf.qpid_username,
+ password=self.conf.qpid_password)
+
+ params = server_params
+ for key in default_params.keys():
+ params.setdefault(key, default_params[key])
+
+ self.broker = params['hostname'] + ":" + str(params['port'])
+ # Create the connection - this does not open the connection
+ self.connection = qpid.messaging.Connection(self.broker)
+
+ # Check if flags are set and if so set them for the connection
+ # before we call open
+ self.connection.username = params['username']
+ self.connection.password = params['password']
+ self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
+ self.connection.reconnect = self.conf.qpid_reconnect
+ if self.conf.qpid_reconnect_timeout:
+ self.connection.reconnect_timeout = (
+ self.conf.qpid_reconnect_timeout)
+ if self.conf.qpid_reconnect_limit:
+ self.connection.reconnect_limit = self.conf.qpid_reconnect_limit
+ if self.conf.qpid_reconnect_interval_max:
+ self.connection.reconnect_interval_max = (
+ self.conf.qpid_reconnect_interval_max)
+ if self.conf.qpid_reconnect_interval_min:
+ self.connection.reconnect_interval_min = (
+ self.conf.qpid_reconnect_interval_min)
+ if self.conf.qpid_reconnect_interval:
+ self.connection.reconnect_interval = (
+ self.conf.qpid_reconnect_interval)
+ self.connection.hearbeat = self.conf.qpid_heartbeat
+ self.connection.protocol = self.conf.qpid_protocol
+ self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay
+
+ # Open is part of reconnect -
+ # NOTE(WGH) not sure we need this with the reconnect flags
+ self.reconnect()
+
+ def _register_consumer(self, consumer):
+ self.consumers[str(consumer.get_receiver())] = consumer
+
+ def _lookup_consumer(self, receiver):
+ return self.consumers[str(receiver)]
+
+ def reconnect(self):
+ """Handles reconnecting and re-establishing sessions and queues"""
+ if self.connection.opened():
+ try:
+ self.connection.close()
+ except qpid.messaging.exceptions.ConnectionError:
+ pass
+
+ while True:
+ try:
+ self.connection.open()
+ except qpid.messaging.exceptions.ConnectionError, e:
+ LOG.error(_('Unable to connect to AMQP server: %s'), e)
+ time.sleep(self.conf.qpid_reconnect_interval or 1)
+ else:
+ break
+
+ LOG.info(_('Connected to AMQP server on %s'), self.broker)
+
+ self.session = self.connection.session()
+
+ for consumer in self.consumers.itervalues():
+ consumer.reconnect(self.session)
+
+ if self.consumers:
+ LOG.debug(_("Re-established AMQP queues"))
+
+ def ensure(self, error_callback, method, *args, **kwargs):
+ while True:
+ try:
+ return method(*args, **kwargs)
+ except (qpid.messaging.exceptions.Empty,
+ qpid.messaging.exceptions.ConnectionError), e:
+ if error_callback:
+ error_callback(e)
+ self.reconnect()
+
+ def close(self):
+ """Close/release this connection"""
+ self.cancel_consumer_thread()
+ self.connection.close()
+ self.connection = None
+
+ def reset(self):
+ """Reset a connection so it can be used again"""
+ self.cancel_consumer_thread()
+ self.session.close()
+ self.session = self.connection.session()
+ self.consumers = {}
+
+ def declare_consumer(self, consumer_cls, topic, callback):
+ """Create a Consumer using the class that was passed in and
+ add it to our list of consumers
+ """
+ def _connect_error(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.error(_("Failed to declare consumer for topic '%(topic)s': "
+ "%(err_str)s") % log_info)
+
+ def _declare_consumer():
+ consumer = consumer_cls(self.conf, self.session, topic, callback)
+ self._register_consumer(consumer)
+ return consumer
+
+ return self.ensure(_connect_error, _declare_consumer)
+
+ def iterconsume(self, limit=None, timeout=None):
+ """Return an iterator that will consume from all queues/consumers"""
+
+ def _error_callback(exc):
+ if isinstance(exc, qpid.messaging.exceptions.Empty):
+ LOG.exception(_('Timed out waiting for RPC response: %s') %
+ str(exc))
+ raise rpc_common.Timeout()
+ else:
+ LOG.exception(_('Failed to consume message from queue: %s') %
+ str(exc))
+
+ def _consume():
+ nxt_receiver = self.session.next_receiver(timeout=timeout)
+ try:
+ self._lookup_consumer(nxt_receiver).consume()
+ except Exception:
+ LOG.exception(_("Error processing message. Skipping it."))
+
+ for iteration in itertools.count(0):
+ if limit and iteration >= limit:
+ raise StopIteration
+ yield self.ensure(_error_callback, _consume)
+
+ def cancel_consumer_thread(self):
+ """Cancel a consumer thread"""
+ if self.consumer_thread is not None:
+ self.consumer_thread.kill()
+ try:
+ self.consumer_thread.wait()
+ except greenlet.GreenletExit:
+ pass
+ self.consumer_thread = None
+
+ def publisher_send(self, cls, topic, msg):
+ """Send to a publisher based on the publisher class"""
+
+ def _connect_error(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.exception(_("Failed to publish message to topic "
+ "'%(topic)s': %(err_str)s") % log_info)
+
+ def _publisher_send():
+ publisher = cls(self.conf, self.session, topic)
+ publisher.send(msg)
+
+ return self.ensure(_connect_error, _publisher_send)
+
+ def declare_direct_consumer(self, topic, callback):
+ """Create a 'direct' queue.
+ In nova's use, this is generally a msg_id queue used for
+ responses for call/multicall
+ """
+ self.declare_consumer(DirectConsumer, topic, callback)
+
+ def declare_topic_consumer(self, topic, callback=None, queue_name=None):
+ """Create a 'topic' consumer."""
+ self.declare_consumer(functools.partial(TopicConsumer,
+ name=queue_name,
+ ),
+ topic, callback)
+
+ def declare_fanout_consumer(self, topic, callback):
+ """Create a 'fanout' consumer"""
+ self.declare_consumer(FanoutConsumer, topic, callback)
+
+ def direct_send(self, msg_id, msg):
+ """Send a 'direct' message"""
+ self.publisher_send(DirectPublisher, msg_id, msg)
+
+ def topic_send(self, topic, msg):
+ """Send a 'topic' message"""
+ self.publisher_send(TopicPublisher, topic, msg)
+
+ def fanout_send(self, topic, msg):
+ """Send a 'fanout' message"""
+ self.publisher_send(FanoutPublisher, topic, msg)
+
+ def notify_send(self, topic, msg, **kwargs):
+ """Send a notify message on a topic"""
+ self.publisher_send(NotifyPublisher, topic, msg)
+
+ def consume(self, limit=None):
+ """Consume from all queues/consumers"""
+ it = self.iterconsume(limit=limit)
+ while True:
+ try:
+ it.next()
+ except StopIteration:
+ return
+
+ def consume_in_thread(self):
+ """Consumer from all queues/consumers in a greenthread"""
+ def _consumer_thread():
+ try:
+ self.consume()
+ except greenlet.GreenletExit:
+ return
+ if self.consumer_thread is None:
+ self.consumer_thread = eventlet.spawn(_consumer_thread)
+ return self.consumer_thread
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ """Create a consumer that calls a method in a proxy object"""
+ proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+
+ if fanout:
+ consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb)
+ else:
+ consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb)
+
+ self._register_consumer(consumer)
+
+ return consumer
+
+ def create_worker(self, topic, proxy, pool_name):
+ """Create a worker that calls a method in a proxy object"""
+ proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+
+ consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb,
+ name=pool_name)
+
+ self._register_consumer(consumer)
+
+ return consumer
+
+
+def create_connection(conf, new=True):
+ """Create a connection"""
+ return rpc_amqp.create_connection(conf, new,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def multicall(conf, context, topic, msg, timeout=None):
+ """Make a call that returns multiple times."""
+ return rpc_amqp.multicall(conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def call(conf, context, topic, msg, timeout=None):
+ """Sends a message on a topic and wait for a response."""
+ return rpc_amqp.call(conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast(conf, context, topic, msg):
+ """Sends a message on a topic without waiting for a response."""
+ return rpc_amqp.cast(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast(conf, context, topic, msg):
+ """Sends a message on a fanout exchange without waiting for a response."""
+ return rpc_amqp.fanout_cast(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a topic to a specific server."""
+ return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a fanout exchange to a specific server."""
+ return rpc_amqp.fanout_cast_to_server(conf, context, server_params, topic,
+ msg, rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def notify(conf, context, topic, msg):
+ """Sends a notification event on a topic."""
+ return rpc_amqp.notify(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cleanup():
+ return rpc_amqp.cleanup(Connection.pool)
diff --git a/openstack/common/rpc/matchmaker.py b/openstack/common/rpc/matchmaker.py
new file mode 100644
index 0000000..4da1dcd
--- /dev/null
+++ b/openstack/common/rpc/matchmaker.py
@@ -0,0 +1,257 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+The MatchMaker classes should except a Topic or Fanout exchange key and
+return keys for direct exchanges, per (approximate) AMQP parlance.
+"""
+
+import contextlib
+import itertools
+import json
+import logging
+
+from openstack.common import cfg
+
+
+matchmaker_opts = [
+ # Matchmaker ring file
+ cfg.StrOpt('matchmaker_ringfile',
+ default='/etc/nova/matchmaker_ring.json',
+ help='Matchmaker ring file (JSON)'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(matchmaker_opts)
+LOG = logging.getLogger(__name__)
+contextmanager = contextlib.contextmanager
+
+
+class MatchMakerException(Exception):
+ """Signified a match could not be found."""
+ message = _("Match not found by MatchMaker.")
+
+
+class Exchange(object):
+ """
+ Implements lookups.
+ Subclass this to support hashtables, dns, etc.
+ """
+ def __init__(self):
+ pass
+
+ def run(self, key):
+ raise NotImplementedError()
+
+
+class Binding(object):
+ """
+ A binding on which to perform a lookup.
+ """
+ def __init__(self):
+ pass
+
+ def test(self, key):
+ raise NotImplementedError()
+
+
+class MatchMakerBase(object):
+ """Match Maker Base Class."""
+
+ def __init__(self):
+ # Array of tuples. Index [2] toggles negation, [3] is last-if-true
+ self.bindings = []
+
+ def add_binding(self, binding, rule, last=True):
+ self.bindings.append((binding, rule, False, last))
+
+ #NOTE(ewindisch): kept the following method in case we implement the
+ # underlying support.
+ #def add_negate_binding(self, binding, rule, last=True):
+ # self.bindings.append((binding, rule, True, last))
+
+ def queues(self, key):
+ workers = []
+
+ # bit is for negate bindings - if we choose to implement it.
+ # last stops processing rules if this matches.
+ for (binding, exchange, bit, last) in self.bindings:
+ if binding.test(key):
+ workers.extend(exchange.run(key))
+
+ # Support last.
+ if last:
+ return workers
+ return workers
+
+
+class DirectBinding(Binding):
+ """
+ Specifies a host in the key via a '.' character
+ Although dots are used in the key, the behavior here is
+ that it maps directly to a host, thus direct.
+ """
+ def test(self, key):
+ if '.' in key:
+ return True
+ return False
+
+
+class TopicBinding(Binding):
+ """
+ Where a 'bare' key without dots.
+ AMQP generally considers topic exchanges to be those *with* dots,
+ but we deviate here in terminology as the behavior here matches
+ that of a topic exchange (whereas where there are dots, behavior
+ matches that of a direct exchange.
+ """
+ def test(self, key):
+ if '.' not in key:
+ return True
+ return False
+
+
+class FanoutBinding(Binding):
+ """Match on fanout keys, where key starts with 'fanout.' string."""
+ def test(self, key):
+ if key.startswith('fanout~'):
+ return True
+ return False
+
+
+class StubExchange(Exchange):
+ """Exchange that does nothing."""
+ def run(self, key):
+ return [(key, None)]
+
+
+class RingExchange(Exchange):
+ """
+ Match Maker where hosts are loaded from a static file containing
+ a hashmap (JSON formatted).
+
+ __init__ takes optional ring dictionary argument, otherwise
+ loads the ringfile from CONF.mathcmaker_ringfile.
+ """
+ def __init__(self, ring=None):
+ super(RingExchange, self).__init__()
+
+ if ring:
+ self.ring = ring
+ else:
+ fh = open(CONF.matchmaker_ringfile, 'r')
+ self.ring = json.load(fh)
+ fh.close()
+
+ self.ring0 = {}
+ for k in self.ring.keys():
+ self.ring0[k] = itertools.cycle(self.ring[k])
+
+ def _ring_has(self, key):
+ if key in self.ring0:
+ return True
+ return False
+
+
+class RoundRobinRingExchange(RingExchange):
+ """A Topic Exchange based on a hashmap."""
+ def __init__(self, ring=None):
+ super(RoundRobinRingExchange, self).__init__(ring)
+
+ def run(self, key):
+ if not self._ring_has(key):
+ LOG.warn(
+ _("No key defining hosts for topic '%s', "
+ "see ringfile") % (key, )
+ )
+ return []
+ host = next(self.ring0[key])
+ return [(key + '.' + host, host)]
+
+
+class FanoutRingExchange(RingExchange):
+ """Fanout Exchange based on a hashmap."""
+ def __init__(self, ring=None):
+ super(FanoutRingExchange, self).__init__(ring)
+
+ def run(self, key):
+ # Assume starts with "fanout~", strip it for lookup.
+ nkey = key.split('fanout~')[1:][0]
+ if not self._ring_has(nkey):
+ LOG.warn(
+ _("No key defining hosts for topic '%s', "
+ "see ringfile") % (nkey, )
+ )
+ return []
+ return map(lambda x: (key + '.' + x, x), self.ring[nkey])
+
+
+class LocalhostExchange(Exchange):
+ """Exchange where all direct topics are local."""
+ def __init__(self):
+ super(Exchange, self).__init__()
+
+ def run(self, key):
+ return [(key.split('.')[0] + '.localhost', 'localhost')]
+
+
+class DirectExchange(Exchange):
+ """
+ Exchange where all topic keys are split, sending to second half.
+ i.e. "compute.host" sends a message to "compute" running on "host"
+ """
+ def __init__(self):
+ super(Exchange, self).__init__()
+
+ def run(self, key):
+ b, e = key.split('.', 1)
+ return [(b, e)]
+
+
+class MatchMakerRing(MatchMakerBase):
+ """
+ Match Maker where hosts are loaded from a static hashmap.
+ """
+ def __init__(self, ring=None):
+ super(MatchMakerRing, self).__init__()
+ self.add_binding(FanoutBinding(), FanoutRingExchange(ring))
+ self.add_binding(DirectBinding(), DirectExchange())
+ self.add_binding(TopicBinding(), RoundRobinRingExchange(ring))
+
+
+class MatchMakerLocalhost(MatchMakerBase):
+ """
+ Match Maker where all bare topics resolve to localhost.
+ Useful for testing.
+ """
+ def __init__(self):
+ super(MatchMakerLocalhost, self).__init__()
+ self.add_binding(FanoutBinding(), LocalhostExchange())
+ self.add_binding(DirectBinding(), DirectExchange())
+ self.add_binding(TopicBinding(), LocalhostExchange())
+
+
+class MatchMakerStub(MatchMakerBase):
+ """
+ Match Maker where topics are untouched.
+ Useful for testing, or for AMQP/brokered queues.
+ Will not work where knowledge of hosts is known (i.e. zeromq)
+ """
+ def __init__(self):
+ super(MatchMakerLocalhost, self).__init__()
+
+ self.add_binding(FanoutBinding(), StubExchange())
+ self.add_binding(DirectBinding(), StubExchange())
+ self.add_binding(TopicBinding(), StubExchange())
diff --git a/openstack/common/rpc/proxy.py b/openstack/common/rpc/proxy.py
new file mode 100644
index 0000000..4f4dff5
--- /dev/null
+++ b/openstack/common/rpc/proxy.py
@@ -0,0 +1,161 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+A helper class for proxy objects to remote APIs.
+
+For more information about rpc API version numbers, see:
+ rpc/dispatcher.py
+"""
+
+
+from openstack.common import rpc
+
+
+class RpcProxy(object):
+ """A helper class for rpc clients.
+
+ This class is a wrapper around the RPC client API. It allows you to
+ specify the topic and API version in a single place. This is intended to
+ be used as a base class for a class that implements the client side of an
+ rpc API.
+ """
+
+ def __init__(self, topic, default_version):
+ """Initialize an RpcProxy.
+
+ :param topic: The topic to use for all messages.
+ :param default_version: The default API version to request in all
+ outgoing messages. This can be overridden on a per-message
+ basis.
+ """
+ self.topic = topic
+ self.default_version = default_version
+ super(RpcProxy, self).__init__()
+
+ def _set_version(self, msg, vers):
+ """Helper method to set the version in a message.
+
+ :param msg: The message having a version added to it.
+ :param vers: The version number to add to the message.
+ """
+ msg['version'] = vers if vers else self.default_version
+
+ def _get_topic(self, topic):
+ """Return the topic to use for a message."""
+ return topic if topic else self.topic
+
+ @staticmethod
+ def make_msg(method, **kwargs):
+ return {'method': method, 'args': kwargs}
+
+ def call(self, context, msg, topic=None, version=None, timeout=None):
+ """rpc.call() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param timeout: (Optional) A timeout to use when waiting for the
+ response. If no timeout is specified, a default timeout will be
+ used that is usually sufficient.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: The return value from the remote method.
+ """
+ self._set_version(msg, version)
+ return rpc.call(context, self._get_topic(topic), msg, timeout)
+
+ def multicall(self, context, msg, topic=None, version=None, timeout=None):
+ """rpc.multicall() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param timeout: (Optional) A timeout to use when waiting for the
+ response. If no timeout is specified, a default timeout will be
+ used that is usually sufficient.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: An iterator that lets you process each of the returned values
+ from the remote method as they arrive.
+ """
+ self._set_version(msg, version)
+ return rpc.multicall(context, self._get_topic(topic), msg, timeout)
+
+ def cast(self, context, msg, topic=None, version=None):
+ """rpc.cast() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.cast() does not wait on any return value from the
+ remote method.
+ """
+ self._set_version(msg, version)
+ rpc.cast(context, self._get_topic(topic), msg)
+
+ def fanout_cast(self, context, msg, version=None):
+ """rpc.fanout_cast() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.fanout_cast() does not wait on any return value
+ from the remote method.
+ """
+ self._set_version(msg, version)
+ rpc.fanout_cast(context, self.topic, msg)
+
+ def cast_to_server(self, context, server_params, msg, topic=None,
+ version=None):
+ """rpc.cast_to_server() a remote method.
+
+ :param context: The request context
+ :param server_params: Server parameters. See rpc.cast_to_server() for
+ details.
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.cast_to_server() does not wait on any
+ return values.
+ """
+ self._set_version(msg, version)
+ rpc.cast_to_server(context, server_params, self._get_topic(topic), msg)
+
+ def fanout_cast_to_server(self, context, server_params, msg, version=None):
+ """rpc.fanout_cast_to_server() a remote method.
+
+ :param context: The request context
+ :param server_params: Server parameters. See rpc.cast_to_server() for
+ details.
+ :param msg: The message to send, including the method and args.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.fanout_cast_to_server() does not wait on any
+ return values.
+ """
+ self._set_version(msg, version)
+ rpc.fanout_cast_to_server(context, server_params, self.topic, msg)
diff --git a/tests/unit/rpc/__init__.py b/tests/unit/rpc/__init__.py
new file mode 100644
index 0000000..848908a
--- /dev/null
+++ b/tests/unit/rpc/__init__.py
@@ -0,0 +1,15 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit/rpc/common.py b/tests/unit/rpc/common.py
new file mode 100644
index 0000000..013418d
--- /dev/null
+++ b/tests/unit/rpc/common.py
@@ -0,0 +1,322 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls shared between all implementations
+"""
+
+import logging
+import time
+import unittest
+
+import eventlet
+from eventlet import greenthread
+import nose
+
+from openstack.common import cfg
+from openstack.common import exception
+from openstack.common.gettextutils import _
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+from openstack.common.rpc import dispatcher as rpc_dispatcher
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class BaseRpcTestCase(unittest.TestCase):
+ def setUp(self, supports_timeouts=True, topic='test',
+ topic_nested='nested'):
+ super(BaseRpcTestCase, self).setUp()
+ self.topic = topic or self.topic
+ self.topic_nested = topic_nested or self.topic_nested
+ self.supports_timeouts = supports_timeouts
+ self.context = rpc_common.CommonRpcContext(user='fake_user',
+ pw='fake_pw')
+
+ if self.rpc:
+ receiver = TestReceiver()
+ self.conn = self._create_consumer(receiver, self.topic)
+
+ def tearDown(self):
+ if self.rpc:
+ self.conn.close()
+ super(BaseRpcTestCase, self).tearDown()
+
+ def _create_consumer(self, proxy, topic, fanout=False):
+ dispatcher = rpc_dispatcher.RpcDispatcher([proxy])
+ conn = self.rpc.create_connection(FLAGS, True)
+ conn.create_consumer(topic, dispatcher, fanout)
+ conn.consume_in_thread()
+ return conn
+
+ def test_call_succeed(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context, self.topic,
+ {"method": "echo", "args": {"value": value}})
+ self.assertEqual(value, result)
+
+ def test_call_succeed_despite_multiple_returns_yield(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context, self.topic,
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ self.assertEqual(value + 2, result)
+
+ def test_multicall_succeed_once(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.multicall(FLAGS, self.context,
+ self.topic,
+ {"method": "echo",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ if i > 0:
+ self.fail('should only receive one response')
+ self.assertEqual(value + i, x)
+
+ def test_multicall_three_nones(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.multicall(FLAGS, self.context,
+ self.topic,
+ {"method": "multicall_three_nones",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(x, None)
+ # i should have been 0, 1, and finally 2:
+ self.assertEqual(i, 2)
+
+ def test_multicall_succeed_three_times_yield(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.multicall(FLAGS, self.context,
+ self.topic,
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(value + i, x)
+
+ def test_context_passed(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ """Makes sure a context is passed through rpc call."""
+ value = 42
+ result = self.rpc.call(FLAGS, self.context,
+ self.topic, {"method": "context",
+ "args": {"value": value}})
+ self.assertEqual(self.context.to_dict(), result)
+
+ def _test_cast(self, fanout=False):
+ """Test casts by pushing items through a channeled queue."""
+
+ # Not a true global, but capitalized so
+ # it is clear it is leaking scope into Nested()
+ QUEUE = eventlet.queue.Queue()
+
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ # We use the nested topic so we don't need QUEUE to be a proper
+ # global, and do not keep state outside this test.
+ class Nested(object):
+ @staticmethod
+ def put_queue(context, value):
+ LOG.debug("Got value in put_queue: %s", value)
+ QUEUE.put(value)
+
+ nested = Nested()
+ conn = self._create_consumer(nested, self.topic_nested, fanout)
+ value = 42
+
+ method = (self.rpc.cast, self.rpc.fanout_cast)[fanout]
+ method(FLAGS, self.context,
+ self.topic_nested,
+ {"method": "put_queue",
+ "args": {"value": value}})
+
+ try:
+ # If it does not succeed in 2 seconds, give up and assume
+ # failure.
+ result = QUEUE.get(True, 2)
+ except Exception:
+ self.assertEqual(value, None)
+
+ conn.close()
+ self.assertEqual(value, result)
+
+ def test_cast_success(self):
+ self._test_cast(False)
+
+ def test_fanout_success(self):
+ self._test_cast(True)
+
+ def test_nested_calls(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ """Test that we can do an rpc.call inside another call."""
+ class Nested(object):
+ @staticmethod
+ def echo(context, queue, value):
+ """Calls echo in the passed queue."""
+ LOG.debug(_("Nested received %(queue)s, %(value)s")
+ % locals())
+ # TODO(comstud):
+ # so, it will replay the context and use the same REQID?
+ # that's bizarre.
+ ret = self.rpc.call(FLAGS, context,
+ queue,
+ {"method": "echo",
+ "args": {"value": value}})
+ LOG.debug(_("Nested return %s"), ret)
+ return value
+
+ nested = Nested()
+ conn = self._create_consumer(nested, self.topic_nested)
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context,
+ self.topic_nested,
+ {"method": "echo",
+ "args": {"queue": "test", "value": value}})
+ conn.close()
+ self.assertEqual(value, result)
+
+ def test_call_timeout(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ """Make sure rpc.call will time out."""
+ if not self.supports_timeouts:
+ raise nose.SkipTest(_("RPC backend does not support timeouts"))
+
+ value = 42
+ self.assertRaises(rpc_common.Timeout,
+ self.rpc.call,
+ FLAGS, self.context,
+ self.topic,
+ {"method": "block",
+ "args": {"value": value}}, timeout=1)
+ try:
+ self.rpc.call(FLAGS, self.context,
+ self.topic,
+ {"method": "block",
+ "args": {"value": value}},
+ timeout=1)
+ self.fail("should have thrown Timeout")
+ except rpc_common.Timeout as exc:
+ pass
+
+
+class BaseRpcAMQPTestCase(BaseRpcTestCase):
+ """Base test class for all AMQP-based RPC tests."""
+ def test_proxycallback_handles_exceptions(self):
+ """Make sure exceptions unpacking messages don't cause hangs."""
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ orig_unpack = rpc_amqp.unpack_context
+
+ info = {'unpacked': False}
+
+ def fake_unpack_context(*args, **kwargs):
+ info['unpacked'] = True
+ raise test.TestingException('moo')
+
+ self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context)
+
+ value = 41
+ self.rpc.cast(FLAGS, self.context, self.topic,
+ {"method": "echo", "args": {"value": value}})
+
+ # Wait for the cast to complete.
+ for x in xrange(50):
+ if info['unpacked']:
+ break
+ greenthread.sleep(0.1)
+ else:
+ self.fail("Timeout waiting for message to be consumed")
+
+ # Now see if we get a response even though we raised an
+ # exception for the cast above.
+ self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack)
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context, self.topic,
+ {"method": "echo",
+ "args": {"value": value}})
+ self.assertEqual(value, result)
+
+
+class TestReceiver(object):
+ """Simple Proxy class so the consumer has methods to call.
+
+ Uses static methods because we aren't actually storing any state.
+
+ """
+ @staticmethod
+ def echo(context, value):
+ """Simply returns whatever value is sent in."""
+ LOG.debug(_("Received %s"), value)
+ return value
+
+ @staticmethod
+ def context(context, value):
+ """Returns dictionary version of context."""
+ LOG.debug(_("Received %s"), context)
+ return context.to_dict()
+
+ @staticmethod
+ def multicall_three_nones(context, value):
+ yield None
+ yield None
+ yield None
+
+ @staticmethod
+ def echo_three_times_yield(context, value):
+ yield value
+ yield value + 1
+ yield value + 2
+
+ @staticmethod
+ def fail(context, value):
+ """Raises an exception with the value sent in."""
+ raise NotImplementedError(value)
+
+ @staticmethod
+ def fail_converted(context, value):
+ """Raises an exception with the value sent in."""
+ raise exception.ApiError(message=value, code='500')
+
+ @staticmethod
+ def block(context, value):
+ time.sleep(2)
diff --git a/tests/unit/rpc/test_common.py b/tests/unit/rpc/test_common.py
new file mode 100644
index 0000000..73fb733
--- /dev/null
+++ b/tests/unit/rpc/test_common.py
@@ -0,0 +1,150 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for 'common' functons used through rpc code.
+"""
+
+import json
+import logging
+import sys
+import unittest
+
+from openstack.common import cfg
+from openstack.common import context
+from openstack.common import exception
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+from tests.unit.rpc import common
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+def raise_exception():
+ raise Exception("test")
+
+
+class FakeUserDefinedException(Exception):
+ def __init__(self):
+ Exception.__init__(self, "Test Message")
+
+
+class RpcCommonTestCase(unittest.TestCase):
+ def test_serialize_remote_exception(self):
+ expected = {
+ 'class': 'Exception',
+ 'module': 'exceptions',
+ 'message': 'test',
+ }
+
+ try:
+ raise_exception()
+ except Exception as exc:
+ failure = rpc_common.serialize_remote_exception(sys.exc_info())
+
+ failure = json.loads(failure)
+ #assure the traceback was added
+ self.assertEqual(expected['class'], failure['class'])
+ self.assertEqual(expected['module'], failure['module'])
+ self.assertEqual(expected['message'], failure['message'])
+
+ def test_serialize_remote_custom_exception(self):
+ def raise_custom_exception():
+ raise exception.OpenstackException()
+
+ expected = {
+ 'class': 'OpenstackException',
+ 'module': 'openstack.common.exception',
+ 'message': exception.OpenstackException.message,
+ }
+
+ try:
+ raise_custom_exception()
+ except Exception as exc:
+ failure = rpc_common.serialize_remote_exception(sys.exc_info())
+
+ failure = json.loads(failure)
+ #assure the traceback was added
+ self.assertEqual(expected['class'], failure['class'])
+ self.assertEqual(expected['module'], failure['module'])
+ self.assertEqual(expected['message'], failure['message'])
+
+ def test_deserialize_remote_exception(self):
+ failure = {
+ 'class': 'OpenstackException',
+ 'module': 'openstack.common.exception',
+ 'message': exception.OpenstackException.message,
+ 'tb': ['raise OpenstackException'],
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, exception.OpenstackException))
+ self.assertTrue('An unknown' in unicode(after_exc))
+ #assure the traceback was added
+ self.assertTrue('raise OpenstackException' in unicode(after_exc))
+
+ def test_deserialize_remote_exception_bad_module(self):
+ failure = {
+ 'class': 'popen2',
+ 'module': 'os',
+ 'kwargs': {'cmd': '/bin/echo failed'},
+ 'message': 'foo',
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
+
+ def test_deserialize_remote_exception_user_defined_exception(self):
+ """Ensure a user defined exception can be deserialized."""
+ FLAGS.set_override('allowed_rpc_exception_modules',
+ [self.__class__.__module__])
+ failure = {
+ 'class': 'FakeUserDefinedException',
+ 'module': self.__class__.__module__,
+ 'tb': ['raise FakeUserDefinedException'],
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, FakeUserDefinedException))
+ #assure the traceback was added
+ self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc))
+ FLAGS.reset()
+
+ def test_deserialize_remote_exception_cannot_recreate(self):
+ """Ensure a RemoteError is returned on initialization failure.
+
+ If an exception cannot be recreated with it's original class then a
+ RemoteError with the exception informations should still be returned.
+
+ """
+ FLAGS.set_override('allowed_rpc_exception_modules',
+ [self.__class__.__module__])
+ failure = {
+ 'class': 'FakeIDontExistException',
+ 'module': self.__class__.__module__,
+ 'tb': ['raise FakeIDontExistException'],
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
+ #assure the traceback was added
+ self.assertTrue('raise FakeIDontExistException' in unicode(after_exc))
+ FLAGS.reset()
diff --git a/tests/unit/rpc/test_dispatcher.py b/tests/unit/rpc/test_dispatcher.py
new file mode 100644
index 0000000..a085567
--- /dev/null
+++ b/tests/unit/rpc/test_dispatcher.py
@@ -0,0 +1,110 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit Tests for rpc.dispatcher
+"""
+
+import unittest
+
+from openstack.common import context
+from openstack.common.rpc import common as rpc_common
+from openstack.common.rpc import dispatcher
+
+
+class RpcDispatcherTestCase(unittest.TestCase):
+ class API1(object):
+ RPC_API_VERSION = '1.0'
+
+ def __init__(self):
+ self.test_method_ctxt = None
+ self.test_method_arg1 = None
+
+ def test_method(self, ctxt, arg1):
+ self.test_method_ctxt = ctxt
+ self.test_method_arg1 = arg1
+
+ class API2(object):
+ RPC_API_VERSION = '2.1'
+
+ def __init__(self):
+ self.test_method_ctxt = None
+ self.test_method_arg1 = None
+
+ def test_method(self, ctxt, arg1):
+ self.test_method_ctxt = ctxt
+ self.test_method_arg1 = arg1
+
+ class API3(object):
+ RPC_API_VERSION = '3.5'
+
+ def __init__(self):
+ self.test_method_ctxt = None
+ self.test_method_arg1 = None
+
+ def test_method(self, ctxt, arg1):
+ self.test_method_ctxt = ctxt
+ self.test_method_arg1 = arg1
+
+ def setUp(self):
+ self.ctxt = context.RequestContext('fake_user', 'fake_project')
+ super(RpcDispatcherTestCase, self).setUp()
+
+ def tearDown(self):
+ super(RpcDispatcherTestCase, self).tearDown()
+
+ def _test_dispatch(self, version, expectations):
+ v2 = self.API2()
+ v3 = self.API3()
+ disp = dispatcher.RpcDispatcher([v2, v3])
+
+ disp.dispatch(self.ctxt, version, 'test_method', arg1=1)
+
+ self.assertEqual(v2.test_method_ctxt, expectations[0])
+ self.assertEqual(v2.test_method_arg1, expectations[1])
+ self.assertEqual(v3.test_method_ctxt, expectations[2])
+ self.assertEqual(v3.test_method_arg1, expectations[3])
+
+ def test_dispatch(self):
+ self._test_dispatch('2.1', (self.ctxt, 1, None, None))
+ self._test_dispatch('3.5', (None, None, self.ctxt, 1))
+
+ def test_dispatch_lower_minor_version(self):
+ self._test_dispatch('2.0', (self.ctxt, 1, None, None))
+ self._test_dispatch('3.1', (None, None, self.ctxt, 1))
+
+ def test_dispatch_higher_minor_version(self):
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '2.6', (None, None, None, None))
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '3.6', (None, None, None, None))
+
+ def test_dispatch_lower_major_version(self):
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '1.0', (None, None, None, None))
+
+ def test_dispatch_higher_major_version(self):
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '4.0', (None, None, None, None))
+
+ def test_dispatch_no_version_uses_v1(self):
+ v1 = self.API1()
+ disp = dispatcher.RpcDispatcher([v1])
+
+ disp.dispatch(self.ctxt, None, 'test_method', arg1=1)
+
+ self.assertEqual(v1.test_method_ctxt, self.ctxt)
+ self.assertEqual(v1.test_method_arg1, 1)
diff --git a/tests/unit/rpc/test_fake.py b/tests/unit/rpc/test_fake.py
new file mode 100644
index 0000000..8ceac47
--- /dev/null
+++ b/tests/unit/rpc/test_fake.py
@@ -0,0 +1,32 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using fake_impl
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+from openstack.common.rpc import impl_fake
+from tests.unit.rpc import common
+
+
+class RpcFakeTestCase(common.BaseRpcTestCase):
+ def setUp(self):
+ self.rpc = impl_fake
+ super(RpcFakeTestCase, self).setUp()
diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py
new file mode 100644
index 0000000..ccab4a2
--- /dev/null
+++ b/tests/unit/rpc/test_kombu.py
@@ -0,0 +1,414 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using kombu
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+import logging
+import unittest
+
+import stubout
+
+from openstack.common import cfg
+from openstack.common import exception
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+from openstack.common import testutils
+from tests.unit.rpc import common
+
+try:
+ import kombu
+ from openstack.common.rpc import impl_kombu
+except ImportError:
+ kombu = None
+ impl_kombu = None
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class MyException(Exception):
+ pass
+
+
+def _raise_exc_stub(stubs, times, obj, method, exc_msg,
+ exc_class=MyException):
+ info = {'called': 0}
+ orig_method = getattr(obj, method)
+
+ def _raise_stub(*args, **kwargs):
+ info['called'] += 1
+ if info['called'] <= times:
+ raise exc_class(exc_msg)
+ orig_method(*args, **kwargs)
+ stubs.Set(obj, method, _raise_stub)
+ return info
+
+
+class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ if kombu:
+ FLAGS.set_override('fake_rabbit', True)
+ FLAGS.set_override('rpc_response_timeout', 5)
+ self.rpc = impl_kombu
+ else:
+ self.rpc = None
+ super(RpcKombuTestCase, self).setUp()
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+ self.stubs.SmartUnsetAll()
+ if kombu:
+ impl_kombu.cleanup()
+ FLAGS.reset()
+ super(RpcKombuTestCase, self).tearDown()
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_reusing_connection(self):
+ """Test that reusing a connection returns same one."""
+ conn_context = self.rpc.create_connection(FLAGS, new=False)
+ conn1 = conn_context.connection
+ conn_context.close()
+ conn_context = self.rpc.create_connection(FLAGS, new=False)
+ conn2 = conn_context.connection
+ conn_context.close()
+ self.assertEqual(conn1, conn2)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_topic_send_receive(self):
+ """Test sending to a topic exchange/queue"""
+
+ conn = self.rpc.create_connection(FLAGS)
+ message = 'topic test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_topic_consumer('a_topic', _callback)
+ conn.topic_send('a_topic', message)
+ conn.consume(limit=1)
+ conn.close()
+
+ self.assertEqual(self.received_message, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_topic_multiple_queues(self):
+ """Test sending to a topic exchange with multiple queues"""
+
+ conn = self.rpc.create_connection(FLAGS)
+ message = 'topic test message'
+
+ self.received_message_1 = None
+ self.received_message_2 = None
+
+ def _callback1(message):
+ self.received_message_1 = message
+
+ def _callback2(message):
+ self.received_message_2 = message
+
+ conn.declare_topic_consumer('a_topic', _callback1, queue_name='queue1')
+ conn.declare_topic_consumer('a_topic', _callback2, queue_name='queue2')
+ conn.topic_send('a_topic', message)
+ conn.consume(limit=2)
+ conn.close()
+
+ self.assertEqual(self.received_message_1, message)
+ self.assertEqual(self.received_message_2, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_direct_send_receive(self):
+ """Test sending to a direct exchange/queue"""
+ conn = self.rpc.create_connection(FLAGS)
+ message = 'direct test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_direct_consumer('a_direct', _callback)
+ conn.direct_send('a_direct', message)
+ conn.consume(limit=1)
+ conn.close()
+
+ self.assertEqual(self.received_message, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_cast_interface_uses_default_options(self):
+ """Test kombu rpc.cast"""
+
+ ctxt = rpc_common.CommonRpcContext(user='fake_user',
+ project='fake_project')
+
+ class MyConnection(impl_kombu.Connection):
+ def __init__(myself, *args, **kwargs):
+ super(MyConnection, myself).__init__(*args, **kwargs)
+ self.assertEqual(myself.params,
+ {'hostname': FLAGS.rabbit_host,
+ 'userid': FLAGS.rabbit_userid,
+ 'password': FLAGS.rabbit_password,
+ 'port': FLAGS.rabbit_port,
+ 'virtual_host': FLAGS.rabbit_virtual_host,
+ 'transport': 'memory'})
+
+ def topic_send(_context, topic, msg):
+ pass
+
+ MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
+ self.stubs.Set(impl_kombu, 'Connection', MyConnection)
+
+ impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'})
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_cast_to_server_uses_server_params(self):
+ """Test kombu rpc.cast"""
+
+ ctxt = rpc_common.CommonRpcContext(user='fake_user',
+ project='fake_project')
+
+ server_params = {'username': 'fake_username',
+ 'password': 'fake_password',
+ 'hostname': 'fake_hostname',
+ 'port': 31337,
+ 'virtual_host': 'fake_virtual_host'}
+
+ class MyConnection(impl_kombu.Connection):
+ def __init__(myself, *args, **kwargs):
+ super(MyConnection, myself).__init__(*args, **kwargs)
+ self.assertEqual(myself.params,
+ {'hostname': server_params['hostname'],
+ 'userid': server_params['username'],
+ 'password': server_params['password'],
+ 'port': server_params['port'],
+ 'virtual_host': server_params['virtual_host'],
+ 'transport': 'memory'})
+
+ def topic_send(_context, topic, msg):
+ pass
+
+ MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
+ self.stubs.Set(impl_kombu, 'Connection', MyConnection)
+
+ impl_kombu.cast_to_server(FLAGS, ctxt, server_params,
+ 'fake_topic', {'msg': 'fake'})
+
+ @testutils.skip_test("kombu memory transport seems buggy with "
+ "fanout queues as this test passes when "
+ "you use rabbit (fake_rabbit=False)")
+ def test_fanout_send_receive(self):
+ """Test sending to a fanout exchange and consuming from 2 queues"""
+
+ conn = self.rpc.create_connection()
+ conn2 = self.rpc.create_connection()
+ message = 'fanout test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_fanout_consumer('a_fanout', _callback)
+ conn2.declare_fanout_consumer('a_fanout', _callback)
+ conn.fanout_send('a_fanout', message)
+
+ conn.consume(limit=1)
+ conn.close()
+ self.assertEqual(self.received_message, message)
+
+ self.received_message = None
+ conn2.consume(limit=1)
+ conn2.close()
+ self.assertEqual(self.received_message, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_declare_consumer_errors_will_reconnect(self):
+ # Test that any exception with 'timeout' in it causes a
+ # reconnection
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
+ '__init__', 'foo timeout foo')
+
+ conn = self.rpc.Connection(FLAGS)
+ result = conn.declare_consumer(self.rpc.DirectConsumer,
+ 'test_topic', None)
+
+ self.assertEqual(info['called'], 3)
+ self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
+
+ # Test that any exception in transport.connection_errors causes
+ # a reconnection
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer,
+ '__init__', 'meow')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.connection_errors = (MyException, )
+
+ result = conn.declare_consumer(self.rpc.DirectConsumer,
+ 'test_topic', None)
+
+ self.assertEqual(info['called'], 2)
+ self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_declare_consumer_ioerrors_will_reconnect(self):
+ """Test that an IOError exception causes a reconnection"""
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
+ '__init__', 'Socket closed', exc_class=IOError)
+
+ conn = self.rpc.Connection(FLAGS)
+ result = conn.declare_consumer(self.rpc.DirectConsumer,
+ 'test_topic', None)
+
+ self.assertEqual(info['called'], 3)
+ self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_publishing_errors_will_reconnect(self):
+ # Test that any exception with 'timeout' in it causes a
+ # reconnection when declaring the publisher class and when
+ # calling send()
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
+ '__init__', 'foo timeout foo')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 3)
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
+ 'send', 'foo timeout foo')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 3)
+
+ # Test that any exception in transport.connection_errors causes
+ # a reconnection when declaring the publisher class and when
+ # calling send()
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
+ '__init__', 'meow')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.connection_errors = (MyException, )
+
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 2)
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
+ 'send', 'meow')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.connection_errors = (MyException, )
+
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 2)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_iterconsume_errors_will_reconnect(self):
+ conn = self.rpc.Connection(FLAGS)
+ message = 'reconnect test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_direct_consumer('a_direct', _callback)
+ conn.direct_send('a_direct', message)
+
+ info = _raise_exc_stub(self.stubs, 1, conn.connection,
+ 'drain_events', 'foo timeout foo')
+ conn.consume(limit=1)
+ conn.close()
+
+ self.assertEqual(self.received_message, message)
+ # Only called once, because our stub goes away during reconnection
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_call_exception(self):
+ """Test that exception gets passed back properly.
+
+ rpc.call returns an Exception object. The value of the
+ exception is converted to a string.
+
+ """
+ FLAGS.set_override('allowed_rpc_exception_modules', ['exceptions'])
+ value = "This is the exception message"
+ self.assertRaises(NotImplementedError,
+ self.rpc.call,
+ FLAGS,
+ self.context,
+ 'test',
+ {"method": "fail",
+ "args": {"value": value}})
+ try:
+ self.rpc.call(FLAGS, self.context,
+ 'test',
+ {"method": "fail",
+ "args": {"value": value}})
+ self.fail("should have thrown Exception")
+ except NotImplementedError as exc:
+ self.assertTrue(value in unicode(exc))
+ #Traceback should be included in exception message
+ self.assertTrue('raise NotImplementedError(value)' in unicode(exc))
+
+ FLAGS.reset()
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_call_converted_exception(self):
+ """Test that exception gets passed back properly.
+
+ rpc.call returns an Exception object. The value of the
+ exception is converted to a string.
+
+ """
+ value = "This is the exception message"
+ # The use of ApiError is an arbitrary choice here ...
+ self.assertRaises(exception.ApiError,
+ self.rpc.call,
+ FLAGS,
+ self.context,
+ 'test',
+ {"method": "fail_converted",
+ "args": {"value": value}})
+ try:
+ self.rpc.call(FLAGS, self.context,
+ 'test',
+ {"method": "fail_converted",
+ "args": {"value": value}})
+ self.fail("should have thrown Exception")
+ except exception.ApiError as exc:
+ self.assertTrue(value in unicode(exc))
+ #Traceback should be included in exception message
+ self.assertTrue('exception.ApiError' in unicode(exc))
diff --git a/tests/unit/rpc/test_kombu_ssl.py b/tests/unit/rpc/test_kombu_ssl.py
new file mode 100644
index 0000000..2aecc3f
--- /dev/null
+++ b/tests/unit/rpc/test_kombu_ssl.py
@@ -0,0 +1,82 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using kombu + ssl
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+import unittest
+
+from openstack.common import cfg
+from openstack.common import testutils
+
+try:
+ import kombu
+ from openstack.common.rpc import impl_kombu
+except ImportError:
+ kombu = None
+ impl_kombu = None
+
+
+# Flag settings we will ensure get passed to amqplib
+SSL_VERSION = "SSLv2"
+SSL_CERT = "/tmp/cert.blah.blah"
+SSL_CA_CERT = "/tmp/cert.ca.blah.blah"
+SSL_KEYFILE = "/tmp/keyfile.blah.blah"
+
+FLAGS = cfg.CONF
+
+
+class RpcKombuSslTestCase(unittest.TestCase):
+
+ def setUp(self):
+ super(RpcKombuSslTestCase, self).setUp()
+ override = {
+ 'kombu_ssl_keyfile': SSL_KEYFILE,
+ 'kombu_ssl_ca_certs': SSL_CA_CERT,
+ 'kombu_ssl_certfile': SSL_CERT,
+ 'kombu_ssl_version': SSL_VERSION,
+ 'rabbit_use_ssl': True,
+ 'fake_rabbit': True,
+ }
+
+ if kombu:
+ for k, v in override.iteritems():
+ FLAGS.set_override(k, v)
+
+ def tearDown(self):
+ super(RpcKombuSslTestCase, self).tearDown()
+ if kombu:
+ FLAGS.reset()
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_ssl_on_extended(self):
+ rpc = impl_kombu
+ conn = rpc.create_connection(FLAGS, True)
+ c = conn.connection
+ #This might be kombu version dependent...
+ #Since we are now peaking into the internals of kombu...
+ self.assertTrue(isinstance(c.connection.ssl, dict))
+ self.assertEqual(SSL_VERSION, c.connection.ssl.get("ssl_version"))
+ self.assertEqual(SSL_CERT, c.connection.ssl.get("certfile"))
+ self.assertEqual(SSL_CA_CERT, c.connection.ssl.get("ca_certs"))
+ self.assertEqual(SSL_KEYFILE, c.connection.ssl.get("keyfile"))
+ #That hash then goes into amqplib which then goes
+ #Into python ssl creation...
diff --git a/tests/unit/rpc/test_matchmaker.py b/tests/unit/rpc/test_matchmaker.py
new file mode 100644
index 0000000..a38b59c
--- /dev/null
+++ b/tests/unit/rpc/test_matchmaker.py
@@ -0,0 +1,60 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import logging
+import unittest
+
+from openstack.common.rpc import matchmaker
+
+
+LOG = logging.getLogger(__name__)
+
+
+class _MatchMakerTestCase(unittest.TestCase):
+ def test_valid_host_matches(self):
+ queues = self.driver.queues(self.topic)
+ matched_hosts = map(lambda x: x[1], queues)
+
+ for host in matched_hosts:
+ self.assertIn(host, self.hosts)
+
+ def test_fanout_host_matches(self):
+ """For known hosts, see if they're in fanout."""
+ queues = self.driver.queues("fanout~" + self.topic)
+ matched_hosts = map(lambda x: x[1], queues)
+
+ LOG.info("Received result from matchmaker: %s", queues)
+ for host in self.hosts:
+ self.assertIn(host, matched_hosts)
+
+
+class MatchMakerFileTestCase(_MatchMakerTestCase):
+ def setUp(self):
+ self.topic = "test"
+ self.hosts = ['hello', 'world', 'foo', 'bar', 'baz']
+ ring = {
+ self.topic: self.hosts
+ }
+ self.driver = matchmaker.MatchMakerRing(ring)
+ super(MatchMakerFileTestCase, self).setUp()
+
+
+class MatchMakerLocalhostTestCase(_MatchMakerTestCase):
+ def setUp(self):
+ self.driver = matchmaker.MatchMakerLocalhost()
+ self.topic = "test"
+ self.hosts = ['localhost']
+ super(MatchMakerLocalhostTestCase, self).setUp()
diff --git a/tests/unit/rpc/test_proxy.py b/tests/unit/rpc/test_proxy.py
new file mode 100644
index 0000000..1af37c7
--- /dev/null
+++ b/tests/unit/rpc/test_proxy.py
@@ -0,0 +1,128 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit Tests for rpc.proxy
+"""
+
+import copy
+import stubout
+import unittest
+
+from openstack.common import context
+from openstack.common import rpc
+from openstack.common.rpc import proxy
+
+
+class RpcProxyTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ super(RpcProxyTestCase, self).setUp()
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+ self.stubs.SmartUnsetAll()
+ super(RpcProxyTestCase, self).tearDown()
+
+ def _test_rpc_method(self, rpc_method, has_timeout=False, has_retval=False,
+ server_params=None, supports_topic_override=True):
+ topic = 'fake_topic'
+ timeout = 123
+ rpc_proxy = proxy.RpcProxy(topic, '1.0')
+ ctxt = context.RequestContext('fake_user', 'fake_project')
+ msg = {'method': 'fake_method', 'args': {'x': 'y'}}
+ expected_msg = {'method': 'fake_method', 'args': {'x': 'y'},
+ 'version': '1.0'}
+
+ expected_retval = 'hi' if has_retval else None
+
+ self.fake_args = None
+ self.fake_kwargs = None
+
+ def _fake_rpc_method(*args, **kwargs):
+ self.fake_args = args
+ self.fake_kwargs = kwargs
+ if has_retval:
+ return expected_retval
+
+ self.stubs.Set(rpc, rpc_method, _fake_rpc_method)
+
+ args = [ctxt, msg]
+ if server_params:
+ args.insert(1, server_params)
+
+ # Base method usage
+ retval = getattr(rpc_proxy, rpc_method)(*args)
+ self.assertEqual(retval, expected_retval)
+ expected_args = [ctxt, topic, expected_msg]
+ if server_params:
+ expected_args.insert(1, server_params)
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ # overriding the version
+ retval = getattr(rpc_proxy, rpc_method)(*args, version='1.1')
+ self.assertEqual(retval, expected_retval)
+ new_msg = copy.deepcopy(expected_msg)
+ new_msg['version'] = '1.1'
+ expected_args = [ctxt, topic, new_msg]
+ if server_params:
+ expected_args.insert(1, server_params)
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ if has_timeout:
+ # set a timeout
+ retval = getattr(rpc_proxy, rpc_method)(ctxt, msg, timeout=timeout)
+ self.assertEqual(retval, expected_retval)
+ expected_args = [ctxt, topic, expected_msg, timeout]
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ if supports_topic_override:
+ # set a topic
+ new_topic = 'foo.bar'
+ retval = getattr(rpc_proxy, rpc_method)(*args, topic=new_topic)
+ self.assertEqual(retval, expected_retval)
+ expected_args = [ctxt, new_topic, expected_msg]
+ if server_params:
+ expected_args.insert(1, server_params)
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ def test_call(self):
+ self._test_rpc_method('call', has_timeout=True, has_retval=True)
+
+ def test_multicall(self):
+ self._test_rpc_method('multicall', has_timeout=True, has_retval=True)
+
+ def test_cast(self):
+ self._test_rpc_method('cast')
+
+ def test_fanout_cast(self):
+ self._test_rpc_method('fanout_cast', supports_topic_override=False)
+
+ def test_cast_to_server(self):
+ self._test_rpc_method('cast_to_server', server_params={'blah': 1})
+
+ def test_fanout_cast_to_server(self):
+ self._test_rpc_method('fanout_cast_to_server',
+ server_params={'blah': 1}, supports_topic_override=False)
+
+ def test_make_msg(self):
+ self.assertEqual(proxy.RpcProxy.make_msg('test_method', a=1, b=2),
+ {'method': 'test_method', 'args': {'a': 1, 'b': 2}})
diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py
new file mode 100644
index 0000000..b753c22
--- /dev/null
+++ b/tests/unit/rpc/test_qpid.py
@@ -0,0 +1,377 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using qpid
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+import logging
+import mox
+import stubout
+import unittest
+
+from openstack.common import cfg
+from openstack.common import context
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common import testutils
+
+try:
+ from openstack.common.rpc import impl_qpid
+ import qpid
+except ImportError:
+ qpid = None
+ impl_qpid = None
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class RpcQpidTestCase(unittest.TestCase):
+ """
+ Exercise the public API of impl_qpid utilizing mox.
+
+ This set of tests utilizes mox to replace the Qpid objects and ensures
+ that the right operations happen on them when the various public rpc API
+ calls are exercised. The API calls tested here include:
+
+ nova.rpc.create_connection()
+ nova.rpc.common.Connection.create_consumer()
+ nova.rpc.common.Connection.close()
+ nova.rpc.cast()
+ nova.rpc.fanout_cast()
+ nova.rpc.call()
+ nova.rpc.multicall()
+ """
+
+ def setUp(self):
+ super(RpcQpidTestCase, self).setUp()
+
+ self.stubs = stubout.StubOutForTesting()
+ self.mox = mox.Mox()
+
+ self.mock_connection = None
+ self.mock_session = None
+ self.mock_sender = None
+ self.mock_receiver = None
+
+ if qpid:
+ self.orig_connection = qpid.messaging.Connection
+ self.orig_session = qpid.messaging.Session
+ self.orig_sender = qpid.messaging.Sender
+ self.orig_receiver = qpid.messaging.Receiver
+ qpid.messaging.Connection = lambda *_x, **_y: self.mock_connection
+ qpid.messaging.Session = lambda *_x, **_y: self.mock_session
+ qpid.messaging.Sender = lambda *_x, **_y: self.mock_sender
+ qpid.messaging.Receiver = lambda *_x, **_y: self.mock_receiver
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.stubs.UnsetAll()
+ self.stubs.SmartUnsetAll()
+ if qpid:
+ qpid.messaging.Connection = self.orig_connection
+ qpid.messaging.Session = self.orig_session
+ qpid.messaging.Sender = self.orig_sender
+ qpid.messaging.Receiver = self.orig_receiver
+ if impl_qpid:
+ # Need to reset this in case we changed the connection_cls
+ # in self._setup_to_server_tests()
+ impl_qpid.Connection.pool.connection_cls = impl_qpid.Connection
+
+ super(RpcQpidTestCase, self).tearDown()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_connection(self):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ self.mock_connection.close()
+
+ self.mox.ReplayAll()
+
+ connection = impl_qpid.create_connection(FLAGS)
+ connection.close()
+
+ def _test_create_consumer(self, fanout):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_receiver = self.mox.CreateMock(self.orig_receiver)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ if fanout:
+ # The link name includes a UUID, so match it with a regex.
+ expected_address = mox.Regex(r'^impl_qpid_test_fanout ; '
+ '{"node": {"x-declare": {"auto-delete": true, "durable": '
+ 'false, "type": "fanout"}, "type": "topic"}, "create": '
+ '"always", "link": {"x-declare": {"auto-delete": true, '
+ '"exclusive": true, "durable": false}, "durable": true, '
+ '"name": "impl_qpid_test_fanout_.*"}}$')
+ else:
+ expected_address = ('nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": true}, "type": "topic"}, '
+ '"create": "always", "link": {"x-declare": {"auto-delete": '
+ 'true, "exclusive": false, "durable": false}, "durable": '
+ 'true, "name": "impl_qpid_test"}}')
+ self.mock_session.receiver(expected_address).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.capacity = 1
+ self.mock_connection.close()
+
+ self.mox.ReplayAll()
+
+ connection = impl_qpid.create_connection(FLAGS)
+ connection.create_consumer("impl_qpid_test",
+ lambda *_x, **_y: None,
+ fanout)
+ connection.close()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_consumer(self):
+ self._test_create_consumer(fanout=False)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_consumer_fanout(self):
+ self._test_create_consumer(fanout=True)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_worker(self):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_receiver = self.mox.CreateMock(self.orig_receiver)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ expected_address = (
+ 'nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": true}, "type": "topic"}, '
+ '"create": "always", "link": {"x-declare": {"auto-delete": '
+ 'true, "exclusive": false, "durable": false}, "durable": '
+ 'true, "name": "impl.qpid.test.workers"}}')
+ self.mock_session.receiver(expected_address).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.capacity = 1
+ self.mock_connection.close()
+
+ self.mox.ReplayAll()
+
+ connection = impl_qpid.create_connection(FLAGS)
+ connection.create_worker("impl_qpid_test",
+ lambda *_x, **_y: None,
+ 'impl.qpid.test.workers',
+ )
+ connection.close()
+
+ def _test_cast(self, fanout, server_params=None):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_sender = self.mox.CreateMock(self.orig_sender)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+
+ self.mock_connection.session().AndReturn(self.mock_session)
+ if fanout:
+ expected_address = ('impl_qpid_test_fanout ; '
+ '{"node": {"x-declare": {"auto-delete": true, '
+ '"durable": false, "type": "fanout"}, '
+ '"type": "topic"}, "create": "always"}')
+ else:
+ expected_address = ('nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": false}, "type": "topic"}, '
+ '"create": "always"}')
+ self.mock_session.sender(expected_address).AndReturn(self.mock_sender)
+ self.mock_sender.send(mox.IgnoreArg())
+ if not server_params:
+ # This is a pooled connection, so instead of closing it, it
+ # gets reset, which is just creating a new session on the
+ # connection.
+ self.mock_session.close()
+ self.mock_connection.session().AndReturn(self.mock_session)
+
+ self.mox.ReplayAll()
+
+ try:
+ ctx = context.RequestContext("user", "project")
+
+ args = [FLAGS, ctx, "impl_qpid_test",
+ {"method": "test_method", "args": {}}]
+
+ if server_params:
+ args.insert(2, server_params)
+ if fanout:
+ method = impl_qpid.fanout_cast_to_server
+ else:
+ method = impl_qpid.cast_to_server
+ else:
+ if fanout:
+ method = impl_qpid.fanout_cast
+ else:
+ method = impl_qpid.cast
+
+ method(*args)
+ finally:
+ while impl_qpid.Connection.pool.free_items:
+ # Pull the mock connection object out of the connection pool so
+ # that it doesn't mess up other test cases.
+ impl_qpid.Connection.pool.get()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_cast(self):
+ self._test_cast(fanout=False)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_fanout_cast(self):
+ self._test_cast(fanout=True)
+
+ def _setup_to_server_tests(self, server_params):
+ class MyConnection(impl_qpid.Connection):
+ def __init__(myself, *args, **kwargs):
+ super(MyConnection, myself).__init__(*args, **kwargs)
+ self.assertEqual(myself.connection.username,
+ server_params['username'])
+ self.assertEqual(myself.connection.password,
+ server_params['password'])
+ self.assertEqual(myself.broker,
+ server_params['hostname'] + ':' +
+ str(server_params['port']))
+
+ MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
+ self.stubs.Set(impl_qpid, 'Connection', MyConnection)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_cast_to_server(self):
+ server_params = {'username': 'fake_username',
+ 'password': 'fake_password',
+ 'hostname': 'fake_hostname',
+ 'port': 31337}
+ self._setup_to_server_tests(server_params)
+ self._test_cast(fanout=False, server_params=server_params)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_fanout_cast_to_server(self):
+ server_params = {'username': 'fake_username',
+ 'password': 'fake_password',
+ 'hostname': 'fake_hostname',
+ 'port': 31337}
+ self._setup_to_server_tests(server_params)
+ self._test_cast(fanout=True, server_params=server_params)
+
+ def _test_call(self, multi):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_sender = self.mox.CreateMock(self.orig_sender)
+ self.mock_receiver = self.mox.CreateMock(self.orig_receiver)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ rcv_addr = mox.Regex(r'^.*/.* ; {"node": {"x-declare": {"auto-delete":'
+ ' true, "durable": true, "type": "direct"}, "type": '
+ '"topic"}, "create": "always", "link": {"x-declare": '
+ '{"auto-delete": true, "exclusive": true, "durable": '
+ 'false}, "durable": true, "name": ".*"}}')
+ self.mock_session.receiver(rcv_addr).AndReturn(self.mock_receiver)
+ self.mock_receiver.capacity = 1
+ send_addr = ('nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": false}, "type": "topic"}, '
+ '"create": "always"}')
+ self.mock_session.sender(send_addr).AndReturn(self.mock_sender)
+ self.mock_sender.send(mox.IgnoreArg())
+
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(qpid.messaging.Message(
+ {"result": "foo", "failure": False, "ending": False}))
+ if multi:
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(
+ qpid.messaging.Message(
+ {"result": "bar", "failure": False,
+ "ending": False}))
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(
+ qpid.messaging.Message(
+ {"result": "baz", "failure": False,
+ "ending": False}))
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(qpid.messaging.Message(
+ {"failure": False, "ending": True}))
+ self.mock_session.close()
+ self.mock_connection.session().AndReturn(self.mock_session)
+
+ self.mox.ReplayAll()
+
+ try:
+ ctx = context.RequestContext("user", "project")
+
+ if multi:
+ method = impl_qpid.multicall
+ else:
+ method = impl_qpid.call
+
+ res = method(FLAGS, ctx, "impl_qpid_test",
+ {"method": "test_method", "args": {}})
+
+ if multi:
+ self.assertEquals(list(res), ["foo", "bar", "baz"])
+ else:
+ self.assertEquals(res, "foo")
+ finally:
+ while impl_qpid.Connection.pool.free_items:
+ # Pull the mock connection object out of the connection pool so
+ # that it doesn't mess up other test cases.
+ impl_qpid.Connection.pool.get()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_call(self):
+ self._test_call(multi=False)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_multicall(self):
+ self._test_call(multi=True)
+
+
+#
+#from nova.tests.rpc import common
+#
+# Qpid does not have a handy in-memory transport like kombu, so it's not
+# terribly straight forward to take advantage of the common unit tests.
+# However, at least at the time of this writing, the common unit tests all pass
+# with qpidd running.
+#
+# class RpcQpidCommonTestCase(common._BaseRpcTestCase):
+# def setUp(self):
+# self.rpc = impl_qpid
+# super(RpcQpidCommonTestCase, self).setUp()
+#
+# def tearDown(self):
+# super(RpcQpidCommonTestCase, self).tearDown()
+#
diff --git a/tools/pip-requires b/tools/pip-requires
index adb2400..3f2d088 100644
--- a/tools/pip-requires
+++ b/tools/pip-requires
@@ -9,3 +9,4 @@ routes==1.12.3
webtest
iso8601>=0.1.4
anyjson==0.2.4
+kombu==1.0.4