summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2012-10-27 23:43:33 +0000
committerGerrit Code Review <review@openstack.org>2012-10-27 23:43:33 +0000
commitfe8685cc2646d2f947a890f2013220995c78fa10 (patch)
tree9467ee15d6b0fa9875c65216a48f61a1aa13e756
parent0a6497de3510db28d91834d52e888d28a59e3f78 (diff)
parent21b0bf7eeed72dffa2628dc8ebb7d0c6e371cc1c (diff)
Merge "Make instance_get_all() not require admin context"
-rw-r--r--nova/db/sqlalchemy/api.py8
-rw-r--r--nova/tests/compute/test_compute.py38
2 files changed, 26 insertions, 20 deletions
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 888bb83a5..da8c559c6 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -1519,7 +1519,7 @@ def _build_instance_get(context, session=None):
options(joinedload('instance_type'))
-@require_admin_context
+@require_context
def instance_get_all(context, columns_to_join=None):
if columns_to_join is None:
columns_to_join = ['info_cache', 'security_groups',
@@ -1527,6 +1527,12 @@ def instance_get_all(context, columns_to_join=None):
query = model_query(context, models.Instance)
for column in columns_to_join:
query = query.options(joinedload(column))
+ if not context.is_admin:
+ # If we're not admin context, add appropriate filter..
+ if context.project_id:
+ query = query.filter_by(project_id=context.project_id)
+ else:
+ query = query.filter_by(user_id=context.user_id)
return query.all()
diff --git a/nova/tests/compute/test_compute.py b/nova/tests/compute/test_compute.py
index 4963edc4f..6a419feea 100644
--- a/nova/tests/compute/test_compute.py
+++ b/nova/tests/compute/test_compute.py
@@ -277,7 +277,7 @@ class ComputeTestCase(BaseTestCase):
try:
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
instance = instances[0]
self.assertTrue(instance.config_drive)
@@ -292,7 +292,7 @@ class ComputeTestCase(BaseTestCase):
try:
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
instance = instances[0]
self.assertTrue(instance.config_drive)
@@ -521,7 +521,7 @@ class ComputeTestCase(BaseTestCase):
try:
self.compute.run_instance(self.context, instance=instance,
is_first_time=True)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
instance = instances[0]
self.assertEqual(instance.access_ip_v4, '192.168.1.100')
@@ -535,7 +535,7 @@ class ComputeTestCase(BaseTestCase):
try:
self.compute.run_instance(self.context, instance=instance,
is_first_time=True)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
instance = instances[0]
self.assertFalse(instance.access_ip_v4)
@@ -626,13 +626,13 @@ class ComputeTestCase(BaseTestCase):
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("Running instances: %s"), instances)
self.assertEqual(len(instances), 1)
self.compute.terminate_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("After terminating instances: %s"), instances)
self.assertEqual(len(instances), 0)
@@ -644,7 +644,7 @@ class ComputeTestCase(BaseTestCase):
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("Running instances: %s"), instances)
self.assertEqual(len(instances), 1)
@@ -667,7 +667,7 @@ class ComputeTestCase(BaseTestCase):
self.compute.terminate_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("After terminating instances: %s"), instances)
self.assertEqual(len(instances), 0)
bdms = db.block_device_mapping_get_all_by_instance(self.context,
@@ -680,7 +680,7 @@ class ComputeTestCase(BaseTestCase):
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("Running instances: %s"), instances)
self.assertEqual(len(instances), 1)
@@ -693,7 +693,7 @@ class ComputeTestCase(BaseTestCase):
self.compute.terminate_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("After terminating instances: %s"), instances)
self.assertEqual(len(instances), 0)
@@ -705,7 +705,7 @@ class ComputeTestCase(BaseTestCase):
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("Running instances: %s"), instances)
self.assertEqual(len(instances), 1)
@@ -723,7 +723,7 @@ class ComputeTestCase(BaseTestCase):
except TypeError:
pass
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("After terminating instances: %s"), instances)
self.assertEqual(len(instances), 1)
self.assertEqual(instances[0]['task_state'], 'deleting')
@@ -1201,7 +1201,7 @@ class ComputeTestCase(BaseTestCase):
def _assert_state(self, state_dict):
"""Assert state of VM is equal to state passed as parameter"""
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
self.assertEqual(len(instances), 1)
if 'vm_state' in state_dict:
@@ -2368,7 +2368,7 @@ class ComputeTestCase(BaseTestCase):
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
+ instances = db.instance_get_all(self.context)
LOG.info(_("Running instances: %s"), instances)
self.assertEqual(len(instances), 1)
@@ -2379,7 +2379,7 @@ class ComputeTestCase(BaseTestCase):
ctxt = context.get_admin_context()
self.compute._sync_power_states(ctxt)
- instances = db.instance_get_all(ctxt)
+ instances = db.instance_get_all(self.context)
LOG.info(_("After force-killing instances: %s"), instances)
self.assertEqual(len(instances), 1)
self.assertEqual(task_states.STOPPING, instances[0]['task_state'])
@@ -2937,7 +2937,7 @@ class ComputeAPITestCase(BaseTestCase):
def test_create_instance_with_invalid_security_group_raises(self):
instance_type = instance_types.get_default_instance_type()
- pre_build_len = len(db.instance_get_all(context.get_admin_context()))
+ pre_build_len = len(db.instance_get_all(self.context))
self.assertRaises(exception.SecurityGroupNotFoundForProject,
self.compute_api.create,
self.context,
@@ -2945,7 +2945,7 @@ class ComputeAPITestCase(BaseTestCase):
image_href=None,
security_group=['this_is_a_fake_sec_group'])
self.assertEqual(pre_build_len,
- len(db.instance_get_all(context.get_admin_context())))
+ len(db.instance_get_all(self.context)))
def test_create_with_large_user_data(self):
"""Test an instance type with too much user data."""
@@ -4489,8 +4489,8 @@ class ComputeAPITestCase(BaseTestCase):
params={'architecture': ''}))
try:
self.compute.run_instance(self.context, instance=instance)
- instances = db.instance_get_all(context.get_admin_context())
- instance = instances[0]
+ instance = db.instance_get_by_uuid(self.context,
+ instance['uuid'])
self.assertNotEqual(instance['architecture'], 'Unknown')
finally:
db.instance_destroy(self.context, instance['uuid'])