diff options
author | Jenkins <jenkins@review.openstack.org> | 2013-06-26 15:42:19 +0000 |
---|---|---|
committer | Gerrit Code Review <review@openstack.org> | 2013-06-26 15:42:19 +0000 |
commit | 82b805ce635ddc9ac062115956a8c40b7ad3b85b (patch) | |
tree | b8c45bd2fe761f991f470bb559eaea8c4e935203 | |
parent | 7696c3c11f0de855cbc53cc04ee7d2be07ae3b9c (diff) | |
parent | 90b5796729c3c6db4b0ff225a4fd11bf29467cf3 (diff) | |
download | nova-82b805ce635ddc9ac062115956a8c40b7ad3b85b.tar.gz nova-82b805ce635ddc9ac062115956a8c40b7ad3b85b.tar.xz nova-82b805ce635ddc9ac062115956a8c40b7ad3b85b.zip |
Merge "Refactor db.security_group_get() instance join behavior"
-rw-r--r-- | nova/compute/api.py | 7 | ||||
-rw-r--r-- | nova/db/api.py | 5 | ||||
-rw-r--r-- | nova/db/sqlalchemy/api.py | 16 | ||||
-rw-r--r-- | nova/tests/api/openstack/compute/contrib/test_security_groups.py | 2 | ||||
-rw-r--r-- | nova/tests/db/test_db_api.py | 30 |
5 files changed, 44 insertions, 16 deletions
diff --git a/nova/compute/api.py b/nova/compute/api.py index 0a9b0e67b..70e205dc2 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -3216,7 +3216,8 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase): def trigger_rules_refresh(self, context, id): """Called when a rule is added to or removed from a security_group.""" - security_group = self.db.security_group_get(context, id) + security_group = self.db.security_group_get( + context, id, columns_to_join=['instances']) for instance in security_group['instances']: if instance['host'] is not None: @@ -3242,8 +3243,8 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase): security_groups = set() for rule in security_group_rules: security_group = self.db.security_group_get( - context, - rule['parent_group_id']) + context, rule['parent_group_id'], + columns_to_join=['instances']) security_groups.add(security_group) # ..then we find the instances that are members of these groups.. diff --git a/nova/db/api.py b/nova/db/api.py index 973be1a26..bd519110c 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -1151,9 +1151,10 @@ def security_group_get_all(context): return IMPL.security_group_get_all(context) -def security_group_get(context, security_group_id): +def security_group_get(context, security_group_id, columns_to_join=None): """Get security group by its id.""" - return IMPL.security_group_get(context, security_group_id) + return IMPL.security_group_get(context, security_group_id, + columns_to_join) def security_group_get_by_name(context, project_id, group_name): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index fd79ae215..c44f62206 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -3241,13 +3241,17 @@ def security_group_get_all(context): @require_context -def security_group_get(context, security_group_id, session=None): - result = _security_group_get_query(context, session=session, - project_only=True).\ - filter_by(id=security_group_id).\ - options(joinedload_all('instances')).\ - first() +def security_group_get(context, security_group_id, columns_to_join=None, + session=None): + query = _security_group_get_query(context, session=session, + project_only=True).\ + filter_by(id=security_group_id) + if columns_to_join is None: + columns_to_join = [] + if 'instances' in columns_to_join: + query = query.options(joinedload_all('instances')) + result = query.first() if not result: raise exception.SecurityGroupNotFound( security_group_id=security_group_id) diff --git a/nova/tests/api/openstack/compute/contrib/test_security_groups.py b/nova/tests/api/openstack/compute/contrib/test_security_groups.py index f1433bd0a..ac3e8885d 100644 --- a/nova/tests/api/openstack/compute/contrib/test_security_groups.py +++ b/nova/tests/api/openstack/compute/contrib/test_security_groups.py @@ -716,7 +716,7 @@ class TestSecurityGroupRules(test.TestCase): db1 = security_group_db(self.sg1) db2 = security_group_db(self.sg2) - def return_security_group(context, group_id): + def return_security_group(context, group_id, columns_to_join=None): if group_id == db1['id']: return db1 if group_id == db2['id']: diff --git a/nova/tests/db/test_db_api.py b/nova/tests/db/test_db_api.py index 279481d87..7ba431695 100644 --- a/nova/tests/db/test_db_api.py +++ b/nova/tests/db/test_db_api.py @@ -30,6 +30,7 @@ from sqlalchemy.dialects import sqlite from sqlalchemy import exc from sqlalchemy.exc import IntegrityError from sqlalchemy import MetaData +from sqlalchemy.orm import exc as sqlalchemy_orm_exc from sqlalchemy.orm import query from sqlalchemy.sql.expression import select @@ -1621,19 +1622,40 @@ class SecurityGroupTestCase(test.TestCase, ModelsObjectComparatorMixin): self.assertRaises(exception.SecurityGroupNotFound, db.security_group_get, self.ctxt, security_group1['id']) - self._assertEqualObjects(db.security_group_get(self.ctxt, - security_group2['id']), - security_group2) + self._assertEqualObjects(db.security_group_get( + self.ctxt, security_group2['id'], + columns_to_join=['instances']), security_group2) def test_security_group_get(self): security_group1 = self._create_security_group({}) security_group2 = self._create_security_group( {'name': 'fake_sec_group2'}) real_security_group = db.security_group_get(self.ctxt, - security_group1['id']) + security_group1['id'], + columns_to_join=['instances']) self._assertEqualObjects(security_group1, real_security_group) + def test_security_group_get_no_instances(self): + instance = db.instance_create(self.ctxt, {}) + sid = self._create_security_group({'instances': [instance]})['id'] + + session = get_session() + self.mox.StubOutWithMock(sqlalchemy_api, 'get_session') + sqlalchemy_api.get_session().AndReturn(session) + sqlalchemy_api.get_session().AndReturn(session) + self.mox.ReplayAll() + + security_group = db.security_group_get(self.ctxt, sid, + columns_to_join=['instances']) + session.expunge(security_group) + self.assertEqual(1, len(security_group['instances'])) + + security_group = db.security_group_get(self.ctxt, sid) + session.expunge(security_group) + self.assertRaises(sqlalchemy_orm_exc.DetachedInstanceError, + getattr, security_group, 'instances') + def test_security_group_get_not_found_exception(self): self.assertRaises(exception.SecurityGroupNotFound, db.security_group_get, self.ctxt, 100500) |