diff options
-rw-r--r-- | openstack/common/rpc/__init__.py | 248 | ||||
-rw-r--r-- | openstack/common/rpc/amqp.py | 416 | ||||
-rw-r--r-- | openstack/common/rpc/common.py | 316 | ||||
-rw-r--r-- | openstack/common/rpc/dispatcher.py | 105 | ||||
-rw-r--r-- | openstack/common/rpc/impl_fake.py | 184 | ||||
-rw-r--r-- | openstack/common/rpc/impl_kombu.py | 758 | ||||
-rw-r--r-- | openstack/common/rpc/impl_qpid.py | 580 | ||||
-rw-r--r-- | openstack/common/rpc/matchmaker.py | 257 | ||||
-rw-r--r-- | openstack/common/rpc/proxy.py | 161 | ||||
-rw-r--r-- | tests/unit/rpc/__init__.py | 15 | ||||
-rw-r--r-- | tests/unit/rpc/common.py | 322 | ||||
-rw-r--r-- | tests/unit/rpc/test_common.py | 150 | ||||
-rw-r--r-- | tests/unit/rpc/test_dispatcher.py | 110 | ||||
-rw-r--r-- | tests/unit/rpc/test_fake.py | 32 | ||||
-rw-r--r-- | tests/unit/rpc/test_kombu.py | 414 | ||||
-rw-r--r-- | tests/unit/rpc/test_kombu_ssl.py | 82 | ||||
-rw-r--r-- | tests/unit/rpc/test_matchmaker.py | 60 | ||||
-rw-r--r-- | tests/unit/rpc/test_proxy.py | 128 | ||||
-rw-r--r-- | tests/unit/rpc/test_qpid.py | 377 | ||||
-rw-r--r-- | tools/pip-requires | 1 |
20 files changed, 4716 insertions, 0 deletions
diff --git a/openstack/common/rpc/__init__.py b/openstack/common/rpc/__init__.py new file mode 100644 index 0000000..116aa84 --- /dev/null +++ b/openstack/common/rpc/__init__.py @@ -0,0 +1,248 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A remote procedure call (rpc) abstraction. + +For some wrappers that add message versioning to rpc, see: + rpc.dispatcher + rpc.proxy +""" + +from openstack.common import cfg +from openstack.common import importutils + + +rpc_opts = [ + cfg.StrOpt('rpc_backend', + default='nova.rpc.impl_kombu', + help="The messaging module to use, defaults to kombu."), + cfg.IntOpt('rpc_thread_pool_size', + default=64, + help='Size of RPC thread pool'), + cfg.IntOpt('rpc_conn_pool_size', + default=30, + help='Size of RPC connection pool'), + cfg.IntOpt('rpc_response_timeout', + default=60, + help='Seconds to wait for a response from call or multicall'), + cfg.ListOpt('allowed_rpc_exception_modules', + default=['openstack.common.exception', 'nova.exception'], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), + cfg.StrOpt('control_exchange', + default='nova', + help='AMQP exchange to connect to if using RabbitMQ or Qpid'), + cfg.BoolOpt('fake_rabbit', + default=False, + help='If passed, use a fake RabbitMQ provider'), + ] + +cfg.CONF.register_opts(rpc_opts) + + +def create_connection(new=True): + """Create a connection to the message bus used for rpc. + + For some example usage of creating a connection and some consumers on that + connection, see nova.service. + + :param new: Whether or not to create a new connection. A new connection + will be created by default. If new is False, the + implementation is free to return an existing connection from a + pool. + + :returns: An instance of nova.rpc.common.Connection + """ + return _get_impl().create_connection(cfg.CONF, new=new) + + +def call(context, topic, msg, timeout=None): + """Invoke a remote method that returns something. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + nova.rpc.common.Connection.create_consumer() and only applies + when the consumer was created with fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + :param timeout: int, number of seconds to use for a response timeout. + If set, this overrides the rpc_response_timeout option. + + :returns: A dict from the remote method. + + :raises: nova.rpc.common.Timeout if a complete response is not received + before the timeout is reached. + """ + return _get_impl().call(cfg.CONF, context, topic, msg, timeout) + + +def cast(context, topic, msg): + """Invoke a remote method that does not return anything. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + nova.rpc.common.Connection.create_consumer() and only applies + when the consumer was created with fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().cast(cfg.CONF, context, topic, msg) + + +def fanout_cast(context, topic, msg): + """Broadcast a remote method invocation with no return. + + This method will get invoked on all consumers that were set up with this + topic name and fanout=True. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + nova.rpc.common.Connection.create_consumer() and only applies + when the consumer was created with fanout=True. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().fanout_cast(cfg.CONF, context, topic, msg) + + +def multicall(context, topic, msg, timeout=None): + """Invoke a remote method and get back an iterator. + + In this case, the remote method will be returning multiple values in + separate messages, so the return values can be processed as the come in via + an iterator. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + nova.rpc.common.Connection.create_consumer() and only applies + when the consumer was created with fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + :param timeout: int, number of seconds to use for a response timeout. + If set, this overrides the rpc_response_timeout option. + + :returns: An iterator. The iterator will yield a tuple (N, X) where N is + an index that starts at 0 and increases by one for each value + returned and X is the Nth value that was returned by the remote + method. + + :raises: nova.rpc.common.Timeout if a complete response is not received + before the timeout is reached. + """ + return _get_impl().multicall(cfg.CONF, context, topic, msg, timeout) + + +def notify(context, topic, msg): + """Send notification event. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the notification to. + :param msg: This is a dict of content of event. + + :returns: None + """ + return _get_impl().notify(cfg.CONF, context, topic, msg) + + +def cleanup(): + """Clean up resoruces in use by implementation. + + Clean up any resources that have been allocated by the RPC implementation. + This is typically open connections to a messaging service. This function + would get called before an application using this API exits to allow + connections to get torn down cleanly. + + :returns: None + """ + return _get_impl().cleanup() + + +def cast_to_server(context, server_params, topic, msg): + """Invoke a remote method that does not return anything. + + :param context: Information that identifies the user that has made this + request. + :param server_params: Connection information + :param topic: The topic to send the notification to. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().cast_to_server(cfg.CONF, context, server_params, topic, + msg) + + +def fanout_cast_to_server(context, server_params, topic, msg): + """Broadcast to a remote method invocation with no return. + + :param context: Information that identifies the user that has made this + request. + :param server_params: Connection information + :param topic: The topic to send the notification to. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().fanout_cast_to_server(cfg.CONF, context, server_params, + topic, msg) + + +def queue_get_for(context, topic, host): + """Get a queue name for a given topic + host. + + This function only works if this naming convention is followed on the + consumer side, as well. For example, in nova, every instance of the + nova-foo service calls create_consumer() for two topics: + + foo + foo.<host> + + Messages sent to the 'foo' topic are distributed to exactly one instance of + the nova-foo service. The services are chosen in a round-robin fashion. + Messages sent to the 'foo.<host>' topic are sent to the nova-foo service on + <host>. + """ + return '%s.%s' % (topic, host) + + +_RPCIMPL = None + + +def _get_impl(): + """Delay import of rpc_backend until configuration is loaded.""" + global _RPCIMPL + if _RPCIMPL is None: + _RPCIMPL = importutils.import_module(cfg.CONF.rpc_backend) + return _RPCIMPL diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py new file mode 100644 index 0000000..a79a3aa --- /dev/null +++ b/openstack/common/rpc/amqp.py @@ -0,0 +1,416 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 - 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Shared code between AMQP based nova.rpc implementations. + +The code in this module is shared between the rpc implemenations based on AMQP. +Specifically, this includes impl_kombu and impl_qpid. impl_carrot also uses +AMQP, but is deprecated and predates this code. +""" + +import inspect +import logging +import sys +import uuid + +from eventlet import greenpool +from eventlet import pools +from eventlet import semaphore + +from openstack.common import excutils +from openstack.common import local +import openstack.common.rpc.common as rpc_common + + +LOG = logging.getLogger(__name__) + + +class Pool(pools.Pool): + """Class that implements a Pool of Connections.""" + def __init__(self, conf, connection_cls, *args, **kwargs): + self.connection_cls = connection_cls + self.conf = conf + kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size) + kwargs.setdefault("order_as_stack", True) + super(Pool, self).__init__(*args, **kwargs) + + # TODO(comstud): Timeout connections not used in a while + def create(self): + LOG.debug('Pool creating new connection') + return self.connection_cls(self.conf) + + def empty(self): + while self.free_items: + self.get().close() + + +_pool_create_sem = semaphore.Semaphore() + + +def get_connection_pool(conf, connection_cls): + with _pool_create_sem: + # Make sure only one thread tries to create the connection pool. + if not connection_cls.pool: + connection_cls.pool = Pool(conf, connection_cls) + return connection_cls.pool + + +class ConnectionContext(rpc_common.Connection): + """The class that is actually returned to the caller of + create_connection(). This is essentially a wrapper around + Connection that supports 'with'. It can also return a new + Connection, or one from a pool. The function will also catch + when an instance of this class is to be deleted. With that + we can return Connections to the pool on exceptions and so + forth without making the caller be responsible for catching + them. If possible the function makes sure to return a + connection to the pool. + """ + + def __init__(self, conf, connection_pool, pooled=True, server_params=None): + """Create a new connection, or get one from the pool""" + self.connection = None + self.conf = conf + self.connection_pool = connection_pool + if pooled: + self.connection = connection_pool.get() + else: + self.connection = connection_pool.connection_cls(conf, + server_params=server_params) + self.pooled = pooled + + def __enter__(self): + """When with ConnectionContext() is used, return self""" + return self + + def _done(self): + """If the connection came from a pool, clean it up and put it back. + If it did not come from a pool, close it. + """ + if self.connection: + if self.pooled: + # Reset the connection so it's ready for the next caller + # to grab from the pool + self.connection.reset() + self.connection_pool.put(self.connection) + else: + try: + self.connection.close() + except Exception: + pass + self.connection = None + + def __exit__(self, exc_type, exc_value, tb): + """End of 'with' statement. We're done here.""" + self._done() + + def __del__(self): + """Caller is done with this connection. Make sure we cleaned up.""" + self._done() + + def close(self): + """Caller is done with this connection.""" + self._done() + + def create_consumer(self, topic, proxy, fanout=False): + self.connection.create_consumer(topic, proxy, fanout) + + def create_worker(self, topic, proxy, pool_name): + self.connection.create_worker(topic, proxy, pool_name) + + def consume_in_thread(self): + self.connection.consume_in_thread() + + def __getattr__(self, key): + """Proxy all other calls to the Connection instance""" + if self.connection: + return getattr(self.connection, key) + else: + raise rpc_common.InvalidRPCConnectionReuse() + + +def msg_reply(conf, msg_id, connection_pool, reply=None, failure=None, + ending=False): + """Sends a reply or an error on the channel signified by msg_id. + + Failure should be a sys.exc_info() tuple. + + """ + with ConnectionContext(conf, connection_pool) as conn: + if failure: + failure = rpc_common.serialize_remote_exception(failure) + + try: + msg = {'result': reply, 'failure': failure} + except TypeError: + msg = {'result': dict((k, repr(v)) + for k, v in reply.__dict__.iteritems()), + 'failure': failure} + if ending: + msg['ending'] = True + conn.direct_send(msg_id, msg) + + +class RpcContext(rpc_common.CommonRpcContext): + """Context that supports replying to a rpc.call""" + def __init__(self, **kwargs): + self.msg_id = kwargs.pop('msg_id', None) + self.conf = kwargs.pop('conf') + super(RpcContext, self).__init__(**kwargs) + + def deepcopy(self): + values = self.to_dict() + values['conf'] = self.conf + values['msg_id'] = self.msg_id + return self.__class__(**values) + + def reply(self, reply=None, failure=None, ending=False, + connection_pool=None): + if self.msg_id: + msg_reply(self.conf, self.msg_id, connection_pool, reply, failure, + ending) + if ending: + self.msg_id = None + + +def unpack_context(conf, msg): + """Unpack context from msg.""" + context_dict = {} + for key in list(msg.keys()): + # NOTE(vish): Some versions of python don't like unicode keys + # in kwargs. + key = str(key) + if key.startswith('_context_'): + value = msg.pop(key) + context_dict[key[9:]] = value + context_dict['msg_id'] = msg.pop('_msg_id', None) + context_dict['conf'] = conf + ctx = RpcContext.from_dict(context_dict) + rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict()) + return ctx + + +def pack_context(msg, context): + """Pack context into msg. + + Values for message keys need to be less than 255 chars, so we pull + context out into a bunch of separate keys. If we want to support + more arguments in rabbit messages, we may want to do the same + for args at some point. + + """ + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) + + +class ProxyCallback(object): + """Calls methods on a proxy object based on method and args.""" + + def __init__(self, conf, proxy, connection_pool): + self.proxy = proxy + self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size) + self.connection_pool = connection_pool + self.conf = conf + + def __call__(self, message_data): + """Consumer callback to call a method on a proxy object. + + Parses the message for validity and fires off a thread to call the + proxy object method. + + Message data should be a dictionary with two keys: + method: string representing the method to call + args: dictionary of arg: value + + Example: {'method': 'echo', 'args': {'value': 42}} + + """ + # It is important to clear the context here, because at this point + # the previous context is stored in local.store.context + if hasattr(local.store, 'context'): + del local.store.context + rpc_common._safe_log(LOG.debug, _('received %s'), message_data) + ctxt = unpack_context(self.conf, message_data) + method = message_data.get('method') + args = message_data.get('args', {}) + version = message_data.get('version', None) + if not method: + LOG.warn(_('no method for message: %s') % message_data) + ctxt.reply(_('No method for message: %s') % message_data, + connection_pool=self.connection_pool) + return + self.pool.spawn_n(self._process_data, ctxt, version, method, args) + + def _process_data(self, ctxt, version, method, args): + """Process a message in a new thread. + + If the proxy object we have has a dispatch method + (see rpc.dispatcher.RpcDispatcher), pass it the version, + method, and args and let it dispatch as appropriate. If not, use + the old behavior of magically calling the specified method on the + proxy we have here. + """ + ctxt.update_store() + try: + rval = self.proxy.dispatch(ctxt, version, method, **args) + # Check if the result was a generator + if inspect.isgenerator(rval): + for x in rval: + ctxt.reply(x, None, connection_pool=self.connection_pool) + else: + ctxt.reply(rval, None, connection_pool=self.connection_pool) + # This final None tells multicall that it is done. + ctxt.reply(ending=True, connection_pool=self.connection_pool) + except Exception as e: + LOG.exception('Exception during message handling') + ctxt.reply(None, sys.exc_info(), + connection_pool=self.connection_pool) + + +class MulticallWaiter(object): + def __init__(self, conf, connection, timeout): + self._connection = connection + self._iterator = connection.iterconsume( + timeout=timeout or conf.rpc_response_timeout) + self._result = None + self._done = False + self._got_ending = False + self._conf = conf + + def done(self): + if self._done: + return + self._done = True + self._iterator.close() + self._iterator = None + self._connection.close() + + def __call__(self, data): + """The consume() callback will call this. Store the result.""" + if data['failure']: + failure = data['failure'] + self._result = rpc_common.deserialize_remote_exception(self._conf, + failure) + + elif data.get('ending', False): + self._got_ending = True + else: + self._result = data['result'] + + def __iter__(self): + """Return a result until we get a 'None' response from consumer""" + if self._done: + raise StopIteration + while True: + try: + self._iterator.next() + except Exception: + with excutils.save_and_reraise_exception(): + self.done() + if self._got_ending: + self.done() + raise StopIteration + result = self._result + if isinstance(result, Exception): + self.done() + raise result + yield result + + +def create_connection(conf, new, connection_pool): + """Create a connection""" + return ConnectionContext(conf, connection_pool, pooled=not new) + + +def multicall(conf, context, topic, msg, timeout, connection_pool): + """Make a call that returns multiple times.""" + # Can't use 'with' for multicall, as it returns an iterator + # that will continue to use the connection. When it's done, + # connection.close() will get called which will put it back into + # the pool + LOG.debug(_('Making asynchronous call on %s ...'), topic) + msg_id = uuid.uuid4().hex + msg.update({'_msg_id': msg_id}) + LOG.debug(_('MSG_ID is %s') % (msg_id)) + pack_context(msg, context) + + conn = ConnectionContext(conf, connection_pool) + wait_msg = MulticallWaiter(conf, conn, timeout) + conn.declare_direct_consumer(msg_id, wait_msg) + conn.topic_send(topic, msg) + return wait_msg + + +def call(conf, context, topic, msg, timeout, connection_pool): + """Sends a message on a topic and wait for a response.""" + rv = multicall(conf, context, topic, msg, timeout, connection_pool) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] + + +def cast(conf, context, topic, msg, connection_pool): + """Sends a message on a topic without waiting for a response.""" + LOG.debug(_('Making asynchronous cast on %s...'), topic) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.topic_send(topic, msg) + + +def fanout_cast(conf, context, topic, msg, connection_pool): + """Sends a message on a fanout exchange without waiting for a response.""" + LOG.debug(_('Making asynchronous fanout cast...')) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.fanout_send(topic, msg) + + +def cast_to_server(conf, context, server_params, topic, msg, connection_pool): + """Sends a message on a topic to a specific server.""" + pack_context(msg, context) + with ConnectionContext(conf, connection_pool, pooled=False, + server_params=server_params) as conn: + conn.topic_send(topic, msg) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg, + connection_pool): + """Sends a message on a fanout exchange to a specific server.""" + pack_context(msg, context) + with ConnectionContext(conf, connection_pool, pooled=False, + server_params=server_params) as conn: + conn.fanout_send(topic, msg) + + +def notify(conf, context, topic, msg, connection_pool): + """Sends a notification event on a topic.""" + event_type = msg.get('event_type') + LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals()) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.notify_send(topic, msg) + + +def cleanup(connection_pool): + if connection_pool: + connection_pool.empty() diff --git a/openstack/common/rpc/common.py b/openstack/common/rpc/common.py new file mode 100644 index 0000000..6acd72c --- /dev/null +++ b/openstack/common/rpc/common.py @@ -0,0 +1,316 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import logging +import sys +import traceback + +from openstack.common import cfg +from openstack.common import importutils +from openstack.common import jsonutils +from openstack.common import local + + +LOG = logging.getLogger(__name__) + + +class RPCException(Exception): + message = _("An unknown RPC related exception occurred.") + + def __init__(self, message=None, **kwargs): + self.kwargs = kwargs + + if not message: + try: + message = self.message % kwargs + + except Exception as e: + # kwargs doesn't match a variable in the message + # log the issue and the kwargs + LOG.exception(_('Exception in string format operation')) + for name, value in kwargs.iteritems(): + LOG.error("%s: %s" % (name, value)) + # at least get the core message out if something happened + message = self.message + + super(RPCException, self).__init__(message) + + +class RemoteError(RPCException): + """Signifies that a remote class has raised an exception. + + Contains a string representation of the type of the original exception, + the value of the original exception, and the traceback. These are + sent to the parent as a joined string so printing the exception + contains all of the relevant info. + + """ + message = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.") + + def __init__(self, exc_type=None, value=None, traceback=None): + self.exc_type = exc_type + self.value = value + self.traceback = traceback + super(RemoteError, self).__init__(exc_type=exc_type, + value=value, + traceback=traceback) + + +class Timeout(RPCException): + """Signifies that a timeout has occurred. + + This exception is raised if the rpc_response_timeout is reached while + waiting for a response from the remote side. + """ + message = _("Timeout while waiting on RPC response.") + + +class InvalidRPCConnectionReuse(RPCException): + message = _("Invalid reuse of an RPC connection.") + + +class UnsupportedRpcVersion(RPCException): + message = _("Specified RPC version, %(version)s, not supported by " + "this endpoint.") + + +class Connection(object): + """A connection, returned by rpc.create_connection(). + + This class represents a connection to the message bus used for rpc. + An instance of this class should never be created by users of the rpc API. + Use rpc.create_connection() instead. + """ + def close(self): + """Close the connection. + + This method must be called when the connection will no longer be used. + It will ensure that any resources associated with the connection, such + as a network connection, and cleaned up. + """ + raise NotImplementedError() + + def create_consumer(self, conf, topic, proxy, fanout=False): + """Create a consumer on this connection. + + A consumer is associated with a message queue on the backend message + bus. The consumer will read messages from the queue, unpack them, and + dispatch them to the proxy object. The contents of the message pulled + off of the queue will determine which method gets called on the proxy + object. + + :param conf: An openstack.common.cfg configuration object. + :param topic: This is a name associated with what to consume from. + Multiple instances of a service may consume from the same + topic. For example, all instances of nova-compute consume + from a queue called "compute". In that case, the + messages will get distributed amongst the consumers in a + round-robin fashion if fanout=False. If fanout=True, + every consumer associated with this topic will get a + copy of every message. + :param proxy: The object that will handle all incoming messages. + :param fanout: Whether or not this is a fanout topic. See the + documentation for the topic parameter for some + additional comments on this. + """ + raise NotImplementedError() + + def create_worker(self, conf, topic, proxy, pool_name): + """Create a worker on this connection. + + A worker is like a regular consumer of messages directed to a + topic, except that it is part of a set of such consumers (the + "pool") which may run in parallel. Every pool of workers will + receive a given message, but only one worker in the pool will + be asked to process it. Load is distributed across the members + of the pool in round-robin fashion. + + :param conf: An openstack.common.cfg configuration object. + :param topic: This is a name associated with what to consume from. + Multiple instances of a service may consume from the same + topic. + :param proxy: The object that will handle all incoming messages. + :param pool_name: String containing the name of the pool of workers + """ + raise NotImplementedError() + + def consume_in_thread(self): + """Spawn a thread to handle incoming messages. + + Spawn a thread that will be responsible for handling all incoming + messages for consumers that were set up on this connection. + + Message dispatching inside of this is expected to be implemented in a + non-blocking manner. An example implementation would be having this + thread pull messages in for all of the consumers, but utilize a thread + pool for dispatching the messages to the proxy objects. + """ + raise NotImplementedError() + + +def _safe_log(log_func, msg, msg_data): + """Sanitizes the msg_data field before logging.""" + SANITIZE = { + 'set_admin_password': ('new_pass',), + 'run_instance': ('admin_password',), + } + + has_method = 'method' in msg_data and msg_data['method'] in SANITIZE + has_context_token = '_context_auth_token' in msg_data + has_token = 'auth_token' in msg_data + + if not any([has_method, has_context_token, has_token]): + return log_func(msg, msg_data) + + msg_data = copy.deepcopy(msg_data) + + if has_method: + method = msg_data['method'] + if method in SANITIZE: + args_to_sanitize = SANITIZE[method] + for arg in args_to_sanitize: + try: + msg_data['args'][arg] = "<SANITIZED>" + except KeyError: + pass + + if has_context_token: + msg_data['_context_auth_token'] = '<SANITIZED>' + + if has_token: + msg_data['auth_token'] = '<SANITIZED>' + + return log_func(msg, msg_data) + + +def serialize_remote_exception(failure_info): + """Prepares exception data to be sent over rpc. + + Failure_info should be a sys.exc_info() tuple. + + """ + tb = traceback.format_exception(*failure_info) + failure = failure_info[1] + LOG.error(_("Returning exception %s to caller"), unicode(failure)) + LOG.error(tb) + + kwargs = {} + if hasattr(failure, 'kwargs'): + kwargs = failure.kwargs + + data = { + 'class': str(failure.__class__.__name__), + 'module': str(failure.__class__.__module__), + 'message': unicode(failure), + 'tb': tb, + 'args': failure.args, + 'kwargs': kwargs + } + + json_data = jsonutils.dumps(data) + + return json_data + + +def deserialize_remote_exception(conf, data): + failure = jsonutils.loads(str(data)) + + trace = failure.get('tb', []) + message = failure.get('message', "") + "\n" + "\n".join(trace) + name = failure.get('class') + module = failure.get('module') + + # NOTE(ameade): We DO NOT want to allow just any module to be imported, in + # order to prevent arbitrary code execution. + if not module in conf.allowed_rpc_exception_modules: + return RemoteError(name, failure.get('message'), trace) + + try: + mod = importutils.import_module(module) + klass = getattr(mod, name) + if not issubclass(klass, Exception): + raise TypeError("Can only deserialize Exceptions") + + failure = klass(**failure.get('kwargs', {})) + except (AttributeError, TypeError, ImportError): + return RemoteError(name, failure.get('message'), trace) + + ex_type = type(failure) + str_override = lambda self: message + new_ex_type = type(ex_type.__name__ + "_Remote", (ex_type,), + {'__str__': str_override, '__unicode__': str_override}) + try: + # NOTE(ameade): Dynamically create a new exception type and swap it in + # as the new type for the exception. This only works on user defined + # Exceptions and not core python exceptions. This is important because + # we cannot necessarily change an exception message so we must override + # the __str__ method. + failure.__class__ = new_ex_type + except TypeError as e: + # NOTE(ameade): If a core exception then just add the traceback to the + # first exception argument. + failure.args = (message,) + failure.args[1:] + return failure + + +class CommonRpcContext(object): + def __init__(self, **kwargs): + self.values = kwargs + + def __getattr__(self, key): + try: + return self.values[key] + except KeyError: + raise AttributeError(key) + + def to_dict(self): + return copy.deepcopy(self.values) + + @classmethod + def from_dict(cls, values): + return cls(**values) + + def deepcopy(self): + return self.from_dict(self.to_dict()) + + def update_store(self): + local.store.context = self + + def elevated(self, read_deleted=None, overwrite=False): + """Return a version of this context with admin flag set.""" + # TODO(russellb) This method is a bit of a nova-ism. It makes + # some assumptions about the data in the request context sent + # across rpc, while the rest of this class does not. We could get + # rid of this if we changed the nova code that uses this to + # convert the RpcContext back to its native RequestContext doing + # something like nova.context.RequestContext.from_dict(ctxt.to_dict()) + + context = self.deepcopy() + context.values['is_admin'] = True + + context.values.setdefault('roles', []) + + if 'admin' not in context.values['roles']: + context.values['roles'].append('admin') + + if read_deleted is not None: + context.values['read_deleted'] = read_deleted + + return context diff --git a/openstack/common/rpc/dispatcher.py b/openstack/common/rpc/dispatcher.py new file mode 100644 index 0000000..7319eb2 --- /dev/null +++ b/openstack/common/rpc/dispatcher.py @@ -0,0 +1,105 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Code for rpc message dispatching. + +Messages that come in have a version number associated with them. RPC API +version numbers are in the form: + + Major.Minor + +For a given message with version X.Y, the receiver must be marked as able to +handle messages of version A.B, where: + + A = X + + B >= Y + +The Major version number would be incremented for an almost completely new API. +The Minor version number would be incremented for backwards compatible changes +to an existing API. A backwards compatible change could be something like +adding a new method, adding an argument to an existing method (but not +requiring it), or changing the type for an existing argument (but still +handling the old type as well). + +The conversion over to a versioned API must be done on both the client side and +server side of the API at the same time. However, as the code stands today, +there can be both versioned and unversioned APIs implemented in the same code +base. +""" + +from openstack.common.rpc import common as rpc_common + + +class RpcDispatcher(object): + """Dispatch rpc messages according to the requested API version. + + This class can be used as the top level 'manager' for a service. It + contains a list of underlying managers that have an API_VERSION attribute. + """ + + def __init__(self, callbacks): + """Initialize the rpc dispatcher. + + :param callbacks: List of proxy objects that are an instance + of a class with rpc methods exposed. Each proxy + object should have an RPC_API_VERSION attribute. + """ + self.callbacks = callbacks + super(RpcDispatcher, self).__init__() + + @staticmethod + def _is_compatible(mversion, version): + """Determine whether versions are compatible. + + :param mversion: The API version implemented by a callback. + :param version: The API version requested by an incoming message. + """ + version_parts = version.split('.') + mversion_parts = mversion.split('.') + if int(version_parts[0]) != int(mversion_parts[0]): # Major + return False + if int(version_parts[1]) > int(mversion_parts[1]): # Minor + return False + return True + + def dispatch(self, ctxt, version, method, **kwargs): + """Dispatch a message based on a requested version. + + :param ctxt: The request context + :param version: The requested API version from the incoming message + :param method: The method requested to be called by the incoming + message. + :param kwargs: A dict of keyword arguments to be passed to the method. + + :returns: Whatever is returned by the underlying method that gets + called. + """ + if not version: + version = '1.0' + + for proxyobj in self.callbacks: + if hasattr(proxyobj, 'RPC_API_VERSION'): + rpc_api_version = proxyobj.RPC_API_VERSION + else: + rpc_api_version = '1.0' + if not hasattr(proxyobj, method): + continue + if self._is_compatible(rpc_api_version, version): + return getattr(proxyobj, method)(ctxt, **kwargs) + + raise rpc_common.UnsupportedRpcVersion(version=version) diff --git a/openstack/common/rpc/impl_fake.py b/openstack/common/rpc/impl_fake.py new file mode 100644 index 0000000..fba20c9 --- /dev/null +++ b/openstack/common/rpc/impl_fake.py @@ -0,0 +1,184 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Fake RPC implementation which calls proxy methods directly with no +queues. Casts will block, but this is very useful for tests. +""" + +import inspect +import json +import time + +import eventlet + +from openstack.common.rpc import common as rpc_common + +CONSUMERS = {} + + +class RpcContext(rpc_common.CommonRpcContext): + def __init__(self, **kwargs): + super(RpcContext, self).__init__(**kwargs) + self._response = [] + self._done = False + + def deepcopy(self): + values = self.to_dict() + new_inst = self.__class__(**values) + new_inst._response = self._response + new_inst._done = self._done + return new_inst + + def reply(self, reply=None, failure=None, ending=False): + if ending: + self._done = True + if not self._done: + self._response.append((reply, failure)) + + +class Consumer(object): + def __init__(self, topic, proxy): + self.topic = topic + self.proxy = proxy + + def call(self, context, version, method, args, timeout): + done = eventlet.event.Event() + + def _inner(): + ctxt = RpcContext.from_dict(context.to_dict()) + try: + rval = self.proxy.dispatch(context, version, method, **args) + res = [] + # Caller might have called ctxt.reply() manually + for (reply, failure) in ctxt._response: + if failure: + raise failure[0], failure[1], failure[2] + res.append(reply) + # if ending not 'sent'...we might have more data to + # return from the function itself + if not ctxt._done: + if inspect.isgenerator(rval): + for val in rval: + res.append(val) + else: + res.append(rval) + done.send(res) + except Exception as e: + done.send_exception(e) + + thread = eventlet.greenthread.spawn(_inner) + + if timeout: + start_time = time.time() + while not done.ready(): + eventlet.greenthread.sleep(1) + cur_time = time.time() + if (cur_time - start_time) > timeout: + thread.kill() + raise rpc_common.Timeout() + + return done.wait() + + +class Connection(object): + """Connection object.""" + + def __init__(self): + self.consumers = [] + + def create_consumer(self, topic, proxy, fanout=False): + consumer = Consumer(topic, proxy) + self.consumers.append(consumer) + if topic not in CONSUMERS: + CONSUMERS[topic] = [] + CONSUMERS[topic].append(consumer) + + def close(self): + for consumer in self.consumers: + CONSUMERS[consumer.topic].remove(consumer) + self.consumers = [] + + def consume_in_thread(self): + pass + + +def create_connection(conf, new=True): + """Create a connection""" + return Connection() + + +def check_serialize(msg): + """Make sure a message intended for rpc can be serialized.""" + json.dumps(msg) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + + check_serialize(msg) + + method = msg.get('method') + if not method: + return + args = msg.get('args', {}) + version = msg.get('version', None) + + try: + consumer = CONSUMERS[topic][0] + except (KeyError, IndexError): + return iter([None]) + else: + return consumer.call(context, version, method, args, timeout) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + rv = multicall(conf, context, topic, msg, timeout) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] + + +def cast(conf, context, topic, msg): + try: + call(conf, context, topic, msg) + except Exception: + pass + + +def notify(conf, context, topic, msg): + check_serialize(msg) + + +def cleanup(): + pass + + +def fanout_cast(conf, context, topic, msg): + """Cast to all consumers of a topic""" + check_serialize(msg) + method = msg.get('method') + if not method: + return + args = msg.get('args', {}) + version = msg.get('version', None) + + for consumer in CONSUMERS.get(topic, []): + try: + consumer.call(context, version, method, args, None) + except Exception: + pass diff --git a/openstack/common/rpc/impl_kombu.py b/openstack/common/rpc/impl_kombu.py new file mode 100644 index 0000000..c32497a --- /dev/null +++ b/openstack/common/rpc/impl_kombu.py @@ -0,0 +1,758 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import socket +import ssl +import sys +import time +import uuid + +import eventlet +import greenlet +import kombu +import kombu.connection +import kombu.entity +import kombu.messaging + +from openstack.common import cfg +from openstack.common.rpc import amqp as rpc_amqp +from openstack.common.rpc import common as rpc_common + +kombu_opts = [ + cfg.StrOpt('kombu_ssl_version', + default='', + help='SSL version to use (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_keyfile', + default='', + help='SSL key file (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_certfile', + default='', + help='SSL cert file (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_ca_certs', + default='', + help=('SSL certification authority file ' + '(valid only if SSL enabled)')), + cfg.StrOpt('rabbit_host', + default='localhost', + help='the RabbitMQ host'), + cfg.IntOpt('rabbit_port', + default=5672, + help='the RabbitMQ port'), + cfg.BoolOpt('rabbit_use_ssl', + default=False, + help='connect over SSL for RabbitMQ'), + cfg.StrOpt('rabbit_userid', + default='guest', + help='the RabbitMQ userid'), + cfg.StrOpt('rabbit_password', + default='guest', + help='the RabbitMQ password'), + cfg.StrOpt('rabbit_virtual_host', + default='/', + help='the RabbitMQ virtual host'), + cfg.IntOpt('rabbit_retry_interval', + default=1, + help='how frequently to retry connecting with RabbitMQ'), + cfg.IntOpt('rabbit_retry_backoff', + default=2, + help='how long to backoff for between retries when connecting ' + 'to RabbitMQ'), + cfg.IntOpt('rabbit_max_retries', + default=0, + help='maximum retries with trying to connect to RabbitMQ ' + '(the default of 0 implies an infinite retry count)'), + cfg.BoolOpt('rabbit_durable_queues', + default=False, + help='use durable queues in RabbitMQ'), + + ] + +cfg.CONF.register_opts(kombu_opts) + +LOG = rpc_common.LOG + + +class ConsumerBase(object): + """Consumer base class.""" + + def __init__(self, channel, callback, tag, **kwargs): + """Declare a queue on an amqp channel. + + 'channel' is the amqp channel to use + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + queue name, exchange name, and other kombu options are + passed in here as a dictionary. + """ + self.callback = callback + self.tag = str(tag) + self.kwargs = kwargs + self.queue = None + self.reconnect(channel) + + def reconnect(self, channel): + """Re-declare the queue after a rabbit reconnect""" + self.channel = channel + self.kwargs['channel'] = channel + self.queue = kombu.entity.Queue(**self.kwargs) + self.queue.declare() + + def consume(self, *args, **kwargs): + """Actually declare the consumer on the amqp channel. This will + start the flow of messages from the queue. Using the + Connection.iterconsume() iterator will process the messages, + calling the appropriate callback. + + If a callback is specified in kwargs, use that. Otherwise, + use the callback passed during __init__() + + If kwargs['nowait'] is True, then this call will block until + a message is read. + + Messages will automatically be acked if the callback doesn't + raise an exception + """ + + options = {'consumer_tag': self.tag} + options['nowait'] = kwargs.get('nowait', False) + callback = kwargs.get('callback', self.callback) + if not callback: + raise ValueError("No callback defined") + + def _callback(raw_message): + message = self.channel.message_to_python(raw_message) + try: + callback(message.payload) + message.ack() + except Exception: + LOG.exception(_("Failed to process message... skipping it.")) + + self.queue.consume(*args, callback=_callback, **options) + + def cancel(self): + """Cancel the consuming from the queue, if it has started""" + try: + self.queue.cancel(self.tag) + except KeyError, e: + # NOTE(comstud): Kludge to get around a amqplib bug + if str(e) != "u'%s'" % self.tag: + raise + self.queue = None + + +class DirectConsumer(ConsumerBase): + """Queue/consumer class for 'direct'""" + + def __init__(self, conf, channel, msg_id, callback, tag, **kwargs): + """Init a 'direct' queue. + + 'channel' is the amqp channel to use + 'msg_id' is the msg_id to listen on + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + Other kombu options may be passed + """ + # Default options + options = {'durable': False, + 'auto_delete': True, + 'exclusive': True} + options.update(kwargs) + exchange = kombu.entity.Exchange( + name=msg_id, + type='direct', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(DirectConsumer, self).__init__( + channel, + callback, + tag, + name=msg_id, + exchange=exchange, + routing_key=msg_id, + **options) + + +class TopicConsumer(ConsumerBase): + """Consumer class for 'topic'""" + + def __init__(self, conf, channel, topic, callback, tag, name=None, + **kwargs): + """Init a 'topic' queue. + + :param channel: the amqp channel to use + :param topic: the topic to listen on + :paramtype topic: str + :param callback: the callback to call when messages are received + :param tag: a unique ID for the consumer on the channel + :param name: optional queue name, defaults to topic + :paramtype name: str + + Other kombu options may be passed as keyword arguments + """ + # Default options + options = {'durable': conf.rabbit_durable_queues, + 'auto_delete': False, + 'exclusive': False} + options.update(kwargs) + exchange = kombu.entity.Exchange( + name=conf.control_exchange, + type='topic', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(TopicConsumer, self).__init__( + channel, + callback, + tag, + name=name or topic, + exchange=exchange, + routing_key=topic, + **options) + + +class FanoutConsumer(ConsumerBase): + """Consumer class for 'fanout'""" + + def __init__(self, conf, channel, topic, callback, tag, **kwargs): + """Init a 'fanout' queue. + + 'channel' is the amqp channel to use + 'topic' is the topic to listen on + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + Other kombu options may be passed + """ + unique = uuid.uuid4().hex + exchange_name = '%s_fanout' % topic + queue_name = '%s_fanout_%s' % (topic, unique) + + # Default options + options = {'durable': False, + 'auto_delete': True, + 'exclusive': True} + options.update(kwargs) + exchange = kombu.entity.Exchange( + name=exchange_name, + type='fanout', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(FanoutConsumer, self).__init__( + channel, + callback, + tag, + name=queue_name, + exchange=exchange, + routing_key=topic, + **options) + + +class Publisher(object): + """Base Publisher class""" + + def __init__(self, channel, exchange_name, routing_key, **kwargs): + """Init the Publisher class with the exchange_name, routing_key, + and other options + """ + self.exchange_name = exchange_name + self.routing_key = routing_key + self.kwargs = kwargs + self.reconnect(channel) + + def reconnect(self, channel): + """Re-establish the Producer after a rabbit reconnection""" + self.exchange = kombu.entity.Exchange(name=self.exchange_name, + **self.kwargs) + self.producer = kombu.messaging.Producer(exchange=self.exchange, + channel=channel, routing_key=self.routing_key) + + def send(self, msg): + """Send a message""" + self.producer.publish(msg) + + +class DirectPublisher(Publisher): + """Publisher class for 'direct'""" + def __init__(self, conf, channel, msg_id, **kwargs): + """init a 'direct' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + + options = {'durable': False, + 'auto_delete': True, + 'exclusive': True} + options.update(kwargs) + super(DirectPublisher, self).__init__(channel, + msg_id, + msg_id, + type='direct', + **options) + + +class TopicPublisher(Publisher): + """Publisher class for 'topic'""" + def __init__(self, conf, channel, topic, **kwargs): + """init a 'topic' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + options = {'durable': conf.rabbit_durable_queues, + 'auto_delete': False, + 'exclusive': False} + options.update(kwargs) + super(TopicPublisher, self).__init__(channel, + conf.control_exchange, + topic, + type='topic', + **options) + + +class FanoutPublisher(Publisher): + """Publisher class for 'fanout'""" + def __init__(self, conf, channel, topic, **kwargs): + """init a 'fanout' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + options = {'durable': False, + 'auto_delete': True, + 'exclusive': True} + options.update(kwargs) + super(FanoutPublisher, self).__init__(channel, + '%s_fanout' % topic, + None, + type='fanout', + **options) + + +class NotifyPublisher(TopicPublisher): + """Publisher class for 'notify'""" + + def __init__(self, conf, channel, topic, **kwargs): + self.durable = kwargs.pop('durable', conf.rabbit_durable_queues) + super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) + + def reconnect(self, channel): + super(NotifyPublisher, self).reconnect(channel) + + # NOTE(jerdfelt): Normally the consumer would create the queue, but + # we do this to ensure that messages don't get dropped if the + # consumer is started after we do + queue = kombu.entity.Queue(channel=channel, + exchange=self.exchange, + durable=self.durable, + name=self.routing_key, + routing_key=self.routing_key) + queue.declare() + + +class Connection(object): + """Connection object.""" + + pool = None + + def __init__(self, conf, server_params=None): + self.consumers = [] + self.consumer_thread = None + self.conf = conf + self.max_retries = self.conf.rabbit_max_retries + # Try forever? + if self.max_retries <= 0: + self.max_retries = None + self.interval_start = self.conf.rabbit_retry_interval + self.interval_stepping = self.conf.rabbit_retry_backoff + # max retry-interval = 30 seconds + self.interval_max = 30 + self.memory_transport = False + + if server_params is None: + server_params = {} + + # Keys to translate from server_params to kombu params + server_params_to_kombu_params = {'username': 'userid'} + + params = {} + for sp_key, value in server_params.iteritems(): + p_key = server_params_to_kombu_params.get(sp_key, sp_key) + params[p_key] = value + + params.setdefault('hostname', self.conf.rabbit_host) + params.setdefault('port', self.conf.rabbit_port) + params.setdefault('userid', self.conf.rabbit_userid) + params.setdefault('password', self.conf.rabbit_password) + params.setdefault('virtual_host', self.conf.rabbit_virtual_host) + + self.params = params + + if self.conf.fake_rabbit: + self.params['transport'] = 'memory' + self.memory_transport = True + else: + self.memory_transport = False + + if self.conf.rabbit_use_ssl: + self.params['ssl'] = self._fetch_ssl_params() + + self.connection = None + self.reconnect() + + def _fetch_ssl_params(self): + """Handles fetching what ssl params + should be used for the connection (if any)""" + ssl_params = dict() + + # http://docs.python.org/library/ssl.html - ssl.wrap_socket + if self.conf.kombu_ssl_version: + ssl_params['ssl_version'] = self.conf.kombu_ssl_version + if self.conf.kombu_ssl_keyfile: + ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile + if self.conf.kombu_ssl_certfile: + ssl_params['certfile'] = self.conf.kombu_ssl_certfile + if self.conf.kombu_ssl_ca_certs: + ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs + # We might want to allow variations in the + # future with this? + ssl_params['cert_reqs'] = ssl.CERT_REQUIRED + + if not ssl_params: + # Just have the default behavior + return True + else: + # Return the extended behavior + return ssl_params + + def _connect(self): + """Connect to rabbit. Re-establish any queues that may have + been declared before if we are reconnecting. Exceptions should + be handled by the caller. + """ + if self.connection: + LOG.info(_("Reconnecting to AMQP server on " + "%(hostname)s:%(port)d") % self.params) + try: + self.connection.close() + except self.connection_errors: + pass + # Setting this in case the next statement fails, though + # it shouldn't be doing any network operations, yet. + self.connection = None + self.connection = kombu.connection.BrokerConnection( + **self.params) + self.connection_errors = self.connection.connection_errors + if self.memory_transport: + # Kludge to speed up tests. + self.connection.transport.polling_interval = 0.0 + self.consumer_num = itertools.count(1) + self.connection.connect() + self.channel = self.connection.channel() + # work around 'memory' transport bug in 1.1.3 + if self.memory_transport: + self.channel._new_queue('ae.undeliver') + for consumer in self.consumers: + consumer.reconnect(self.channel) + LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d'), + self.params) + + def reconnect(self): + """Handles reconnecting and re-establishing queues. + Will retry up to self.max_retries number of times. + self.max_retries = 0 means to retry forever. + Sleep between tries, starting at self.interval_start + seconds, backing off self.interval_stepping number of seconds + each attempt. + """ + + attempt = 0 + while True: + attempt += 1 + try: + self._connect() + return + except (self.connection_errors, IOError), e: + pass + except Exception, e: + # NOTE(comstud): Unfortunately it's possible for amqplib + # to return an error not covered by its transport + # connection_errors in the case of a timeout waiting for + # a protocol response. (See paste link in LP888621) + # So, we check all exceptions for 'timeout' in them + # and try to reconnect in this case. + if 'timeout' not in str(e): + raise + + log_info = {} + log_info['err_str'] = str(e) + log_info['max_retries'] = self.max_retries + log_info.update(self.params) + + if self.max_retries and attempt == self.max_retries: + LOG.exception(_('Unable to connect to AMQP server on ' + '%(hostname)s:%(port)d after %(max_retries)d ' + 'tries: %(err_str)s') % log_info) + # NOTE(comstud): Copied from original code. There's + # really no better recourse because if this was a queue we + # need to consume on, we have no way to consume anymore. + sys.exit(1) + + if attempt == 1: + sleep_time = self.interval_start or 1 + elif attempt > 1: + sleep_time += self.interval_stepping + if self.interval_max: + sleep_time = min(sleep_time, self.interval_max) + + log_info['sleep_time'] = sleep_time + LOG.exception(_('AMQP server on %(hostname)s:%(port)d is' + ' unreachable: %(err_str)s. Trying again in ' + '%(sleep_time)d seconds.') % log_info) + time.sleep(sleep_time) + + def ensure(self, error_callback, method, *args, **kwargs): + while True: + try: + return method(*args, **kwargs) + except (self.connection_errors, socket.timeout, IOError), e: + pass + except Exception, e: + # NOTE(comstud): Unfortunately it's possible for amqplib + # to return an error not covered by its transport + # connection_errors in the case of a timeout waiting for + # a protocol response. (See paste link in LP888621) + # So, we check all exceptions for 'timeout' in them + # and try to reconnect in this case. + if 'timeout' not in str(e): + raise + if error_callback: + error_callback(e) + self.reconnect() + + def get_channel(self): + """Convenience call for bin/clear_rabbit_queues""" + return self.channel + + def close(self): + """Close/release this connection""" + self.cancel_consumer_thread() + self.connection.release() + self.connection = None + + def reset(self): + """Reset a connection so it can be used again""" + self.cancel_consumer_thread() + self.channel.close() + self.channel = self.connection.channel() + # work around 'memory' transport bug in 1.1.3 + if self.memory_transport: + self.channel._new_queue('ae.undeliver') + self.consumers = [] + + def declare_consumer(self, consumer_cls, topic, callback): + """Create a Consumer using the class that was passed in and + add it to our list of consumers + """ + + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.error(_("Failed to declare consumer for topic '%(topic)s': " + "%(err_str)s") % log_info) + + def _declare_consumer(): + consumer = consumer_cls(self.conf, self.channel, topic, callback, + self.consumer_num.next()) + self.consumers.append(consumer) + return consumer + + return self.ensure(_connect_error, _declare_consumer) + + def iterconsume(self, limit=None, timeout=None): + """Return an iterator that will consume from all queues/consumers""" + + info = {'do_consume': True} + + def _error_callback(exc): + if isinstance(exc, socket.timeout): + LOG.exception(_('Timed out waiting for RPC response: %s') % + str(exc)) + raise rpc_common.Timeout() + else: + LOG.exception(_('Failed to consume message from queue: %s') % + str(exc)) + info['do_consume'] = True + + def _consume(): + if info['do_consume']: + queues_head = self.consumers[:-1] + queues_tail = self.consumers[-1] + for queue in queues_head: + queue.consume(nowait=True) + queues_tail.consume(nowait=False) + info['do_consume'] = False + return self.connection.drain_events(timeout=timeout) + + for iteration in itertools.count(0): + if limit and iteration >= limit: + raise StopIteration + yield self.ensure(_error_callback, _consume) + + def cancel_consumer_thread(self): + """Cancel a consumer thread""" + if self.consumer_thread is not None: + self.consumer_thread.kill() + try: + self.consumer_thread.wait() + except greenlet.GreenletExit: + pass + self.consumer_thread = None + + def publisher_send(self, cls, topic, msg, **kwargs): + """Send to a publisher based on the publisher class""" + + def _error_callback(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.exception(_("Failed to publish message to topic " + "'%(topic)s': %(err_str)s") % log_info) + + def _publish(): + publisher = cls(self.conf, self.channel, topic, **kwargs) + publisher.send(msg) + + self.ensure(_error_callback, _publish) + + def declare_direct_consumer(self, topic, callback): + """Create a 'direct' queue. + In nova's use, this is generally a msg_id queue used for + responses for call/multicall + """ + self.declare_consumer(DirectConsumer, topic, callback) + + def declare_topic_consumer(self, topic, callback=None, queue_name=None): + """Create a 'topic' consumer.""" + self.declare_consumer(functools.partial(TopicConsumer, + name=queue_name, + ), + topic, callback) + + def declare_fanout_consumer(self, topic, callback): + """Create a 'fanout' consumer""" + self.declare_consumer(FanoutConsumer, topic, callback) + + def direct_send(self, msg_id, msg): + """Send a 'direct' message""" + self.publisher_send(DirectPublisher, msg_id, msg) + + def topic_send(self, topic, msg): + """Send a 'topic' message""" + self.publisher_send(TopicPublisher, topic, msg) + + def fanout_send(self, topic, msg): + """Send a 'fanout' message""" + self.publisher_send(FanoutPublisher, topic, msg) + + def notify_send(self, topic, msg, **kwargs): + """Send a notify message on a topic""" + self.publisher_send(NotifyPublisher, topic, msg, **kwargs) + + def consume(self, limit=None): + """Consume from all queues/consumers""" + it = self.iterconsume(limit=limit) + while True: + try: + it.next() + except StopIteration: + return + + def consume_in_thread(self): + """Consumer from all queues/consumers in a greenthread""" + def _consumer_thread(): + try: + self.consume() + except greenlet.GreenletExit: + return + if self.consumer_thread is None: + self.consumer_thread = eventlet.spawn(_consumer_thread) + return self.consumer_thread + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + + if fanout: + self.declare_fanout_consumer(topic, proxy_cb) + else: + self.declare_topic_consumer(topic, proxy_cb) + + def create_worker(self, topic, proxy, pool_name): + """Create a worker that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.declare_topic_consumer(topic, proxy_cb, pool_name) + + +def create_connection(conf, new=True): + """Create a connection""" + return rpc_amqp.create_connection(conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + return rpc_amqp.multicall(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + return rpc_amqp.call(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast(conf, context, topic, msg): + """Sends a message on a topic without waiting for a response.""" + return rpc_amqp.cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast(conf, context, topic, msg): + """Sends a message on a fanout exchange without waiting for a response.""" + return rpc_amqp.fanout_cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a topic to a specific server.""" + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a fanout exchange to a specific server.""" + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def notify(conf, context, topic, msg): + """Sends a notification event on a topic.""" + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cleanup(): + return rpc_amqp.cleanup(Connection.pool) diff --git a/openstack/common/rpc/impl_qpid.py b/openstack/common/rpc/impl_qpid.py new file mode 100644 index 0000000..3c46309 --- /dev/null +++ b/openstack/common/rpc/impl_qpid.py @@ -0,0 +1,580 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC +# Copyright 2011 - 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import json +import logging +import time +import uuid + +import eventlet +import greenlet +import qpid.messaging +import qpid.messaging.exceptions + +from openstack.common import cfg +from openstack.common.gettextutils import _ +from openstack.common.rpc import amqp as rpc_amqp +from openstack.common.rpc import common as rpc_common + +LOG = logging.getLogger(__name__) + +qpid_opts = [ + cfg.StrOpt('qpid_hostname', + default='localhost', + help='Qpid broker hostname'), + cfg.StrOpt('qpid_port', + default='5672', + help='Qpid broker port'), + cfg.StrOpt('qpid_username', + default='', + help='Username for qpid connection'), + cfg.StrOpt('qpid_password', + default='', + help='Password for qpid connection'), + cfg.StrOpt('qpid_sasl_mechanisms', + default='', + help='Space separated list of SASL mechanisms to use for auth'), + cfg.BoolOpt('qpid_reconnect', + default=True, + help='Automatically reconnect'), + cfg.IntOpt('qpid_reconnect_timeout', + default=0, + help='Reconnection timeout in seconds'), + cfg.IntOpt('qpid_reconnect_limit', + default=0, + help='Max reconnections before giving up'), + cfg.IntOpt('qpid_reconnect_interval_min', + default=0, + help='Minimum seconds between reconnection attempts'), + cfg.IntOpt('qpid_reconnect_interval_max', + default=0, + help='Maximum seconds between reconnection attempts'), + cfg.IntOpt('qpid_reconnect_interval', + default=0, + help='Equivalent to setting max and min to the same value'), + cfg.IntOpt('qpid_heartbeat', + default=5, + help='Seconds between connection keepalive heartbeats'), + cfg.StrOpt('qpid_protocol', + default='tcp', + help="Transport to use, either 'tcp' or 'ssl'"), + cfg.BoolOpt('qpid_tcp_nodelay', + default=True, + help='Disable Nagle algorithm'), + ] + +cfg.CONF.register_opts(qpid_opts) + + +class ConsumerBase(object): + """Consumer base class.""" + + def __init__(self, session, callback, node_name, node_opts, + link_name, link_opts): + """Declare a queue on an amqp session. + + 'session' is the amqp session to use + 'callback' is the callback to call when messages are received + 'node_name' is the first part of the Qpid address string, before ';' + 'node_opts' will be applied to the "x-declare" section of "node" + in the address string. + 'link_name' goes into the "name" field of the "link" in the address + string + 'link_opts' will be applied to the "x-declare" section of "link" + in the address string. + """ + self.callback = callback + self.receiver = None + self.session = None + + addr_opts = { + "create": "always", + "node": { + "type": "topic", + "x-declare": { + "durable": True, + "auto-delete": True, + }, + }, + "link": { + "name": link_name, + "durable": True, + "x-declare": { + "durable": False, + "auto-delete": True, + "exclusive": False, + }, + }, + } + addr_opts["node"]["x-declare"].update(node_opts) + addr_opts["link"]["x-declare"].update(link_opts) + + self.address = "%s ; %s" % (node_name, json.dumps(addr_opts)) + + self.reconnect(session) + + def reconnect(self, session): + """Re-declare the receiver after a qpid reconnect""" + self.session = session + self.receiver = session.receiver(self.address) + self.receiver.capacity = 1 + + def consume(self): + """Fetch the message and pass it to the callback object""" + message = self.receiver.fetch() + self.callback(message.content) + + def get_receiver(self): + return self.receiver + + +class DirectConsumer(ConsumerBase): + """Queue/consumer class for 'direct'""" + + def __init__(self, conf, session, msg_id, callback): + """Init a 'direct' queue. + + 'session' is the amqp session to use + 'msg_id' is the msg_id to listen on + 'callback' is the callback to call when messages are received + """ + + super(DirectConsumer, self).__init__(session, callback, + "%s/%s" % (msg_id, msg_id), + {"type": "direct"}, + msg_id, + {"exclusive": True}) + + +class TopicConsumer(ConsumerBase): + """Consumer class for 'topic'""" + + def __init__(self, conf, session, topic, callback, name=None): + """Init a 'topic' queue. + + :param session: the amqp session to use + :param topic: is the topic to listen on + :paramtype topic: str + :param callback: the callback to call when messages are received + :param name: optional queue name, defaults to topic + """ + + super(TopicConsumer, self).__init__(session, callback, + "%s/%s" % (conf.control_exchange, topic), {}, + name or topic, {}) + + +class FanoutConsumer(ConsumerBase): + """Consumer class for 'fanout'""" + + def __init__(self, conf, session, topic, callback): + """Init a 'fanout' queue. + + 'session' is the amqp session to use + 'topic' is the topic to listen on + 'callback' is the callback to call when messages are received + """ + + super(FanoutConsumer, self).__init__(session, callback, + "%s_fanout" % topic, + {"durable": False, "type": "fanout"}, + "%s_fanout_%s" % (topic, uuid.uuid4().hex), + {"exclusive": True}) + + +class Publisher(object): + """Base Publisher class""" + + def __init__(self, session, node_name, node_opts=None): + """Init the Publisher class with the exchange_name, routing_key, + and other options + """ + self.sender = None + self.session = session + + addr_opts = { + "create": "always", + "node": { + "type": "topic", + "x-declare": { + "durable": False, + # auto-delete isn't implemented for exchanges in qpid, + # but put in here anyway + "auto-delete": True, + }, + }, + } + if node_opts: + addr_opts["node"]["x-declare"].update(node_opts) + + self.address = "%s ; %s" % (node_name, json.dumps(addr_opts)) + + self.reconnect(session) + + def reconnect(self, session): + """Re-establish the Sender after a reconnection""" + self.sender = session.sender(self.address) + + def send(self, msg): + """Send a message""" + self.sender.send(msg) + + +class DirectPublisher(Publisher): + """Publisher class for 'direct'""" + def __init__(self, conf, session, msg_id): + """Init a 'direct' publisher.""" + super(DirectPublisher, self).__init__(session, msg_id, + {"type": "Direct"}) + + +class TopicPublisher(Publisher): + """Publisher class for 'topic'""" + def __init__(self, conf, session, topic): + """init a 'topic' publisher. + """ + super(TopicPublisher, self).__init__(session, + "%s/%s" % (conf.control_exchange, topic)) + + +class FanoutPublisher(Publisher): + """Publisher class for 'fanout'""" + def __init__(self, conf, session, topic): + """init a 'fanout' publisher. + """ + super(FanoutPublisher, self).__init__(session, + "%s_fanout" % topic, {"type": "fanout"}) + + +class NotifyPublisher(Publisher): + """Publisher class for notifications""" + def __init__(self, conf, session, topic): + """init a 'topic' publisher. + """ + super(NotifyPublisher, self).__init__(session, + "%s/%s" % (conf.control_exchange, topic), + {"durable": True}) + + +class Connection(object): + """Connection object.""" + + pool = None + + def __init__(self, conf, server_params=None): + self.session = None + self.consumers = {} + self.consumer_thread = None + self.conf = conf + + if server_params is None: + server_params = {} + + default_params = dict(hostname=self.conf.qpid_hostname, + port=self.conf.qpid_port, + username=self.conf.qpid_username, + password=self.conf.qpid_password) + + params = server_params + for key in default_params.keys(): + params.setdefault(key, default_params[key]) + + self.broker = params['hostname'] + ":" + str(params['port']) + # Create the connection - this does not open the connection + self.connection = qpid.messaging.Connection(self.broker) + + # Check if flags are set and if so set them for the connection + # before we call open + self.connection.username = params['username'] + self.connection.password = params['password'] + self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms + self.connection.reconnect = self.conf.qpid_reconnect + if self.conf.qpid_reconnect_timeout: + self.connection.reconnect_timeout = ( + self.conf.qpid_reconnect_timeout) + if self.conf.qpid_reconnect_limit: + self.connection.reconnect_limit = self.conf.qpid_reconnect_limit + if self.conf.qpid_reconnect_interval_max: + self.connection.reconnect_interval_max = ( + self.conf.qpid_reconnect_interval_max) + if self.conf.qpid_reconnect_interval_min: + self.connection.reconnect_interval_min = ( + self.conf.qpid_reconnect_interval_min) + if self.conf.qpid_reconnect_interval: + self.connection.reconnect_interval = ( + self.conf.qpid_reconnect_interval) + self.connection.hearbeat = self.conf.qpid_heartbeat + self.connection.protocol = self.conf.qpid_protocol + self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay + + # Open is part of reconnect - + # NOTE(WGH) not sure we need this with the reconnect flags + self.reconnect() + + def _register_consumer(self, consumer): + self.consumers[str(consumer.get_receiver())] = consumer + + def _lookup_consumer(self, receiver): + return self.consumers[str(receiver)] + + def reconnect(self): + """Handles reconnecting and re-establishing sessions and queues""" + if self.connection.opened(): + try: + self.connection.close() + except qpid.messaging.exceptions.ConnectionError: + pass + + while True: + try: + self.connection.open() + except qpid.messaging.exceptions.ConnectionError, e: + LOG.error(_('Unable to connect to AMQP server: %s'), e) + time.sleep(self.conf.qpid_reconnect_interval or 1) + else: + break + + LOG.info(_('Connected to AMQP server on %s'), self.broker) + + self.session = self.connection.session() + + for consumer in self.consumers.itervalues(): + consumer.reconnect(self.session) + + if self.consumers: + LOG.debug(_("Re-established AMQP queues")) + + def ensure(self, error_callback, method, *args, **kwargs): + while True: + try: + return method(*args, **kwargs) + except (qpid.messaging.exceptions.Empty, + qpid.messaging.exceptions.ConnectionError), e: + if error_callback: + error_callback(e) + self.reconnect() + + def close(self): + """Close/release this connection""" + self.cancel_consumer_thread() + self.connection.close() + self.connection = None + + def reset(self): + """Reset a connection so it can be used again""" + self.cancel_consumer_thread() + self.session.close() + self.session = self.connection.session() + self.consumers = {} + + def declare_consumer(self, consumer_cls, topic, callback): + """Create a Consumer using the class that was passed in and + add it to our list of consumers + """ + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.error(_("Failed to declare consumer for topic '%(topic)s': " + "%(err_str)s") % log_info) + + def _declare_consumer(): + consumer = consumer_cls(self.conf, self.session, topic, callback) + self._register_consumer(consumer) + return consumer + + return self.ensure(_connect_error, _declare_consumer) + + def iterconsume(self, limit=None, timeout=None): + """Return an iterator that will consume from all queues/consumers""" + + def _error_callback(exc): + if isinstance(exc, qpid.messaging.exceptions.Empty): + LOG.exception(_('Timed out waiting for RPC response: %s') % + str(exc)) + raise rpc_common.Timeout() + else: + LOG.exception(_('Failed to consume message from queue: %s') % + str(exc)) + + def _consume(): + nxt_receiver = self.session.next_receiver(timeout=timeout) + try: + self._lookup_consumer(nxt_receiver).consume() + except Exception: + LOG.exception(_("Error processing message. Skipping it.")) + + for iteration in itertools.count(0): + if limit and iteration >= limit: + raise StopIteration + yield self.ensure(_error_callback, _consume) + + def cancel_consumer_thread(self): + """Cancel a consumer thread""" + if self.consumer_thread is not None: + self.consumer_thread.kill() + try: + self.consumer_thread.wait() + except greenlet.GreenletExit: + pass + self.consumer_thread = None + + def publisher_send(self, cls, topic, msg): + """Send to a publisher based on the publisher class""" + + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.exception(_("Failed to publish message to topic " + "'%(topic)s': %(err_str)s") % log_info) + + def _publisher_send(): + publisher = cls(self.conf, self.session, topic) + publisher.send(msg) + + return self.ensure(_connect_error, _publisher_send) + + def declare_direct_consumer(self, topic, callback): + """Create a 'direct' queue. + In nova's use, this is generally a msg_id queue used for + responses for call/multicall + """ + self.declare_consumer(DirectConsumer, topic, callback) + + def declare_topic_consumer(self, topic, callback=None, queue_name=None): + """Create a 'topic' consumer.""" + self.declare_consumer(functools.partial(TopicConsumer, + name=queue_name, + ), + topic, callback) + + def declare_fanout_consumer(self, topic, callback): + """Create a 'fanout' consumer""" + self.declare_consumer(FanoutConsumer, topic, callback) + + def direct_send(self, msg_id, msg): + """Send a 'direct' message""" + self.publisher_send(DirectPublisher, msg_id, msg) + + def topic_send(self, topic, msg): + """Send a 'topic' message""" + self.publisher_send(TopicPublisher, topic, msg) + + def fanout_send(self, topic, msg): + """Send a 'fanout' message""" + self.publisher_send(FanoutPublisher, topic, msg) + + def notify_send(self, topic, msg, **kwargs): + """Send a notify message on a topic""" + self.publisher_send(NotifyPublisher, topic, msg) + + def consume(self, limit=None): + """Consume from all queues/consumers""" + it = self.iterconsume(limit=limit) + while True: + try: + it.next() + except StopIteration: + return + + def consume_in_thread(self): + """Consumer from all queues/consumers in a greenthread""" + def _consumer_thread(): + try: + self.consume() + except greenlet.GreenletExit: + return + if self.consumer_thread is None: + self.consumer_thread = eventlet.spawn(_consumer_thread) + return self.consumer_thread + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + + if fanout: + consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb) + else: + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb) + + self._register_consumer(consumer) + + return consumer + + def create_worker(self, topic, proxy, pool_name): + """Create a worker that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb, + name=pool_name) + + self._register_consumer(consumer) + + return consumer + + +def create_connection(conf, new=True): + """Create a connection""" + return rpc_amqp.create_connection(conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + return rpc_amqp.multicall(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + return rpc_amqp.call(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast(conf, context, topic, msg): + """Sends a message on a topic without waiting for a response.""" + return rpc_amqp.cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast(conf, context, topic, msg): + """Sends a message on a fanout exchange without waiting for a response.""" + return rpc_amqp.fanout_cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a topic to a specific server.""" + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a fanout exchange to a specific server.""" + return rpc_amqp.fanout_cast_to_server(conf, context, server_params, topic, + msg, rpc_amqp.get_connection_pool(conf, Connection)) + + +def notify(conf, context, topic, msg): + """Sends a notification event on a topic.""" + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cleanup(): + return rpc_amqp.cleanup(Connection.pool) diff --git a/openstack/common/rpc/matchmaker.py b/openstack/common/rpc/matchmaker.py new file mode 100644 index 0000000..4da1dcd --- /dev/null +++ b/openstack/common/rpc/matchmaker.py @@ -0,0 +1,257 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should except a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +import contextlib +import itertools +import json +import logging + +from openstack.common import cfg + + +matchmaker_opts = [ + # Matchmaker ring file + cfg.StrOpt('matchmaker_ringfile', + default='/etc/nova/matchmaker_ring.json', + help='Matchmaker ring file (JSON)'), +] + +CONF = cfg.CONF +CONF.register_opts(matchmaker_opts) +LOG = logging.getLogger(__name__) +contextmanager = contextlib.contextmanager + + +class MatchMakerException(Exception): + """Signified a match could not be found.""" + message = _("Match not found by MatchMaker.") + + +class Exchange(object): + """ + Implements lookups. + Subclass this to support hashtables, dns, etc. + """ + def __init__(self): + pass + + def run(self, key): + raise NotImplementedError() + + +class Binding(object): + """ + A binding on which to perform a lookup. + """ + def __init__(self): + pass + + def test(self, key): + raise NotImplementedError() + + +class MatchMakerBase(object): + """Match Maker Base Class.""" + + def __init__(self): + # Array of tuples. Index [2] toggles negation, [3] is last-if-true + self.bindings = [] + + def add_binding(self, binding, rule, last=True): + self.bindings.append((binding, rule, False, last)) + + #NOTE(ewindisch): kept the following method in case we implement the + # underlying support. + #def add_negate_binding(self, binding, rule, last=True): + # self.bindings.append((binding, rule, True, last)) + + def queues(self, key): + workers = [] + + # bit is for negate bindings - if we choose to implement it. + # last stops processing rules if this matches. + for (binding, exchange, bit, last) in self.bindings: + if binding.test(key): + workers.extend(exchange.run(key)) + + # Support last. + if last: + return workers + return workers + + +class DirectBinding(Binding): + """ + Specifies a host in the key via a '.' character + Although dots are used in the key, the behavior here is + that it maps directly to a host, thus direct. + """ + def test(self, key): + if '.' in key: + return True + return False + + +class TopicBinding(Binding): + """ + Where a 'bare' key without dots. + AMQP generally considers topic exchanges to be those *with* dots, + but we deviate here in terminology as the behavior here matches + that of a topic exchange (whereas where there are dots, behavior + matches that of a direct exchange. + """ + def test(self, key): + if '.' not in key: + return True + return False + + +class FanoutBinding(Binding): + """Match on fanout keys, where key starts with 'fanout.' string.""" + def test(self, key): + if key.startswith('fanout~'): + return True + return False + + +class StubExchange(Exchange): + """Exchange that does nothing.""" + def run(self, key): + return [(key, None)] + + +class RingExchange(Exchange): + """ + Match Maker where hosts are loaded from a static file containing + a hashmap (JSON formatted). + + __init__ takes optional ring dictionary argument, otherwise + loads the ringfile from CONF.mathcmaker_ringfile. + """ + def __init__(self, ring=None): + super(RingExchange, self).__init__() + + if ring: + self.ring = ring + else: + fh = open(CONF.matchmaker_ringfile, 'r') + self.ring = json.load(fh) + fh.close() + + self.ring0 = {} + for k in self.ring.keys(): + self.ring0[k] = itertools.cycle(self.ring[k]) + + def _ring_has(self, key): + if key in self.ring0: + return True + return False + + +class RoundRobinRingExchange(RingExchange): + """A Topic Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(RoundRobinRingExchange, self).__init__(ring) + + def run(self, key): + if not self._ring_has(key): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (key, ) + ) + return [] + host = next(self.ring0[key]) + return [(key + '.' + host, host)] + + +class FanoutRingExchange(RingExchange): + """Fanout Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(FanoutRingExchange, self).__init__(ring) + + def run(self, key): + # Assume starts with "fanout~", strip it for lookup. + nkey = key.split('fanout~')[1:][0] + if not self._ring_has(nkey): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (nkey, ) + ) + return [] + return map(lambda x: (key + '.' + x, x), self.ring[nkey]) + + +class LocalhostExchange(Exchange): + """Exchange where all direct topics are local.""" + def __init__(self): + super(Exchange, self).__init__() + + def run(self, key): + return [(key.split('.')[0] + '.localhost', 'localhost')] + + +class DirectExchange(Exchange): + """ + Exchange where all topic keys are split, sending to second half. + i.e. "compute.host" sends a message to "compute" running on "host" + """ + def __init__(self): + super(Exchange, self).__init__() + + def run(self, key): + b, e = key.split('.', 1) + return [(b, e)] + + +class MatchMakerRing(MatchMakerBase): + """ + Match Maker where hosts are loaded from a static hashmap. + """ + def __init__(self, ring=None): + super(MatchMakerRing, self).__init__() + self.add_binding(FanoutBinding(), FanoutRingExchange(ring)) + self.add_binding(DirectBinding(), DirectExchange()) + self.add_binding(TopicBinding(), RoundRobinRingExchange(ring)) + + +class MatchMakerLocalhost(MatchMakerBase): + """ + Match Maker where all bare topics resolve to localhost. + Useful for testing. + """ + def __init__(self): + super(MatchMakerLocalhost, self).__init__() + self.add_binding(FanoutBinding(), LocalhostExchange()) + self.add_binding(DirectBinding(), DirectExchange()) + self.add_binding(TopicBinding(), LocalhostExchange()) + + +class MatchMakerStub(MatchMakerBase): + """ + Match Maker where topics are untouched. + Useful for testing, or for AMQP/brokered queues. + Will not work where knowledge of hosts is known (i.e. zeromq) + """ + def __init__(self): + super(MatchMakerLocalhost, self).__init__() + + self.add_binding(FanoutBinding(), StubExchange()) + self.add_binding(DirectBinding(), StubExchange()) + self.add_binding(TopicBinding(), StubExchange()) diff --git a/openstack/common/rpc/proxy.py b/openstack/common/rpc/proxy.py new file mode 100644 index 0000000..4f4dff5 --- /dev/null +++ b/openstack/common/rpc/proxy.py @@ -0,0 +1,161 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A helper class for proxy objects to remote APIs. + +For more information about rpc API version numbers, see: + rpc/dispatcher.py +""" + + +from openstack.common import rpc + + +class RpcProxy(object): + """A helper class for rpc clients. + + This class is a wrapper around the RPC client API. It allows you to + specify the topic and API version in a single place. This is intended to + be used as a base class for a class that implements the client side of an + rpc API. + """ + + def __init__(self, topic, default_version): + """Initialize an RpcProxy. + + :param topic: The topic to use for all messages. + :param default_version: The default API version to request in all + outgoing messages. This can be overridden on a per-message + basis. + """ + self.topic = topic + self.default_version = default_version + super(RpcProxy, self).__init__() + + def _set_version(self, msg, vers): + """Helper method to set the version in a message. + + :param msg: The message having a version added to it. + :param vers: The version number to add to the message. + """ + msg['version'] = vers if vers else self.default_version + + def _get_topic(self, topic): + """Return the topic to use for a message.""" + return topic if topic else self.topic + + @staticmethod + def make_msg(method, **kwargs): + return {'method': method, 'args': kwargs} + + def call(self, context, msg, topic=None, version=None, timeout=None): + """rpc.call() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param timeout: (Optional) A timeout to use when waiting for the + response. If no timeout is specified, a default timeout will be + used that is usually sufficient. + :param version: (Optional) Override the requested API version in this + message. + + :returns: The return value from the remote method. + """ + self._set_version(msg, version) + return rpc.call(context, self._get_topic(topic), msg, timeout) + + def multicall(self, context, msg, topic=None, version=None, timeout=None): + """rpc.multicall() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param timeout: (Optional) A timeout to use when waiting for the + response. If no timeout is specified, a default timeout will be + used that is usually sufficient. + :param version: (Optional) Override the requested API version in this + message. + + :returns: An iterator that lets you process each of the returned values + from the remote method as they arrive. + """ + self._set_version(msg, version) + return rpc.multicall(context, self._get_topic(topic), msg, timeout) + + def cast(self, context, msg, topic=None, version=None): + """rpc.cast() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.cast() does not wait on any return value from the + remote method. + """ + self._set_version(msg, version) + rpc.cast(context, self._get_topic(topic), msg) + + def fanout_cast(self, context, msg, version=None): + """rpc.fanout_cast() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.fanout_cast() does not wait on any return value + from the remote method. + """ + self._set_version(msg, version) + rpc.fanout_cast(context, self.topic, msg) + + def cast_to_server(self, context, server_params, msg, topic=None, + version=None): + """rpc.cast_to_server() a remote method. + + :param context: The request context + :param server_params: Server parameters. See rpc.cast_to_server() for + details. + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.cast_to_server() does not wait on any + return values. + """ + self._set_version(msg, version) + rpc.cast_to_server(context, server_params, self._get_topic(topic), msg) + + def fanout_cast_to_server(self, context, server_params, msg, version=None): + """rpc.fanout_cast_to_server() a remote method. + + :param context: The request context + :param server_params: Server parameters. See rpc.cast_to_server() for + details. + :param msg: The message to send, including the method and args. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.fanout_cast_to_server() does not wait on any + return values. + """ + self._set_version(msg, version) + rpc.fanout_cast_to_server(context, server_params, self.topic, msg) diff --git a/tests/unit/rpc/__init__.py b/tests/unit/rpc/__init__.py new file mode 100644 index 0000000..848908a --- /dev/null +++ b/tests/unit/rpc/__init__.py @@ -0,0 +1,15 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/tests/unit/rpc/common.py b/tests/unit/rpc/common.py new file mode 100644 index 0000000..013418d --- /dev/null +++ b/tests/unit/rpc/common.py @@ -0,0 +1,322 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for remote procedure calls shared between all implementations +""" + +import logging +import time +import unittest + +import eventlet +from eventlet import greenthread +import nose + +from openstack.common import cfg +from openstack.common import exception +from openstack.common.gettextutils import _ +from openstack.common.rpc import amqp as rpc_amqp +from openstack.common.rpc import common as rpc_common +from openstack.common.rpc import dispatcher as rpc_dispatcher + + +FLAGS = cfg.CONF +LOG = logging.getLogger(__name__) + + +class BaseRpcTestCase(unittest.TestCase): + def setUp(self, supports_timeouts=True, topic='test', + topic_nested='nested'): + super(BaseRpcTestCase, self).setUp() + self.topic = topic or self.topic + self.topic_nested = topic_nested or self.topic_nested + self.supports_timeouts = supports_timeouts + self.context = rpc_common.CommonRpcContext(user='fake_user', + pw='fake_pw') + + if self.rpc: + receiver = TestReceiver() + self.conn = self._create_consumer(receiver, self.topic) + + def tearDown(self): + if self.rpc: + self.conn.close() + super(BaseRpcTestCase, self).tearDown() + + def _create_consumer(self, proxy, topic, fanout=False): + dispatcher = rpc_dispatcher.RpcDispatcher([proxy]) + conn = self.rpc.create_connection(FLAGS, True) + conn.create_consumer(topic, dispatcher, fanout) + conn.consume_in_thread() + return conn + + def test_call_succeed(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + value = 42 + result = self.rpc.call(FLAGS, self.context, self.topic, + {"method": "echo", "args": {"value": value}}) + self.assertEqual(value, result) + + def test_call_succeed_despite_multiple_returns_yield(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + value = 42 + result = self.rpc.call(FLAGS, self.context, self.topic, + {"method": "echo_three_times_yield", + "args": {"value": value}}) + self.assertEqual(value + 2, result) + + def test_multicall_succeed_once(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + value = 42 + result = self.rpc.multicall(FLAGS, self.context, + self.topic, + {"method": "echo", + "args": {"value": value}}) + for i, x in enumerate(result): + if i > 0: + self.fail('should only receive one response') + self.assertEqual(value + i, x) + + def test_multicall_three_nones(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + value = 42 + result = self.rpc.multicall(FLAGS, self.context, + self.topic, + {"method": "multicall_three_nones", + "args": {"value": value}}) + for i, x in enumerate(result): + self.assertEqual(x, None) + # i should have been 0, 1, and finally 2: + self.assertEqual(i, 2) + + def test_multicall_succeed_three_times_yield(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + value = 42 + result = self.rpc.multicall(FLAGS, self.context, + self.topic, + {"method": "echo_three_times_yield", + "args": {"value": value}}) + for i, x in enumerate(result): + self.assertEqual(value + i, x) + + def test_context_passed(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + """Makes sure a context is passed through rpc call.""" + value = 42 + result = self.rpc.call(FLAGS, self.context, + self.topic, {"method": "context", + "args": {"value": value}}) + self.assertEqual(self.context.to_dict(), result) + + def _test_cast(self, fanout=False): + """Test casts by pushing items through a channeled queue.""" + + # Not a true global, but capitalized so + # it is clear it is leaking scope into Nested() + QUEUE = eventlet.queue.Queue() + + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + # We use the nested topic so we don't need QUEUE to be a proper + # global, and do not keep state outside this test. + class Nested(object): + @staticmethod + def put_queue(context, value): + LOG.debug("Got value in put_queue: %s", value) + QUEUE.put(value) + + nested = Nested() + conn = self._create_consumer(nested, self.topic_nested, fanout) + value = 42 + + method = (self.rpc.cast, self.rpc.fanout_cast)[fanout] + method(FLAGS, self.context, + self.topic_nested, + {"method": "put_queue", + "args": {"value": value}}) + + try: + # If it does not succeed in 2 seconds, give up and assume + # failure. + result = QUEUE.get(True, 2) + except Exception: + self.assertEqual(value, None) + + conn.close() + self.assertEqual(value, result) + + def test_cast_success(self): + self._test_cast(False) + + def test_fanout_success(self): + self._test_cast(True) + + def test_nested_calls(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + """Test that we can do an rpc.call inside another call.""" + class Nested(object): + @staticmethod + def echo(context, queue, value): + """Calls echo in the passed queue.""" + LOG.debug(_("Nested received %(queue)s, %(value)s") + % locals()) + # TODO(comstud): + # so, it will replay the context and use the same REQID? + # that's bizarre. + ret = self.rpc.call(FLAGS, context, + queue, + {"method": "echo", + "args": {"value": value}}) + LOG.debug(_("Nested return %s"), ret) + return value + + nested = Nested() + conn = self._create_consumer(nested, self.topic_nested) + + value = 42 + result = self.rpc.call(FLAGS, self.context, + self.topic_nested, + {"method": "echo", + "args": {"queue": "test", "value": value}}) + conn.close() + self.assertEqual(value, result) + + def test_call_timeout(self): + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + """Make sure rpc.call will time out.""" + if not self.supports_timeouts: + raise nose.SkipTest(_("RPC backend does not support timeouts")) + + value = 42 + self.assertRaises(rpc_common.Timeout, + self.rpc.call, + FLAGS, self.context, + self.topic, + {"method": "block", + "args": {"value": value}}, timeout=1) + try: + self.rpc.call(FLAGS, self.context, + self.topic, + {"method": "block", + "args": {"value": value}}, + timeout=1) + self.fail("should have thrown Timeout") + except rpc_common.Timeout as exc: + pass + + +class BaseRpcAMQPTestCase(BaseRpcTestCase): + """Base test class for all AMQP-based RPC tests.""" + def test_proxycallback_handles_exceptions(self): + """Make sure exceptions unpacking messages don't cause hangs.""" + if not self.rpc: + raise nose.SkipTest('rpc driver not available.') + + orig_unpack = rpc_amqp.unpack_context + + info = {'unpacked': False} + + def fake_unpack_context(*args, **kwargs): + info['unpacked'] = True + raise test.TestingException('moo') + + self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context) + + value = 41 + self.rpc.cast(FLAGS, self.context, self.topic, + {"method": "echo", "args": {"value": value}}) + + # Wait for the cast to complete. + for x in xrange(50): + if info['unpacked']: + break + greenthread.sleep(0.1) + else: + self.fail("Timeout waiting for message to be consumed") + + # Now see if we get a response even though we raised an + # exception for the cast above. + self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack) + + value = 42 + result = self.rpc.call(FLAGS, self.context, self.topic, + {"method": "echo", + "args": {"value": value}}) + self.assertEqual(value, result) + + +class TestReceiver(object): + """Simple Proxy class so the consumer has methods to call. + + Uses static methods because we aren't actually storing any state. + + """ + @staticmethod + def echo(context, value): + """Simply returns whatever value is sent in.""" + LOG.debug(_("Received %s"), value) + return value + + @staticmethod + def context(context, value): + """Returns dictionary version of context.""" + LOG.debug(_("Received %s"), context) + return context.to_dict() + + @staticmethod + def multicall_three_nones(context, value): + yield None + yield None + yield None + + @staticmethod + def echo_three_times_yield(context, value): + yield value + yield value + 1 + yield value + 2 + + @staticmethod + def fail(context, value): + """Raises an exception with the value sent in.""" + raise NotImplementedError(value) + + @staticmethod + def fail_converted(context, value): + """Raises an exception with the value sent in.""" + raise exception.ApiError(message=value, code='500') + + @staticmethod + def block(context, value): + time.sleep(2) diff --git a/tests/unit/rpc/test_common.py b/tests/unit/rpc/test_common.py new file mode 100644 index 0000000..73fb733 --- /dev/null +++ b/tests/unit/rpc/test_common.py @@ -0,0 +1,150 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for 'common' functons used through rpc code. +""" + +import json +import logging +import sys +import unittest + +from openstack.common import cfg +from openstack.common import context +from openstack.common import exception +from openstack.common.rpc import amqp as rpc_amqp +from openstack.common.rpc import common as rpc_common +from tests.unit.rpc import common + + +FLAGS = cfg.CONF +LOG = logging.getLogger(__name__) + + +def raise_exception(): + raise Exception("test") + + +class FakeUserDefinedException(Exception): + def __init__(self): + Exception.__init__(self, "Test Message") + + +class RpcCommonTestCase(unittest.TestCase): + def test_serialize_remote_exception(self): + expected = { + 'class': 'Exception', + 'module': 'exceptions', + 'message': 'test', + } + + try: + raise_exception() + except Exception as exc: + failure = rpc_common.serialize_remote_exception(sys.exc_info()) + + failure = json.loads(failure) + #assure the traceback was added + self.assertEqual(expected['class'], failure['class']) + self.assertEqual(expected['module'], failure['module']) + self.assertEqual(expected['message'], failure['message']) + + def test_serialize_remote_custom_exception(self): + def raise_custom_exception(): + raise exception.OpenstackException() + + expected = { + 'class': 'OpenstackException', + 'module': 'openstack.common.exception', + 'message': exception.OpenstackException.message, + } + + try: + raise_custom_exception() + except Exception as exc: + failure = rpc_common.serialize_remote_exception(sys.exc_info()) + + failure = json.loads(failure) + #assure the traceback was added + self.assertEqual(expected['class'], failure['class']) + self.assertEqual(expected['module'], failure['module']) + self.assertEqual(expected['message'], failure['message']) + + def test_deserialize_remote_exception(self): + failure = { + 'class': 'OpenstackException', + 'module': 'openstack.common.exception', + 'message': exception.OpenstackException.message, + 'tb': ['raise OpenstackException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) + self.assertTrue(isinstance(after_exc, exception.OpenstackException)) + self.assertTrue('An unknown' in unicode(after_exc)) + #assure the traceback was added + self.assertTrue('raise OpenstackException' in unicode(after_exc)) + + def test_deserialize_remote_exception_bad_module(self): + failure = { + 'class': 'popen2', + 'module': 'os', + 'kwargs': {'cmd': '/bin/echo failed'}, + 'message': 'foo', + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) + self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) + + def test_deserialize_remote_exception_user_defined_exception(self): + """Ensure a user defined exception can be deserialized.""" + FLAGS.set_override('allowed_rpc_exception_modules', + [self.__class__.__module__]) + failure = { + 'class': 'FakeUserDefinedException', + 'module': self.__class__.__module__, + 'tb': ['raise FakeUserDefinedException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) + self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) + #assure the traceback was added + self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc)) + FLAGS.reset() + + def test_deserialize_remote_exception_cannot_recreate(self): + """Ensure a RemoteError is returned on initialization failure. + + If an exception cannot be recreated with it's original class then a + RemoteError with the exception informations should still be returned. + + """ + FLAGS.set_override('allowed_rpc_exception_modules', + [self.__class__.__module__]) + failure = { + 'class': 'FakeIDontExistException', + 'module': self.__class__.__module__, + 'tb': ['raise FakeIDontExistException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) + self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) + #assure the traceback was added + self.assertTrue('raise FakeIDontExistException' in unicode(after_exc)) + FLAGS.reset() diff --git a/tests/unit/rpc/test_dispatcher.py b/tests/unit/rpc/test_dispatcher.py new file mode 100644 index 0000000..a085567 --- /dev/null +++ b/tests/unit/rpc/test_dispatcher.py @@ -0,0 +1,110 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Unit Tests for rpc.dispatcher +""" + +import unittest + +from openstack.common import context +from openstack.common.rpc import common as rpc_common +from openstack.common.rpc import dispatcher + + +class RpcDispatcherTestCase(unittest.TestCase): + class API1(object): + RPC_API_VERSION = '1.0' + + def __init__(self): + self.test_method_ctxt = None + self.test_method_arg1 = None + + def test_method(self, ctxt, arg1): + self.test_method_ctxt = ctxt + self.test_method_arg1 = arg1 + + class API2(object): + RPC_API_VERSION = '2.1' + + def __init__(self): + self.test_method_ctxt = None + self.test_method_arg1 = None + + def test_method(self, ctxt, arg1): + self.test_method_ctxt = ctxt + self.test_method_arg1 = arg1 + + class API3(object): + RPC_API_VERSION = '3.5' + + def __init__(self): + self.test_method_ctxt = None + self.test_method_arg1 = None + + def test_method(self, ctxt, arg1): + self.test_method_ctxt = ctxt + self.test_method_arg1 = arg1 + + def setUp(self): + self.ctxt = context.RequestContext('fake_user', 'fake_project') + super(RpcDispatcherTestCase, self).setUp() + + def tearDown(self): + super(RpcDispatcherTestCase, self).tearDown() + + def _test_dispatch(self, version, expectations): + v2 = self.API2() + v3 = self.API3() + disp = dispatcher.RpcDispatcher([v2, v3]) + + disp.dispatch(self.ctxt, version, 'test_method', arg1=1) + + self.assertEqual(v2.test_method_ctxt, expectations[0]) + self.assertEqual(v2.test_method_arg1, expectations[1]) + self.assertEqual(v3.test_method_ctxt, expectations[2]) + self.assertEqual(v3.test_method_arg1, expectations[3]) + + def test_dispatch(self): + self._test_dispatch('2.1', (self.ctxt, 1, None, None)) + self._test_dispatch('3.5', (None, None, self.ctxt, 1)) + + def test_dispatch_lower_minor_version(self): + self._test_dispatch('2.0', (self.ctxt, 1, None, None)) + self._test_dispatch('3.1', (None, None, self.ctxt, 1)) + + def test_dispatch_higher_minor_version(self): + self.assertRaises(rpc_common.UnsupportedRpcVersion, + self._test_dispatch, '2.6', (None, None, None, None)) + self.assertRaises(rpc_common.UnsupportedRpcVersion, + self._test_dispatch, '3.6', (None, None, None, None)) + + def test_dispatch_lower_major_version(self): + self.assertRaises(rpc_common.UnsupportedRpcVersion, + self._test_dispatch, '1.0', (None, None, None, None)) + + def test_dispatch_higher_major_version(self): + self.assertRaises(rpc_common.UnsupportedRpcVersion, + self._test_dispatch, '4.0', (None, None, None, None)) + + def test_dispatch_no_version_uses_v1(self): + v1 = self.API1() + disp = dispatcher.RpcDispatcher([v1]) + + disp.dispatch(self.ctxt, None, 'test_method', arg1=1) + + self.assertEqual(v1.test_method_ctxt, self.ctxt) + self.assertEqual(v1.test_method_arg1, 1) diff --git a/tests/unit/rpc/test_fake.py b/tests/unit/rpc/test_fake.py new file mode 100644 index 0000000..8ceac47 --- /dev/null +++ b/tests/unit/rpc/test_fake.py @@ -0,0 +1,32 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for remote procedure calls using fake_impl +""" + +import eventlet +eventlet.monkey_patch() + +from openstack.common.rpc import impl_fake +from tests.unit.rpc import common + + +class RpcFakeTestCase(common.BaseRpcTestCase): + def setUp(self): + self.rpc = impl_fake + super(RpcFakeTestCase, self).setUp() diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py new file mode 100644 index 0000000..ccab4a2 --- /dev/null +++ b/tests/unit/rpc/test_kombu.py @@ -0,0 +1,414 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for remote procedure calls using kombu +""" + +import eventlet +eventlet.monkey_patch() + +import logging +import unittest + +import stubout + +from openstack.common import cfg +from openstack.common import exception +from openstack.common.rpc import amqp as rpc_amqp +from openstack.common.rpc import common as rpc_common +from openstack.common import testutils +from tests.unit.rpc import common + +try: + import kombu + from openstack.common.rpc import impl_kombu +except ImportError: + kombu = None + impl_kombu = None + + +FLAGS = cfg.CONF +LOG = logging.getLogger(__name__) + + +class MyException(Exception): + pass + + +def _raise_exc_stub(stubs, times, obj, method, exc_msg, + exc_class=MyException): + info = {'called': 0} + orig_method = getattr(obj, method) + + def _raise_stub(*args, **kwargs): + info['called'] += 1 + if info['called'] <= times: + raise exc_class(exc_msg) + orig_method(*args, **kwargs) + stubs.Set(obj, method, _raise_stub) + return info + + +class RpcKombuTestCase(common.BaseRpcAMQPTestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + if kombu: + FLAGS.set_override('fake_rabbit', True) + FLAGS.set_override('rpc_response_timeout', 5) + self.rpc = impl_kombu + else: + self.rpc = None + super(RpcKombuTestCase, self).setUp() + + def tearDown(self): + self.stubs.UnsetAll() + self.stubs.SmartUnsetAll() + if kombu: + impl_kombu.cleanup() + FLAGS.reset() + super(RpcKombuTestCase, self).tearDown() + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_reusing_connection(self): + """Test that reusing a connection returns same one.""" + conn_context = self.rpc.create_connection(FLAGS, new=False) + conn1 = conn_context.connection + conn_context.close() + conn_context = self.rpc.create_connection(FLAGS, new=False) + conn2 = conn_context.connection + conn_context.close() + self.assertEqual(conn1, conn2) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_topic_send_receive(self): + """Test sending to a topic exchange/queue""" + + conn = self.rpc.create_connection(FLAGS) + message = 'topic test message' + + self.received_message = None + + def _callback(message): + self.received_message = message + + conn.declare_topic_consumer('a_topic', _callback) + conn.topic_send('a_topic', message) + conn.consume(limit=1) + conn.close() + + self.assertEqual(self.received_message, message) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_topic_multiple_queues(self): + """Test sending to a topic exchange with multiple queues""" + + conn = self.rpc.create_connection(FLAGS) + message = 'topic test message' + + self.received_message_1 = None + self.received_message_2 = None + + def _callback1(message): + self.received_message_1 = message + + def _callback2(message): + self.received_message_2 = message + + conn.declare_topic_consumer('a_topic', _callback1, queue_name='queue1') + conn.declare_topic_consumer('a_topic', _callback2, queue_name='queue2') + conn.topic_send('a_topic', message) + conn.consume(limit=2) + conn.close() + + self.assertEqual(self.received_message_1, message) + self.assertEqual(self.received_message_2, message) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_direct_send_receive(self): + """Test sending to a direct exchange/queue""" + conn = self.rpc.create_connection(FLAGS) + message = 'direct test message' + + self.received_message = None + + def _callback(message): + self.received_message = message + + conn.declare_direct_consumer('a_direct', _callback) + conn.direct_send('a_direct', message) + conn.consume(limit=1) + conn.close() + + self.assertEqual(self.received_message, message) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_cast_interface_uses_default_options(self): + """Test kombu rpc.cast""" + + ctxt = rpc_common.CommonRpcContext(user='fake_user', + project='fake_project') + + class MyConnection(impl_kombu.Connection): + def __init__(myself, *args, **kwargs): + super(MyConnection, myself).__init__(*args, **kwargs) + self.assertEqual(myself.params, + {'hostname': FLAGS.rabbit_host, + 'userid': FLAGS.rabbit_userid, + 'password': FLAGS.rabbit_password, + 'port': FLAGS.rabbit_port, + 'virtual_host': FLAGS.rabbit_virtual_host, + 'transport': 'memory'}) + + def topic_send(_context, topic, msg): + pass + + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) + self.stubs.Set(impl_kombu, 'Connection', MyConnection) + + impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'}) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_cast_to_server_uses_server_params(self): + """Test kombu rpc.cast""" + + ctxt = rpc_common.CommonRpcContext(user='fake_user', + project='fake_project') + + server_params = {'username': 'fake_username', + 'password': 'fake_password', + 'hostname': 'fake_hostname', + 'port': 31337, + 'virtual_host': 'fake_virtual_host'} + + class MyConnection(impl_kombu.Connection): + def __init__(myself, *args, **kwargs): + super(MyConnection, myself).__init__(*args, **kwargs) + self.assertEqual(myself.params, + {'hostname': server_params['hostname'], + 'userid': server_params['username'], + 'password': server_params['password'], + 'port': server_params['port'], + 'virtual_host': server_params['virtual_host'], + 'transport': 'memory'}) + + def topic_send(_context, topic, msg): + pass + + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) + self.stubs.Set(impl_kombu, 'Connection', MyConnection) + + impl_kombu.cast_to_server(FLAGS, ctxt, server_params, + 'fake_topic', {'msg': 'fake'}) + + @testutils.skip_test("kombu memory transport seems buggy with " + "fanout queues as this test passes when " + "you use rabbit (fake_rabbit=False)") + def test_fanout_send_receive(self): + """Test sending to a fanout exchange and consuming from 2 queues""" + + conn = self.rpc.create_connection() + conn2 = self.rpc.create_connection() + message = 'fanout test message' + + self.received_message = None + + def _callback(message): + self.received_message = message + + conn.declare_fanout_consumer('a_fanout', _callback) + conn2.declare_fanout_consumer('a_fanout', _callback) + conn.fanout_send('a_fanout', message) + + conn.consume(limit=1) + conn.close() + self.assertEqual(self.received_message, message) + + self.received_message = None + conn2.consume(limit=1) + conn2.close() + self.assertEqual(self.received_message, message) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_declare_consumer_errors_will_reconnect(self): + # Test that any exception with 'timeout' in it causes a + # reconnection + info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, + '__init__', 'foo timeout foo') + + conn = self.rpc.Connection(FLAGS) + result = conn.declare_consumer(self.rpc.DirectConsumer, + 'test_topic', None) + + self.assertEqual(info['called'], 3) + self.assertTrue(isinstance(result, self.rpc.DirectConsumer)) + + # Test that any exception in transport.connection_errors causes + # a reconnection + self.stubs.UnsetAll() + + info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer, + '__init__', 'meow') + + conn = self.rpc.Connection(FLAGS) + conn.connection_errors = (MyException, ) + + result = conn.declare_consumer(self.rpc.DirectConsumer, + 'test_topic', None) + + self.assertEqual(info['called'], 2) + self.assertTrue(isinstance(result, self.rpc.DirectConsumer)) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_declare_consumer_ioerrors_will_reconnect(self): + """Test that an IOError exception causes a reconnection""" + info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, + '__init__', 'Socket closed', exc_class=IOError) + + conn = self.rpc.Connection(FLAGS) + result = conn.declare_consumer(self.rpc.DirectConsumer, + 'test_topic', None) + + self.assertEqual(info['called'], 3) + self.assertTrue(isinstance(result, self.rpc.DirectConsumer)) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_publishing_errors_will_reconnect(self): + # Test that any exception with 'timeout' in it causes a + # reconnection when declaring the publisher class and when + # calling send() + info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, + '__init__', 'foo timeout foo') + + conn = self.rpc.Connection(FLAGS) + conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') + + self.assertEqual(info['called'], 3) + self.stubs.UnsetAll() + + info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, + 'send', 'foo timeout foo') + + conn = self.rpc.Connection(FLAGS) + conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') + + self.assertEqual(info['called'], 3) + + # Test that any exception in transport.connection_errors causes + # a reconnection when declaring the publisher class and when + # calling send() + self.stubs.UnsetAll() + + info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, + '__init__', 'meow') + + conn = self.rpc.Connection(FLAGS) + conn.connection_errors = (MyException, ) + + conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') + + self.assertEqual(info['called'], 2) + self.stubs.UnsetAll() + + info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, + 'send', 'meow') + + conn = self.rpc.Connection(FLAGS) + conn.connection_errors = (MyException, ) + + conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') + + self.assertEqual(info['called'], 2) + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_iterconsume_errors_will_reconnect(self): + conn = self.rpc.Connection(FLAGS) + message = 'reconnect test message' + + self.received_message = None + + def _callback(message): + self.received_message = message + + conn.declare_direct_consumer('a_direct', _callback) + conn.direct_send('a_direct', message) + + info = _raise_exc_stub(self.stubs, 1, conn.connection, + 'drain_events', 'foo timeout foo') + conn.consume(limit=1) + conn.close() + + self.assertEqual(self.received_message, message) + # Only called once, because our stub goes away during reconnection + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_call_exception(self): + """Test that exception gets passed back properly. + + rpc.call returns an Exception object. The value of the + exception is converted to a string. + + """ + FLAGS.set_override('allowed_rpc_exception_modules', ['exceptions']) + value = "This is the exception message" + self.assertRaises(NotImplementedError, + self.rpc.call, + FLAGS, + self.context, + 'test', + {"method": "fail", + "args": {"value": value}}) + try: + self.rpc.call(FLAGS, self.context, + 'test', + {"method": "fail", + "args": {"value": value}}) + self.fail("should have thrown Exception") + except NotImplementedError as exc: + self.assertTrue(value in unicode(exc)) + #Traceback should be included in exception message + self.assertTrue('raise NotImplementedError(value)' in unicode(exc)) + + FLAGS.reset() + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_call_converted_exception(self): + """Test that exception gets passed back properly. + + rpc.call returns an Exception object. The value of the + exception is converted to a string. + + """ + value = "This is the exception message" + # The use of ApiError is an arbitrary choice here ... + self.assertRaises(exception.ApiError, + self.rpc.call, + FLAGS, + self.context, + 'test', + {"method": "fail_converted", + "args": {"value": value}}) + try: + self.rpc.call(FLAGS, self.context, + 'test', + {"method": "fail_converted", + "args": {"value": value}}) + self.fail("should have thrown Exception") + except exception.ApiError as exc: + self.assertTrue(value in unicode(exc)) + #Traceback should be included in exception message + self.assertTrue('exception.ApiError' in unicode(exc)) diff --git a/tests/unit/rpc/test_kombu_ssl.py b/tests/unit/rpc/test_kombu_ssl.py new file mode 100644 index 0000000..2aecc3f --- /dev/null +++ b/tests/unit/rpc/test_kombu_ssl.py @@ -0,0 +1,82 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for remote procedure calls using kombu + ssl +""" + +import eventlet +eventlet.monkey_patch() + +import unittest + +from openstack.common import cfg +from openstack.common import testutils + +try: + import kombu + from openstack.common.rpc import impl_kombu +except ImportError: + kombu = None + impl_kombu = None + + +# Flag settings we will ensure get passed to amqplib +SSL_VERSION = "SSLv2" +SSL_CERT = "/tmp/cert.blah.blah" +SSL_CA_CERT = "/tmp/cert.ca.blah.blah" +SSL_KEYFILE = "/tmp/keyfile.blah.blah" + +FLAGS = cfg.CONF + + +class RpcKombuSslTestCase(unittest.TestCase): + + def setUp(self): + super(RpcKombuSslTestCase, self).setUp() + override = { + 'kombu_ssl_keyfile': SSL_KEYFILE, + 'kombu_ssl_ca_certs': SSL_CA_CERT, + 'kombu_ssl_certfile': SSL_CERT, + 'kombu_ssl_version': SSL_VERSION, + 'rabbit_use_ssl': True, + 'fake_rabbit': True, + } + + if kombu: + for k, v in override.iteritems(): + FLAGS.set_override(k, v) + + def tearDown(self): + super(RpcKombuSslTestCase, self).tearDown() + if kombu: + FLAGS.reset() + + @testutils.skip_if(kombu is None, "Test requires kombu") + def test_ssl_on_extended(self): + rpc = impl_kombu + conn = rpc.create_connection(FLAGS, True) + c = conn.connection + #This might be kombu version dependent... + #Since we are now peaking into the internals of kombu... + self.assertTrue(isinstance(c.connection.ssl, dict)) + self.assertEqual(SSL_VERSION, c.connection.ssl.get("ssl_version")) + self.assertEqual(SSL_CERT, c.connection.ssl.get("certfile")) + self.assertEqual(SSL_CA_CERT, c.connection.ssl.get("ca_certs")) + self.assertEqual(SSL_KEYFILE, c.connection.ssl.get("keyfile")) + #That hash then goes into amqplib which then goes + #Into python ssl creation... diff --git a/tests/unit/rpc/test_matchmaker.py b/tests/unit/rpc/test_matchmaker.py new file mode 100644 index 0000000..a38b59c --- /dev/null +++ b/tests/unit/rpc/test_matchmaker.py @@ -0,0 +1,60 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import unittest + +from openstack.common.rpc import matchmaker + + +LOG = logging.getLogger(__name__) + + +class _MatchMakerTestCase(unittest.TestCase): + def test_valid_host_matches(self): + queues = self.driver.queues(self.topic) + matched_hosts = map(lambda x: x[1], queues) + + for host in matched_hosts: + self.assertIn(host, self.hosts) + + def test_fanout_host_matches(self): + """For known hosts, see if they're in fanout.""" + queues = self.driver.queues("fanout~" + self.topic) + matched_hosts = map(lambda x: x[1], queues) + + LOG.info("Received result from matchmaker: %s", queues) + for host in self.hosts: + self.assertIn(host, matched_hosts) + + +class MatchMakerFileTestCase(_MatchMakerTestCase): + def setUp(self): + self.topic = "test" + self.hosts = ['hello', 'world', 'foo', 'bar', 'baz'] + ring = { + self.topic: self.hosts + } + self.driver = matchmaker.MatchMakerRing(ring) + super(MatchMakerFileTestCase, self).setUp() + + +class MatchMakerLocalhostTestCase(_MatchMakerTestCase): + def setUp(self): + self.driver = matchmaker.MatchMakerLocalhost() + self.topic = "test" + self.hosts = ['localhost'] + super(MatchMakerLocalhostTestCase, self).setUp() diff --git a/tests/unit/rpc/test_proxy.py b/tests/unit/rpc/test_proxy.py new file mode 100644 index 0000000..1af37c7 --- /dev/null +++ b/tests/unit/rpc/test_proxy.py @@ -0,0 +1,128 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Unit Tests for rpc.proxy +""" + +import copy +import stubout +import unittest + +from openstack.common import context +from openstack.common import rpc +from openstack.common.rpc import proxy + + +class RpcProxyTestCase(unittest.TestCase): + + def setUp(self): + self.stubs = stubout.StubOutForTesting() + super(RpcProxyTestCase, self).setUp() + + def tearDown(self): + self.stubs.UnsetAll() + self.stubs.SmartUnsetAll() + super(RpcProxyTestCase, self).tearDown() + + def _test_rpc_method(self, rpc_method, has_timeout=False, has_retval=False, + server_params=None, supports_topic_override=True): + topic = 'fake_topic' + timeout = 123 + rpc_proxy = proxy.RpcProxy(topic, '1.0') + ctxt = context.RequestContext('fake_user', 'fake_project') + msg = {'method': 'fake_method', 'args': {'x': 'y'}} + expected_msg = {'method': 'fake_method', 'args': {'x': 'y'}, + 'version': '1.0'} + + expected_retval = 'hi' if has_retval else None + + self.fake_args = None + self.fake_kwargs = None + + def _fake_rpc_method(*args, **kwargs): + self.fake_args = args + self.fake_kwargs = kwargs + if has_retval: + return expected_retval + + self.stubs.Set(rpc, rpc_method, _fake_rpc_method) + + args = [ctxt, msg] + if server_params: + args.insert(1, server_params) + + # Base method usage + retval = getattr(rpc_proxy, rpc_method)(*args) + self.assertEqual(retval, expected_retval) + expected_args = [ctxt, topic, expected_msg] + if server_params: + expected_args.insert(1, server_params) + for arg, expected_arg in zip(self.fake_args, expected_args): + self.assertEqual(arg, expected_arg) + + # overriding the version + retval = getattr(rpc_proxy, rpc_method)(*args, version='1.1') + self.assertEqual(retval, expected_retval) + new_msg = copy.deepcopy(expected_msg) + new_msg['version'] = '1.1' + expected_args = [ctxt, topic, new_msg] + if server_params: + expected_args.insert(1, server_params) + for arg, expected_arg in zip(self.fake_args, expected_args): + self.assertEqual(arg, expected_arg) + + if has_timeout: + # set a timeout + retval = getattr(rpc_proxy, rpc_method)(ctxt, msg, timeout=timeout) + self.assertEqual(retval, expected_retval) + expected_args = [ctxt, topic, expected_msg, timeout] + for arg, expected_arg in zip(self.fake_args, expected_args): + self.assertEqual(arg, expected_arg) + + if supports_topic_override: + # set a topic + new_topic = 'foo.bar' + retval = getattr(rpc_proxy, rpc_method)(*args, topic=new_topic) + self.assertEqual(retval, expected_retval) + expected_args = [ctxt, new_topic, expected_msg] + if server_params: + expected_args.insert(1, server_params) + for arg, expected_arg in zip(self.fake_args, expected_args): + self.assertEqual(arg, expected_arg) + + def test_call(self): + self._test_rpc_method('call', has_timeout=True, has_retval=True) + + def test_multicall(self): + self._test_rpc_method('multicall', has_timeout=True, has_retval=True) + + def test_cast(self): + self._test_rpc_method('cast') + + def test_fanout_cast(self): + self._test_rpc_method('fanout_cast', supports_topic_override=False) + + def test_cast_to_server(self): + self._test_rpc_method('cast_to_server', server_params={'blah': 1}) + + def test_fanout_cast_to_server(self): + self._test_rpc_method('fanout_cast_to_server', + server_params={'blah': 1}, supports_topic_override=False) + + def test_make_msg(self): + self.assertEqual(proxy.RpcProxy.make_msg('test_method', a=1, b=2), + {'method': 'test_method', 'args': {'a': 1, 'b': 2}}) diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py new file mode 100644 index 0000000..b753c22 --- /dev/null +++ b/tests/unit/rpc/test_qpid.py @@ -0,0 +1,377 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for remote procedure calls using qpid +""" + +import eventlet +eventlet.monkey_patch() + +import logging +import mox +import stubout +import unittest + +from openstack.common import cfg +from openstack.common import context +from openstack.common.rpc import amqp as rpc_amqp +from openstack.common import testutils + +try: + from openstack.common.rpc import impl_qpid + import qpid +except ImportError: + qpid = None + impl_qpid = None + + +FLAGS = cfg.CONF +LOG = logging.getLogger(__name__) + + +class RpcQpidTestCase(unittest.TestCase): + """ + Exercise the public API of impl_qpid utilizing mox. + + This set of tests utilizes mox to replace the Qpid objects and ensures + that the right operations happen on them when the various public rpc API + calls are exercised. The API calls tested here include: + + nova.rpc.create_connection() + nova.rpc.common.Connection.create_consumer() + nova.rpc.common.Connection.close() + nova.rpc.cast() + nova.rpc.fanout_cast() + nova.rpc.call() + nova.rpc.multicall() + """ + + def setUp(self): + super(RpcQpidTestCase, self).setUp() + + self.stubs = stubout.StubOutForTesting() + self.mox = mox.Mox() + + self.mock_connection = None + self.mock_session = None + self.mock_sender = None + self.mock_receiver = None + + if qpid: + self.orig_connection = qpid.messaging.Connection + self.orig_session = qpid.messaging.Session + self.orig_sender = qpid.messaging.Sender + self.orig_receiver = qpid.messaging.Receiver + qpid.messaging.Connection = lambda *_x, **_y: self.mock_connection + qpid.messaging.Session = lambda *_x, **_y: self.mock_session + qpid.messaging.Sender = lambda *_x, **_y: self.mock_sender + qpid.messaging.Receiver = lambda *_x, **_y: self.mock_receiver + + def tearDown(self): + self.mox.UnsetStubs() + self.stubs.UnsetAll() + self.stubs.SmartUnsetAll() + if qpid: + qpid.messaging.Connection = self.orig_connection + qpid.messaging.Session = self.orig_session + qpid.messaging.Sender = self.orig_sender + qpid.messaging.Receiver = self.orig_receiver + if impl_qpid: + # Need to reset this in case we changed the connection_cls + # in self._setup_to_server_tests() + impl_qpid.Connection.pool.connection_cls = impl_qpid.Connection + + super(RpcQpidTestCase, self).tearDown() + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_create_connection(self): + self.mock_connection = self.mox.CreateMock(self.orig_connection) + self.mock_session = self.mox.CreateMock(self.orig_session) + + self.mock_connection.opened().AndReturn(False) + self.mock_connection.open() + self.mock_connection.session().AndReturn(self.mock_session) + self.mock_connection.close() + + self.mox.ReplayAll() + + connection = impl_qpid.create_connection(FLAGS) + connection.close() + + def _test_create_consumer(self, fanout): + self.mock_connection = self.mox.CreateMock(self.orig_connection) + self.mock_session = self.mox.CreateMock(self.orig_session) + self.mock_receiver = self.mox.CreateMock(self.orig_receiver) + + self.mock_connection.opened().AndReturn(False) + self.mock_connection.open() + self.mock_connection.session().AndReturn(self.mock_session) + if fanout: + # The link name includes a UUID, so match it with a regex. + expected_address = mox.Regex(r'^impl_qpid_test_fanout ; ' + '{"node": {"x-declare": {"auto-delete": true, "durable": ' + 'false, "type": "fanout"}, "type": "topic"}, "create": ' + '"always", "link": {"x-declare": {"auto-delete": true, ' + '"exclusive": true, "durable": false}, "durable": true, ' + '"name": "impl_qpid_test_fanout_.*"}}$') + else: + expected_address = ('nova/impl_qpid_test ; {"node": {"x-declare": ' + '{"auto-delete": true, "durable": true}, "type": "topic"}, ' + '"create": "always", "link": {"x-declare": {"auto-delete": ' + 'true, "exclusive": false, "durable": false}, "durable": ' + 'true, "name": "impl_qpid_test"}}') + self.mock_session.receiver(expected_address).AndReturn( + self.mock_receiver) + self.mock_receiver.capacity = 1 + self.mock_connection.close() + + self.mox.ReplayAll() + + connection = impl_qpid.create_connection(FLAGS) + connection.create_consumer("impl_qpid_test", + lambda *_x, **_y: None, + fanout) + connection.close() + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_create_consumer(self): + self._test_create_consumer(fanout=False) + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_create_consumer_fanout(self): + self._test_create_consumer(fanout=True) + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_create_worker(self): + self.mock_connection = self.mox.CreateMock(self.orig_connection) + self.mock_session = self.mox.CreateMock(self.orig_session) + self.mock_receiver = self.mox.CreateMock(self.orig_receiver) + + self.mock_connection.opened().AndReturn(False) + self.mock_connection.open() + self.mock_connection.session().AndReturn(self.mock_session) + expected_address = ( + 'nova/impl_qpid_test ; {"node": {"x-declare": ' + '{"auto-delete": true, "durable": true}, "type": "topic"}, ' + '"create": "always", "link": {"x-declare": {"auto-delete": ' + 'true, "exclusive": false, "durable": false}, "durable": ' + 'true, "name": "impl.qpid.test.workers"}}') + self.mock_session.receiver(expected_address).AndReturn( + self.mock_receiver) + self.mock_receiver.capacity = 1 + self.mock_connection.close() + + self.mox.ReplayAll() + + connection = impl_qpid.create_connection(FLAGS) + connection.create_worker("impl_qpid_test", + lambda *_x, **_y: None, + 'impl.qpid.test.workers', + ) + connection.close() + + def _test_cast(self, fanout, server_params=None): + self.mock_connection = self.mox.CreateMock(self.orig_connection) + self.mock_session = self.mox.CreateMock(self.orig_session) + self.mock_sender = self.mox.CreateMock(self.orig_sender) + + self.mock_connection.opened().AndReturn(False) + self.mock_connection.open() + + self.mock_connection.session().AndReturn(self.mock_session) + if fanout: + expected_address = ('impl_qpid_test_fanout ; ' + '{"node": {"x-declare": {"auto-delete": true, ' + '"durable": false, "type": "fanout"}, ' + '"type": "topic"}, "create": "always"}') + else: + expected_address = ('nova/impl_qpid_test ; {"node": {"x-declare": ' + '{"auto-delete": true, "durable": false}, "type": "topic"}, ' + '"create": "always"}') + self.mock_session.sender(expected_address).AndReturn(self.mock_sender) + self.mock_sender.send(mox.IgnoreArg()) + if not server_params: + # This is a pooled connection, so instead of closing it, it + # gets reset, which is just creating a new session on the + # connection. + self.mock_session.close() + self.mock_connection.session().AndReturn(self.mock_session) + + self.mox.ReplayAll() + + try: + ctx = context.RequestContext("user", "project") + + args = [FLAGS, ctx, "impl_qpid_test", + {"method": "test_method", "args": {}}] + + if server_params: + args.insert(2, server_params) + if fanout: + method = impl_qpid.fanout_cast_to_server + else: + method = impl_qpid.cast_to_server + else: + if fanout: + method = impl_qpid.fanout_cast + else: + method = impl_qpid.cast + + method(*args) + finally: + while impl_qpid.Connection.pool.free_items: + # Pull the mock connection object out of the connection pool so + # that it doesn't mess up other test cases. + impl_qpid.Connection.pool.get() + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_cast(self): + self._test_cast(fanout=False) + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_fanout_cast(self): + self._test_cast(fanout=True) + + def _setup_to_server_tests(self, server_params): + class MyConnection(impl_qpid.Connection): + def __init__(myself, *args, **kwargs): + super(MyConnection, myself).__init__(*args, **kwargs) + self.assertEqual(myself.connection.username, + server_params['username']) + self.assertEqual(myself.connection.password, + server_params['password']) + self.assertEqual(myself.broker, + server_params['hostname'] + ':' + + str(server_params['port'])) + + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) + self.stubs.Set(impl_qpid, 'Connection', MyConnection) + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_cast_to_server(self): + server_params = {'username': 'fake_username', + 'password': 'fake_password', + 'hostname': 'fake_hostname', + 'port': 31337} + self._setup_to_server_tests(server_params) + self._test_cast(fanout=False, server_params=server_params) + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_fanout_cast_to_server(self): + server_params = {'username': 'fake_username', + 'password': 'fake_password', + 'hostname': 'fake_hostname', + 'port': 31337} + self._setup_to_server_tests(server_params) + self._test_cast(fanout=True, server_params=server_params) + + def _test_call(self, multi): + self.mock_connection = self.mox.CreateMock(self.orig_connection) + self.mock_session = self.mox.CreateMock(self.orig_session) + self.mock_sender = self.mox.CreateMock(self.orig_sender) + self.mock_receiver = self.mox.CreateMock(self.orig_receiver) + + self.mock_connection.opened().AndReturn(False) + self.mock_connection.open() + self.mock_connection.session().AndReturn(self.mock_session) + rcv_addr = mox.Regex(r'^.*/.* ; {"node": {"x-declare": {"auto-delete":' + ' true, "durable": true, "type": "direct"}, "type": ' + '"topic"}, "create": "always", "link": {"x-declare": ' + '{"auto-delete": true, "exclusive": true, "durable": ' + 'false}, "durable": true, "name": ".*"}}') + self.mock_session.receiver(rcv_addr).AndReturn(self.mock_receiver) + self.mock_receiver.capacity = 1 + send_addr = ('nova/impl_qpid_test ; {"node": {"x-declare": ' + '{"auto-delete": true, "durable": false}, "type": "topic"}, ' + '"create": "always"}') + self.mock_session.sender(send_addr).AndReturn(self.mock_sender) + self.mock_sender.send(mox.IgnoreArg()) + + self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn( + self.mock_receiver) + self.mock_receiver.fetch().AndReturn(qpid.messaging.Message( + {"result": "foo", "failure": False, "ending": False})) + if multi: + self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn( + self.mock_receiver) + self.mock_receiver.fetch().AndReturn( + qpid.messaging.Message( + {"result": "bar", "failure": False, + "ending": False})) + self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn( + self.mock_receiver) + self.mock_receiver.fetch().AndReturn( + qpid.messaging.Message( + {"result": "baz", "failure": False, + "ending": False})) + self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn( + self.mock_receiver) + self.mock_receiver.fetch().AndReturn(qpid.messaging.Message( + {"failure": False, "ending": True})) + self.mock_session.close() + self.mock_connection.session().AndReturn(self.mock_session) + + self.mox.ReplayAll() + + try: + ctx = context.RequestContext("user", "project") + + if multi: + method = impl_qpid.multicall + else: + method = impl_qpid.call + + res = method(FLAGS, ctx, "impl_qpid_test", + {"method": "test_method", "args": {}}) + + if multi: + self.assertEquals(list(res), ["foo", "bar", "baz"]) + else: + self.assertEquals(res, "foo") + finally: + while impl_qpid.Connection.pool.free_items: + # Pull the mock connection object out of the connection pool so + # that it doesn't mess up other test cases. + impl_qpid.Connection.pool.get() + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_call(self): + self._test_call(multi=False) + + @testutils.skip_if(qpid is None, "Test requires qpid") + def test_multicall(self): + self._test_call(multi=True) + + +# +#from nova.tests.rpc import common +# +# Qpid does not have a handy in-memory transport like kombu, so it's not +# terribly straight forward to take advantage of the common unit tests. +# However, at least at the time of this writing, the common unit tests all pass +# with qpidd running. +# +# class RpcQpidCommonTestCase(common._BaseRpcTestCase): +# def setUp(self): +# self.rpc = impl_qpid +# super(RpcQpidCommonTestCase, self).setUp() +# +# def tearDown(self): +# super(RpcQpidCommonTestCase, self).tearDown() +# diff --git a/tools/pip-requires b/tools/pip-requires index adb2400..3f2d088 100644 --- a/tools/pip-requires +++ b/tools/pip-requires @@ -9,3 +9,4 @@ routes==1.12.3 webtest iso8601>=0.1.4 anyjson==0.2.4 +kombu==1.0.4 |