summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nova/compute/api.py7
-rw-r--r--nova/db/api.py5
-rw-r--r--nova/db/sqlalchemy/api.py16
-rw-r--r--nova/tests/api/openstack/compute/contrib/test_security_groups.py2
-rw-r--r--nova/tests/db/test_db_api.py30
5 files changed, 44 insertions, 16 deletions
diff --git a/nova/compute/api.py b/nova/compute/api.py
index 0a9b0e67b..70e205dc2 100644
--- a/nova/compute/api.py
+++ b/nova/compute/api.py
@@ -3216,7 +3216,8 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase):
def trigger_rules_refresh(self, context, id):
"""Called when a rule is added to or removed from a security_group."""
- security_group = self.db.security_group_get(context, id)
+ security_group = self.db.security_group_get(
+ context, id, columns_to_join=['instances'])
for instance in security_group['instances']:
if instance['host'] is not None:
@@ -3242,8 +3243,8 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase):
security_groups = set()
for rule in security_group_rules:
security_group = self.db.security_group_get(
- context,
- rule['parent_group_id'])
+ context, rule['parent_group_id'],
+ columns_to_join=['instances'])
security_groups.add(security_group)
# ..then we find the instances that are members of these groups..
diff --git a/nova/db/api.py b/nova/db/api.py
index 973be1a26..bd519110c 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -1151,9 +1151,10 @@ def security_group_get_all(context):
return IMPL.security_group_get_all(context)
-def security_group_get(context, security_group_id):
+def security_group_get(context, security_group_id, columns_to_join=None):
"""Get security group by its id."""
- return IMPL.security_group_get(context, security_group_id)
+ return IMPL.security_group_get(context, security_group_id,
+ columns_to_join)
def security_group_get_by_name(context, project_id, group_name):
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index fd79ae215..c44f62206 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -3241,13 +3241,17 @@ def security_group_get_all(context):
@require_context
-def security_group_get(context, security_group_id, session=None):
- result = _security_group_get_query(context, session=session,
- project_only=True).\
- filter_by(id=security_group_id).\
- options(joinedload_all('instances')).\
- first()
+def security_group_get(context, security_group_id, columns_to_join=None,
+ session=None):
+ query = _security_group_get_query(context, session=session,
+ project_only=True).\
+ filter_by(id=security_group_id)
+ if columns_to_join is None:
+ columns_to_join = []
+ if 'instances' in columns_to_join:
+ query = query.options(joinedload_all('instances'))
+ result = query.first()
if not result:
raise exception.SecurityGroupNotFound(
security_group_id=security_group_id)
diff --git a/nova/tests/api/openstack/compute/contrib/test_security_groups.py b/nova/tests/api/openstack/compute/contrib/test_security_groups.py
index f1433bd0a..ac3e8885d 100644
--- a/nova/tests/api/openstack/compute/contrib/test_security_groups.py
+++ b/nova/tests/api/openstack/compute/contrib/test_security_groups.py
@@ -716,7 +716,7 @@ class TestSecurityGroupRules(test.TestCase):
db1 = security_group_db(self.sg1)
db2 = security_group_db(self.sg2)
- def return_security_group(context, group_id):
+ def return_security_group(context, group_id, columns_to_join=None):
if group_id == db1['id']:
return db1
if group_id == db2['id']:
diff --git a/nova/tests/db/test_db_api.py b/nova/tests/db/test_db_api.py
index 279481d87..7ba431695 100644
--- a/nova/tests/db/test_db_api.py
+++ b/nova/tests/db/test_db_api.py
@@ -30,6 +30,7 @@ from sqlalchemy.dialects import sqlite
from sqlalchemy import exc
from sqlalchemy.exc import IntegrityError
from sqlalchemy import MetaData
+from sqlalchemy.orm import exc as sqlalchemy_orm_exc
from sqlalchemy.orm import query
from sqlalchemy.sql.expression import select
@@ -1621,19 +1622,40 @@ class SecurityGroupTestCase(test.TestCase, ModelsObjectComparatorMixin):
self.assertRaises(exception.SecurityGroupNotFound,
db.security_group_get,
self.ctxt, security_group1['id'])
- self._assertEqualObjects(db.security_group_get(self.ctxt,
- security_group2['id']),
- security_group2)
+ self._assertEqualObjects(db.security_group_get(
+ self.ctxt, security_group2['id'],
+ columns_to_join=['instances']), security_group2)
def test_security_group_get(self):
security_group1 = self._create_security_group({})
security_group2 = self._create_security_group(
{'name': 'fake_sec_group2'})
real_security_group = db.security_group_get(self.ctxt,
- security_group1['id'])
+ security_group1['id'],
+ columns_to_join=['instances'])
self._assertEqualObjects(security_group1,
real_security_group)
+ def test_security_group_get_no_instances(self):
+ instance = db.instance_create(self.ctxt, {})
+ sid = self._create_security_group({'instances': [instance]})['id']
+
+ session = get_session()
+ self.mox.StubOutWithMock(sqlalchemy_api, 'get_session')
+ sqlalchemy_api.get_session().AndReturn(session)
+ sqlalchemy_api.get_session().AndReturn(session)
+ self.mox.ReplayAll()
+
+ security_group = db.security_group_get(self.ctxt, sid,
+ columns_to_join=['instances'])
+ session.expunge(security_group)
+ self.assertEqual(1, len(security_group['instances']))
+
+ security_group = db.security_group_get(self.ctxt, sid)
+ session.expunge(security_group)
+ self.assertRaises(sqlalchemy_orm_exc.DetachedInstanceError,
+ getattr, security_group, 'instances')
+
def test_security_group_get_not_found_exception(self):
self.assertRaises(exception.SecurityGroupNotFound,
db.security_group_get, self.ctxt, 100500)