From c9e14d6257f0b488bd892c09d284091c0f612dd7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 1 Oct 2010 01:44:17 -0700 Subject: Locked down fixed ips and improved network tests --- nova/db/sqlalchemy/api.py | 98 ++++++++++++++++-------------------------- nova/tests/network_unittest.py | 44 ++++++++++--------- 2 files changed, 60 insertions(+), 82 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index fc5ee2235..860723516 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -78,7 +78,7 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() -def use_deleted(context): +def can_read_deleted(context): """Indicates if the context has access to deleted objects.""" if not context: return False @@ -124,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: @@ -222,9 +222,8 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() - if not result: raise exception.NotFound('No service for %s, %s' % (host, binary)) @@ -256,7 +255,6 @@ def service_update(context, service_id, values): @require_context def floating_ip_allocate_address(context, host, project_id): authorize_project_context(context, project_id) - session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -287,7 +285,6 @@ def floating_ip_create(context, values): @require_context def floating_ip_count_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -374,7 +371,6 @@ def floating_ip_get_all_by_host(context, host): @require_context def floating_ip_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -391,7 +387,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -406,6 +402,7 @@ def floating_ip_get_by_address(context, address, session=None): def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -416,9 +413,7 @@ def fixed_ip_associate(context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = instance_get(context, - instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) @@ -472,21 +467,21 @@ def fixed_ip_disassociate(context, address): @require_context def fixed_ip_get_by_address(context, address, session=None): - # TODO(devcamcar): Ensure floating ip belongs to user. - # Only possible if it is associated with an instance. - # May have to use system context for this always. if not session: session = get_session() result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() if not result: raise exception.NotFound('No floating ip for address %s' % address) + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) + return result @@ -562,7 +557,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -581,7 +576,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -590,7 +585,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() @@ -603,7 +598,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -615,7 +610,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -633,7 +628,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -749,7 +744,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -775,7 +770,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -847,7 +842,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -914,7 +909,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -992,7 +987,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -1038,7 +1033,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1167,7 +1162,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1184,7 +1179,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @require_context @@ -1194,7 +1189,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1206,7 +1201,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1233,47 +1228,26 @@ def volume_ec2_id_exists(context, ec2_id, session=None): ).one()[0] -@require_context +@require_admin_context def volume_get_instance(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) - ).options(joinedload('instance') - ).first() - elif is_user_context(context): - result = session.query(models.Volume - ).filter_by(project_id=context.project.id - ).filter_by(deleted=False - ).options(joinedload('instance') - ).first() - else: - raise exception.NotAuthorized() - + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() if not result: raise exception.NotFound('Volume %s not found' % ec2_id) return result.instance -@require_context +@require_admin_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - elif is_user_context(context): - result = session.query(models.ExportDevice - ).join(models.Volume - ).filter(models.Volume.project_id==context.project.id - ).filter_by(volume_id=volume_id - ).first() + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() if not result: raise exception.NotFound('No export device found for volume %s' % volume_id) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index e01d7cff9..e601c480c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,7 +49,6 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) - # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i @@ -60,11 +59,9 @@ class NetworkTestCase(test.TrialTestCase): user_context = context.APIRequestContext(project=self.projects[i], user=self.user) self.network.set_network_host(user_context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -77,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -84,6 +90,11 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips @@ -134,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -152,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -185,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -246,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) -- cgit