From 7937144fce54570b2da543663e6ee5e64b1c3cdb Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Fri, 14 Sep 2012 00:21:03 +0000 Subject: Clean up handling of project_only in network_get There was some funky logic for getting networks to work around the project only decorator. This changes the code to match what we actually want which is: In Flat and FlatDHCP mode non-admins should be able to access networks that belong to their project or networks that have no project_id assigned. In VlanManager, project_id=None projects should not be accessible as this means the project hasn't been assigned yet. The assignment is done with an elevated context. This patch adds some logic to model_query to allow None in the project_only filter and makes network_get_all_by_uuids and network_get use it. fixes bug 1048869 Change-Id: I5377cea87dec8e9d0d9cec84e07128c5c6e8dca3 --- nova/db/api.py | 10 +++--- nova/db/sqlalchemy/api.py | 40 +++++++++++------------ nova/network/manager.py | 33 +++++++++++-------- nova/tests/fake_network.py | 6 ++-- nova/tests/image/test_s3.py | 3 +- nova/tests/network/test_manager.py | 66 +++++++++++++++++++------------------- 6 files changed, 82 insertions(+), 76 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 785944d14..de393287a 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -810,9 +810,9 @@ def network_disassociate(context, network_id): return IMPL.network_disassociate(context, network_id) -def network_get(context, network_id): +def network_get(context, network_id, project_only="allow_none"): """Get a network or raise if it does not exist.""" - return IMPL.network_get(context, network_id) + return IMPL.network_get(context, network_id, project_only=project_only) def network_get_all(context): @@ -820,9 +820,11 @@ def network_get_all(context): return IMPL.network_get_all(context) -def network_get_all_by_uuids(context, network_uuids, project_id=None): +def network_get_all_by_uuids(context, network_uuids, + project_only="allow_none"): """Return networks by ids.""" - return IMPL.network_get_all_by_uuids(context, network_uuids, project_id) + return IMPL.network_get_all_by_uuids(context, network_uuids, + project_only=project_only) # pylint: disable=C0103 diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index f39856cc6..ea8d7cbec 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -188,20 +188,21 @@ def require_aggregate_exists(f): return wrapper -def model_query(context, *args, **kwargs): +def model_query(context, model, *args, **kwargs): """Query helper that accounts for context's `read_deleted` field. :param context: context to query under :param session: if present, the session to use :param read_deleted: if present, overrides context's read_deleted field. :param project_only: if present and context is user-type, then restrict - query to match the context's project_id. + query to match the context's project_id. If set to 'allow_none', + restriction includes project_id = None. """ session = kwargs.get('session') or get_session() read_deleted = kwargs.get('read_deleted') or context.read_deleted - project_only = kwargs.get('project_only') + project_only = kwargs.get('project_only', False) - query = session.query(*args) + query = session.query(model, *args) if read_deleted == 'no': query = query.filter_by(deleted=False) @@ -213,8 +214,12 @@ def model_query(context, *args, **kwargs): raise Exception( _("Unrecognized read_deleted value '%s'") % read_deleted) - if project_only and is_user_context(context): - query = query.filter_by(project_id=context.project_id) + if is_user_context(context) and project_only: + if project_only == 'allow_none': + query = query.filter(or_(model.project_id == context.project_id, + model.project_id == None)) + else: + query = query.filter_by(project_id=context.project_id) return query @@ -2130,9 +2135,9 @@ def network_disassociate(context, network_id): @require_context -def network_get(context, network_id, session=None): +def network_get(context, network_id, session=None, project_only='allow_none'): result = model_query(context, models.Network, session=session, - project_only=True).\ + project_only=project_only).\ filter_by(id=network_id).\ first() @@ -2152,24 +2157,17 @@ def network_get_all(context): return result -@require_admin_context -def network_get_all_by_uuids(context, network_uuids, project_id=None): - project_or_none = or_(models.Network.project_id == project_id, - models.Network.project_id == None) - result = model_query(context, models.Network, read_deleted="no").\ +@require_context +def network_get_all_by_uuids(context, network_uuids, + project_only="allow_none"): + result = model_query(context, models.Network, read_deleted="no", + project_only=project_only).\ filter(models.Network.uuid.in_(network_uuids)).\ - filter(project_or_none).\ all() if not result: raise exception.NoNetworksFound() - #check if host is set to all of the networks - # returned in the result - for network in result: - if network['host'] is None: - raise exception.NetworkHostNotSet(network_id=network['id']) - #check if the result contains all the networks #we are looking for for network_uuid in network_uuids: @@ -2179,7 +2177,7 @@ def network_get_all_by_uuids(context, network_uuids, project_id=None): found = True break if not found: - if project_id: + if project_only: raise exception.NetworkNotFoundForProject( network_uuid=network_uuid, project_id=context.project_id) raise exception.NetworkNotFound(network_id=network_uuid) diff --git a/nova/network/manager.py b/nova/network/manager.py index 1f7a9e3a1..4aeef2f86 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -935,7 +935,7 @@ class NetworkManager(manager.SchedulerDependentManager): # a non-vlan instance should connect to if requested_networks is not None and len(requested_networks) != 0: network_uuids = [uuid for (uuid, fixed_ip) in requested_networks] - networks = self.db.network_get_all_by_uuids(context, network_uuids) + networks = self._get_networks_by_uuids(context, network_uuids) else: try: networks = self.db.network_get_all(context) @@ -1703,10 +1703,12 @@ class NetworkManager(manager.SchedulerDependentManager): instance_uuid=fixed_ip_ref['instance_uuid']) def _get_network_by_id(self, context, network_id): - return self.db.network_get(context, network_id) + return self.db.network_get(context, network_id, + project_only="allow_none") def _get_networks_by_uuids(self, context, network_uuids): - return self.db.network_get_all_by_uuids(context, network_uuids) + return self.db.network_get_all_by_uuids(context, network_uuids, + project_only="allow_none") @wrap_check_policy def get_vifs_by_instance(self, context, instance_id): @@ -1905,10 +1907,6 @@ class FlatDHCPManager(RPCAllocateFixedIP, FloatingIP, NetworkManager): dev = self.driver.get_dev(network) self.driver.update_dhcp(context, dev, network) - def _get_network_by_id(self, context, network_id): - return NetworkManager._get_network_by_id(self, context.elevated(), - network_id) - def _get_network_dict(self, network): """Returns the dict representing necessary and meta network fields""" @@ -1996,15 +1994,26 @@ class VlanManager(RPCAllocateFixedIP, FloatingIP, NetworkManager): network_id = None self.db.network_associate(context, project_id, network_id, force=True) + def _get_network_by_id(self, context, network_id): + # NOTE(vish): Don't allow access to networks with project_id=None as + # these are networksa that haven't been allocated to a + # project yet. + return self.db.network_get(context, network_id, project_only=True) + + def _get_networks_by_uuids(self, context, network_uuids): + # NOTE(vish): Don't allow access to networks with project_id=None as + # these are networksa that haven't been allocated to a + # project yet. + return self.db.network_get_all_by_uuids(context, network_uuids, + project_only=True) + def _get_networks_for_instance(self, context, instance_id, project_id, requested_networks=None): """Determine which networks an instance should connect to.""" # get networks associated with project if requested_networks is not None and len(requested_networks) != 0: network_uuids = [uuid for (uuid, fixed_ip) in requested_networks] - networks = self.db.network_get_all_by_uuids(context, - network_uuids, - project_id) + networks = self._get_networks_by_uuids(context, network_uuids) else: networks = self.db.project_get_networks(context, project_id) return networks @@ -2066,10 +2075,6 @@ class VlanManager(RPCAllocateFixedIP, FloatingIP, NetworkManager): dev = self.driver.get_dev(network) self.driver.update_dhcp(context, dev, network) - def _get_networks_by_uuids(self, context, network_uuids): - return self.db.network_get_all_by_uuids(context, network_uuids, - context.project_id) - def _get_network_dict(self, network): """Returns the dict representing necessary and meta network fields""" diff --git a/nova/tests/fake_network.py b/nova/tests/fake_network.py index 25ec5c070..ef256dec0 100644 --- a/nova/tests/fake_network.py +++ b/nova/tests/fake_network.py @@ -118,7 +118,7 @@ class FakeNetworkManager(network_manager.NetworkManager): fakenet['id'] = 999 return fakenet - def network_get(self, context, network_id): + def network_get(self, context, network_id, project_only="allow_none"): return {'cidr_v6': '2001:db8:69:%x::/64' % network_id} def network_get_by_uuid(self, context, network_uuid): @@ -127,7 +127,7 @@ class FakeNetworkManager(network_manager.NetworkManager): def network_get_all(self, context): raise exception.NoNetworksFound() - def network_get_all_by_uuids(self, context): + def network_get_all_by_uuids(self, context, project_only="allow_none"): raise exception.NoNetworksFound() def network_disassociate(self, context, network_id): @@ -294,7 +294,7 @@ def fake_get_instance_nw_info(stubs, num_networks=1, ips_per_vif=2, 'network': None, 'instance_uuid': 0} - def network_get_fake(context, network_id): + def network_get_fake(context, network_id, project_only='allow_none'): nets = [n for n in networks if n['id'] == network_id] if not nets: raise exception.NetworkNotFound(network_id=network_id) diff --git a/nova/tests/image/test_s3.py b/nova/tests/image/test_s3.py index 5002be16f..3c92ffb2e 100644 --- a/nova/tests/image/test_s3.py +++ b/nova/tests/image/test_s3.py @@ -187,7 +187,8 @@ class TestS3ImageService(test.TestCase): img = self.image_service._s3_create(self.context, metadata) eventlet.sleep() - translated = self.image_service._translate_id_to_uuid(context, img) + translated = self.image_service._translate_id_to_uuid(self.context, + img) uuid = translated['id'] image_service = fake.FakeImageService() updated_image = image_service.update(self.context, uuid, diff --git a/nova/tests/network/test_manager.py b/nova/tests/network/test_manager.py index 31b600b16..8ef37fe95 100644 --- a/nova/tests/network/test_manager.py +++ b/nova/tests/network/test_manager.py @@ -62,7 +62,7 @@ networks = [{'id': 0, 'project_id': 'fake_project', 'vpn_public_address': '192.168.0.2'}, {'id': 1, - 'uuid': "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + 'uuid': 'bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb', 'label': 'test1', 'injected': False, 'multi_host': False, @@ -83,14 +83,14 @@ networks = [{'id': 0, 'vpn_public_address': '192.168.1.2'}] fixed_ips = [{'id': 0, - 'network_id': 0, + 'network_id': FAKEUUID, 'address': '192.168.0.100', 'instance_uuid': 0, 'allocated': False, 'virtual_interface_id': 0, 'floating_ips': []}, {'id': 0, - 'network_id': 1, + 'network_id': 'bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb', 'address': '192.168.1.100', 'instance_uuid': 0, 'allocated': False, @@ -202,10 +202,11 @@ class FlatNetworkTestCase(test.TestCase): requested_networks = [('bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb', '192.168.1.100')] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) db.network_get(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks[1]) + mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks[1]) ip = fixed_ips[1].copy() ip['instance_uuid'] = None @@ -238,8 +239,8 @@ class FlatNetworkTestCase(test.TestCase): def test_validate_networks_invalid_fixed_ip(self): self.mox.StubOutWithMock(db, 'network_get_all_by_uuids') requested_networks = [(1, "192.168.0.100.1")] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) self.mox.ReplayAll() self.assertRaises(exception.FixedIpInvalid, @@ -250,8 +251,8 @@ class FlatNetworkTestCase(test.TestCase): self.mox.StubOutWithMock(db, 'network_get_all_by_uuids') requested_networks = [(1, "")] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) self.mox.ReplayAll() self.assertRaises(exception.FixedIpInvalid, @@ -262,8 +263,8 @@ class FlatNetworkTestCase(test.TestCase): self.mox.StubOutWithMock(db, 'network_get_all_by_uuids') requested_networks = [(1, None)] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) self.mox.ReplayAll() self.network.validate_networks(self.context, requested_networks) @@ -293,7 +294,8 @@ class FlatNetworkTestCase(test.TestCase): mox.IgnoreArg(), mox.IgnoreArg()).AndReturn('192.168.0.101') db.network_get(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks[0]) + mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks[0]) db.network_update(mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg()) self.mox.ReplayAll() self.network.add_fixed_ip_to_instance(self.context, 1, HOST, @@ -391,7 +393,8 @@ class FlatNetworkTestCase(test.TestCase): mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(fixedip) db.network_get(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks[0]) + mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks[0]) db.network_update(mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg()) self.mox.ReplayAll() @@ -491,7 +494,7 @@ class VlanNetworkTestCase(test.TestCase): cidr='192.168.0.1/24', network_size=100) def test_validate_networks(self): - def network_get(_context, network_id): + def network_get(_context, network_id, project_only='allow_none'): return networks[network_id] self.stubs.Set(db, 'network_get', network_get) @@ -500,9 +503,8 @@ class VlanNetworkTestCase(test.TestCase): requested_networks = [("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "192.168.1.100")] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) fixed_ips[1]['network_id'] = networks[1]['id'] fixed_ips[1]['instance_uuid'] = None @@ -524,9 +526,8 @@ class VlanNetworkTestCase(test.TestCase): def test_validate_networks_invalid_fixed_ip(self): self.mox.StubOutWithMock(db, 'network_get_all_by_uuids') requested_networks = [(1, "192.168.0.100.1")] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) self.mox.ReplayAll() self.assertRaises(exception.FixedIpInvalid, @@ -537,9 +538,8 @@ class VlanNetworkTestCase(test.TestCase): self.mox.StubOutWithMock(db, 'network_get_all_by_uuids') requested_networks = [(1, "")] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) self.mox.ReplayAll() self.assertRaises(exception.FixedIpInvalid, @@ -550,9 +550,8 @@ class VlanNetworkTestCase(test.TestCase): self.mox.StubOutWithMock(db, 'network_get_all_by_uuids') requested_networks = [(1, None)] - db.network_get_all_by_uuids(mox.IgnoreArg(), - mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks) + db.network_get_all_by_uuids(mox.IgnoreArg(), mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks) self.mox.ReplayAll() self.network.validate_networks(self.context, requested_networks) @@ -879,7 +878,8 @@ class VlanNetworkTestCase(test.TestCase): mox.IgnoreArg(), mox.IgnoreArg()).AndReturn('192.168.0.101') db.network_get(mox.IgnoreArg(), - mox.IgnoreArg()).AndReturn(networks[0]) + mox.IgnoreArg(), + project_only=mox.IgnoreArg()).AndReturn(networks[0]) self.mox.ReplayAll() self.network.add_fixed_ip_to_instance(self.context, 1, HOST, networks[0]['id']) @@ -888,7 +888,7 @@ class VlanNetworkTestCase(test.TestCase): """Makes sure that we cannot deallocaate or disassociate a public ip of other project""" - def network_get(_context, network_id): + def network_get(_context, network_id, project_only="allow_none"): return networks[network_id] self.stubs.Set(db, 'network_get', network_get) @@ -941,7 +941,7 @@ class VlanNetworkTestCase(test.TestCase): Ensures https://bugs.launchpad.net/nova/+bug/973442 doesn't return""" - def network_get(_context, network_id): + def network_get(_context, network_id, project_only="allow_none"): return networks[network_id] self.stubs.Set(db, 'network_get', network_get) @@ -974,7 +974,7 @@ class VlanNetworkTestCase(test.TestCase): def test_deallocate_fixed_deleted(self): """Verify doesn't deallocate deleted fixed_ip from deleted network""" - def network_get(_context, network_id): + def network_get(_context, network_id, project_only="allow_none"): return networks[network_id] def teardown_network_on_host(_context, network): @@ -1012,7 +1012,7 @@ class VlanNetworkTestCase(test.TestCase): Ensures https://bugs.launchpad.net/nova/+bug/968457 doesn't return""" - def network_get(_context, network_id): + def network_get(_context, network_id, project_only="allow_none"): return networks[network_id] self.stubs.Set(db, 'network_get', network_get) @@ -1037,7 +1037,7 @@ class VlanNetworkTestCase(test.TestCase): def test_fixed_ip_cleanup_fail(self): """Verify IP is not deallocated if the security group refresh fails.""" - def network_get(_context, network_id): + def network_get(_context, network_id, project_only="allow_none"): return networks[network_id] self.stubs.Set(db, 'network_get', network_get) -- cgit