diff options
-rw-r--r-- | nova/compute/api.py | 36 | ||||
-rw-r--r-- | nova/tests/api/openstack/compute/test_servers.py | 16 |
2 files changed, 35 insertions, 17 deletions
diff --git a/nova/compute/api.py b/nova/compute/api.py index bba6ee1eb..6433f04e0 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -332,6 +332,19 @@ class API(base.Base): LOG.warn(msg) raise exception.InvalidMetadataSize(reason=msg) + def _check_requested_secgroups(self, context, secgroups): + """ + Check if the security group requested exists and belongs to + the project. + """ + for secgroup in secgroups: + # NOTE(sdague): default is handled special + if secgroup == "default": + continue + if not self.security_group_api.get(context, secgroup): + raise exception.SecurityGroupNotFoundForProject( + project_id=context.project_id, security_group_id=secgroup) + def _check_requested_networks(self, context, requested_networks): """ Check if the networks requested belongs to the project @@ -447,7 +460,7 @@ class API(base.Base): image_href, kernel_id, ramdisk_id, min_count, max_count, display_name, display_description, - key_name, key_data, security_group, + key_name, key_data, security_groups, availability_zone, user_data, metadata, injected_files, access_ip_v4, access_ip_v6, @@ -460,8 +473,8 @@ class API(base.Base): if not metadata: metadata = {} - if not security_group: - security_group = 'default' + if not security_groups: + security_groups = ['default'] if not instance_type: instance_type = instance_types.get_default_instance_type() @@ -504,6 +517,7 @@ class API(base.Base): self._check_metadata_properties_quota(context, metadata) self._check_injected_file_quota(context, injected_files) + self._check_requested_secgroups(context, security_groups) self._check_requested_networks(context, requested_networks) if image_href: @@ -597,7 +611,8 @@ class API(base.Base): options = base_options.copy() instance = self.create_db_entry_for_new_instance( context, instance_type, image, options, - security_group, block_device_mapping, num_instances, i) + security_groups, block_device_mapping, + num_instances, i) instances.append(instance) instance_uuids.append(instance['uuid']) @@ -626,7 +641,7 @@ class API(base.Base): 'instance_type': instance_type, 'instance_uuids': instance_uuids, 'block_device_mapping': block_device_mapping, - 'security_group': security_group, + 'security_group': security_groups, } return (instances, request_spec, filter_properties) @@ -857,8 +872,7 @@ class API(base.Base): base_image_ref = base_options['image_ref'] instance['system_metadata']['image_base_image_ref'] = base_image_ref - self.security_group_api.populate_security_groups(instance, - security_groups) + instance['security_groups'] = security_groups return instance @@ -3158,11 +3172,3 @@ class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase): groups = instance.get('security_groups') if groups: return [{'name': group['name']} for group in groups] - - def populate_security_groups(self, instance, security_groups): - # Use 'default' security_group if none specified. - if security_groups is None: - security_groups = ['default'] - elif not isinstance(security_groups, list): - security_groups = [security_groups] - instance['security_groups'] = security_groups diff --git a/nova/tests/api/openstack/compute/test_servers.py b/nova/tests/api/openstack/compute/test_servers.py index 7d41f7d8c..ea4c66493 100644 --- a/nova/tests/api/openstack/compute/test_servers.py +++ b/nova/tests/api/openstack/compute/test_servers.py @@ -2120,15 +2120,27 @@ class ServersControllerCreateTest(test.TestCase): def test_create_instance_with_security_group_enabled(self): self.ext_mgr.extensions = {'os-security-groups': 'fake'} group = 'foo' - params = {'security_groups': [{'name': group}]} old_create = compute_api.API.create + def sec_group_get(ctx, proj, name): + if name == group: + return True + else: + raise exception.SecurityGroupNotFoundForProject( + project_id=proj, security_group_id=name) + def create(*args, **kwargs): self.assertEqual(kwargs['security_group'], [group]) return old_create(*args, **kwargs) + self.stubs.Set(db, 'security_group_get_by_name', sec_group_get) + # negative test + self.assertRaises(webob.exc.HTTPBadRequest, + self._test_create_extra, + {'security_groups': [{'name': 'bogus'}]}) + # positive test - extra assert in create path self.stubs.Set(compute_api.API, 'create', create) - self._test_create_extra(params) + self._test_create_extra({'security_groups': [{'name': group}]}) def test_create_instance_with_security_group_disabled(self): group = 'foo' |