summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nova/tests/test_driver.py60
-rw-r--r--nova/virt/driver.py11
-rw-r--r--nova/virt/libvirt/driver.py9
3 files changed, 75 insertions, 5 deletions
diff --git a/nova/tests/test_driver.py b/nova/tests/test_driver.py
new file mode 100644
index 000000000..2dee7725f
--- /dev/null
+++ b/nova/tests/test_driver.py
@@ -0,0 +1,60 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2013 Citrix Systems, Inc.
+# Copyright 2013 OpenStack LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from nova import test
+from nova.virt import driver
+
+
+class FakeDriver(object):
+ def __init__(self, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+
+class FakeDriver2(FakeDriver):
+ pass
+
+
+class ToDriverRegistryTestCase(test.TestCase):
+
+ def assertDriverInstance(self, inst, class_, *args, **kwargs):
+ self.assertEquals(class_, inst.__class__)
+ self.assertEquals(args, inst.args)
+ self.assertEquals(kwargs, inst.kwargs)
+
+ def test_driver_dict_from_config(self):
+ drvs = driver.driver_dict_from_config(
+ [
+ 'key1=nova.tests.test_driver.FakeDriver',
+ 'key2=nova.tests.test_driver.FakeDriver2',
+ ], 'arg1', 'arg2', param1='value1', param2='value2'
+ )
+
+ self.assertEquals(
+ sorted(['key1', 'key2']),
+ sorted(drvs.keys())
+ )
+
+ self.assertDriverInstance(
+ drvs['key1'],
+ FakeDriver, 'arg1', 'arg2', param1='value1',
+ param2='value2')
+
+ self.assertDriverInstance(
+ drvs['key2'],
+ FakeDriver2, 'arg1', 'arg2', param1='value1',
+ param2='value2')
diff --git a/nova/virt/driver.py b/nova/virt/driver.py
index aa0439e74..f699d2011 100644
--- a/nova/virt/driver.py
+++ b/nova/virt/driver.py
@@ -49,6 +49,17 @@ CONF.register_opts(driver_opts)
LOG = logging.getLogger(__name__)
+def driver_dict_from_config(named_driver_config, *args, **kwargs):
+ driver_registry = dict()
+
+ for driver_str in named_driver_config:
+ driver_type, _sep, driver = driver_str.partition('=')
+ driver_class = importutils.import_class(driver)
+ driver_registry[driver_type] = driver_class(*args, **kwargs)
+
+ return driver_registry
+
+
def block_device_info_get_root(block_device_info):
block_device_info = block_device_info or {}
return block_device_info.get('root_device_name')
diff --git a/nova/virt/libvirt/driver.py b/nova/virt/libvirt/driver.py
index 597aa39a0..86afa1687 100644
--- a/nova/virt/libvirt/driver.py
+++ b/nova/virt/libvirt/driver.py
@@ -284,11 +284,10 @@ class LibvirtDriver(driver.ComputeDriver):
self.virtapi,
get_connection=self._get_connection)
self.vif_driver = importutils.import_object(CONF.libvirt_vif_driver)
- self.volume_drivers = {}
- for driver_str in CONF.libvirt_volume_drivers:
- driver_type, _sep, driver = driver_str.partition('=')
- driver_class = importutils.import_class(driver)
- self.volume_drivers[driver_type] = driver_class(self)
+
+ self.volume_drivers = driver.driver_dict_from_config(
+ CONF.libvirt_volume_drivers, self)
+
self._host_state = None
disk_prefix_map = {"lxc": "", "uml": "ubd", "xen": "sd"}