author     Sandy Walsh <sandy.walsh@rackspace.com>   2011-05-27 07:31:29 -0700
committer  Sandy Walsh <sandy.walsh@rackspace.com>   2011-05-27 07:31:29 -0700
commit     ceb6eee5ddbbd202af80ae32795bbf53d2e9ef49 (patch)
tree       9cf90f9db1d5e563c5326c96d2d34e76a879d7a8
parent     3f911877a2a9facdf153f173b3fb76a18e44a2ac (diff)
parent     a7c36f68793a7db454d344187d4596ebecc8ade0 (diff)
trunk merge
-rw-r--r--   nova/fakerabbit.py                                    31
-rw-r--r--   nova/rpc.py                                          271
-rw-r--r--   nova/service.py                                       60
-rw-r--r--   nova/test.py                                           9
-rw-r--r--   nova/tests/integrated/integrated_helpers.py            5
-rw-r--r--   nova/tests/test_cloud.py                              26
-rw-r--r--   nova/tests/test_rpc.py                               116
-rw-r--r--   nova/tests/test_service.py                            59
-rw-r--r--   nova/tests/test_xenapi.py                             23
-rw-r--r--   nova/tests/xenapi/stubs.py                            26
-rw-r--r--   nova/virt/xenapi/fake.py                               5
-rw-r--r--   nova/virt/xenapi/vm_utils.py                          49
-rw-r--r--   nova/virt/xenapi/vmops.py                             59
-rw-r--r--   plugins/xenserver/xenapi/etc/xapi.d/plugins/glance    96
14 files changed, 637 insertions(+), 198 deletions(-)
diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py
index a7dee8caf..e7e9dab77 100644
--- a/nova/fakerabbit.py
+++ b/nova/fakerabbit.py
@@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit")
EXCHANGES = {}
QUEUES = {}
+CONSUMERS = {}
class Message(base.BaseMessage):
@@ -96,17 +97,29 @@ class Backend(base.BaseBackend):
' key %(routing_key)s') % locals())
EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key)
- def declare_consumer(self, queue, callback, *args, **kwargs):
- self.current_queue = queue
- self.current_callback = callback
+ def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs):
+ global CONSUMERS
+ LOG.debug("Adding consumer %s", consumer_tag)
+ CONSUMERS[consumer_tag] = (queue, callback)
+
+ def cancel(self, consumer_tag):
+ global CONSUMERS
+ LOG.debug("Removing consumer %s", consumer_tag)
+ del CONSUMERS[consumer_tag]
def consume(self, limit=None):
+ global CONSUMERS
+ num = 0
while True:
- item = self.get(self.current_queue)
- if item:
- self.current_callback(item)
- raise StopIteration()
- greenthread.sleep(0)
+ for (queue, callback) in CONSUMERS.itervalues():
+ item = self.get(queue)
+ if item:
+ callback(item)
+ num += 1
+ yield
+ if limit and num == limit:
+ raise StopIteration()
+ greenthread.sleep(0.1)
def get(self, queue, no_ack=False):
global QUEUES
@@ -134,5 +147,7 @@ class Backend(base.BaseBackend):
def reset_all():
global EXCHANGES
global QUEUES
+ global CONSUMERS
EXCHANGES = {}
QUEUES = {}
+ CONSUMERS = {}
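The rewritten consume() above is a generator: it round-robins every registered consumer, yields after each delivery, and stops once `limit` deliveries have been made. The same pattern in isolation, as a minimal sketch with plain lists standing in for queues (illustrative only, not part of the commit):

    import time

    def consume(consumers, limit=None):
        # consumers: dict of consumer_tag -> (queue, callback), mirroring
        # the CONSUMERS dict above; queues are plain lists here.
        num = 0
        while True:
            for queue, callback in consumers.values():
                if queue:
                    callback(queue.pop(0))
                    num += 1
                    yield
                    if limit and num == limit:
                        return
            time.sleep(0.1)

    queue = ['a', 'b', 'c']
    seen = []
    for _ in consume({'tag0': (queue, seen.append)}, limit=3):
        pass
    assert seen == ['a', 'b', 'c']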
diff --git a/nova/rpc.py b/nova/rpc.py
index 2116f22c3..c5277c6a9 100644
--- a/nova/rpc.py
+++ b/nova/rpc.py
@@ -28,12 +28,15 @@ import json
import sys
import time
import traceback
+import types
import uuid
from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenpool
-from eventlet import greenthread
+from eventlet import pools
+from eventlet import queue
+import greenlet
from nova import context
from nova import exception
@@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc')
FLAGS = flags.FLAGS
-flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool')
+flags.DEFINE_integer('rpc_thread_pool_size', 1024,
+ 'Size of RPC thread pool')
+flags.DEFINE_integer('rpc_conn_pool_size', 30,
+ 'Size of RPC connection pool')
class Connection(carrot_connection.BrokerConnection):
@@ -90,6 +96,22 @@ class Connection(carrot_connection.BrokerConnection):
return cls.instance()
+class Pool(pools.Pool):
+ """Class that implements a Pool of Connections."""
+
+ # TODO(comstud): Timeout connections not used in a while
+ def create(self):
+ LOG.debug('Creating new connection')
+ return Connection.instance(new=True)
+
+# Create a ConnectionPool to use for RPC calls. We'll order the
+# pool as a stack (LIFO), so that we can potentially loop through and
+# time out old unused connections at some point.
+ConnectionPool = Pool(
+ max_size=FLAGS.rpc_conn_pool_size,
+ order_as_stack=True)
+
+
class Consumer(messaging.Consumer):
"""Consumer base class.
@@ -131,7 +153,9 @@ class Consumer(messaging.Consumer):
self.connection = Connection.recreate()
self.backend = self.connection.create_backend()
self.declare()
- super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
+ return super(Consumer, self).fetch(no_ack,
+ auto_ack,
+ enable_callbacks)
if self.failed_connection:
LOG.error(_('Reconnected to queue'))
self.failed_connection = False
@@ -159,13 +183,13 @@ class AdapterConsumer(Consumer):
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
super(AdapterConsumer, self).__init__(connection=connection,
topic=topic)
+ self.register_callback(self.process_data)
- def receive(self, *args, **kwargs):
- self.pool.spawn_n(self._receive, *args, **kwargs)
+ def process_data(self, message_data, message):
+ """Consumer callback to call a method on a proxy object.
- @exception.wrap_exception
- def _receive(self, message_data, message):
- """Magically looks for a method on the proxy object and calls it.
+ Parses the message for validity and fires off a thread to call the
+ proxy object method.
Message data should be a dictionary with two keys:
method: string representing the method to call
@@ -175,8 +199,8 @@ class AdapterConsumer(Consumer):
"""
LOG.debug(_('received %s') % message_data)
- msg_id = message_data.pop('_msg_id', None)
-
+ # This will be popped off in _unpack_context
+ msg_id = message_data.get('_msg_id', None)
ctxt = _unpack_context(message_data)
method = message_data.get('method')
@@ -188,8 +212,17 @@ class AdapterConsumer(Consumer):
# we just log the message and send an error string
# back to the caller
LOG.warn(_('no method for message: %s') % message_data)
- msg_reply(msg_id, _('No method for message: %s') % message_data)
+ if msg_id:
+ msg_reply(msg_id,
+ _('No method for message: %s') % message_data)
return
+ self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args)
+
+ @exception.wrap_exception
+ def _process_data(self, msg_id, ctxt, method, args):
+        """Thread that magically looks for a method on the proxy
+ object and calls it.
+ """
node_func = getattr(self.proxy, str(method))
node_args = dict((str(k), v) for k, v in args.iteritems())
@@ -197,7 +230,18 @@ class AdapterConsumer(Consumer):
try:
rval = node_func(context=ctxt, **node_args)
if msg_id:
- msg_reply(msg_id, rval, None)
+ # Check if the result was a generator
+ if isinstance(rval, types.GeneratorType):
+ for x in rval:
+ msg_reply(msg_id, x, None)
+ else:
+ msg_reply(msg_id, rval, None)
+
+ # This final None tells multicall that it is done.
+ msg_reply(msg_id, None, None)
+ elif isinstance(rval, types.GeneratorType):
+ # NOTE(vish): this iterates through the generator
+ list(rval)
except Exception as e:
logging.exception('Exception during message handling')
if msg_id:
@@ -205,11 +249,6 @@ class AdapterConsumer(Consumer):
return
-class Publisher(messaging.Publisher):
- """Publisher base class."""
- pass
-
-
class TopicAdapterConsumer(AdapterConsumer):
"""Consumes messages on a specific topic."""
@@ -242,6 +281,58 @@ class FanoutAdapterConsumer(AdapterConsumer):
topic=topic, proxy=proxy)
+class ConsumerSet(object):
+    """Groups consumers to listen together on a single connection."""
+
+ def __init__(self, connection, consumer_list):
+ self.consumer_list = set(consumer_list)
+ self.consumer_set = None
+ self.enabled = True
+ self.init(connection)
+
+ def init(self, conn):
+ if not conn:
+ conn = Connection.instance(new=True)
+ if self.consumer_set:
+ self.consumer_set.close()
+ self.consumer_set = messaging.ConsumerSet(conn)
+ for consumer in self.consumer_list:
+ consumer.connection = conn
+ # consumer.backend is set for us
+ self.consumer_set.add_consumer(consumer)
+
+ def reconnect(self):
+ self.init(None)
+
+ def wait(self, limit=None):
+ running = True
+ while running:
+ it = self.consumer_set.iterconsume(limit=limit)
+ if not it:
+ break
+ while True:
+ try:
+ it.next()
+ except StopIteration:
+ return
+ except greenlet.GreenletExit:
+ running = False
+ break
+ except Exception as e:
+ LOG.exception(_("Exception while processing consumer"))
+ self.reconnect()
+ # Break to outer loop
+ break
+
+ def close(self):
+ self.consumer_set.close()
+
+
+class Publisher(messaging.Publisher):
+ """Publisher base class."""
+ pass
+
+
class TopicPublisher(Publisher):
"""Publishes messages on a specific topic."""
@@ -306,16 +397,18 @@ def msg_reply(msg_id, reply=None, failure=None):
LOG.error(_("Returning exception %s to caller"), message)
LOG.error(tb)
failure = (failure[0].__name__, str(failure[1]), tb)
- conn = Connection.instance()
- publisher = DirectPublisher(connection=conn, msg_id=msg_id)
- try:
- publisher.send({'result': reply, 'failure': failure})
- except TypeError:
- publisher.send(
- {'result': dict((k, repr(v))
- for k, v in reply.__dict__.iteritems()),
- 'failure': failure})
- publisher.close()
+
+ with ConnectionPool.item() as conn:
+ publisher = DirectPublisher(connection=conn, msg_id=msg_id)
+ try:
+ publisher.send({'result': reply, 'failure': failure})
+ except TypeError:
+ publisher.send(
+ {'result': dict((k, repr(v))
+ for k, v in reply.__dict__.iteritems()),
+ 'failure': failure})
+
+ publisher.close()
class RemoteError(exception.Error):
@@ -347,8 +440,9 @@ def _unpack_context(msg):
if key.startswith('_context_'):
value = msg.pop(key)
context_dict[key[9:]] = value
+ context_dict['msg_id'] = msg.pop('_msg_id', None)
LOG.debug(_('unpacked context: %s'), context_dict)
- return context.RequestContext.from_dict(context_dict)
+ return RpcContext.from_dict(context_dict)
def _pack_context(msg, context):
@@ -360,70 +454,112 @@ def _pack_context(msg, context):
for args at some point.
"""
- context = dict([('_context_%s' % key, value)
- for (key, value) in context.to_dict().iteritems()])
- msg.update(context)
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.to_dict().iteritems()])
+ msg.update(context_d)
-def call(context, topic, msg):
- """Sends a message on a topic and wait for a response."""
+class RpcContext(context.RequestContext):
+ def __init__(self, *args, **kwargs):
+ msg_id = kwargs.pop('msg_id', None)
+ self.msg_id = msg_id
+ super(RpcContext, self).__init__(*args, **kwargs)
+
+ def reply(self, *args, **kwargs):
+ msg_reply(self.msg_id, *args, **kwargs)
+
+
+def multicall(context, topic, msg):
+ """Make a call that returns multiple times."""
LOG.debug(_('Making asynchronous call on %s ...'), topic)
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug(_('MSG_ID is %s') % (msg_id))
_pack_context(msg, context)
- class WaitMessage(object):
- def __call__(self, data, message):
- """Acks message and sets result."""
- message.ack()
- if data['failure']:
- self.result = RemoteError(*data['failure'])
- else:
- self.result = data['result']
-
- wait_msg = WaitMessage()
- conn = Connection.instance()
- consumer = DirectConsumer(connection=conn, msg_id=msg_id)
+ con_conn = ConnectionPool.get()
+ consumer = DirectConsumer(connection=con_conn, msg_id=msg_id)
+ wait_msg = MulticallWaiter(consumer)
consumer.register_callback(wait_msg)
- conn = Connection.instance()
- publisher = TopicPublisher(connection=conn, topic=topic)
+ publisher = TopicPublisher(connection=con_conn, topic=topic)
publisher.send(msg)
publisher.close()
- try:
- consumer.wait(limit=1)
- except StopIteration:
- pass
- consumer.close()
- # NOTE(termie): this is a little bit of a change from the original
- # non-eventlet code where returning a Failure
- # instance from a deferred call is very similar to
- # raising an exception
- if isinstance(wait_msg.result, Exception):
- raise wait_msg.result
- return wait_msg.result
+ return wait_msg
+
+
+class MulticallWaiter(object):
+ def __init__(self, consumer):
+ self._consumer = consumer
+ self._results = queue.Queue()
+ self._closed = False
+
+ def close(self):
+ self._closed = True
+ self._consumer.close()
+ ConnectionPool.put(self._consumer.connection)
+
+ def __call__(self, data, message):
+ """Acks message and sets result."""
+ message.ack()
+ if data['failure']:
+ self._results.put(RemoteError(*data['failure']))
+ else:
+ self._results.put(data['result'])
+
+ def __iter__(self):
+ return self.wait()
+
+ def wait(self):
+ while True:
+ rv = None
+ while rv is None and not self._closed:
+ try:
+ rv = self._consumer.fetch(enable_callbacks=True)
+ except Exception:
+ self.close()
+ raise
+ time.sleep(0.01)
+
+ result = self._results.get()
+ if isinstance(result, Exception):
+ self.close()
+ raise result
+            if result is None:
+ self.close()
+ raise StopIteration
+ yield result
+
+
+def call(context, topic, msg):
+    """Sends a message on a topic and waits for a response."""
+ rv = multicall(context, topic, msg)
+ # NOTE(vish): return the last result from the multicall
+ rv = list(rv)
+ if not rv:
+ return
+ return rv[-1]
def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic)
_pack_context(msg, context)
- conn = Connection.instance()
- publisher = TopicPublisher(connection=conn, topic=topic)
- publisher.send(msg)
- publisher.close()
+ with ConnectionPool.item() as conn:
+ publisher = TopicPublisher(connection=conn, topic=topic)
+ publisher.send(msg)
+ publisher.close()
def fanout_cast(context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...'))
_pack_context(msg, context)
- conn = Connection.instance()
- publisher = FanoutPublisher(topic, connection=conn)
- publisher.send(msg)
- publisher.close()
+ with ConnectionPool.item() as conn:
+ publisher = FanoutPublisher(topic, connection=conn)
+ publisher.send(msg)
+ publisher.close()
def generic_response(message_data, message):
@@ -459,6 +595,7 @@ def send_message(topic, message, wait=True):
if wait:
consumer.wait()
+ consumer.close()
if __name__ == '__main__':
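Usage-wise, multicall() returns an iterable MulticallWaiter: the remote side replies several times (via ctxt.reply() or by yielding), a final None reply ends the iteration, and call() simply keeps the last value. A hedged sketch, assuming a consumer is wired to topic 'test' with the TestReceiver proxy that the tests below set up (e.g. under fake_rabbit):

    from nova import context
    from nova import rpc

    ctxt = context.get_admin_context()
    results = rpc.multicall(ctxt, 'test',
                            {"method": "echo_three_times_yield",
                             "args": {"value": 42}})
    for value in results:   # iterates until the terminating None reply
        print value         # 42, then 43, then 44

    # call() is now just "multicall, return the last result":
    last = rpc.call(ctxt, 'test', {"method": "echo_three_times_yield",
                                   "args": {"value": 42}})
    assert last == 44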
diff --git a/nova/service.py b/nova/service.py
index ab1238c3b..74f9f04d8 100644
--- a/nova/service.py
+++ b/nova/service.py
@@ -19,14 +19,11 @@
"""Generic Node baseclass for all workers that run on hosts."""
+import greenlet
import inspect
import os
-import sys
-import time
-from eventlet import event
from eventlet import greenthread
-from eventlet import greenpool
from nova import context
from nova import db
@@ -91,27 +88,37 @@ class Service(object):
if 'nova-compute' == self.binary:
self.manager.update_available_resource(ctxt)
- conn1 = rpc.Connection.instance(new=True)
- conn2 = rpc.Connection.instance(new=True)
- conn3 = rpc.Connection.instance(new=True)
- if self.report_interval:
- consumer_all = rpc.TopicAdapterConsumer(
- connection=conn1,
- topic=self.topic,
- proxy=self)
- consumer_node = rpc.TopicAdapterConsumer(
- connection=conn2,
- topic='%s.%s' % (self.topic, self.host),
- proxy=self)
- fanout = rpc.FanoutAdapterConsumer(
- connection=conn3,
- topic=self.topic,
- proxy=self)
-
- self.timers.append(consumer_all.attach_to_eventlet())
- self.timers.append(consumer_node.attach_to_eventlet())
- self.timers.append(fanout.attach_to_eventlet())
+ self.conn = rpc.Connection.instance(new=True)
+ logging.debug("Creating Consumer connection for Service %s" %
+ self.topic)
+
+ # Share this same connection for these Consumers
+ consumer_all = rpc.TopicAdapterConsumer(
+ connection=self.conn,
+ topic=self.topic,
+ proxy=self)
+ consumer_node = rpc.TopicAdapterConsumer(
+ connection=self.conn,
+ topic='%s.%s' % (self.topic, self.host),
+ proxy=self)
+ fanout = rpc.FanoutAdapterConsumer(
+ connection=self.conn,
+ topic=self.topic,
+ proxy=self)
+ consumer_set = rpc.ConsumerSet(
+ connection=self.conn,
+ consumer_list=[consumer_all, consumer_node, fanout])
+
+ # Wait forever, processing these consumers
+ def _wait():
+ try:
+ consumer_set.wait()
+ finally:
+ consumer_set.close()
+
+ self.consumer_set_thread = greenthread.spawn(_wait)
+ if self.report_interval:
pulse = utils.LoopingCall(self.report_state)
pulse.start(interval=self.report_interval, now=False)
self.timers.append(pulse)
@@ -174,6 +181,11 @@ class Service(object):
logging.warn(_('Service killed that has no database entry'))
def stop(self):
+ self.consumer_set_thread.kill()
+ try:
+ self.consumer_set_thread.wait()
+ except greenlet.GreenletExit:
+ pass
for x in self.timers:
try:
x.stop()
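The net effect in Service.start() is one shared broker connection, several consumers on it, and a single greenthread draining them all. Condensed to its essentials (a sketch; the 'demo' topic and Proxy class are illustrative, not from the commit):

    from eventlet import greenthread

    from nova import rpc

    class Proxy(object):
        """Stand-in for a Service; its methods become rpc-callable."""
        def echo(self, context, value):
            return value

    conn = rpc.Connection.instance(new=True)
    consumer = rpc.TopicAdapterConsumer(connection=conn, topic='demo',
                                        proxy=Proxy())
    consumer_set = rpc.ConsumerSet(connection=conn, consumer_list=[consumer])

    def _wait():
        try:
            consumer_set.wait()   # loops on iterconsume(), reconnecting on error
        finally:
            consumer_set.close()

    thread = greenthread.spawn(_wait)
    # ... on shutdown, mirror Service.stop():
    thread.kill()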
diff --git a/nova/test.py b/nova/test.py
index 4deb2a175..80b2d0a74 100644
--- a/nova/test.py
+++ b/nova/test.py
@@ -31,17 +31,15 @@ import uuid
import unittest
import mox
-import shutil
import stubout
from eventlet import greenthread
-from nova import context
-from nova import db
from nova import fakerabbit
from nova import flags
from nova import rpc
from nova import service
from nova import wsgi
+from nova.virt import fake
FLAGS = flags.FLAGS
@@ -85,6 +83,7 @@ class TestCase(unittest.TestCase):
self._monkey_patch_attach()
self._monkey_patch_wsgi()
self._original_flags = FLAGS.FlagValuesDict()
+ rpc.ConnectionPool = rpc.Pool(max_size=FLAGS.rpc_conn_pool_size)
def tearDown(self):
"""Runs after each test method to tear down test environment."""
@@ -99,6 +98,10 @@ class TestCase(unittest.TestCase):
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
+ if FLAGS.connection_type == 'fake':
+ if hasattr(fake.FakeConnection, '_instance'):
+ del fake.FakeConnection._instance
+
# Reset any overriden flags
self.reset_flags()
diff --git a/nova/tests/integrated/integrated_helpers.py b/nova/tests/integrated/integrated_helpers.py
index bc98921f0..7f590441e 100644
--- a/nova/tests/integrated/integrated_helpers.py
+++ b/nova/tests/integrated/integrated_helpers.py
@@ -154,10 +154,7 @@ class _IntegratedTestBase(test.TestCase):
# set up services
self.start_service('compute')
self.start_service('volume')
- # NOTE(justinsb): There's a bug here which is eluding me...
- # If we start the network_service, all is good, but then subsequent
- # tests fail: CloudTestCase.test_ajax_console in particular.
- #self.start_service('network')
+ self.start_service('network')
self.start_service('scheduler')
self._start_api_service()
diff --git a/nova/tests/test_cloud.py b/nova/tests/test_cloud.py
index 54c0454de..b64be662e 100644
--- a/nova/tests/test_cloud.py
+++ b/nova/tests/test_cloud.py
@@ -17,13 +17,9 @@
# under the License.
from base64 import b64decode
-import json
from M2Crypto import BIO
from M2Crypto import RSA
import os
-import shutil
-import tempfile
-import time
from eventlet import greenthread
@@ -33,12 +29,10 @@ from nova import db
from nova import flags
from nova import log as logging
from nova import rpc
-from nova import service
from nova import test
from nova import utils
from nova import exception
from nova.auth import manager
-from nova.compute import power_state
from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils
from nova.image import local
@@ -79,14 +73,21 @@ class CloudTestCase(test.TestCase):
self.stubs.Set(local.LocalImageService, 'show', fake_show)
self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show)
+ # NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
+ rpc_cast = rpc.cast
+
+ def finish_cast(*args, **kwargs):
+ rpc_cast(*args, **kwargs)
+ greenthread.sleep(0.2)
+
+ self.stubs.Set(rpc, 'cast', finish_cast)
+
def tearDown(self):
network_ref = db.project_get_network(self.context,
self.project.id)
db.network_disassociate(self.context, network_ref['id'])
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
- self.compute.kill()
- self.network.kill()
super(CloudTestCase, self).tearDown()
def _create_key(self, name):
@@ -113,7 +114,6 @@ class CloudTestCase(test.TestCase):
self.cloud.describe_addresses(self.context)
self.cloud.release_address(self.context,
public_ip=address)
- greenthread.sleep(0.3)
db.floating_ip_destroy(self.context, address)
def test_associate_disassociate_address(self):
@@ -129,12 +129,10 @@ class CloudTestCase(test.TestCase):
self.cloud.associate_address(self.context,
instance_id=ec2_id,
public_ip=address)
- greenthread.sleep(0.3)
self.cloud.disassociate_address(self.context,
public_ip=address)
self.cloud.release_address(self.context,
public_ip=address)
- greenthread.sleep(0.3)
self.network.deallocate_fixed_ip(self.context, fixed)
db.instance_destroy(self.context, inst['id'])
db.floating_ip_destroy(self.context, address)
@@ -306,31 +304,25 @@ class CloudTestCase(test.TestCase):
'instance_type': instance_type,
'max_count': max_count}
rv = self.cloud.run_instances(self.context, **kwargs)
- greenthread.sleep(0.3)
instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
- greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id])
- greenthread.sleep(0.3)
def test_ajax_console(self):
kwargs = {'image_id': 'ami-1'}
rv = self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId']
- greenthread.sleep(0.3)
output = self.cloud.get_ajax_console(context=self.context,
instance_id=[instance_id])
self.assertEquals(output['url'],
'%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url)
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
- greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id])
- greenthread.sleep(0.3)
def test_key_generation(self):
result = self._create_key('test')
diff --git a/nova/tests/test_rpc.py b/nova/tests/test_rpc.py
index 44d7c91eb..ffd748efe 100644
--- a/nova/tests/test_rpc.py
+++ b/nova/tests/test_rpc.py
@@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc')
class RpcTestCase(test.TestCase):
- """Test cases for rpc"""
def setUp(self):
super(RpcTestCase, self).setUp()
self.conn = rpc.Connection.instance(True)
@@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase):
self.context = context.get_admin_context()
def test_call_succeed(self):
- """Get a value through rpc call"""
value = 42
result = rpc.call(self.context, 'test', {"method": "echo",
"args": {"value": value}})
self.assertEqual(value, result)
+ def test_call_succeed_despite_multiple_returns(self):
+ value = 42
+ result = rpc.call(self.context, 'test', {"method": "echo_three_times",
+ "args": {"value": value}})
+ self.assertEqual(value + 2, result)
+
+ def test_call_succeed_despite_multiple_returns_yield(self):
+ value = 42
+ result = rpc.call(self.context, 'test',
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ self.assertEqual(value + 2, result)
+
+ def test_multicall_succeed_once(self):
+ value = 42
+ result = rpc.multicall(self.context,
+ 'test',
+ {"method": "echo",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ if i > 0:
+ self.fail('should only receive one response')
+ self.assertEqual(value + i, x)
+
+ def test_multicall_succeed_three_times(self):
+ value = 42
+ result = rpc.multicall(self.context,
+ 'test',
+ {"method": "echo_three_times",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(value + i, x)
+
+ def test_multicall_succeed_three_times_yield(self):
+ value = 42
+ result = rpc.multicall(self.context,
+ 'test',
+ {"method": "echo_three_times_yield",
+ "args": {"value": value}})
+ for i, x in enumerate(result):
+ self.assertEqual(value + i, x)
+
def test_context_passed(self):
- """Makes sure a context is passed through rpc call"""
+ """Makes sure a context is passed through rpc call."""
value = 42
result = rpc.call(self.context,
'test', {"method": "context",
@@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase):
self.assertEqual(self.context.to_dict(), result)
def test_call_exception(self):
- """Test that exception gets passed back properly
+ """Test that exception gets passed back properly.
rpc.call returns a RemoteError object. The value of the
exception is converted to a string, so we convert it back
to an int in the test.
+
"""
value = 42
self.assertRaises(rpc.RemoteError,
@@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase):
self.assertEqual(int(exc.value), value)
def test_nested_calls(self):
- """Test that we can do an rpc.call inside another call"""
+ """Test that we can do an rpc.call inside another call."""
class Nested(object):
@staticmethod
def echo(context, queue, value):
@@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase):
"value": value}})
self.assertEqual(value, result)
+ def test_connectionpool_single(self):
+ """Test that ConnectionPool recycles a single connection."""
+ conn1 = rpc.ConnectionPool.get()
+ rpc.ConnectionPool.put(conn1)
+ conn2 = rpc.ConnectionPool.get()
+ rpc.ConnectionPool.put(conn2)
+ self.assertEqual(conn1, conn2)
+
+ def test_connectionpool_double(self):
+ """Test that ConnectionPool returns and reuses separate connections.
+
+ When called consecutively we should get separate connections and upon
+ returning them those connections should be reused for future calls
+ before generating a new connection.
+
+ """
+ conn1 = rpc.ConnectionPool.get()
+ conn2 = rpc.ConnectionPool.get()
+
+ self.assertNotEqual(conn1, conn2)
+ rpc.ConnectionPool.put(conn1)
+ rpc.ConnectionPool.put(conn2)
+
+ conn3 = rpc.ConnectionPool.get()
+ conn4 = rpc.ConnectionPool.get()
+ self.assertEqual(conn1, conn3)
+ self.assertEqual(conn2, conn4)
+
+ def test_connectionpool_limit(self):
+ """Test connection pool limit and connection uniqueness."""
+ max_size = FLAGS.rpc_conn_pool_size
+ conns = []
+
+ for i in xrange(max_size):
+ conns.append(rpc.ConnectionPool.get())
+
+ self.assertFalse(rpc.ConnectionPool.free_items)
+ self.assertEqual(rpc.ConnectionPool.current_size,
+ rpc.ConnectionPool.max_size)
+ self.assertEqual(len(set(conns)), max_size)
+
class TestReceiver(object):
-    """Simple Proxy class so the consumer has methods to call
-
-    Uses static methods because we aren't actually storing any state"""
+    """Simple Proxy class so the consumer has methods to call.
+
+    Uses static methods because we aren't actually storing any state.
+
+    """
@staticmethod
def echo(context, value):
- """Simply returns whatever value is sent in"""
+ """Simply returns whatever value is sent in."""
LOG.debug(_("Received %s"), value)
return value
@staticmethod
def context(context, value):
- """Returns dictionary version of context"""
+ """Returns dictionary version of context."""
LOG.debug(_("Received %s"), context)
return context.to_dict()
@staticmethod
+ def echo_three_times(context, value):
+ context.reply(value)
+ context.reply(value + 1)
+ context.reply(value + 2)
+
+ @staticmethod
+ def echo_three_times_yield(context, value):
+ yield value
+ yield value + 1
+ yield value + 2
+
+ @staticmethod
def fail(context, value):
- """Raises an exception with the value sent in"""
+ """Raises an exception with the value sent in."""
raise Exception(value)
diff --git a/nova/tests/test_service.py b/nova/tests/test_service.py
index d48de2057..d1cc8bd61 100644
--- a/nova/tests/test_service.py
+++ b/nova/tests/test_service.py
@@ -106,7 +106,10 @@ class ServiceTestCase(test.TestCase):
# NOTE(vish): Create was moved out of mox replay to make sure that
# the looping calls are created in StartService.
- app = service.Service.create(host=host, binary=binary)
+ app = service.Service.create(host=host, binary=binary, topic=topic)
+
+ self.mox.StubOutWithMock(service.rpc.Connection, 'instance')
+ service.rpc.Connection.instance(new=mox.IgnoreArg())
self.mox.StubOutWithMock(rpc,
'TopicAdapterConsumer',
@@ -114,6 +117,11 @@ class ServiceTestCase(test.TestCase):
self.mox.StubOutWithMock(rpc,
'FanoutAdapterConsumer',
use_mock_anything=True)
+
+ self.mox.StubOutWithMock(rpc,
+ 'ConsumerSet',
+ use_mock_anything=True)
+
rpc.TopicAdapterConsumer(connection=mox.IgnoreArg(),
topic=topic,
proxy=mox.IsA(service.Service)).AndReturn(
@@ -129,9 +137,14 @@ class ServiceTestCase(test.TestCase):
proxy=mox.IsA(service.Service)).AndReturn(
rpc.FanoutAdapterConsumer)
- rpc.TopicAdapterConsumer.attach_to_eventlet()
- rpc.TopicAdapterConsumer.attach_to_eventlet()
- rpc.FanoutAdapterConsumer.attach_to_eventlet()
+ def wait_func(self, limit=None):
+ return None
+
+ mock_cset = self.mox.CreateMock(rpc.ConsumerSet,
+ {'wait': wait_func})
+ rpc.ConsumerSet(connection=mox.IgnoreArg(),
+ consumer_list=mox.IsA(list)).AndReturn(mock_cset)
+ wait_func(mox.IgnoreArg())
service_create = {'host': host,
'binary': binary,
@@ -287,8 +300,42 @@ class ServiceTestCase(test.TestCase):
# Creating mocks
self.mox.StubOutWithMock(service.rpc.Connection, 'instance')
service.rpc.Connection.instance(new=mox.IgnoreArg())
- service.rpc.Connection.instance(new=mox.IgnoreArg())
- service.rpc.Connection.instance(new=mox.IgnoreArg())
+
+ self.mox.StubOutWithMock(rpc,
+ 'TopicAdapterConsumer',
+ use_mock_anything=True)
+ self.mox.StubOutWithMock(rpc,
+ 'FanoutAdapterConsumer',
+ use_mock_anything=True)
+
+ self.mox.StubOutWithMock(rpc,
+ 'ConsumerSet',
+ use_mock_anything=True)
+
+ rpc.TopicAdapterConsumer(connection=mox.IgnoreArg(),
+ topic=topic,
+ proxy=mox.IsA(service.Service)).AndReturn(
+ rpc.TopicAdapterConsumer)
+
+ rpc.TopicAdapterConsumer(connection=mox.IgnoreArg(),
+ topic='%s.%s' % (topic, host),
+ proxy=mox.IsA(service.Service)).AndReturn(
+ rpc.TopicAdapterConsumer)
+
+ rpc.FanoutAdapterConsumer(connection=mox.IgnoreArg(),
+ topic=topic,
+ proxy=mox.IsA(service.Service)).AndReturn(
+ rpc.FanoutAdapterConsumer)
+
+ def wait_func(self, limit=None):
+ return None
+
+ mock_cset = self.mox.CreateMock(rpc.ConsumerSet,
+ {'wait': wait_func})
+ rpc.ConsumerSet(connection=mox.IgnoreArg(),
+ consumer_list=mox.IsA(list)).AndReturn(mock_cset)
+ wait_func(mox.IgnoreArg())
+
self.mox.StubOutWithMock(serv.manager.driver,
'update_available_resource')
serv.manager.driver.update_available_resource(mox.IgnoreArg(), host)
diff --git a/nova/tests/test_xenapi.py b/nova/tests/test_xenapi.py
index be1e35697..18a267896 100644
--- a/nova/tests/test_xenapi.py
+++ b/nova/tests/test_xenapi.py
@@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase):
os_type="linux")
self.check_vm_params_for_linux()
+ def test_spawn_vhd_glance_swapdisk(self):
+ # Change the default host_call_plugin to one that'll return
+ # a swap disk
+ orig_func = stubs.FakeSessionForVMTests.host_call_plugin
+
+ stubs.FakeSessionForVMTests.host_call_plugin = \
+ stubs.FakeSessionForVMTests.host_call_plugin_swap
+
+ try:
+ # We'll steal the above glance linux test
+ self.test_spawn_vhd_glance_linux()
+ finally:
+ # Make sure to put this back
+ stubs.FakeSessionForVMTests.host_call_plugin = orig_func
+
+ # We should have 2 VBDs.
+ self.assertEqual(len(self.vm['VBDs']), 2)
+ # Now test that we have 1.
+ self.tearDown()
+ self.setUp()
+ self.test_spawn_vhd_glance_linux()
+ self.assertEqual(len(self.vm['VBDs']), 1)
+
def test_spawn_vhd_glance_windows(self):
FLAGS.xenapi_image_service = 'glance'
self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None,
diff --git a/nova/tests/xenapi/stubs.py b/nova/tests/xenapi/stubs.py
index 4833ccb07..35308d95f 100644
--- a/nova/tests/xenapi/stubs.py
+++ b/nova/tests/xenapi/stubs.py
@@ -17,6 +17,7 @@
"""Stubouts, mocks and fixtures for the test suite"""
import eventlet
+import json
from nova.virt import xenapi_conn
from nova.virt.xenapi import fake
from nova.virt.xenapi import volume_utils
@@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs):
sr_ref=sr_ref, sharable=False)
vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
vdi_uuid = vdi_rec['uuid']
- return vdi_uuid
+ return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
@@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase):
def __init__(self, uri):
super(FakeSessionForVMTests, self).__init__(uri)
- def host_call_plugin(self, _1, _2, _3, _4, _5):
+ def host_call_plugin(self, _1, _2, plugin, method, _5):
+ sr_ref = fake.get_all('SR')[0]
+ vdi_ref = fake.create_vdi('', False, sr_ref, False)
+ vdi_rec = fake.get_record('VDI', vdi_ref)
+ if plugin == "glance" and method == "download_vhd":
+ ret_str = json.dumps([dict(vdi_type='os',
+ vdi_uuid=vdi_rec['uuid'])])
+ else:
+ ret_str = vdi_rec['uuid']
+ return '<string>%s</string>' % ret_str
+
+ def host_call_plugin_swap(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref)
- return '<string>%s</string>' % vdi_rec['uuid']
+ if plugin == "glance" and method == "download_vhd":
+ swap_vdi_ref = fake.create_vdi('', False, sr_ref, False)
+ swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref)
+ ret_str = json.dumps(
+ [dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']),
+ dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])])
+ else:
+ ret_str = vdi_rec['uuid']
+ return '<string>%s</string>' % ret_str
def VM_start(self, _1, ref, _2, _3):
vm = fake.get_record('VM', ref)
diff --git a/nova/virt/xenapi/fake.py b/nova/virt/xenapi/fake.py
index e36ef3288..76988b172 100644
--- a/nova/virt/xenapi/fake.py
+++ b/nova/virt/xenapi/fake.py
@@ -159,7 +159,10 @@ def after_VBD_create(vbd_ref, vbd_rec):
vbd_rec['device'] = ''
vm_ref = vbd_rec['VM']
vm_rec = _db_content['VM'][vm_ref]
- vm_rec['VBDs'] = [vbd_ref]
+ if vm_rec.get('VBDs', None):
+ vm_rec['VBDs'].append(vbd_ref)
+ else:
+ vm_rec['VBDs'] = [vbd_ref]
vm_name_label = _db_content['VM'][vm_ref]['name_label']
vbd_rec['vm_name_label'] = vm_name_label
diff --git a/nova/virt/xenapi/vm_utils.py b/nova/virt/xenapi/vm_utils.py
index 9f6cd608c..06ee8ee9b 100644
--- a/nova/virt/xenapi/vm_utils.py
+++ b/nova/virt/xenapi/vm_utils.py
@@ -19,6 +19,7 @@ Helper methods for operations related to the management of VM records and
their attributes like VDIs, VIFs, as well as their lookup functions.
"""
+import json
import os
import pickle
import re
@@ -376,6 +377,9 @@ class VMHelper(HelperBase):
xenapi_image_service = ['glance', 'objectstore']
glance_address = 'address for glance services'
glance_port = 'port for glance services'
+
+ Returns: A single filename if image_type is KERNEL_RAMDISK
+ A list of dictionaries that describe VDIs, otherwise
"""
access = AuthManager().get_access_key(user, project)
@@ -390,6 +394,10 @@ class VMHelper(HelperBase):
@classmethod
def _fetch_image_glance_vhd(cls, session, instance_id, image, access,
image_type):
+ """Tell glance to download an image and put the VHDs into the SR
+
+ Returns: A list of dictionaries that describe VDIs
+ """
LOG.debug(_("Asking xapi to fetch vhd image %(image)s")
% locals())
@@ -408,18 +416,26 @@ class VMHelper(HelperBase):
kwargs = {'params': pickle.dumps(params)}
task = session.async_call_plugin('glance', 'download_vhd', kwargs)
- vdi_uuid = session.wait_for_task(task, instance_id)
+ result = session.wait_for_task(task, instance_id)
+ # 'download_vhd' will return a json encoded string containing
+ # a list of dictionaries describing VDIs. The dictionary will
+ # contain 'vdi_type' and 'vdi_uuid' keys. 'vdi_type' can be
+ # 'os' or 'swap' right now.
+ vdis = json.loads(result)
+ for vdi in vdis:
+            LOG.debug(_("xapi 'download_vhd' returned VDI of "
+                        "type '%(vdi_type)s' with UUID '%(vdi_uuid)s'") % vdi)
cls.scan_sr(session, instance_id, sr_ref)
+ # Pull out the UUID of the first VDI
+ vdi_uuid = vdis[0]['vdi_uuid']
# Set the name-label to ease debugging
vdi_ref = session.get_xenapi().VDI.get_by_uuid(vdi_uuid)
- name_label = get_name_label_for_image(image)
- session.get_xenapi().VDI.set_name_label(vdi_ref, name_label)
+ primary_name_label = get_name_label_for_image(image)
+ session.get_xenapi().VDI.set_name_label(vdi_ref, primary_name_label)
- LOG.debug(_("xapi 'download_vhd' returned VDI UUID %(vdi_uuid)s")
- % locals())
- return vdi_uuid
+ return vdis
@classmethod
def _fetch_image_glance_disk(cls, session, instance_id, image, access,
@@ -431,6 +447,8 @@ class VMHelper(HelperBase):
plugin; instead, it streams the disks through domU to the VDI
directly.
+ Returns: A single filename if image_type is KERNEL_RAMDISK
+ A list of dictionaries that describe VDIs, otherwise
"""
# FIXME(sirp): Since the Glance plugin seems to be required for the
# VHD disk, it may be worth using the plugin for both VHD and RAW and
@@ -476,7 +494,8 @@ class VMHelper(HelperBase):
LOG.debug(_("Kernel/Ramdisk VDI %s destroyed"), vdi_ref)
return filename
else:
- return session.get_xenapi().VDI.get_uuid(vdi_ref)
+ vdi_uuid = session.get_xenapi().VDI.get_uuid(vdi_ref)
+ return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
@classmethod
def determine_disk_image_type(cls, instance):
@@ -535,6 +554,11 @@ class VMHelper(HelperBase):
@classmethod
def _fetch_image_glance(cls, session, instance_id, image, access,
image_type):
+ """Fetch image from glance based on image type.
+
+ Returns: A single filename if image_type is KERNEL_RAMDISK
+ A list of dictionaries that describe VDIs, otherwise
+ """
if image_type == ImageType.DISK_VHD:
return cls._fetch_image_glance_vhd(
session, instance_id, image, access, image_type)
@@ -545,6 +569,11 @@ class VMHelper(HelperBase):
@classmethod
def _fetch_image_objectstore(cls, session, instance_id, image, access,
secret, image_type):
+ """Fetch an image from objectstore.
+
+ Returns: A single filename if image_type is KERNEL_RAMDISK
+ A list of dictionaries that describe VDIs, otherwise
+ """
url = images.image_url(image)
LOG.debug(_("Asking xapi to fetch %(url)s as %(access)s") % locals())
if image_type == ImageType.KERNEL_RAMDISK:
@@ -562,8 +591,10 @@ class VMHelper(HelperBase):
if image_type == ImageType.DISK_RAW:
args['raw'] = 'true'
task = session.async_call_plugin('objectstore', fn, args)
- uuid = session.wait_for_task(task, instance_id)
- return uuid
+ uuid_or_fn = session.wait_for_task(task, instance_id)
+ if image_type != ImageType.KERNEL_RAMDISK:
+ return [dict(vdi_type='os', vdi_uuid=uuid_or_fn)]
+ return uuid_or_fn
@classmethod
def determine_is_pv(cls, session, instance_id, vdi_ref, disk_image_type,
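After this change, fetch_image() has two return shapes: a bare filename for ImageType.KERNEL_RAMDISK, otherwise a list of dicts describing VDIs. A small hedged sketch of a caller coping with both (the helper name is made up):

    def first_os_vdi_uuid(result):
        """result: whatever VMHelper.fetch_image returned."""
        if isinstance(result, basestring):
            # KERNEL_RAMDISK case: a filename, no VDIs to pick from
            raise ValueError('kernel/ramdisk images have no VDIs')
        # Otherwise: [{'vdi_type': 'os', 'vdi_uuid': ...},
        #             {'vdi_type': 'swap', 'vdi_uuid': ...}, ...]
        for vdi in result:
            if vdi['vdi_type'] == 'os':
                return vdi['vdi_uuid']
        raise ValueError('image contained no os VDI')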
diff --git a/nova/virt/xenapi/vmops.py b/nova/virt/xenapi/vmops.py
index be6ef48ea..6d516ddbc 100644
--- a/nova/virt/xenapi/vmops.py
+++ b/nova/virt/xenapi/vmops.py
@@ -91,7 +91,8 @@ class VMOps(object):
def finish_resize(self, instance, disk_info):
vdi_uuid = self.link_disks(instance, disk_info['base_copy'],
disk_info['cow'])
- vm_ref = self._create_vm(instance, vdi_uuid)
+ vm_ref = self._create_vm(instance,
+ [dict(vdi_type='os', vdi_uuid=vdi_uuid)])
self.resize_instance(instance, vdi_uuid)
self._spawn(instance, vm_ref)
@@ -105,24 +106,25 @@ class VMOps(object):
LOG.debug(_("Starting instance %s"), instance.name)
self._session.call_xenapi('VM.start', vm_ref, False, False)
- def _create_disk(self, instance):
+ def _create_disks(self, instance):
user = AuthManager().get_user(instance.user_id)
project = AuthManager().get_project(instance.project_id)
disk_image_type = VMHelper.determine_disk_image_type(instance)
- vdi_uuid = VMHelper.fetch_image(self._session, instance.id,
- instance.image_id, user, project, disk_image_type)
- return vdi_uuid
+ vdis = VMHelper.fetch_image(self._session,
+ instance.id, instance.image_id, user, project,
+ disk_image_type)
+ return vdis
def spawn(self, instance, network_info=None):
- vdi_uuid = self._create_disk(instance)
- vm_ref = self._create_vm(instance, vdi_uuid, network_info)
+ vdis = self._create_disks(instance)
+ vm_ref = self._create_vm(instance, vdis, network_info)
self._spawn(instance, vm_ref)
def spawn_rescue(self, instance):
"""Spawn a rescue instance."""
self.spawn(instance)
- def _create_vm(self, instance, vdi_uuid, network_info=None):
+ def _create_vm(self, instance, vdis, network_info=None):
"""Create VM instance."""
instance_name = instance.name
vm_ref = VMHelper.lookup(self._session, instance_name)
@@ -141,28 +143,43 @@ class VMOps(object):
user = AuthManager().get_user(instance.user_id)
project = AuthManager().get_project(instance.project_id)
- # Are we building from a pre-existing disk?
- vdi_ref = self._session.call_xenapi('VDI.get_by_uuid', vdi_uuid)
-
disk_image_type = VMHelper.determine_disk_image_type(instance)
kernel = None
if instance.kernel_id:
kernel = VMHelper.fetch_image(self._session, instance.id,
- instance.kernel_id, user, project, ImageType.KERNEL_RAMDISK)
+ instance.kernel_id, user, project,
+ ImageType.KERNEL_RAMDISK)
ramdisk = None
if instance.ramdisk_id:
ramdisk = VMHelper.fetch_image(self._session, instance.id,
- instance.ramdisk_id, user, project, ImageType.KERNEL_RAMDISK)
-
- use_pv_kernel = VMHelper.determine_is_pv(self._session, instance.id,
- vdi_ref, disk_image_type, instance.os_type)
- vm_ref = VMHelper.create_vm(self._session, instance, kernel, ramdisk,
- use_pv_kernel)
-
+ instance.ramdisk_id, user, project,
+ ImageType.KERNEL_RAMDISK)
+
+ # Create the VM ref and attach the first disk
+ first_vdi_ref = self._session.call_xenapi('VDI.get_by_uuid',
+ vdis[0]['vdi_uuid'])
+ use_pv_kernel = VMHelper.determine_is_pv(self._session,
+ instance.id, first_vdi_ref, disk_image_type,
+ instance.os_type)
+ vm_ref = VMHelper.create_vm(self._session, instance,
+ kernel, ramdisk, use_pv_kernel)
VMHelper.create_vbd(session=self._session, vm_ref=vm_ref,
- vdi_ref=vdi_ref, userdevice=0, bootable=True)
+ vdi_ref=first_vdi_ref, userdevice=0, bootable=True)
+
+ # Attach any other disks
+ # userdevice 1 is reserved for rescue
+ userdevice = 2
+ for vdi in vdis[1:]:
+ # vdi['vdi_type'] is either 'os' or 'swap', but we don't
+ # really care what it is right here.
+ vdi_ref = self._session.call_xenapi('VDI.get_by_uuid',
+ vdi['vdi_uuid'])
+ VMHelper.create_vbd(session=self._session, vm_ref=vm_ref,
+ vdi_ref=vdi_ref, userdevice=userdevice,
+ bootable=False)
+ userdevice += 1
# TODO(tr3buchet) - check to make sure we have network info, otherwise
# create it now. This goes away once nova-multi-nic hits.
@@ -172,7 +189,7 @@ class VMOps(object):
# Alter the image before VM start for, e.g. network injection
if FLAGS.xenapi_inject_image:
VMHelper.preconfigure_instance(self._session, instance,
- vdi_ref, network_info)
+ first_vdi_ref, network_info)
self.create_vifs(vm_ref, network_info)
self.inject_network_info(instance, network_info, vm_ref)
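The VBD layout _create_vm() now produces: userdevice 0 is the boot disk, userdevice 1 stays reserved for rescue, and any extra VDIs (e.g. swap) are attached from 2 upward. A toy sketch of that numbering:

    def vbd_layout(vdis):
        # (userdevice, vdi_uuid, bootable) triples, as the loop above builds
        layout = [(0, vdis[0]['vdi_uuid'], True)]
        userdevice = 2
        for vdi in vdis[1:]:
            layout.append((userdevice, vdi['vdi_uuid'], False))
            userdevice += 1
        return layout

    vdis = [dict(vdi_type='os', vdi_uuid='aaaa'),
            dict(vdi_type='swap', vdi_uuid='bbbb')]
    assert vbd_layout(vdis) == [(0, 'aaaa', True), (2, 'bbbb', False)]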
diff --git a/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance b/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance
index 4b45671ae..0c00d168b 100644
--- a/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance
+++ b/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance
@@ -22,6 +22,10 @@
#
import httplib
+try:
+ import json
+except ImportError:
+ import simplejson as json
import os
import os.path
import pickle
@@ -87,8 +91,8 @@ def _download_tarball(sr_path, staging_path, image_id, glance_host,
conn.close()
-def _fixup_vhds(sr_path, staging_path, uuid_stack):
- """Fixup the downloaded VHDs before we move them into the SR.
+def _import_vhds(sr_path, staging_path, uuid_stack):
+ """Import the VHDs found in the staging path.
We cannot extract VHDs directly into the SR since they don't yet have
UUIDs, aren't properly associated with each other, and would be subject to
@@ -98,16 +102,25 @@ def _fixup_vhds(sr_path, staging_path, uuid_stack):
To avoid these problems, we use a staging area to fix up the VHDs before
moving them into the SR. The steps involved are:
- 1. Extracting tarball into staging area
+ 1. Extracting tarball into staging area (done prior to this call)
2. Renaming VHDs to use UUIDs ('snap.vhd' -> 'ffff-aaaa-...vhd')
- 3. Linking the two VHDs together
+ 3. Linking VHDs together if there's a snap.vhd
4. Pseudo-atomically moving the images into the SR. (It's not really
- atomic because it takes place as two os.rename operations; however,
- the chances of an SR.scan occuring between the two rename()
+ atomic because it takes place as multiple os.rename operations;
+         however, the chances of an SR.scan occurring between the rename()
invocations are so small that we can safely ignore it)
+
+ Returns: A list of VDIs. Each list element is a dictionary containing
+ information about the VHD. Dictionary keys are:
+    1. "vdi_type" - The type of VDI. Currently this can be "os" or
+       "swap"
+    2. "vdi_uuid" - The UUID of the VDI
+
+    Example return: [{"vdi_type": "os", "vdi_uuid": "ffff-aaa..."},
+                     {"vdi_type": "swap", "vdi_uuid": "ffff-bbb..."}]
"""
def rename_with_uuid(orig_path):
"""Rename VHD using UUID so that it will be recognized by SR on a
@@ -158,27 +171,59 @@ def _fixup_vhds(sr_path, staging_path, uuid_stack):
"VHD %(path)s is marked as hidden without child" %
locals())
- orig_base_copy_path = os.path.join(staging_path, 'image.vhd')
- if not os.path.exists(orig_base_copy_path):
+ def prepare_if_exists(staging_path, vhd_name, parent_path=None):
+ """
+    Check for the existence of a particular VHD in the staging path and
+    prepare it for moving into the SR.
+
+    Returns: Tuple of (path to move into the SR, VDI UUID), or
+        None if the vhd_name doesn't exist in the staging path
+
+ If the VHD exists, we will do the following:
+ 1. Rename it with a UUID.
+ 2. If parent_path exists, we'll link up the VHDs.
+ """
+ orig_path = os.path.join(staging_path, vhd_name)
+ if not os.path.exists(orig_path):
+ return None
+ new_path, vdi_uuid = rename_with_uuid(orig_path)
+ if parent_path:
+ # NOTE(sirp): this step is necessary so that an SR scan won't
+ # delete the base_copy out from under us (since it would be
+ # orphaned)
+ link_vhds(new_path, parent_path)
+ return (new_path, vdi_uuid)
+
+ vdi_return_list = []
+ paths_to_move = []
+
+ image_info = prepare_if_exists(staging_path, 'image.vhd')
+ if not image_info:
raise Exception("Invalid image: image.vhd not present")
- base_copy_path, base_copy_uuid = rename_with_uuid(orig_base_copy_path)
-
- vdi_uuid = base_copy_uuid
- orig_snap_path = os.path.join(staging_path, 'snap.vhd')
- if os.path.exists(orig_snap_path):
- snap_path, snap_uuid = rename_with_uuid(orig_snap_path)
- vdi_uuid = snap_uuid
- # NOTE(sirp): this step is necessary so that an SR scan won't
- # delete the base_copy out from under us (since it would be
- # orphaned)
- link_vhds(snap_path, base_copy_path)
- move_into_sr(snap_path)
+ paths_to_move.append(image_info[0])
+
+ snap_info = prepare_if_exists(staging_path, 'snap.vhd',
+ image_info[0])
+ if snap_info:
+ paths_to_move.append(snap_info[0])
+ # We return this snap as the VDI instead of image.vhd
+ vdi_return_list.append(dict(vdi_type="os", vdi_uuid=snap_info[1]))
else:
- assert_vhd_not_hidden(base_copy_path)
+ assert_vhd_not_hidden(image_info[0])
+ # If there's no snap, we return the image.vhd UUID
+ vdi_return_list.append(dict(vdi_type="os", vdi_uuid=image_info[1]))
+
+ swap_info = prepare_if_exists(staging_path, 'swap.vhd')
+ if swap_info:
+ assert_vhd_not_hidden(swap_info[0])
+ paths_to_move.append(swap_info[0])
+ vdi_return_list.append(dict(vdi_type="swap", vdi_uuid=swap_info[1]))
+
+ for path in paths_to_move:
+ move_into_sr(path)
- move_into_sr(base_copy_path)
- return vdi_uuid
+ return vdi_return_list
def _prepare_staging_area_for_upload(sr_path, staging_path, vdi_uuids):
@@ -324,8 +369,9 @@ def download_vhd(session, args):
try:
_download_tarball(sr_path, staging_path, image_id, glance_host,
glance_port)
- vdi_uuid = _fixup_vhds(sr_path, staging_path, uuid_stack)
- return vdi_uuid
+ # Right now, it's easier to return a single string via XenAPI,
+ # so we'll json encode the list of VHDs.
+ return json.dumps(_import_vhds(sr_path, staging_path, uuid_stack))
finally:
_cleanup_staging_area(staging_path)
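End to end, the plugin now hands a JSON string back through XenAPI and _fetch_image_glance_vhd() decodes it. A round-trip sketch of the payload (UUIDs made up):

    import json

    payload = json.dumps([dict(vdi_type='os', vdi_uuid='ffff-aaaa'),
                          dict(vdi_type='swap', vdi_uuid='ffff-bbbb')])
    vdis = json.loads(payload)
    assert vdis[0]['vdi_type'] == 'os'
    assert [v['vdi_uuid'] for v in vdis] == ['ffff-aaaa', 'ffff-bbbb']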