diff options
| -rw-r--r-- | nova/fakerabbit.py | 31 | ||||
| -rw-r--r-- | nova/rpc.py | 271 | ||||
| -rw-r--r-- | nova/service.py | 60 | ||||
| -rw-r--r-- | nova/test.py | 9 | ||||
| -rw-r--r-- | nova/tests/integrated/integrated_helpers.py | 5 | ||||
| -rw-r--r-- | nova/tests/test_cloud.py | 26 | ||||
| -rw-r--r-- | nova/tests/test_rpc.py | 116 | ||||
| -rw-r--r-- | nova/tests/test_service.py | 59 | ||||
| -rw-r--r-- | nova/tests/test_xenapi.py | 23 | ||||
| -rw-r--r-- | nova/tests/xenapi/stubs.py | 26 | ||||
| -rw-r--r-- | nova/virt/xenapi/fake.py | 5 | ||||
| -rw-r--r-- | nova/virt/xenapi/vm_utils.py | 49 | ||||
| -rw-r--r-- | nova/virt/xenapi/vmops.py | 59 | ||||
| -rw-r--r-- | plugins/xenserver/xenapi/etc/xapi.d/plugins/glance | 96 |
14 files changed, 637 insertions, 198 deletions
diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py index a7dee8caf..e7e9dab77 100644 --- a/nova/fakerabbit.py +++ b/nova/fakerabbit.py @@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit") EXCHANGES = {} QUEUES = {} +CONSUMERS = {} class Message(base.BaseMessage): @@ -96,17 +97,29 @@ class Backend(base.BaseBackend): ' key %(routing_key)s') % locals()) EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key) - def declare_consumer(self, queue, callback, *args, **kwargs): - self.current_queue = queue - self.current_callback = callback + def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs): + global CONSUMERS + LOG.debug("Adding consumer %s", consumer_tag) + CONSUMERS[consumer_tag] = (queue, callback) + + def cancel(self, consumer_tag): + global CONSUMERS + LOG.debug("Removing consumer %s", consumer_tag) + del CONSUMERS[consumer_tag] def consume(self, limit=None): + global CONSUMERS + num = 0 while True: - item = self.get(self.current_queue) - if item: - self.current_callback(item) - raise StopIteration() - greenthread.sleep(0) + for (queue, callback) in CONSUMERS.itervalues(): + item = self.get(queue) + if item: + callback(item) + num += 1 + yield + if limit and num == limit: + raise StopIteration() + greenthread.sleep(0.1) def get(self, queue, no_ack=False): global QUEUES @@ -134,5 +147,7 @@ class Backend(base.BaseBackend): def reset_all(): global EXCHANGES global QUEUES + global CONSUMERS EXCHANGES = {} QUEUES = {} + CONSUMERS = {} diff --git a/nova/rpc.py b/nova/rpc.py index 2116f22c3..c5277c6a9 100644 --- a/nova/rpc.py +++ b/nova/rpc.py @@ -28,12 +28,15 @@ import json import sys import time import traceback +import types import uuid from carrot import connection as carrot_connection from carrot import messaging from eventlet import greenpool -from eventlet import greenthread +from eventlet import pools +from eventlet import queue +import greenlet from nova import context from nova import exception @@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc') FLAGS = flags.FLAGS -flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool') +flags.DEFINE_integer('rpc_thread_pool_size', 1024, + 'Size of RPC thread pool') +flags.DEFINE_integer('rpc_conn_pool_size', 30, + 'Size of RPC connection pool') class Connection(carrot_connection.BrokerConnection): @@ -90,6 +96,22 @@ class Connection(carrot_connection.BrokerConnection): return cls.instance() +class Pool(pools.Pool): + """Class that implements a Pool of Connections.""" + + # TODO(comstud): Timeout connections not used in a while + def create(self): + LOG.debug('Creating new connection') + return Connection.instance(new=True) + +# Create a ConnectionPool to use for RPC calls. We'll order the +# pool as a stack (LIFO), so that we can potentially loop through and +# timeout old unused connections at some point +ConnectionPool = Pool( + max_size=FLAGS.rpc_conn_pool_size, + order_as_stack=True) + + class Consumer(messaging.Consumer): """Consumer base class. @@ -131,7 +153,9 @@ class Consumer(messaging.Consumer): self.connection = Connection.recreate() self.backend = self.connection.create_backend() self.declare() - super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks) + return super(Consumer, self).fetch(no_ack, + auto_ack, + enable_callbacks) if self.failed_connection: LOG.error(_('Reconnected to queue')) self.failed_connection = False @@ -159,13 +183,13 @@ class AdapterConsumer(Consumer): self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) super(AdapterConsumer, self).__init__(connection=connection, topic=topic) + self.register_callback(self.process_data) - def receive(self, *args, **kwargs): - self.pool.spawn_n(self._receive, *args, **kwargs) + def process_data(self, message_data, message): + """Consumer callback to call a method on a proxy object. - @exception.wrap_exception - def _receive(self, message_data, message): - """Magically looks for a method on the proxy object and calls it. + Parses the message for validity and fires off a thread to call the + proxy object method. Message data should be a dictionary with two keys: method: string representing the method to call @@ -175,8 +199,8 @@ class AdapterConsumer(Consumer): """ LOG.debug(_('received %s') % message_data) - msg_id = message_data.pop('_msg_id', None) - + # This will be popped off in _unpack_context + msg_id = message_data.get('_msg_id', None) ctxt = _unpack_context(message_data) method = message_data.get('method') @@ -188,8 +212,17 @@ class AdapterConsumer(Consumer): # we just log the message and send an error string # back to the caller LOG.warn(_('no method for message: %s') % message_data) - msg_reply(msg_id, _('No method for message: %s') % message_data) + if msg_id: + msg_reply(msg_id, + _('No method for message: %s') % message_data) return + self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args) + + @exception.wrap_exception + def _process_data(self, msg_id, ctxt, method, args): + """Thread that maigcally looks for a method on the proxy + object and calls it. + """ node_func = getattr(self.proxy, str(method)) node_args = dict((str(k), v) for k, v in args.iteritems()) @@ -197,7 +230,18 @@ class AdapterConsumer(Consumer): try: rval = node_func(context=ctxt, **node_args) if msg_id: - msg_reply(msg_id, rval, None) + # Check if the result was a generator + if isinstance(rval, types.GeneratorType): + for x in rval: + msg_reply(msg_id, x, None) + else: + msg_reply(msg_id, rval, None) + + # This final None tells multicall that it is done. + msg_reply(msg_id, None, None) + elif isinstance(rval, types.GeneratorType): + # NOTE(vish): this iterates through the generator + list(rval) except Exception as e: logging.exception('Exception during message handling') if msg_id: @@ -205,11 +249,6 @@ class AdapterConsumer(Consumer): return -class Publisher(messaging.Publisher): - """Publisher base class.""" - pass - - class TopicAdapterConsumer(AdapterConsumer): """Consumes messages on a specific topic.""" @@ -242,6 +281,58 @@ class FanoutAdapterConsumer(AdapterConsumer): topic=topic, proxy=proxy) +class ConsumerSet(object): + """Groups consumers to listen on together on a single connection.""" + + def __init__(self, connection, consumer_list): + self.consumer_list = set(consumer_list) + self.consumer_set = None + self.enabled = True + self.init(connection) + + def init(self, conn): + if not conn: + conn = Connection.instance(new=True) + if self.consumer_set: + self.consumer_set.close() + self.consumer_set = messaging.ConsumerSet(conn) + for consumer in self.consumer_list: + consumer.connection = conn + # consumer.backend is set for us + self.consumer_set.add_consumer(consumer) + + def reconnect(self): + self.init(None) + + def wait(self, limit=None): + running = True + while running: + it = self.consumer_set.iterconsume(limit=limit) + if not it: + break + while True: + try: + it.next() + except StopIteration: + return + except greenlet.GreenletExit: + running = False + break + except Exception as e: + LOG.exception(_("Exception while processing consumer")) + self.reconnect() + # Break to outer loop + break + + def close(self): + self.consumer_set.close() + + +class Publisher(messaging.Publisher): + """Publisher base class.""" + pass + + class TopicPublisher(Publisher): """Publishes messages on a specific topic.""" @@ -306,16 +397,18 @@ def msg_reply(msg_id, reply=None, failure=None): LOG.error(_("Returning exception %s to caller"), message) LOG.error(tb) failure = (failure[0].__name__, str(failure[1]), tb) - conn = Connection.instance() - publisher = DirectPublisher(connection=conn, msg_id=msg_id) - try: - publisher.send({'result': reply, 'failure': failure}) - except TypeError: - publisher.send( - {'result': dict((k, repr(v)) - for k, v in reply.__dict__.iteritems()), - 'failure': failure}) - publisher.close() + + with ConnectionPool.item() as conn: + publisher = DirectPublisher(connection=conn, msg_id=msg_id) + try: + publisher.send({'result': reply, 'failure': failure}) + except TypeError: + publisher.send( + {'result': dict((k, repr(v)) + for k, v in reply.__dict__.iteritems()), + 'failure': failure}) + + publisher.close() class RemoteError(exception.Error): @@ -347,8 +440,9 @@ def _unpack_context(msg): if key.startswith('_context_'): value = msg.pop(key) context_dict[key[9:]] = value + context_dict['msg_id'] = msg.pop('_msg_id', None) LOG.debug(_('unpacked context: %s'), context_dict) - return context.RequestContext.from_dict(context_dict) + return RpcContext.from_dict(context_dict) def _pack_context(msg, context): @@ -360,70 +454,112 @@ def _pack_context(msg, context): for args at some point. """ - context = dict([('_context_%s' % key, value) - for (key, value) in context.to_dict().iteritems()]) - msg.update(context) + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) -def call(context, topic, msg): - """Sends a message on a topic and wait for a response.""" +class RpcContext(context.RequestContext): + def __init__(self, *args, **kwargs): + msg_id = kwargs.pop('msg_id', None) + self.msg_id = msg_id + super(RpcContext, self).__init__(*args, **kwargs) + + def reply(self, *args, **kwargs): + msg_reply(self.msg_id, *args, **kwargs) + + +def multicall(context, topic, msg): + """Make a call that returns multiple times.""" LOG.debug(_('Making asynchronous call on %s ...'), topic) msg_id = uuid.uuid4().hex msg.update({'_msg_id': msg_id}) LOG.debug(_('MSG_ID is %s') % (msg_id)) _pack_context(msg, context) - class WaitMessage(object): - def __call__(self, data, message): - """Acks message and sets result.""" - message.ack() - if data['failure']: - self.result = RemoteError(*data['failure']) - else: - self.result = data['result'] - - wait_msg = WaitMessage() - conn = Connection.instance() - consumer = DirectConsumer(connection=conn, msg_id=msg_id) + con_conn = ConnectionPool.get() + consumer = DirectConsumer(connection=con_conn, msg_id=msg_id) + wait_msg = MulticallWaiter(consumer) consumer.register_callback(wait_msg) - conn = Connection.instance() - publisher = TopicPublisher(connection=conn, topic=topic) + publisher = TopicPublisher(connection=con_conn, topic=topic) publisher.send(msg) publisher.close() - try: - consumer.wait(limit=1) - except StopIteration: - pass - consumer.close() - # NOTE(termie): this is a little bit of a change from the original - # non-eventlet code where returning a Failure - # instance from a deferred call is very similar to - # raising an exception - if isinstance(wait_msg.result, Exception): - raise wait_msg.result - return wait_msg.result + return wait_msg + + +class MulticallWaiter(object): + def __init__(self, consumer): + self._consumer = consumer + self._results = queue.Queue() + self._closed = False + + def close(self): + self._closed = True + self._consumer.close() + ConnectionPool.put(self._consumer.connection) + + def __call__(self, data, message): + """Acks message and sets result.""" + message.ack() + if data['failure']: + self._results.put(RemoteError(*data['failure'])) + else: + self._results.put(data['result']) + + def __iter__(self): + return self.wait() + + def wait(self): + while True: + rv = None + while rv is None and not self._closed: + try: + rv = self._consumer.fetch(enable_callbacks=True) + except Exception: + self.close() + raise + time.sleep(0.01) + + result = self._results.get() + if isinstance(result, Exception): + self.close() + raise result + if result == None: + self.close() + raise StopIteration + yield result + + +def call(context, topic, msg): + """Sends a message on a topic and wait for a response.""" + rv = multicall(context, topic, msg) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] def cast(context, topic, msg): """Sends a message on a topic without waiting for a response.""" LOG.debug(_('Making asynchronous cast on %s...'), topic) _pack_context(msg, context) - conn = Connection.instance() - publisher = TopicPublisher(connection=conn, topic=topic) - publisher.send(msg) - publisher.close() + with ConnectionPool.item() as conn: + publisher = TopicPublisher(connection=conn, topic=topic) + publisher.send(msg) + publisher.close() def fanout_cast(context, topic, msg): """Sends a message on a fanout exchange without waiting for a response.""" LOG.debug(_('Making asynchronous fanout cast...')) _pack_context(msg, context) - conn = Connection.instance() - publisher = FanoutPublisher(topic, connection=conn) - publisher.send(msg) - publisher.close() + with ConnectionPool.item() as conn: + publisher = FanoutPublisher(topic, connection=conn) + publisher.send(msg) + publisher.close() def generic_response(message_data, message): @@ -459,6 +595,7 @@ def send_message(topic, message, wait=True): if wait: consumer.wait() + consumer.close() if __name__ == '__main__': diff --git a/nova/service.py b/nova/service.py index ab1238c3b..74f9f04d8 100644 --- a/nova/service.py +++ b/nova/service.py @@ -19,14 +19,11 @@ """Generic Node baseclass for all workers that run on hosts.""" +import greenlet import inspect import os -import sys -import time -from eventlet import event from eventlet import greenthread -from eventlet import greenpool from nova import context from nova import db @@ -91,27 +88,37 @@ class Service(object): if 'nova-compute' == self.binary: self.manager.update_available_resource(ctxt) - conn1 = rpc.Connection.instance(new=True) - conn2 = rpc.Connection.instance(new=True) - conn3 = rpc.Connection.instance(new=True) - if self.report_interval: - consumer_all = rpc.TopicAdapterConsumer( - connection=conn1, - topic=self.topic, - proxy=self) - consumer_node = rpc.TopicAdapterConsumer( - connection=conn2, - topic='%s.%s' % (self.topic, self.host), - proxy=self) - fanout = rpc.FanoutAdapterConsumer( - connection=conn3, - topic=self.topic, - proxy=self) - - self.timers.append(consumer_all.attach_to_eventlet()) - self.timers.append(consumer_node.attach_to_eventlet()) - self.timers.append(fanout.attach_to_eventlet()) + self.conn = rpc.Connection.instance(new=True) + logging.debug("Creating Consumer connection for Service %s" % + self.topic) + + # Share this same connection for these Consumers + consumer_all = rpc.TopicAdapterConsumer( + connection=self.conn, + topic=self.topic, + proxy=self) + consumer_node = rpc.TopicAdapterConsumer( + connection=self.conn, + topic='%s.%s' % (self.topic, self.host), + proxy=self) + fanout = rpc.FanoutAdapterConsumer( + connection=self.conn, + topic=self.topic, + proxy=self) + consumer_set = rpc.ConsumerSet( + connection=self.conn, + consumer_list=[consumer_all, consumer_node, fanout]) + + # Wait forever, processing these consumers + def _wait(): + try: + consumer_set.wait() + finally: + consumer_set.close() + + self.consumer_set_thread = greenthread.spawn(_wait) + if self.report_interval: pulse = utils.LoopingCall(self.report_state) pulse.start(interval=self.report_interval, now=False) self.timers.append(pulse) @@ -174,6 +181,11 @@ class Service(object): logging.warn(_('Service killed that has no database entry')) def stop(self): + self.consumer_set_thread.kill() + try: + self.consumer_set_thread.wait() + except greenlet.GreenletExit: + pass for x in self.timers: try: x.stop() diff --git a/nova/test.py b/nova/test.py index 4deb2a175..80b2d0a74 100644 --- a/nova/test.py +++ b/nova/test.py @@ -31,17 +31,15 @@ import uuid import unittest import mox -import shutil import stubout from eventlet import greenthread -from nova import context -from nova import db from nova import fakerabbit from nova import flags from nova import rpc from nova import service from nova import wsgi +from nova.virt import fake FLAGS = flags.FLAGS @@ -85,6 +83,7 @@ class TestCase(unittest.TestCase): self._monkey_patch_attach() self._monkey_patch_wsgi() self._original_flags = FLAGS.FlagValuesDict() + rpc.ConnectionPool = rpc.Pool(max_size=FLAGS.rpc_conn_pool_size) def tearDown(self): """Runs after each test method to tear down test environment.""" @@ -99,6 +98,10 @@ class TestCase(unittest.TestCase): if FLAGS.fake_rabbit: fakerabbit.reset_all() + if FLAGS.connection_type == 'fake': + if hasattr(fake.FakeConnection, '_instance'): + del fake.FakeConnection._instance + # Reset any overriden flags self.reset_flags() diff --git a/nova/tests/integrated/integrated_helpers.py b/nova/tests/integrated/integrated_helpers.py index bc98921f0..7f590441e 100644 --- a/nova/tests/integrated/integrated_helpers.py +++ b/nova/tests/integrated/integrated_helpers.py @@ -154,10 +154,7 @@ class _IntegratedTestBase(test.TestCase): # set up services self.start_service('compute') self.start_service('volume') - # NOTE(justinsb): There's a bug here which is eluding me... - # If we start the network_service, all is good, but then subsequent - # tests fail: CloudTestCase.test_ajax_console in particular. - #self.start_service('network') + self.start_service('network') self.start_service('scheduler') self._start_api_service() diff --git a/nova/tests/test_cloud.py b/nova/tests/test_cloud.py index 54c0454de..b64be662e 100644 --- a/nova/tests/test_cloud.py +++ b/nova/tests/test_cloud.py @@ -17,13 +17,9 @@ # under the License. from base64 import b64decode -import json from M2Crypto import BIO from M2Crypto import RSA import os -import shutil -import tempfile -import time from eventlet import greenthread @@ -33,12 +29,10 @@ from nova import db from nova import flags from nova import log as logging from nova import rpc -from nova import service from nova import test from nova import utils from nova import exception from nova.auth import manager -from nova.compute import power_state from nova.api.ec2 import cloud from nova.api.ec2 import ec2utils from nova.image import local @@ -79,14 +73,21 @@ class CloudTestCase(test.TestCase): self.stubs.Set(local.LocalImageService, 'show', fake_show) self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show) + # NOTE(vish): set up a manual wait so rpc.cast has a chance to finish + rpc_cast = rpc.cast + + def finish_cast(*args, **kwargs): + rpc_cast(*args, **kwargs) + greenthread.sleep(0.2) + + self.stubs.Set(rpc, 'cast', finish_cast) + def tearDown(self): network_ref = db.project_get_network(self.context, self.project.id) db.network_disassociate(self.context, network_ref['id']) self.manager.delete_project(self.project) self.manager.delete_user(self.user) - self.compute.kill() - self.network.kill() super(CloudTestCase, self).tearDown() def _create_key(self, name): @@ -113,7 +114,6 @@ class CloudTestCase(test.TestCase): self.cloud.describe_addresses(self.context) self.cloud.release_address(self.context, public_ip=address) - greenthread.sleep(0.3) db.floating_ip_destroy(self.context, address) def test_associate_disassociate_address(self): @@ -129,12 +129,10 @@ class CloudTestCase(test.TestCase): self.cloud.associate_address(self.context, instance_id=ec2_id, public_ip=address) - greenthread.sleep(0.3) self.cloud.disassociate_address(self.context, public_ip=address) self.cloud.release_address(self.context, public_ip=address) - greenthread.sleep(0.3) self.network.deallocate_fixed_ip(self.context, fixed) db.instance_destroy(self.context, inst['id']) db.floating_ip_destroy(self.context, address) @@ -306,31 +304,25 @@ class CloudTestCase(test.TestCase): 'instance_type': instance_type, 'max_count': max_count} rv = self.cloud.run_instances(self.context, **kwargs) - greenthread.sleep(0.3) instance_id = rv['instancesSet'][0]['instanceId'] output = self.cloud.get_console_output(context=self.context, instance_id=[instance_id]) self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT') # TODO(soren): We need this until we can stop polling in the rpc code # for unit tests. - greenthread.sleep(0.3) rv = self.cloud.terminate_instances(self.context, [instance_id]) - greenthread.sleep(0.3) def test_ajax_console(self): kwargs = {'image_id': 'ami-1'} rv = self.cloud.run_instances(self.context, **kwargs) instance_id = rv['instancesSet'][0]['instanceId'] - greenthread.sleep(0.3) output = self.cloud.get_ajax_console(context=self.context, instance_id=[instance_id]) self.assertEquals(output['url'], '%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url) # TODO(soren): We need this until we can stop polling in the rpc code # for unit tests. - greenthread.sleep(0.3) rv = self.cloud.terminate_instances(self.context, [instance_id]) - greenthread.sleep(0.3) def test_key_generation(self): result = self._create_key('test') diff --git a/nova/tests/test_rpc.py b/nova/tests/test_rpc.py index 44d7c91eb..ffd748efe 100644 --- a/nova/tests/test_rpc.py +++ b/nova/tests/test_rpc.py @@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc') class RpcTestCase(test.TestCase): - """Test cases for rpc""" def setUp(self): super(RpcTestCase, self).setUp() self.conn = rpc.Connection.instance(True) @@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase): self.context = context.get_admin_context() def test_call_succeed(self): - """Get a value through rpc call""" value = 42 result = rpc.call(self.context, 'test', {"method": "echo", "args": {"value": value}}) self.assertEqual(value, result) + def test_call_succeed_despite_multiple_returns(self): + value = 42 + result = rpc.call(self.context, 'test', {"method": "echo_three_times", + "args": {"value": value}}) + self.assertEqual(value + 2, result) + + def test_call_succeed_despite_multiple_returns_yield(self): + value = 42 + result = rpc.call(self.context, 'test', + {"method": "echo_three_times_yield", + "args": {"value": value}}) + self.assertEqual(value + 2, result) + + def test_multicall_succeed_once(self): + value = 42 + result = rpc.multicall(self.context, + 'test', + {"method": "echo", + "args": {"value": value}}) + for i, x in enumerate(result): + if i > 0: + self.fail('should only receive one response') + self.assertEqual(value + i, x) + + def test_multicall_succeed_three_times(self): + value = 42 + result = rpc.multicall(self.context, + 'test', + {"method": "echo_three_times", + "args": {"value": value}}) + for i, x in enumerate(result): + self.assertEqual(value + i, x) + + def test_multicall_succeed_three_times_yield(self): + value = 42 + result = rpc.multicall(self.context, + 'test', + {"method": "echo_three_times_yield", + "args": {"value": value}}) + for i, x in enumerate(result): + self.assertEqual(value + i, x) + def test_context_passed(self): - """Makes sure a context is passed through rpc call""" + """Makes sure a context is passed through rpc call.""" value = 42 result = rpc.call(self.context, 'test', {"method": "context", @@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase): self.assertEqual(self.context.to_dict(), result) def test_call_exception(self): - """Test that exception gets passed back properly + """Test that exception gets passed back properly. rpc.call returns a RemoteError object. The value of the exception is converted to a string, so we convert it back to an int in the test. + """ value = 42 self.assertRaises(rpc.RemoteError, @@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase): self.assertEqual(int(exc.value), value) def test_nested_calls(self): - """Test that we can do an rpc.call inside another call""" + """Test that we can do an rpc.call inside another call.""" class Nested(object): @staticmethod def echo(context, queue, value): @@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase): "value": value}}) self.assertEqual(value, result) + def test_connectionpool_single(self): + """Test that ConnectionPool recycles a single connection.""" + conn1 = rpc.ConnectionPool.get() + rpc.ConnectionPool.put(conn1) + conn2 = rpc.ConnectionPool.get() + rpc.ConnectionPool.put(conn2) + self.assertEqual(conn1, conn2) + + def test_connectionpool_double(self): + """Test that ConnectionPool returns and reuses separate connections. + + When called consecutively we should get separate connections and upon + returning them those connections should be reused for future calls + before generating a new connection. + + """ + conn1 = rpc.ConnectionPool.get() + conn2 = rpc.ConnectionPool.get() + + self.assertNotEqual(conn1, conn2) + rpc.ConnectionPool.put(conn1) + rpc.ConnectionPool.put(conn2) + + conn3 = rpc.ConnectionPool.get() + conn4 = rpc.ConnectionPool.get() + self.assertEqual(conn1, conn3) + self.assertEqual(conn2, conn4) + + def test_connectionpool_limit(self): + """Test connection pool limit and connection uniqueness.""" + max_size = FLAGS.rpc_conn_pool_size + conns = [] + + for i in xrange(max_size): + conns.append(rpc.ConnectionPool.get()) + + self.assertFalse(rpc.ConnectionPool.free_items) + self.assertEqual(rpc.ConnectionPool.current_size, + rpc.ConnectionPool.max_size) + self.assertEqual(len(set(conns)), max_size) + class TestReceiver(object): - """Simple Proxy class so the consumer has methods to call + """Simple Proxy class so the consumer has methods to call. + + Uses static methods because we aren't actually storing any state. - Uses static methods because we aren't actually storing any state""" + """ @staticmethod def echo(context, value): - """Simply returns whatever value is sent in""" + """Simply returns whatever value is sent in.""" LOG.debug(_("Received %s"), value) return value @staticmethod def context(context, value): - """Returns dictionary version of context""" + """Returns dictionary version of context.""" LOG.debug(_("Received %s"), context) return context.to_dict() @staticmethod + def echo_three_times(context, value): + context.reply(value) + context.reply(value + 1) + context.reply(value + 2) + + @staticmethod + def echo_three_times_yield(context, value): + yield value + yield value + 1 + yield value + 2 + + @staticmethod def fail(context, value): - """Raises an exception with the value sent in""" + """Raises an exception with the value sent in.""" raise Exception(value) diff --git a/nova/tests/test_service.py b/nova/tests/test_service.py index d48de2057..d1cc8bd61 100644 --- a/nova/tests/test_service.py +++ b/nova/tests/test_service.py @@ -106,7 +106,10 @@ class ServiceTestCase(test.TestCase): # NOTE(vish): Create was moved out of mox replay to make sure that # the looping calls are created in StartService. - app = service.Service.create(host=host, binary=binary) + app = service.Service.create(host=host, binary=binary, topic=topic) + + self.mox.StubOutWithMock(service.rpc.Connection, 'instance') + service.rpc.Connection.instance(new=mox.IgnoreArg()) self.mox.StubOutWithMock(rpc, 'TopicAdapterConsumer', @@ -114,6 +117,11 @@ class ServiceTestCase(test.TestCase): self.mox.StubOutWithMock(rpc, 'FanoutAdapterConsumer', use_mock_anything=True) + + self.mox.StubOutWithMock(rpc, + 'ConsumerSet', + use_mock_anything=True) + rpc.TopicAdapterConsumer(connection=mox.IgnoreArg(), topic=topic, proxy=mox.IsA(service.Service)).AndReturn( @@ -129,9 +137,14 @@ class ServiceTestCase(test.TestCase): proxy=mox.IsA(service.Service)).AndReturn( rpc.FanoutAdapterConsumer) - rpc.TopicAdapterConsumer.attach_to_eventlet() - rpc.TopicAdapterConsumer.attach_to_eventlet() - rpc.FanoutAdapterConsumer.attach_to_eventlet() + def wait_func(self, limit=None): + return None + + mock_cset = self.mox.CreateMock(rpc.ConsumerSet, + {'wait': wait_func}) + rpc.ConsumerSet(connection=mox.IgnoreArg(), + consumer_list=mox.IsA(list)).AndReturn(mock_cset) + wait_func(mox.IgnoreArg()) service_create = {'host': host, 'binary': binary, @@ -287,8 +300,42 @@ class ServiceTestCase(test.TestCase): # Creating mocks self.mox.StubOutWithMock(service.rpc.Connection, 'instance') service.rpc.Connection.instance(new=mox.IgnoreArg()) - service.rpc.Connection.instance(new=mox.IgnoreArg()) - service.rpc.Connection.instance(new=mox.IgnoreArg()) + + self.mox.StubOutWithMock(rpc, + 'TopicAdapterConsumer', + use_mock_anything=True) + self.mox.StubOutWithMock(rpc, + 'FanoutAdapterConsumer', + use_mock_anything=True) + + self.mox.StubOutWithMock(rpc, + 'ConsumerSet', + use_mock_anything=True) + + rpc.TopicAdapterConsumer(connection=mox.IgnoreArg(), + topic=topic, + proxy=mox.IsA(service.Service)).AndReturn( + rpc.TopicAdapterConsumer) + + rpc.TopicAdapterConsumer(connection=mox.IgnoreArg(), + topic='%s.%s' % (topic, host), + proxy=mox.IsA(service.Service)).AndReturn( + rpc.TopicAdapterConsumer) + + rpc.FanoutAdapterConsumer(connection=mox.IgnoreArg(), + topic=topic, + proxy=mox.IsA(service.Service)).AndReturn( + rpc.FanoutAdapterConsumer) + + def wait_func(self, limit=None): + return None + + mock_cset = self.mox.CreateMock(rpc.ConsumerSet, + {'wait': wait_func}) + rpc.ConsumerSet(connection=mox.IgnoreArg(), + consumer_list=mox.IsA(list)).AndReturn(mock_cset) + wait_func(mox.IgnoreArg()) + self.mox.StubOutWithMock(serv.manager.driver, 'update_available_resource') serv.manager.driver.update_available_resource(mox.IgnoreArg(), host) diff --git a/nova/tests/test_xenapi.py b/nova/tests/test_xenapi.py index be1e35697..18a267896 100644 --- a/nova/tests/test_xenapi.py +++ b/nova/tests/test_xenapi.py @@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase): os_type="linux") self.check_vm_params_for_linux() + def test_spawn_vhd_glance_swapdisk(self): + # Change the default host_call_plugin to one that'll return + # a swap disk + orig_func = stubs.FakeSessionForVMTests.host_call_plugin + + stubs.FakeSessionForVMTests.host_call_plugin = \ + stubs.FakeSessionForVMTests.host_call_plugin_swap + + try: + # We'll steal the above glance linux test + self.test_spawn_vhd_glance_linux() + finally: + # Make sure to put this back + stubs.FakeSessionForVMTests.host_call_plugin = orig_func + + # We should have 2 VBDs. + self.assertEqual(len(self.vm['VBDs']), 2) + # Now test that we have 1. + self.tearDown() + self.setUp() + self.test_spawn_vhd_glance_linux() + self.assertEqual(len(self.vm['VBDs']), 1) + def test_spawn_vhd_glance_windows(self): FLAGS.xenapi_image_service = 'glance' self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None, diff --git a/nova/tests/xenapi/stubs.py b/nova/tests/xenapi/stubs.py index 4833ccb07..35308d95f 100644 --- a/nova/tests/xenapi/stubs.py +++ b/nova/tests/xenapi/stubs.py @@ -17,6 +17,7 @@ """Stubouts, mocks and fixtures for the test suite""" import eventlet +import json from nova.virt import xenapi_conn from nova.virt.xenapi import fake from nova.virt.xenapi import volume_utils @@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs): sr_ref=sr_ref, sharable=False) vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref) vdi_uuid = vdi_rec['uuid'] - return vdi_uuid + return [dict(vdi_type='os', vdi_uuid=vdi_uuid)] stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image) @@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase): def __init__(self, uri): super(FakeSessionForVMTests, self).__init__(uri) - def host_call_plugin(self, _1, _2, _3, _4, _5): + def host_call_plugin(self, _1, _2, plugin, method, _5): + sr_ref = fake.get_all('SR')[0] + vdi_ref = fake.create_vdi('', False, sr_ref, False) + vdi_rec = fake.get_record('VDI', vdi_ref) + if plugin == "glance" and method == "download_vhd": + ret_str = json.dumps([dict(vdi_type='os', + vdi_uuid=vdi_rec['uuid'])]) + else: + ret_str = vdi_rec['uuid'] + return '<string>%s</string>' % ret_str + + def host_call_plugin_swap(self, _1, _2, plugin, method, _5): sr_ref = fake.get_all('SR')[0] vdi_ref = fake.create_vdi('', False, sr_ref, False) vdi_rec = fake.get_record('VDI', vdi_ref) - return '<string>%s</string>' % vdi_rec['uuid'] + if plugin == "glance" and method == "download_vhd": + swap_vdi_ref = fake.create_vdi('', False, sr_ref, False) + swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref) + ret_str = json.dumps( + [dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']), + dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])]) + else: + ret_str = vdi_rec['uuid'] + return '<string>%s</string>' % ret_str def VM_start(self, _1, ref, _2, _3): vm = fake.get_record('VM', ref) diff --git a/nova/virt/xenapi/fake.py b/nova/virt/xenapi/fake.py index e36ef3288..76988b172 100644 --- a/nova/virt/xenapi/fake.py +++ b/nova/virt/xenapi/fake.py @@ -159,7 +159,10 @@ def after_VBD_create(vbd_ref, vbd_rec): vbd_rec['device'] = '' vm_ref = vbd_rec['VM'] vm_rec = _db_content['VM'][vm_ref] - vm_rec['VBDs'] = [vbd_ref] + if vm_rec.get('VBDs', None): + vm_rec['VBDs'].append(vbd_ref) + else: + vm_rec['VBDs'] = [vbd_ref] vm_name_label = _db_content['VM'][vm_ref]['name_label'] vbd_rec['vm_name_label'] = vm_name_label diff --git a/nova/virt/xenapi/vm_utils.py b/nova/virt/xenapi/vm_utils.py index 9f6cd608c..06ee8ee9b 100644 --- a/nova/virt/xenapi/vm_utils.py +++ b/nova/virt/xenapi/vm_utils.py @@ -19,6 +19,7 @@ Helper methods for operations related to the management of VM records and their attributes like VDIs, VIFs, as well as their lookup functions. """ +import json import os import pickle import re @@ -376,6 +377,9 @@ class VMHelper(HelperBase): xenapi_image_service = ['glance', 'objectstore'] glance_address = 'address for glance services' glance_port = 'port for glance services' + + Returns: A single filename if image_type is KERNEL_RAMDISK + A list of dictionaries that describe VDIs, otherwise """ access = AuthManager().get_access_key(user, project) @@ -390,6 +394,10 @@ class VMHelper(HelperBase): @classmethod def _fetch_image_glance_vhd(cls, session, instance_id, image, access, image_type): + """Tell glance to download an image and put the VHDs into the SR + + Returns: A list of dictionaries that describe VDIs + """ LOG.debug(_("Asking xapi to fetch vhd image %(image)s") % locals()) @@ -408,18 +416,26 @@ class VMHelper(HelperBase): kwargs = {'params': pickle.dumps(params)} task = session.async_call_plugin('glance', 'download_vhd', kwargs) - vdi_uuid = session.wait_for_task(task, instance_id) + result = session.wait_for_task(task, instance_id) + # 'download_vhd' will return a json encoded string containing + # a list of dictionaries describing VDIs. The dictionary will + # contain 'vdi_type' and 'vdi_uuid' keys. 'vdi_type' can be + # 'os' or 'swap' right now. + vdis = json.loads(result) + for vdi in vdis: + LOG.debug(_("xapi 'download_vhd' returned VDI of " + "type '%(vdi_type)s' with UUID '%(vdi_uuid)s'" % vdi)) cls.scan_sr(session, instance_id, sr_ref) + # Pull out the UUID of the first VDI + vdi_uuid = vdis[0]['vdi_uuid'] # Set the name-label to ease debugging vdi_ref = session.get_xenapi().VDI.get_by_uuid(vdi_uuid) - name_label = get_name_label_for_image(image) - session.get_xenapi().VDI.set_name_label(vdi_ref, name_label) + primary_name_label = get_name_label_for_image(image) + session.get_xenapi().VDI.set_name_label(vdi_ref, primary_name_label) - LOG.debug(_("xapi 'download_vhd' returned VDI UUID %(vdi_uuid)s") - % locals()) - return vdi_uuid + return vdis @classmethod def _fetch_image_glance_disk(cls, session, instance_id, image, access, @@ -431,6 +447,8 @@ class VMHelper(HelperBase): plugin; instead, it streams the disks through domU to the VDI directly. + Returns: A single filename if image_type is KERNEL_RAMDISK + A list of dictionaries that describe VDIs, otherwise """ # FIXME(sirp): Since the Glance plugin seems to be required for the # VHD disk, it may be worth using the plugin for both VHD and RAW and @@ -476,7 +494,8 @@ class VMHelper(HelperBase): LOG.debug(_("Kernel/Ramdisk VDI %s destroyed"), vdi_ref) return filename else: - return session.get_xenapi().VDI.get_uuid(vdi_ref) + vdi_uuid = session.get_xenapi().VDI.get_uuid(vdi_ref) + return [dict(vdi_type='os', vdi_uuid=vdi_uuid)] @classmethod def determine_disk_image_type(cls, instance): @@ -535,6 +554,11 @@ class VMHelper(HelperBase): @classmethod def _fetch_image_glance(cls, session, instance_id, image, access, image_type): + """Fetch image from glance based on image type. + + Returns: A single filename if image_type is KERNEL_RAMDISK + A list of dictionaries that describe VDIs, otherwise + """ if image_type == ImageType.DISK_VHD: return cls._fetch_image_glance_vhd( session, instance_id, image, access, image_type) @@ -545,6 +569,11 @@ class VMHelper(HelperBase): @classmethod def _fetch_image_objectstore(cls, session, instance_id, image, access, secret, image_type): + """Fetch an image from objectstore. + + Returns: A single filename if image_type is KERNEL_RAMDISK + A list of dictionaries that describe VDIs, otherwise + """ url = images.image_url(image) LOG.debug(_("Asking xapi to fetch %(url)s as %(access)s") % locals()) if image_type == ImageType.KERNEL_RAMDISK: @@ -562,8 +591,10 @@ class VMHelper(HelperBase): if image_type == ImageType.DISK_RAW: args['raw'] = 'true' task = session.async_call_plugin('objectstore', fn, args) - uuid = session.wait_for_task(task, instance_id) - return uuid + uuid_or_fn = session.wait_for_task(task, instance_id) + if image_type != ImageType.KERNEL_RAMDISK: + return [dict(vdi_type='os', vdi_uuid=uuid_or_fn)] + return uuid_or_fn @classmethod def determine_is_pv(cls, session, instance_id, vdi_ref, disk_image_type, diff --git a/nova/virt/xenapi/vmops.py b/nova/virt/xenapi/vmops.py index be6ef48ea..6d516ddbc 100644 --- a/nova/virt/xenapi/vmops.py +++ b/nova/virt/xenapi/vmops.py @@ -91,7 +91,8 @@ class VMOps(object): def finish_resize(self, instance, disk_info): vdi_uuid = self.link_disks(instance, disk_info['base_copy'], disk_info['cow']) - vm_ref = self._create_vm(instance, vdi_uuid) + vm_ref = self._create_vm(instance, + [dict(vdi_type='os', vdi_uuid=vdi_uuid)]) self.resize_instance(instance, vdi_uuid) self._spawn(instance, vm_ref) @@ -105,24 +106,25 @@ class VMOps(object): LOG.debug(_("Starting instance %s"), instance.name) self._session.call_xenapi('VM.start', vm_ref, False, False) - def _create_disk(self, instance): + def _create_disks(self, instance): user = AuthManager().get_user(instance.user_id) project = AuthManager().get_project(instance.project_id) disk_image_type = VMHelper.determine_disk_image_type(instance) - vdi_uuid = VMHelper.fetch_image(self._session, instance.id, - instance.image_id, user, project, disk_image_type) - return vdi_uuid + vdis = VMHelper.fetch_image(self._session, + instance.id, instance.image_id, user, project, + disk_image_type) + return vdis def spawn(self, instance, network_info=None): - vdi_uuid = self._create_disk(instance) - vm_ref = self._create_vm(instance, vdi_uuid, network_info) + vdis = self._create_disks(instance) + vm_ref = self._create_vm(instance, vdis, network_info) self._spawn(instance, vm_ref) def spawn_rescue(self, instance): """Spawn a rescue instance.""" self.spawn(instance) - def _create_vm(self, instance, vdi_uuid, network_info=None): + def _create_vm(self, instance, vdis, network_info=None): """Create VM instance.""" instance_name = instance.name vm_ref = VMHelper.lookup(self._session, instance_name) @@ -141,28 +143,43 @@ class VMOps(object): user = AuthManager().get_user(instance.user_id) project = AuthManager().get_project(instance.project_id) - # Are we building from a pre-existing disk? - vdi_ref = self._session.call_xenapi('VDI.get_by_uuid', vdi_uuid) - disk_image_type = VMHelper.determine_disk_image_type(instance) kernel = None if instance.kernel_id: kernel = VMHelper.fetch_image(self._session, instance.id, - instance.kernel_id, user, project, ImageType.KERNEL_RAMDISK) + instance.kernel_id, user, project, + ImageType.KERNEL_RAMDISK) ramdisk = None if instance.ramdisk_id: ramdisk = VMHelper.fetch_image(self._session, instance.id, - instance.ramdisk_id, user, project, ImageType.KERNEL_RAMDISK) - - use_pv_kernel = VMHelper.determine_is_pv(self._session, instance.id, - vdi_ref, disk_image_type, instance.os_type) - vm_ref = VMHelper.create_vm(self._session, instance, kernel, ramdisk, - use_pv_kernel) - + instance.ramdisk_id, user, project, + ImageType.KERNEL_RAMDISK) + + # Create the VM ref and attach the first disk + first_vdi_ref = self._session.call_xenapi('VDI.get_by_uuid', + vdis[0]['vdi_uuid']) + use_pv_kernel = VMHelper.determine_is_pv(self._session, + instance.id, first_vdi_ref, disk_image_type, + instance.os_type) + vm_ref = VMHelper.create_vm(self._session, instance, + kernel, ramdisk, use_pv_kernel) VMHelper.create_vbd(session=self._session, vm_ref=vm_ref, - vdi_ref=vdi_ref, userdevice=0, bootable=True) + vdi_ref=first_vdi_ref, userdevice=0, bootable=True) + + # Attach any other disks + # userdevice 1 is reserved for rescue + userdevice = 2 + for vdi in vdis[1:]: + # vdi['vdi_type'] is either 'os' or 'swap', but we don't + # really care what it is right here. + vdi_ref = self._session.call_xenapi('VDI.get_by_uuid', + vdi['vdi_uuid']) + VMHelper.create_vbd(session=self._session, vm_ref=vm_ref, + vdi_ref=vdi_ref, userdevice=userdevice, + bootable=False) + userdevice += 1 # TODO(tr3buchet) - check to make sure we have network info, otherwise # create it now. This goes away once nova-multi-nic hits. @@ -172,7 +189,7 @@ class VMOps(object): # Alter the image before VM start for, e.g. network injection if FLAGS.xenapi_inject_image: VMHelper.preconfigure_instance(self._session, instance, - vdi_ref, network_info) + first_vdi_ref, network_info) self.create_vifs(vm_ref, network_info) self.inject_network_info(instance, network_info, vm_ref) diff --git a/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance b/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance index 4b45671ae..0c00d168b 100644 --- a/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance +++ b/plugins/xenserver/xenapi/etc/xapi.d/plugins/glance @@ -22,6 +22,10 @@ # import httplib +try: + import json +except ImportError: + import simplejson as json import os import os.path import pickle @@ -87,8 +91,8 @@ def _download_tarball(sr_path, staging_path, image_id, glance_host, conn.close() -def _fixup_vhds(sr_path, staging_path, uuid_stack): - """Fixup the downloaded VHDs before we move them into the SR. +def _import_vhds(sr_path, staging_path, uuid_stack): + """Import the VHDs found in the staging path. We cannot extract VHDs directly into the SR since they don't yet have UUIDs, aren't properly associated with each other, and would be subject to @@ -98,16 +102,25 @@ def _fixup_vhds(sr_path, staging_path, uuid_stack): To avoid these we problems, we use a staging area to fixup the VHDs before moving them into the SR. The steps involved are: - 1. Extracting tarball into staging area + 1. Extracting tarball into staging area (done prior to this call) 2. Renaming VHDs to use UUIDs ('snap.vhd' -> 'ffff-aaaa-...vhd') - 3. Linking the two VHDs together + 3. Linking VHDs together if there's a snap.vhd 4. Pseudo-atomically moving the images into the SR. (It's not really - atomic because it takes place as two os.rename operations; however, - the chances of an SR.scan occuring between the two rename() + atomic because it takes place as multiple os.rename operations; + however, the chances of an SR.scan occuring between the rename()s invocations is so small that we can safely ignore it) + + Returns: A list of VDIs. Each list element is a dictionary containing + information about the VHD. Dictionary keys are: + 1. "vdi_type" - The type of VDI. Currently they can be "os_disk" or + "swap" + 2. "vdi_uuid" - The UUID of the VDI + + Example return: [{"vdi_type": "os_disk","vdi_uuid": "ffff-aaa..vhd"}, + {"vdi_type": "swap","vdi_uuid": "ffff-bbb..vhd"}] """ def rename_with_uuid(orig_path): """Rename VHD using UUID so that it will be recognized by SR on a @@ -158,27 +171,59 @@ def _fixup_vhds(sr_path, staging_path, uuid_stack): "VHD %(path)s is marked as hidden without child" % locals()) - orig_base_copy_path = os.path.join(staging_path, 'image.vhd') - if not os.path.exists(orig_base_copy_path): + def prepare_if_exists(staging_path, vhd_name, parent_path=None): + """ + Check for existance of a particular VHD in the staging path and + preparing it for moving into the SR. + + Returns: Tuple of (Path to move into the SR, VDI_UUID) + None, if the vhd_name doesn't exist in the staging path + + If the VHD exists, we will do the following: + 1. Rename it with a UUID. + 2. If parent_path exists, we'll link up the VHDs. + """ + orig_path = os.path.join(staging_path, vhd_name) + if not os.path.exists(orig_path): + return None + new_path, vdi_uuid = rename_with_uuid(orig_path) + if parent_path: + # NOTE(sirp): this step is necessary so that an SR scan won't + # delete the base_copy out from under us (since it would be + # orphaned) + link_vhds(new_path, parent_path) + return (new_path, vdi_uuid) + + vdi_return_list = [] + paths_to_move = [] + + image_info = prepare_if_exists(staging_path, 'image.vhd') + if not image_info: raise Exception("Invalid image: image.vhd not present") - base_copy_path, base_copy_uuid = rename_with_uuid(orig_base_copy_path) - - vdi_uuid = base_copy_uuid - orig_snap_path = os.path.join(staging_path, 'snap.vhd') - if os.path.exists(orig_snap_path): - snap_path, snap_uuid = rename_with_uuid(orig_snap_path) - vdi_uuid = snap_uuid - # NOTE(sirp): this step is necessary so that an SR scan won't - # delete the base_copy out from under us (since it would be - # orphaned) - link_vhds(snap_path, base_copy_path) - move_into_sr(snap_path) + paths_to_move.append(image_info[0]) + + snap_info = prepare_if_exists(staging_path, 'snap.vhd', + image_info[0]) + if snap_info: + paths_to_move.append(snap_info[0]) + # We return this snap as the VDI instead of image.vhd + vdi_return_list.append(dict(vdi_type="os", vdi_uuid=snap_info[1])) else: - assert_vhd_not_hidden(base_copy_path) + assert_vhd_not_hidden(image_info[0]) + # If there's no snap, we return the image.vhd UUID + vdi_return_list.append(dict(vdi_type="os", vdi_uuid=image_info[1])) + + swap_info = prepare_if_exists(staging_path, 'swap.vhd') + if swap_info: + assert_vhd_not_hidden(swap_info[0]) + paths_to_move.append(swap_info[0]) + vdi_return_list.append(dict(vdi_type="swap", vdi_uuid=swap_info[1])) + + for path in paths_to_move: + move_into_sr(path) - move_into_sr(base_copy_path) - return vdi_uuid + return vdi_return_list def _prepare_staging_area_for_upload(sr_path, staging_path, vdi_uuids): @@ -324,8 +369,9 @@ def download_vhd(session, args): try: _download_tarball(sr_path, staging_path, image_id, glance_host, glance_port) - vdi_uuid = _fixup_vhds(sr_path, staging_path, uuid_stack) - return vdi_uuid + # Right now, it's easier to return a single string via XenAPI, + # so we'll json encode the list of VHDs. + return json.dumps(_import_vhds(sr_path, staging_path, uuid_stack)) finally: _cleanup_staging_area(staging_path) |
