summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/unit/rpc/__init__.py15
-rw-r--r--tests/unit/rpc/common.py322
-rw-r--r--tests/unit/rpc/test_common.py150
-rw-r--r--tests/unit/rpc/test_dispatcher.py110
-rw-r--r--tests/unit/rpc/test_fake.py32
-rw-r--r--tests/unit/rpc/test_kombu.py414
-rw-r--r--tests/unit/rpc/test_kombu_ssl.py82
-rw-r--r--tests/unit/rpc/test_matchmaker.py60
-rw-r--r--tests/unit/rpc/test_proxy.py128
-rw-r--r--tests/unit/rpc/test_qpid.py377
10 files changed, 1690 insertions, 0 deletions
diff --git a/tests/unit/rpc/__init__.py b/tests/unit/rpc/__init__.py
new file mode 100644
index 0000000..848908a
--- /dev/null
+++ b/tests/unit/rpc/__init__.py
@@ -0,0 +1,15 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit/rpc/common.py b/tests/unit/rpc/common.py
new file mode 100644
index 0000000..013418d
--- /dev/null
+++ b/tests/unit/rpc/common.py
@@ -0,0 +1,322 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls shared between all implementations
+"""
+
+import logging
+import time
+import unittest
+
+import eventlet
+from eventlet import greenthread
+import nose
+
+from openstack.common import cfg
+from openstack.common import exception
+from openstack.common.gettextutils import _
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+from openstack.common.rpc import dispatcher as rpc_dispatcher
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class BaseRpcTestCase(unittest.TestCase):
+ def setUp(self, supports_timeouts=True, topic='test',
+ topic_nested='nested'):
+ super(BaseRpcTestCase, self).setUp()
+ self.topic = topic or self.topic
+ self.topic_nested = topic_nested or self.topic_nested
+ self.supports_timeouts = supports_timeouts
+ self.context = rpc_common.CommonRpcContext(user='fake_user',
+ pw='fake_pw')
+
+ if self.rpc:
+ receiver = TestReceiver()
+ self.conn = self._create_consumer(receiver, self.topic)
+
+ def tearDown(self):
+ if self.rpc:
+ self.conn.close()
+ super(BaseRpcTestCase, self).tearDown()
+
+ def _create_consumer(self, proxy, topic, fanout=False):
+ dispatcher = rpc_dispatcher.RpcDispatcher([proxy])
+ conn = self.rpc.create_connection(FLAGS, True)
+ conn.create_consumer(topic, dispatcher, fanout)
+ conn.consume_in_thread()
+ return conn
+
+ def test_call_succeed(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context, self.topic,
+ {"method": "echo", "args": {"value": value}})
+ self.assertEqual(value, result)
+
+ def test_call_succeed_despite_multiple_returns_yield(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context, self.topic,
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ self.assertEqual(value + 2, result)
+
+ def test_multicall_succeed_once(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.multicall(FLAGS, self.context,
+ self.topic,
+ {"method": "echo",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ if i > 0:
+ self.fail('should only receive one response')
+ self.assertEqual(value + i, x)
+
+ def test_multicall_three_nones(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.multicall(FLAGS, self.context,
+ self.topic,
+ {"method": "multicall_three_nones",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(x, None)
+ # i should have been 0, 1, and finally 2:
+ self.assertEqual(i, 2)
+
+ def test_multicall_succeed_three_times_yield(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ value = 42
+ result = self.rpc.multicall(FLAGS, self.context,
+ self.topic,
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(value + i, x)
+
+ def test_context_passed(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ """Makes sure a context is passed through rpc call."""
+ value = 42
+ result = self.rpc.call(FLAGS, self.context,
+ self.topic, {"method": "context",
+ "args": {"value": value}})
+ self.assertEqual(self.context.to_dict(), result)
+
+ def _test_cast(self, fanout=False):
+ """Test casts by pushing items through a channeled queue."""
+
+ # Not a true global, but capitalized so
+ # it is clear it is leaking scope into Nested()
+ QUEUE = eventlet.queue.Queue()
+
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ # We use the nested topic so we don't need QUEUE to be a proper
+ # global, and do not keep state outside this test.
+ class Nested(object):
+ @staticmethod
+ def put_queue(context, value):
+ LOG.debug("Got value in put_queue: %s", value)
+ QUEUE.put(value)
+
+ nested = Nested()
+ conn = self._create_consumer(nested, self.topic_nested, fanout)
+ value = 42
+
+ method = (self.rpc.cast, self.rpc.fanout_cast)[fanout]
+ method(FLAGS, self.context,
+ self.topic_nested,
+ {"method": "put_queue",
+ "args": {"value": value}})
+
+ try:
+ # If it does not succeed in 2 seconds, give up and assume
+ # failure.
+ result = QUEUE.get(True, 2)
+ except Exception:
+ self.assertEqual(value, None)
+
+ conn.close()
+ self.assertEqual(value, result)
+
+ def test_cast_success(self):
+ self._test_cast(False)
+
+ def test_fanout_success(self):
+ self._test_cast(True)
+
+ def test_nested_calls(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ """Test that we can do an rpc.call inside another call."""
+ class Nested(object):
+ @staticmethod
+ def echo(context, queue, value):
+ """Calls echo in the passed queue."""
+ LOG.debug(_("Nested received %(queue)s, %(value)s")
+ % locals())
+ # TODO(comstud):
+ # so, it will replay the context and use the same REQID?
+ # that's bizarre.
+ ret = self.rpc.call(FLAGS, context,
+ queue,
+ {"method": "echo",
+ "args": {"value": value}})
+ LOG.debug(_("Nested return %s"), ret)
+ return value
+
+ nested = Nested()
+ conn = self._create_consumer(nested, self.topic_nested)
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context,
+ self.topic_nested,
+ {"method": "echo",
+ "args": {"queue": "test", "value": value}})
+ conn.close()
+ self.assertEqual(value, result)
+
+ def test_call_timeout(self):
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ """Make sure rpc.call will time out."""
+ if not self.supports_timeouts:
+ raise nose.SkipTest(_("RPC backend does not support timeouts"))
+
+ value = 42
+ self.assertRaises(rpc_common.Timeout,
+ self.rpc.call,
+ FLAGS, self.context,
+ self.topic,
+ {"method": "block",
+ "args": {"value": value}}, timeout=1)
+ try:
+ self.rpc.call(FLAGS, self.context,
+ self.topic,
+ {"method": "block",
+ "args": {"value": value}},
+ timeout=1)
+ self.fail("should have thrown Timeout")
+ except rpc_common.Timeout as exc:
+ pass
+
+
+class BaseRpcAMQPTestCase(BaseRpcTestCase):
+ """Base test class for all AMQP-based RPC tests."""
+ def test_proxycallback_handles_exceptions(self):
+ """Make sure exceptions unpacking messages don't cause hangs."""
+ if not self.rpc:
+ raise nose.SkipTest('rpc driver not available.')
+
+ orig_unpack = rpc_amqp.unpack_context
+
+ info = {'unpacked': False}
+
+ def fake_unpack_context(*args, **kwargs):
+ info['unpacked'] = True
+ raise test.TestingException('moo')
+
+ self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context)
+
+ value = 41
+ self.rpc.cast(FLAGS, self.context, self.topic,
+ {"method": "echo", "args": {"value": value}})
+
+ # Wait for the cast to complete.
+ for x in xrange(50):
+ if info['unpacked']:
+ break
+ greenthread.sleep(0.1)
+ else:
+ self.fail("Timeout waiting for message to be consumed")
+
+ # Now see if we get a response even though we raised an
+ # exception for the cast above.
+ self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack)
+
+ value = 42
+ result = self.rpc.call(FLAGS, self.context, self.topic,
+ {"method": "echo",
+ "args": {"value": value}})
+ self.assertEqual(value, result)
+
+
+class TestReceiver(object):
+ """Simple Proxy class so the consumer has methods to call.
+
+ Uses static methods because we aren't actually storing any state.
+
+ """
+ @staticmethod
+ def echo(context, value):
+ """Simply returns whatever value is sent in."""
+ LOG.debug(_("Received %s"), value)
+ return value
+
+ @staticmethod
+ def context(context, value):
+ """Returns dictionary version of context."""
+ LOG.debug(_("Received %s"), context)
+ return context.to_dict()
+
+ @staticmethod
+ def multicall_three_nones(context, value):
+ yield None
+ yield None
+ yield None
+
+ @staticmethod
+ def echo_three_times_yield(context, value):
+ yield value
+ yield value + 1
+ yield value + 2
+
+ @staticmethod
+ def fail(context, value):
+ """Raises an exception with the value sent in."""
+ raise NotImplementedError(value)
+
+ @staticmethod
+ def fail_converted(context, value):
+ """Raises an exception with the value sent in."""
+ raise exception.ApiError(message=value, code='500')
+
+ @staticmethod
+ def block(context, value):
+ time.sleep(2)
diff --git a/tests/unit/rpc/test_common.py b/tests/unit/rpc/test_common.py
new file mode 100644
index 0000000..73fb733
--- /dev/null
+++ b/tests/unit/rpc/test_common.py
@@ -0,0 +1,150 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for 'common' functons used through rpc code.
+"""
+
+import json
+import logging
+import sys
+import unittest
+
+from openstack.common import cfg
+from openstack.common import context
+from openstack.common import exception
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+from tests.unit.rpc import common
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+def raise_exception():
+ raise Exception("test")
+
+
+class FakeUserDefinedException(Exception):
+ def __init__(self):
+ Exception.__init__(self, "Test Message")
+
+
+class RpcCommonTestCase(unittest.TestCase):
+ def test_serialize_remote_exception(self):
+ expected = {
+ 'class': 'Exception',
+ 'module': 'exceptions',
+ 'message': 'test',
+ }
+
+ try:
+ raise_exception()
+ except Exception as exc:
+ failure = rpc_common.serialize_remote_exception(sys.exc_info())
+
+ failure = json.loads(failure)
+ #assure the traceback was added
+ self.assertEqual(expected['class'], failure['class'])
+ self.assertEqual(expected['module'], failure['module'])
+ self.assertEqual(expected['message'], failure['message'])
+
+ def test_serialize_remote_custom_exception(self):
+ def raise_custom_exception():
+ raise exception.OpenstackException()
+
+ expected = {
+ 'class': 'OpenstackException',
+ 'module': 'openstack.common.exception',
+ 'message': exception.OpenstackException.message,
+ }
+
+ try:
+ raise_custom_exception()
+ except Exception as exc:
+ failure = rpc_common.serialize_remote_exception(sys.exc_info())
+
+ failure = json.loads(failure)
+ #assure the traceback was added
+ self.assertEqual(expected['class'], failure['class'])
+ self.assertEqual(expected['module'], failure['module'])
+ self.assertEqual(expected['message'], failure['message'])
+
+ def test_deserialize_remote_exception(self):
+ failure = {
+ 'class': 'OpenstackException',
+ 'module': 'openstack.common.exception',
+ 'message': exception.OpenstackException.message,
+ 'tb': ['raise OpenstackException'],
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, exception.OpenstackException))
+ self.assertTrue('An unknown' in unicode(after_exc))
+ #assure the traceback was added
+ self.assertTrue('raise OpenstackException' in unicode(after_exc))
+
+ def test_deserialize_remote_exception_bad_module(self):
+ failure = {
+ 'class': 'popen2',
+ 'module': 'os',
+ 'kwargs': {'cmd': '/bin/echo failed'},
+ 'message': 'foo',
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
+
+ def test_deserialize_remote_exception_user_defined_exception(self):
+ """Ensure a user defined exception can be deserialized."""
+ FLAGS.set_override('allowed_rpc_exception_modules',
+ [self.__class__.__module__])
+ failure = {
+ 'class': 'FakeUserDefinedException',
+ 'module': self.__class__.__module__,
+ 'tb': ['raise FakeUserDefinedException'],
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, FakeUserDefinedException))
+ #assure the traceback was added
+ self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc))
+ FLAGS.reset()
+
+ def test_deserialize_remote_exception_cannot_recreate(self):
+ """Ensure a RemoteError is returned on initialization failure.
+
+ If an exception cannot be recreated with it's original class then a
+ RemoteError with the exception informations should still be returned.
+
+ """
+ FLAGS.set_override('allowed_rpc_exception_modules',
+ [self.__class__.__module__])
+ failure = {
+ 'class': 'FakeIDontExistException',
+ 'module': self.__class__.__module__,
+ 'tb': ['raise FakeIDontExistException'],
+ }
+ serialized = json.dumps(failure)
+
+ after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
+ self.assertTrue(isinstance(after_exc, rpc_common.RemoteError))
+ #assure the traceback was added
+ self.assertTrue('raise FakeIDontExistException' in unicode(after_exc))
+ FLAGS.reset()
diff --git a/tests/unit/rpc/test_dispatcher.py b/tests/unit/rpc/test_dispatcher.py
new file mode 100644
index 0000000..a085567
--- /dev/null
+++ b/tests/unit/rpc/test_dispatcher.py
@@ -0,0 +1,110 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit Tests for rpc.dispatcher
+"""
+
+import unittest
+
+from openstack.common import context
+from openstack.common.rpc import common as rpc_common
+from openstack.common.rpc import dispatcher
+
+
+class RpcDispatcherTestCase(unittest.TestCase):
+ class API1(object):
+ RPC_API_VERSION = '1.0'
+
+ def __init__(self):
+ self.test_method_ctxt = None
+ self.test_method_arg1 = None
+
+ def test_method(self, ctxt, arg1):
+ self.test_method_ctxt = ctxt
+ self.test_method_arg1 = arg1
+
+ class API2(object):
+ RPC_API_VERSION = '2.1'
+
+ def __init__(self):
+ self.test_method_ctxt = None
+ self.test_method_arg1 = None
+
+ def test_method(self, ctxt, arg1):
+ self.test_method_ctxt = ctxt
+ self.test_method_arg1 = arg1
+
+ class API3(object):
+ RPC_API_VERSION = '3.5'
+
+ def __init__(self):
+ self.test_method_ctxt = None
+ self.test_method_arg1 = None
+
+ def test_method(self, ctxt, arg1):
+ self.test_method_ctxt = ctxt
+ self.test_method_arg1 = arg1
+
+ def setUp(self):
+ self.ctxt = context.RequestContext('fake_user', 'fake_project')
+ super(RpcDispatcherTestCase, self).setUp()
+
+ def tearDown(self):
+ super(RpcDispatcherTestCase, self).tearDown()
+
+ def _test_dispatch(self, version, expectations):
+ v2 = self.API2()
+ v3 = self.API3()
+ disp = dispatcher.RpcDispatcher([v2, v3])
+
+ disp.dispatch(self.ctxt, version, 'test_method', arg1=1)
+
+ self.assertEqual(v2.test_method_ctxt, expectations[0])
+ self.assertEqual(v2.test_method_arg1, expectations[1])
+ self.assertEqual(v3.test_method_ctxt, expectations[2])
+ self.assertEqual(v3.test_method_arg1, expectations[3])
+
+ def test_dispatch(self):
+ self._test_dispatch('2.1', (self.ctxt, 1, None, None))
+ self._test_dispatch('3.5', (None, None, self.ctxt, 1))
+
+ def test_dispatch_lower_minor_version(self):
+ self._test_dispatch('2.0', (self.ctxt, 1, None, None))
+ self._test_dispatch('3.1', (None, None, self.ctxt, 1))
+
+ def test_dispatch_higher_minor_version(self):
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '2.6', (None, None, None, None))
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '3.6', (None, None, None, None))
+
+ def test_dispatch_lower_major_version(self):
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '1.0', (None, None, None, None))
+
+ def test_dispatch_higher_major_version(self):
+ self.assertRaises(rpc_common.UnsupportedRpcVersion,
+ self._test_dispatch, '4.0', (None, None, None, None))
+
+ def test_dispatch_no_version_uses_v1(self):
+ v1 = self.API1()
+ disp = dispatcher.RpcDispatcher([v1])
+
+ disp.dispatch(self.ctxt, None, 'test_method', arg1=1)
+
+ self.assertEqual(v1.test_method_ctxt, self.ctxt)
+ self.assertEqual(v1.test_method_arg1, 1)
diff --git a/tests/unit/rpc/test_fake.py b/tests/unit/rpc/test_fake.py
new file mode 100644
index 0000000..8ceac47
--- /dev/null
+++ b/tests/unit/rpc/test_fake.py
@@ -0,0 +1,32 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using fake_impl
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+from openstack.common.rpc import impl_fake
+from tests.unit.rpc import common
+
+
+class RpcFakeTestCase(common.BaseRpcTestCase):
+ def setUp(self):
+ self.rpc = impl_fake
+ super(RpcFakeTestCase, self).setUp()
diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py
new file mode 100644
index 0000000..ccab4a2
--- /dev/null
+++ b/tests/unit/rpc/test_kombu.py
@@ -0,0 +1,414 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using kombu
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+import logging
+import unittest
+
+import stubout
+
+from openstack.common import cfg
+from openstack.common import exception
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common.rpc import common as rpc_common
+from openstack.common import testutils
+from tests.unit.rpc import common
+
+try:
+ import kombu
+ from openstack.common.rpc import impl_kombu
+except ImportError:
+ kombu = None
+ impl_kombu = None
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class MyException(Exception):
+ pass
+
+
+def _raise_exc_stub(stubs, times, obj, method, exc_msg,
+ exc_class=MyException):
+ info = {'called': 0}
+ orig_method = getattr(obj, method)
+
+ def _raise_stub(*args, **kwargs):
+ info['called'] += 1
+ if info['called'] <= times:
+ raise exc_class(exc_msg)
+ orig_method(*args, **kwargs)
+ stubs.Set(obj, method, _raise_stub)
+ return info
+
+
+class RpcKombuTestCase(common.BaseRpcAMQPTestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ if kombu:
+ FLAGS.set_override('fake_rabbit', True)
+ FLAGS.set_override('rpc_response_timeout', 5)
+ self.rpc = impl_kombu
+ else:
+ self.rpc = None
+ super(RpcKombuTestCase, self).setUp()
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+ self.stubs.SmartUnsetAll()
+ if kombu:
+ impl_kombu.cleanup()
+ FLAGS.reset()
+ super(RpcKombuTestCase, self).tearDown()
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_reusing_connection(self):
+ """Test that reusing a connection returns same one."""
+ conn_context = self.rpc.create_connection(FLAGS, new=False)
+ conn1 = conn_context.connection
+ conn_context.close()
+ conn_context = self.rpc.create_connection(FLAGS, new=False)
+ conn2 = conn_context.connection
+ conn_context.close()
+ self.assertEqual(conn1, conn2)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_topic_send_receive(self):
+ """Test sending to a topic exchange/queue"""
+
+ conn = self.rpc.create_connection(FLAGS)
+ message = 'topic test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_topic_consumer('a_topic', _callback)
+ conn.topic_send('a_topic', message)
+ conn.consume(limit=1)
+ conn.close()
+
+ self.assertEqual(self.received_message, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_topic_multiple_queues(self):
+ """Test sending to a topic exchange with multiple queues"""
+
+ conn = self.rpc.create_connection(FLAGS)
+ message = 'topic test message'
+
+ self.received_message_1 = None
+ self.received_message_2 = None
+
+ def _callback1(message):
+ self.received_message_1 = message
+
+ def _callback2(message):
+ self.received_message_2 = message
+
+ conn.declare_topic_consumer('a_topic', _callback1, queue_name='queue1')
+ conn.declare_topic_consumer('a_topic', _callback2, queue_name='queue2')
+ conn.topic_send('a_topic', message)
+ conn.consume(limit=2)
+ conn.close()
+
+ self.assertEqual(self.received_message_1, message)
+ self.assertEqual(self.received_message_2, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_direct_send_receive(self):
+ """Test sending to a direct exchange/queue"""
+ conn = self.rpc.create_connection(FLAGS)
+ message = 'direct test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_direct_consumer('a_direct', _callback)
+ conn.direct_send('a_direct', message)
+ conn.consume(limit=1)
+ conn.close()
+
+ self.assertEqual(self.received_message, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_cast_interface_uses_default_options(self):
+ """Test kombu rpc.cast"""
+
+ ctxt = rpc_common.CommonRpcContext(user='fake_user',
+ project='fake_project')
+
+ class MyConnection(impl_kombu.Connection):
+ def __init__(myself, *args, **kwargs):
+ super(MyConnection, myself).__init__(*args, **kwargs)
+ self.assertEqual(myself.params,
+ {'hostname': FLAGS.rabbit_host,
+ 'userid': FLAGS.rabbit_userid,
+ 'password': FLAGS.rabbit_password,
+ 'port': FLAGS.rabbit_port,
+ 'virtual_host': FLAGS.rabbit_virtual_host,
+ 'transport': 'memory'})
+
+ def topic_send(_context, topic, msg):
+ pass
+
+ MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
+ self.stubs.Set(impl_kombu, 'Connection', MyConnection)
+
+ impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'})
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_cast_to_server_uses_server_params(self):
+ """Test kombu rpc.cast"""
+
+ ctxt = rpc_common.CommonRpcContext(user='fake_user',
+ project='fake_project')
+
+ server_params = {'username': 'fake_username',
+ 'password': 'fake_password',
+ 'hostname': 'fake_hostname',
+ 'port': 31337,
+ 'virtual_host': 'fake_virtual_host'}
+
+ class MyConnection(impl_kombu.Connection):
+ def __init__(myself, *args, **kwargs):
+ super(MyConnection, myself).__init__(*args, **kwargs)
+ self.assertEqual(myself.params,
+ {'hostname': server_params['hostname'],
+ 'userid': server_params['username'],
+ 'password': server_params['password'],
+ 'port': server_params['port'],
+ 'virtual_host': server_params['virtual_host'],
+ 'transport': 'memory'})
+
+ def topic_send(_context, topic, msg):
+ pass
+
+ MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
+ self.stubs.Set(impl_kombu, 'Connection', MyConnection)
+
+ impl_kombu.cast_to_server(FLAGS, ctxt, server_params,
+ 'fake_topic', {'msg': 'fake'})
+
+ @testutils.skip_test("kombu memory transport seems buggy with "
+ "fanout queues as this test passes when "
+ "you use rabbit (fake_rabbit=False)")
+ def test_fanout_send_receive(self):
+ """Test sending to a fanout exchange and consuming from 2 queues"""
+
+ conn = self.rpc.create_connection()
+ conn2 = self.rpc.create_connection()
+ message = 'fanout test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_fanout_consumer('a_fanout', _callback)
+ conn2.declare_fanout_consumer('a_fanout', _callback)
+ conn.fanout_send('a_fanout', message)
+
+ conn.consume(limit=1)
+ conn.close()
+ self.assertEqual(self.received_message, message)
+
+ self.received_message = None
+ conn2.consume(limit=1)
+ conn2.close()
+ self.assertEqual(self.received_message, message)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_declare_consumer_errors_will_reconnect(self):
+ # Test that any exception with 'timeout' in it causes a
+ # reconnection
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
+ '__init__', 'foo timeout foo')
+
+ conn = self.rpc.Connection(FLAGS)
+ result = conn.declare_consumer(self.rpc.DirectConsumer,
+ 'test_topic', None)
+
+ self.assertEqual(info['called'], 3)
+ self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
+
+ # Test that any exception in transport.connection_errors causes
+ # a reconnection
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer,
+ '__init__', 'meow')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.connection_errors = (MyException, )
+
+ result = conn.declare_consumer(self.rpc.DirectConsumer,
+ 'test_topic', None)
+
+ self.assertEqual(info['called'], 2)
+ self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_declare_consumer_ioerrors_will_reconnect(self):
+ """Test that an IOError exception causes a reconnection"""
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
+ '__init__', 'Socket closed', exc_class=IOError)
+
+ conn = self.rpc.Connection(FLAGS)
+ result = conn.declare_consumer(self.rpc.DirectConsumer,
+ 'test_topic', None)
+
+ self.assertEqual(info['called'], 3)
+ self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_publishing_errors_will_reconnect(self):
+ # Test that any exception with 'timeout' in it causes a
+ # reconnection when declaring the publisher class and when
+ # calling send()
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
+ '__init__', 'foo timeout foo')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 3)
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher,
+ 'send', 'foo timeout foo')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 3)
+
+ # Test that any exception in transport.connection_errors causes
+ # a reconnection when declaring the publisher class and when
+ # calling send()
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
+ '__init__', 'meow')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.connection_errors = (MyException, )
+
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 2)
+ self.stubs.UnsetAll()
+
+ info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher,
+ 'send', 'meow')
+
+ conn = self.rpc.Connection(FLAGS)
+ conn.connection_errors = (MyException, )
+
+ conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg')
+
+ self.assertEqual(info['called'], 2)
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_iterconsume_errors_will_reconnect(self):
+ conn = self.rpc.Connection(FLAGS)
+ message = 'reconnect test message'
+
+ self.received_message = None
+
+ def _callback(message):
+ self.received_message = message
+
+ conn.declare_direct_consumer('a_direct', _callback)
+ conn.direct_send('a_direct', message)
+
+ info = _raise_exc_stub(self.stubs, 1, conn.connection,
+ 'drain_events', 'foo timeout foo')
+ conn.consume(limit=1)
+ conn.close()
+
+ self.assertEqual(self.received_message, message)
+ # Only called once, because our stub goes away during reconnection
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_call_exception(self):
+ """Test that exception gets passed back properly.
+
+ rpc.call returns an Exception object. The value of the
+ exception is converted to a string.
+
+ """
+ FLAGS.set_override('allowed_rpc_exception_modules', ['exceptions'])
+ value = "This is the exception message"
+ self.assertRaises(NotImplementedError,
+ self.rpc.call,
+ FLAGS,
+ self.context,
+ 'test',
+ {"method": "fail",
+ "args": {"value": value}})
+ try:
+ self.rpc.call(FLAGS, self.context,
+ 'test',
+ {"method": "fail",
+ "args": {"value": value}})
+ self.fail("should have thrown Exception")
+ except NotImplementedError as exc:
+ self.assertTrue(value in unicode(exc))
+ #Traceback should be included in exception message
+ self.assertTrue('raise NotImplementedError(value)' in unicode(exc))
+
+ FLAGS.reset()
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_call_converted_exception(self):
+ """Test that exception gets passed back properly.
+
+ rpc.call returns an Exception object. The value of the
+ exception is converted to a string.
+
+ """
+ value = "This is the exception message"
+ # The use of ApiError is an arbitrary choice here ...
+ self.assertRaises(exception.ApiError,
+ self.rpc.call,
+ FLAGS,
+ self.context,
+ 'test',
+ {"method": "fail_converted",
+ "args": {"value": value}})
+ try:
+ self.rpc.call(FLAGS, self.context,
+ 'test',
+ {"method": "fail_converted",
+ "args": {"value": value}})
+ self.fail("should have thrown Exception")
+ except exception.ApiError as exc:
+ self.assertTrue(value in unicode(exc))
+ #Traceback should be included in exception message
+ self.assertTrue('exception.ApiError' in unicode(exc))
diff --git a/tests/unit/rpc/test_kombu_ssl.py b/tests/unit/rpc/test_kombu_ssl.py
new file mode 100644
index 0000000..2aecc3f
--- /dev/null
+++ b/tests/unit/rpc/test_kombu_ssl.py
@@ -0,0 +1,82 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using kombu + ssl
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+import unittest
+
+from openstack.common import cfg
+from openstack.common import testutils
+
+try:
+ import kombu
+ from openstack.common.rpc import impl_kombu
+except ImportError:
+ kombu = None
+ impl_kombu = None
+
+
+# Flag settings we will ensure get passed to amqplib
+SSL_VERSION = "SSLv2"
+SSL_CERT = "/tmp/cert.blah.blah"
+SSL_CA_CERT = "/tmp/cert.ca.blah.blah"
+SSL_KEYFILE = "/tmp/keyfile.blah.blah"
+
+FLAGS = cfg.CONF
+
+
+class RpcKombuSslTestCase(unittest.TestCase):
+
+ def setUp(self):
+ super(RpcKombuSslTestCase, self).setUp()
+ override = {
+ 'kombu_ssl_keyfile': SSL_KEYFILE,
+ 'kombu_ssl_ca_certs': SSL_CA_CERT,
+ 'kombu_ssl_certfile': SSL_CERT,
+ 'kombu_ssl_version': SSL_VERSION,
+ 'rabbit_use_ssl': True,
+ 'fake_rabbit': True,
+ }
+
+ if kombu:
+ for k, v in override.iteritems():
+ FLAGS.set_override(k, v)
+
+ def tearDown(self):
+ super(RpcKombuSslTestCase, self).tearDown()
+ if kombu:
+ FLAGS.reset()
+
+ @testutils.skip_if(kombu is None, "Test requires kombu")
+ def test_ssl_on_extended(self):
+ rpc = impl_kombu
+ conn = rpc.create_connection(FLAGS, True)
+ c = conn.connection
+ #This might be kombu version dependent...
+ #Since we are now peaking into the internals of kombu...
+ self.assertTrue(isinstance(c.connection.ssl, dict))
+ self.assertEqual(SSL_VERSION, c.connection.ssl.get("ssl_version"))
+ self.assertEqual(SSL_CERT, c.connection.ssl.get("certfile"))
+ self.assertEqual(SSL_CA_CERT, c.connection.ssl.get("ca_certs"))
+ self.assertEqual(SSL_KEYFILE, c.connection.ssl.get("keyfile"))
+ #That hash then goes into amqplib which then goes
+ #Into python ssl creation...
diff --git a/tests/unit/rpc/test_matchmaker.py b/tests/unit/rpc/test_matchmaker.py
new file mode 100644
index 0000000..a38b59c
--- /dev/null
+++ b/tests/unit/rpc/test_matchmaker.py
@@ -0,0 +1,60 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import logging
+import unittest
+
+from openstack.common.rpc import matchmaker
+
+
+LOG = logging.getLogger(__name__)
+
+
+class _MatchMakerTestCase(unittest.TestCase):
+ def test_valid_host_matches(self):
+ queues = self.driver.queues(self.topic)
+ matched_hosts = map(lambda x: x[1], queues)
+
+ for host in matched_hosts:
+ self.assertIn(host, self.hosts)
+
+ def test_fanout_host_matches(self):
+ """For known hosts, see if they're in fanout."""
+ queues = self.driver.queues("fanout~" + self.topic)
+ matched_hosts = map(lambda x: x[1], queues)
+
+ LOG.info("Received result from matchmaker: %s", queues)
+ for host in self.hosts:
+ self.assertIn(host, matched_hosts)
+
+
+class MatchMakerFileTestCase(_MatchMakerTestCase):
+ def setUp(self):
+ self.topic = "test"
+ self.hosts = ['hello', 'world', 'foo', 'bar', 'baz']
+ ring = {
+ self.topic: self.hosts
+ }
+ self.driver = matchmaker.MatchMakerRing(ring)
+ super(MatchMakerFileTestCase, self).setUp()
+
+
+class MatchMakerLocalhostTestCase(_MatchMakerTestCase):
+ def setUp(self):
+ self.driver = matchmaker.MatchMakerLocalhost()
+ self.topic = "test"
+ self.hosts = ['localhost']
+ super(MatchMakerLocalhostTestCase, self).setUp()
diff --git a/tests/unit/rpc/test_proxy.py b/tests/unit/rpc/test_proxy.py
new file mode 100644
index 0000000..1af37c7
--- /dev/null
+++ b/tests/unit/rpc/test_proxy.py
@@ -0,0 +1,128 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit Tests for rpc.proxy
+"""
+
+import copy
+import stubout
+import unittest
+
+from openstack.common import context
+from openstack.common import rpc
+from openstack.common.rpc import proxy
+
+
+class RpcProxyTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ super(RpcProxyTestCase, self).setUp()
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+ self.stubs.SmartUnsetAll()
+ super(RpcProxyTestCase, self).tearDown()
+
+ def _test_rpc_method(self, rpc_method, has_timeout=False, has_retval=False,
+ server_params=None, supports_topic_override=True):
+ topic = 'fake_topic'
+ timeout = 123
+ rpc_proxy = proxy.RpcProxy(topic, '1.0')
+ ctxt = context.RequestContext('fake_user', 'fake_project')
+ msg = {'method': 'fake_method', 'args': {'x': 'y'}}
+ expected_msg = {'method': 'fake_method', 'args': {'x': 'y'},
+ 'version': '1.0'}
+
+ expected_retval = 'hi' if has_retval else None
+
+ self.fake_args = None
+ self.fake_kwargs = None
+
+ def _fake_rpc_method(*args, **kwargs):
+ self.fake_args = args
+ self.fake_kwargs = kwargs
+ if has_retval:
+ return expected_retval
+
+ self.stubs.Set(rpc, rpc_method, _fake_rpc_method)
+
+ args = [ctxt, msg]
+ if server_params:
+ args.insert(1, server_params)
+
+ # Base method usage
+ retval = getattr(rpc_proxy, rpc_method)(*args)
+ self.assertEqual(retval, expected_retval)
+ expected_args = [ctxt, topic, expected_msg]
+ if server_params:
+ expected_args.insert(1, server_params)
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ # overriding the version
+ retval = getattr(rpc_proxy, rpc_method)(*args, version='1.1')
+ self.assertEqual(retval, expected_retval)
+ new_msg = copy.deepcopy(expected_msg)
+ new_msg['version'] = '1.1'
+ expected_args = [ctxt, topic, new_msg]
+ if server_params:
+ expected_args.insert(1, server_params)
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ if has_timeout:
+ # set a timeout
+ retval = getattr(rpc_proxy, rpc_method)(ctxt, msg, timeout=timeout)
+ self.assertEqual(retval, expected_retval)
+ expected_args = [ctxt, topic, expected_msg, timeout]
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ if supports_topic_override:
+ # set a topic
+ new_topic = 'foo.bar'
+ retval = getattr(rpc_proxy, rpc_method)(*args, topic=new_topic)
+ self.assertEqual(retval, expected_retval)
+ expected_args = [ctxt, new_topic, expected_msg]
+ if server_params:
+ expected_args.insert(1, server_params)
+ for arg, expected_arg in zip(self.fake_args, expected_args):
+ self.assertEqual(arg, expected_arg)
+
+ def test_call(self):
+ self._test_rpc_method('call', has_timeout=True, has_retval=True)
+
+ def test_multicall(self):
+ self._test_rpc_method('multicall', has_timeout=True, has_retval=True)
+
+ def test_cast(self):
+ self._test_rpc_method('cast')
+
+ def test_fanout_cast(self):
+ self._test_rpc_method('fanout_cast', supports_topic_override=False)
+
+ def test_cast_to_server(self):
+ self._test_rpc_method('cast_to_server', server_params={'blah': 1})
+
+ def test_fanout_cast_to_server(self):
+ self._test_rpc_method('fanout_cast_to_server',
+ server_params={'blah': 1}, supports_topic_override=False)
+
+ def test_make_msg(self):
+ self.assertEqual(proxy.RpcProxy.make_msg('test_method', a=1, b=2),
+ {'method': 'test_method', 'args': {'a': 1, 'b': 2}})
diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py
new file mode 100644
index 0000000..b753c22
--- /dev/null
+++ b/tests/unit/rpc/test_qpid.py
@@ -0,0 +1,377 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for remote procedure calls using qpid
+"""
+
+import eventlet
+eventlet.monkey_patch()
+
+import logging
+import mox
+import stubout
+import unittest
+
+from openstack.common import cfg
+from openstack.common import context
+from openstack.common.rpc import amqp as rpc_amqp
+from openstack.common import testutils
+
+try:
+ from openstack.common.rpc import impl_qpid
+ import qpid
+except ImportError:
+ qpid = None
+ impl_qpid = None
+
+
+FLAGS = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class RpcQpidTestCase(unittest.TestCase):
+ """
+ Exercise the public API of impl_qpid utilizing mox.
+
+ This set of tests utilizes mox to replace the Qpid objects and ensures
+ that the right operations happen on them when the various public rpc API
+ calls are exercised. The API calls tested here include:
+
+ nova.rpc.create_connection()
+ nova.rpc.common.Connection.create_consumer()
+ nova.rpc.common.Connection.close()
+ nova.rpc.cast()
+ nova.rpc.fanout_cast()
+ nova.rpc.call()
+ nova.rpc.multicall()
+ """
+
+ def setUp(self):
+ super(RpcQpidTestCase, self).setUp()
+
+ self.stubs = stubout.StubOutForTesting()
+ self.mox = mox.Mox()
+
+ self.mock_connection = None
+ self.mock_session = None
+ self.mock_sender = None
+ self.mock_receiver = None
+
+ if qpid:
+ self.orig_connection = qpid.messaging.Connection
+ self.orig_session = qpid.messaging.Session
+ self.orig_sender = qpid.messaging.Sender
+ self.orig_receiver = qpid.messaging.Receiver
+ qpid.messaging.Connection = lambda *_x, **_y: self.mock_connection
+ qpid.messaging.Session = lambda *_x, **_y: self.mock_session
+ qpid.messaging.Sender = lambda *_x, **_y: self.mock_sender
+ qpid.messaging.Receiver = lambda *_x, **_y: self.mock_receiver
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.stubs.UnsetAll()
+ self.stubs.SmartUnsetAll()
+ if qpid:
+ qpid.messaging.Connection = self.orig_connection
+ qpid.messaging.Session = self.orig_session
+ qpid.messaging.Sender = self.orig_sender
+ qpid.messaging.Receiver = self.orig_receiver
+ if impl_qpid:
+ # Need to reset this in case we changed the connection_cls
+ # in self._setup_to_server_tests()
+ impl_qpid.Connection.pool.connection_cls = impl_qpid.Connection
+
+ super(RpcQpidTestCase, self).tearDown()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_connection(self):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ self.mock_connection.close()
+
+ self.mox.ReplayAll()
+
+ connection = impl_qpid.create_connection(FLAGS)
+ connection.close()
+
+ def _test_create_consumer(self, fanout):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_receiver = self.mox.CreateMock(self.orig_receiver)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ if fanout:
+ # The link name includes a UUID, so match it with a regex.
+ expected_address = mox.Regex(r'^impl_qpid_test_fanout ; '
+ '{"node": {"x-declare": {"auto-delete": true, "durable": '
+ 'false, "type": "fanout"}, "type": "topic"}, "create": '
+ '"always", "link": {"x-declare": {"auto-delete": true, '
+ '"exclusive": true, "durable": false}, "durable": true, '
+ '"name": "impl_qpid_test_fanout_.*"}}$')
+ else:
+ expected_address = ('nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": true}, "type": "topic"}, '
+ '"create": "always", "link": {"x-declare": {"auto-delete": '
+ 'true, "exclusive": false, "durable": false}, "durable": '
+ 'true, "name": "impl_qpid_test"}}')
+ self.mock_session.receiver(expected_address).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.capacity = 1
+ self.mock_connection.close()
+
+ self.mox.ReplayAll()
+
+ connection = impl_qpid.create_connection(FLAGS)
+ connection.create_consumer("impl_qpid_test",
+ lambda *_x, **_y: None,
+ fanout)
+ connection.close()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_consumer(self):
+ self._test_create_consumer(fanout=False)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_consumer_fanout(self):
+ self._test_create_consumer(fanout=True)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_create_worker(self):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_receiver = self.mox.CreateMock(self.orig_receiver)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ expected_address = (
+ 'nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": true}, "type": "topic"}, '
+ '"create": "always", "link": {"x-declare": {"auto-delete": '
+ 'true, "exclusive": false, "durable": false}, "durable": '
+ 'true, "name": "impl.qpid.test.workers"}}')
+ self.mock_session.receiver(expected_address).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.capacity = 1
+ self.mock_connection.close()
+
+ self.mox.ReplayAll()
+
+ connection = impl_qpid.create_connection(FLAGS)
+ connection.create_worker("impl_qpid_test",
+ lambda *_x, **_y: None,
+ 'impl.qpid.test.workers',
+ )
+ connection.close()
+
+ def _test_cast(self, fanout, server_params=None):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_sender = self.mox.CreateMock(self.orig_sender)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+
+ self.mock_connection.session().AndReturn(self.mock_session)
+ if fanout:
+ expected_address = ('impl_qpid_test_fanout ; '
+ '{"node": {"x-declare": {"auto-delete": true, '
+ '"durable": false, "type": "fanout"}, '
+ '"type": "topic"}, "create": "always"}')
+ else:
+ expected_address = ('nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": false}, "type": "topic"}, '
+ '"create": "always"}')
+ self.mock_session.sender(expected_address).AndReturn(self.mock_sender)
+ self.mock_sender.send(mox.IgnoreArg())
+ if not server_params:
+ # This is a pooled connection, so instead of closing it, it
+ # gets reset, which is just creating a new session on the
+ # connection.
+ self.mock_session.close()
+ self.mock_connection.session().AndReturn(self.mock_session)
+
+ self.mox.ReplayAll()
+
+ try:
+ ctx = context.RequestContext("user", "project")
+
+ args = [FLAGS, ctx, "impl_qpid_test",
+ {"method": "test_method", "args": {}}]
+
+ if server_params:
+ args.insert(2, server_params)
+ if fanout:
+ method = impl_qpid.fanout_cast_to_server
+ else:
+ method = impl_qpid.cast_to_server
+ else:
+ if fanout:
+ method = impl_qpid.fanout_cast
+ else:
+ method = impl_qpid.cast
+
+ method(*args)
+ finally:
+ while impl_qpid.Connection.pool.free_items:
+ # Pull the mock connection object out of the connection pool so
+ # that it doesn't mess up other test cases.
+ impl_qpid.Connection.pool.get()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_cast(self):
+ self._test_cast(fanout=False)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_fanout_cast(self):
+ self._test_cast(fanout=True)
+
+ def _setup_to_server_tests(self, server_params):
+ class MyConnection(impl_qpid.Connection):
+ def __init__(myself, *args, **kwargs):
+ super(MyConnection, myself).__init__(*args, **kwargs)
+ self.assertEqual(myself.connection.username,
+ server_params['username'])
+ self.assertEqual(myself.connection.password,
+ server_params['password'])
+ self.assertEqual(myself.broker,
+ server_params['hostname'] + ':' +
+ str(server_params['port']))
+
+ MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection)
+ self.stubs.Set(impl_qpid, 'Connection', MyConnection)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_cast_to_server(self):
+ server_params = {'username': 'fake_username',
+ 'password': 'fake_password',
+ 'hostname': 'fake_hostname',
+ 'port': 31337}
+ self._setup_to_server_tests(server_params)
+ self._test_cast(fanout=False, server_params=server_params)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_fanout_cast_to_server(self):
+ server_params = {'username': 'fake_username',
+ 'password': 'fake_password',
+ 'hostname': 'fake_hostname',
+ 'port': 31337}
+ self._setup_to_server_tests(server_params)
+ self._test_cast(fanout=True, server_params=server_params)
+
+ def _test_call(self, multi):
+ self.mock_connection = self.mox.CreateMock(self.orig_connection)
+ self.mock_session = self.mox.CreateMock(self.orig_session)
+ self.mock_sender = self.mox.CreateMock(self.orig_sender)
+ self.mock_receiver = self.mox.CreateMock(self.orig_receiver)
+
+ self.mock_connection.opened().AndReturn(False)
+ self.mock_connection.open()
+ self.mock_connection.session().AndReturn(self.mock_session)
+ rcv_addr = mox.Regex(r'^.*/.* ; {"node": {"x-declare": {"auto-delete":'
+ ' true, "durable": true, "type": "direct"}, "type": '
+ '"topic"}, "create": "always", "link": {"x-declare": '
+ '{"auto-delete": true, "exclusive": true, "durable": '
+ 'false}, "durable": true, "name": ".*"}}')
+ self.mock_session.receiver(rcv_addr).AndReturn(self.mock_receiver)
+ self.mock_receiver.capacity = 1
+ send_addr = ('nova/impl_qpid_test ; {"node": {"x-declare": '
+ '{"auto-delete": true, "durable": false}, "type": "topic"}, '
+ '"create": "always"}')
+ self.mock_session.sender(send_addr).AndReturn(self.mock_sender)
+ self.mock_sender.send(mox.IgnoreArg())
+
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(qpid.messaging.Message(
+ {"result": "foo", "failure": False, "ending": False}))
+ if multi:
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(
+ qpid.messaging.Message(
+ {"result": "bar", "failure": False,
+ "ending": False}))
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(
+ qpid.messaging.Message(
+ {"result": "baz", "failure": False,
+ "ending": False}))
+ self.mock_session.next_receiver(timeout=mox.IsA(int)).AndReturn(
+ self.mock_receiver)
+ self.mock_receiver.fetch().AndReturn(qpid.messaging.Message(
+ {"failure": False, "ending": True}))
+ self.mock_session.close()
+ self.mock_connection.session().AndReturn(self.mock_session)
+
+ self.mox.ReplayAll()
+
+ try:
+ ctx = context.RequestContext("user", "project")
+
+ if multi:
+ method = impl_qpid.multicall
+ else:
+ method = impl_qpid.call
+
+ res = method(FLAGS, ctx, "impl_qpid_test",
+ {"method": "test_method", "args": {}})
+
+ if multi:
+ self.assertEquals(list(res), ["foo", "bar", "baz"])
+ else:
+ self.assertEquals(res, "foo")
+ finally:
+ while impl_qpid.Connection.pool.free_items:
+ # Pull the mock connection object out of the connection pool so
+ # that it doesn't mess up other test cases.
+ impl_qpid.Connection.pool.get()
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_call(self):
+ self._test_call(multi=False)
+
+ @testutils.skip_if(qpid is None, "Test requires qpid")
+ def test_multicall(self):
+ self._test_call(multi=True)
+
+
+#
+#from nova.tests.rpc import common
+#
+# Qpid does not have a handy in-memory transport like kombu, so it's not
+# terribly straight forward to take advantage of the common unit tests.
+# However, at least at the time of this writing, the common unit tests all pass
+# with qpidd running.
+#
+# class RpcQpidCommonTestCase(common._BaseRpcTestCase):
+# def setUp(self):
+# self.rpc = impl_qpid
+# super(RpcQpidCommonTestCase, self).setUp()
+#
+# def tearDown(self):
+# super(RpcQpidCommonTestCase, self).tearDown()
+#