# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""
AMQP-based RPC. Queues have consumers and publishers.
No fan-out support yet.
"""

import json
import logging
import sys
import time
import traceback
import uuid

from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenthread

from nova import context
from nova import exception
from nova import fakerabbit
from nova import flags
from nova import utils


FLAGS = flags.FLAGS

LOG = logging.getLogger('amqplib')
LOG.setLevel(logging.DEBUG)


class Connection(carrot_connection.BrokerConnection):
    """Broker connection configured from the rabbit_* flags.

    A single shared instance is cached on the class as ``_instance``;
    callers may also ask for a fresh, uncached connection.
    """

    @classmethod
    def instance(cls, new=False):
        """Return a connection built from FLAGS.

        :param new: when True, build and return a fresh connection without
                    touching the cached one; when False, return the cached
                    connection, creating it on first use.
        """
        if new or not hasattr(cls, '_instance'):
            params = dict(hostname=FLAGS.rabbit_host,
                          port=FLAGS.rabbit_port,
                          userid=FLAGS.rabbit_userid,
                          password=FLAGS.rabbit_password,
                          virtual_host=FLAGS.rabbit_virtual_host)

            if FLAGS.fake_rabbit:
                params['backend_cls'] = fakerabbit.Backend

            # NOTE(vish): magic is fun!
            # pylint: disable-msg=W0142
            if new:
                return cls(**params)
            else:
                cls._instance = cls(**params)
        return cls._instance

    @classmethod
    def recreate(cls):
        """Drop the cached connection (if any) and return a new cached one.

        This is necessary to recover from some network errors/disconnects.
        """
        try:
            del cls._instance
        except AttributeError:
            # No shared instance was ever cached (e.g. every caller used
            # instance(new=True)).  There is nothing to discard, and raising
            # here would make Consumer.fetch fail its reconnect forever.
            pass
        return cls.instance()


class Consumer(messaging.Consumer):
    """Consumer base class.

    Contains methods for connecting the fetch method to async loops and
    tracks connection health in ``self.failed_connection``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the parent consumer, retrying on failure.

        Makes up to FLAGS.rabbit_max_retries attempts, sleeping
        FLAGS.rabbit_retry_interval seconds between attempts.  If no attempt
        succeeds the process exits with status 1.
        """
        # Start pessimistic so the attribute exists even when the loop body
        # never runs (rabbit_max_retries == 0).
        self.failed_connection = True
        for i in xrange(FLAGS.rabbit_max_retries):
            if i > 0:
                time.sleep(FLAGS.rabbit_retry_interval)
            try:
                super(Consumer, self).__init__(*args, **kwargs)
                self.failed_connection = False
                break
            # Catching all because carrot sucks.  Exception rather than a
            # bare except so KeyboardInterrupt/SystemExit still abort the
            # retry loop.
            except Exception:
                logging.exception(_("AMQP server on %s:%d is unreachable."
                    " Trying again in %d seconds.") % (
                    FLAGS.rabbit_host,
                    FLAGS.rabbit_port,
                    FLAGS.rabbit_retry_interval))
                self.failed_connection = True
        if self.failed_connection:
            # Not inside an except block here; each failed attempt has
            # already logged its own traceback above.
            logging.error(_("Unable to connect to AMQP server"
                " after %d tries. Shutting down.") % FLAGS.rabbit_max_retries)
            sys.exit(1)

    def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
        """Wraps the parent fetch with some logic for failed connections.

        If the previous fetch failed, the connection and backend are rebuilt
        and the queue re-declared before fetching.  Errors are swallowed and
        logged once per outage.  The parent's return value is discarded.
        """
        # TODO(vish): the logic for failed connections and logging should be
        #             refactored into some sort of connection manager object
        try:
            if self.failed_connection:
                # NOTE(vish): connection is defined in the parent class, we can
                #             recreate it as long as we create the backend too
                # pylint: disable-msg=W0201
                self.connection = Connection.recreate()
                self.backend = self.connection.create_backend()
                self.declare()
            super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
            if self.failed_connection:
                logging.error(_("Reconnected to queue"))
                self.failed_connection = False
        # NOTE(vish): This is catching all errors because we really don't
        #             want exceptions to be logged 10 times a second if some
        #             persistent failure occurs.
        except Exception:  # pylint: disable-msg=W0703
            if not self.failed_connection:
                logging.exception(_("Failed to fetch message from queue"))
                self.failed_connection = True

    def attach_to_eventlet(self):
        """Only needed for unit tests!

        Starts a utils.LoopingCall that invokes fetch with callbacks enabled
        (started with an argument of 0.1) and returns the timer.
        """
        timer = utils.LoopingCall(self.fetch, enable_callbacks=True)
        timer.start(0.1)
        return timer


class Publisher(messaging.Publisher):
    """Publisher base class"""
    pass


class TopicConsumer(Consumer):
    """Consumes messages on a specific topic"""
    exchange_type = "topic"

    def __init__(self, connection=None, topic="broadcast"):
        """Use ``topic`` as queue name and routing key on the control exchange.

        The attributes are assigned before the parent initialiser runs.
        """
        self.queue = topic
        self.routing_key = topic
        self.exchange = FLAGS.control_exchange
        self.durable = False
        super(TopicConsumer, self).__init__(connection=connection)


class AdapterConsumer(TopicConsumer):
    """Calls methods on a proxy object based on method and args"""

    def __init__(self, connection=None, topic="broadcast", proxy=None):
        """Remember ``proxy`` and start consuming on ``topic``."""
        LOG.debug(_('Initing the Adapter Consumer for %s') % (topic))
        self.proxy = proxy
        super(AdapterConsumer, self).__init__(connection=connection,
                                              topic=topic)

    @exception.wrap_exception
    def receive(self, message_data, message):
        """Magically looks for a method on the proxy object and calls it

        Message data should be a dictionary with two keys:
            method: string representing the method to call
            args: dictionary of arg: value

        Example: {'method': 'echo', 'args': {'value': 42}}

        An optional '_msg_id' key names the reply channel; when present the
        return value (or the failure) is sent back through msg_reply.
        '_context_*' keys are unpacked into the request context passed to
        the method as ``context``.
        """
        LOG.debug(_('received %s') % (message_data))
        msg_id = message_data.pop('_msg_id', None)

        ctxt = _unpack_context(message_data)

        method = message_data.get('method')
        args = message_data.get('args', {})
        message.ack()
        if not method:
            # NOTE(vish): we may not want to ack here, but that means that bad
            #             messages stay in the queue indefinitely, so for now
            #             we just log the message and send an error string
            #             back to the caller
            LOG.warn(_('no method for message: %s') % (message_data))
            # Only reply when the sender asked for one, matching the guards
            # on the success and failure paths below.
            if msg_id:
                msg_reply(msg_id,
                          _('No method for message: %s') % message_data)
            return

        node_func = getattr(self.proxy, str(method))
        node_args = dict((str(k), v) for k, v in args.iteritems())
        # NOTE(vish): magic is fun!
        try:
            rval = node_func(context=ctxt, **node_args)
            if msg_id:
                msg_reply(msg_id, rval, None)
        except Exception:
            # Grab the exception info before anything else runs, then log so
            # that failures of casts (no msg_id) are not silently dropped.
            failure = sys.exc_info()
            LOG.exception(_('Exception during message handling'))
            if msg_id:
                msg_reply(msg_id, None, failure)
        return


class TopicPublisher(Publisher):
    """Publishes messages on a specific topic"""
    exchange_type = "topic"

    def __init__(self, connection=None, topic="broadcast"):
        """Use ``topic`` as the routing key on the control exchange."""
        self.routing_key = topic
        self.exchange = FLAGS.control_exchange
        self.durable = False
        super(TopicPublisher, self).__init__(connection=connection)


class DirectConsumer(Consumer):
    """Consumes messages directly on a channel specified by msg_id"""
    exchange_type = "direct"

    def __init__(self, connection=None, msg_id=None):
        """Use ``msg_id`` as queue, routing key and exchange name."""
        self.queue = msg_id
        self.routing_key = msg_id
        self.exchange = msg_id
        self.auto_delete = True
        self.exclusive = True
        super(DirectConsumer, self).__init__(connection=connection)


class DirectPublisher(Publisher):
    """Publishes messages directly on a channel specified by msg_id"""
    exchange_type = "direct"

    def __init__(self, connection=None, msg_id=None):
        """Use ``msg_id`` as routing key and exchange name."""
        self.routing_key = msg_id
        self.exchange = msg_id
        self.auto_delete = True
        super(DirectPublisher, self).__init__(connection=connection)


def msg_reply(msg_id, reply=None, failure=None):
    """Sends a reply or an error on the channel signified by msg_id

    failure should be a sys.exc_info() tuple.  It is converted to a
    (type name, message, formatted traceback) tuple before sending.

    If sending raises TypeError, the reply is re-sent as a dict of the
    repr() of each entry in ``reply.__dict__``.
    """
    if failure:
        message = str(failure[1])
        tb = traceback.format_exception(*failure)
        logging.error(_("Returning exception %s to caller"), message)
        logging.error(tb)
        failure = (failure[0].__name__, str(failure[1]), tb)
    conn = Connection.instance(True)
    publisher = DirectPublisher(connection=conn, msg_id=msg_id)
    try:
        publisher.send({'result': reply, 'failure': failure})
    except TypeError:
        publisher.send(
                {'result': dict((k, repr(v))
                                for k, v in reply.__dict__.iteritems()),
                 'failure': failure})
    finally:
        # Close the publisher even when sending fails.
        publisher.close()


class RemoteError(exception.Error):
    """Signifies that a remote class has raised an exception

    Contains a string representation of the type of the original exception,
    the value of the original exception, and the traceback.  These are
    sent to the parent as a joined string so printing the exception
    contains all of the relevant info."""

    def __init__(self, exc_type, value, traceback):
        self.exc_type = exc_type
        self.value = value
        self.traceback = traceback
        super(RemoteError, self).__init__("%s %s\n%s" % (exc_type,
                                                         value,
                                                         traceback))


def _unpack_context(msg):
    """Unpack context from msg.

    Removes every '_context_<name>' key from ``msg`` and builds a
    RequestContext from the resulting {<name>: value} dictionary.
    """
    context_dict = {}
    for key in list(msg.keys()):
        # NOTE(vish): Some versions of python don't like unicode keys
        #             in kwargs.
        key = str(key)
        if key.startswith('_context_'):
            value = msg.pop(key)
            # 9 == len('_context_'): strip the prefix.
            context_dict[key[9:]] = value
    LOG.debug(_('unpacked context: %s'), context_dict)
    return context.RequestContext.from_dict(context_dict)


def _pack_context(msg, context):
    """Pack context into msg.

    Values for message keys need to be less than 255 chars, so we pull
    context out into a bunch of separate keys. If we want to support
    more arguments in rabbit messages, we may want to do the same
    for args at some point.

    ``msg`` is updated in place with one '_context_<key>' entry per item of
    ``context.to_dict()``.
    """
    context = dict([('_context_%s' % key, value)
                   for (key, value) in context.to_dict().iteritems()])
    msg.update(context)


def call(context, topic, msg):
    """Sends a message on a topic and wait for a response

    A fresh '_msg_id' is added to ``msg`` (mutating it) and used as the name
    of the direct reply channel.  Returns the remote result, or raises it if
    the result is an exception (e.g. RemoteError).
    """
    LOG.debug(_("Making asynchronous call..."))
    msg_id = uuid.uuid4().hex
    msg.update({'_msg_id': msg_id})
    LOG.debug(_("MSG_ID is %s") % (msg_id))
    _pack_context(msg, context)

    class WaitMessage(object):
        def __call__(self, data, message):
            """Acks message and sets result."""
            message.ack()
            if data['failure']:
                self.result = RemoteError(*data['failure'])
            else:
                self.result = data['result']

    wait_msg = WaitMessage()
    conn = Connection.instance(True)
    consumer = DirectConsumer(connection=conn, msg_id=msg_id)
    try:
        consumer.register_callback(wait_msg)

        conn = Connection.instance()
        publisher = TopicPublisher(connection=conn, topic=topic)
        publisher.send(msg)
        publisher.close()

        try:
            consumer.wait(limit=1)
        except StopIteration:
            pass
    finally:
        # Close the reply consumer even if publishing or waiting raised.
        consumer.close()
    # NOTE(termie): this is a little bit of a change from the original
    #               non-eventlet code where returning a Failure
    #               instance from a deferred call is very similar to
    #               raising an exception
    if isinstance(wait_msg.result, Exception):
        raise wait_msg.result
    return wait_msg.result


def cast(context, topic, msg):
    """Sends a message on a topic without waiting for a response

    The request context is packed into ``msg`` (mutating it) before sending.
    """
    LOG.debug(_("Making asynchronous cast..."))
    _pack_context(msg, context)
    conn = Connection.instance()
    publisher = TopicPublisher(connection=conn, topic=topic)
    publisher.send(msg)
    publisher.close()


def generic_response(message_data, message):
    """Logs a result and exits"""
    LOG.debug(_('response %s'), message_data)
    message.ack()
    sys.exit(0)


def send_message(topic, message, wait=True):
    """Sends a message for testing

    A fresh '_msg_id' is added to ``message``.  When ``wait`` is true a
    consumer on that id is registered with generic_response, which exits the
    process once a response arrives.
    """
    msg_id = uuid.uuid4().hex
    message.update({'_msg_id': msg_id})
    LOG.debug(_('topic is %s'), topic)
    LOG.debug(_('message %s'), message)

    if wait:
        consumer = messaging.Consumer(connection=Connection.instance(),
                                      queue=msg_id,
                                      exchange=msg_id,
                                      auto_delete=True,
                                      exchange_type="direct",
                                      routing_key=msg_id)
        consumer.register_callback(generic_response)

    publisher = messaging.Publisher(connection=Connection.instance(),
                                    exchange=FLAGS.control_exchange,
                                    durable=False,
                                    exchange_type="topic",
                                    routing_key=topic)
    publisher.send(message)
    publisher.close()

    if wait:
        consumer.wait()


if __name__ == "__main__":
    # NOTE(vish): you can send messages from the command line using
    #             topic and a JSON string representing a dictionary
    #             for the method
    send_message(sys.argv[1], json.loads(sys.argv[2]))