diff options
| author | Jenkins <jenkins@review.openstack.org> | 2012-06-14 15:43:34 +0000 |
|---|---|---|
| committer | Gerrit Code Review <review@openstack.org> | 2012-06-14 15:43:34 +0000 |
| commit | 2adeb5a76d8376d3506f4c63ec73211bfa1e5cc0 (patch) | |
| tree | 83264d1db655bfa3cd91f4f9f146a5d549e0523b | |
| parent | 81fef25e96b20f69f58044fa341b108edea67d93 (diff) | |
| parent | 123b28cd1a4ffa1e972e29963cb0e6be46b0d7c2 (diff) | |
| download | nova-2adeb5a76d8376d3506f4c63ec73211bfa1e5cc0.tar.gz nova-2adeb5a76d8376d3506f4c63ec73211bfa1e5cc0.tar.xz nova-2adeb5a76d8376d3506f4c63ec73211bfa1e5cc0.zip | |
Merge "Dedupe native and EC2 security group APIs."
| -rw-r--r-- | nova/api/ec2/cloud.py | 376 | ||||
| -rw-r--r-- | nova/api/openstack/compute/contrib/security_groups.py | 428 | ||||
| -rw-r--r-- | nova/compute/api.py | 648 | ||||
| -rw-r--r-- | nova/compute/rpcapi.py | 171 | ||||
| -rw-r--r-- | nova/network/manager.py | 11 | ||||
| -rw-r--r-- | nova/tests/api/ec2/test_cloud.py | 2 | ||||
| -rw-r--r-- | nova/tests/compute/test_compute.py | 16 | ||||
| -rw-r--r-- | nova/tests/compute/test_rpcapi.py | 9 | ||||
| -rw-r--r-- | nova/tests/policy.json | 4 |
9 files changed, 835 insertions, 830 deletions
diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 6f0d605ed..037a84783 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -25,7 +25,6 @@ datastore. import base64 import re import time -import urllib from nova.api.ec2 import ec2utils from nova.api.ec2 import inst_state @@ -41,7 +40,6 @@ from nova.image import s3 from nova import log as logging from nova import network from nova.openstack.common import excutils -from nova.openstack.common import importutils from nova import quota from nova import utils from nova import volume @@ -190,10 +188,11 @@ class CloudController(object): self.image_service = s3.S3ImageService() self.network_api = network.API() self.volume_api = volume.API() + self.security_group_api = CloudSecurityGroupAPI() self.compute_api = compute.API(network_api=self.network_api, - volume_api=self.volume_api) + volume_api=self.volume_api, + security_group_api=self.security_group_api) self.keypair_api = compute.api.KeypairAPI() - self.sgh = importutils.import_object(FLAGS.security_group_handler) def __str__(self): return 'CloudController' @@ -411,25 +410,12 @@ class CloudController(object): def describe_security_groups(self, context, group_name=None, group_id=None, **kwargs): - self.compute_api.ensure_default_security_group(context) - if group_name or group_id: - groups = [] - if group_name: - for name in group_name: - group = db.security_group_get_by_name(context, - context.project_id, - name) - groups.append(group) - if group_id: - for gid in group_id: - group = db.security_group_get(context, gid) - groups.append(group) - elif context.is_admin: - groups = db.security_group_get_all(context) - else: - groups = db.security_group_get_by_project(context, - context.project_id) - groups = [self._format_security_group(context, g) for g in groups] + raw_groups = self.security_group_api.list(context, + group_name, + group_id, + context.project_id) + + groups = [self._format_security_group(context, g) for g in raw_groups] return {'securityGroupInfo': list(sorted(groups, @@ -536,146 +522,51 @@ class CloudController(object): notfound = exception.SecurityGroupNotFound if not source_security_group: raise notfound(security_group_id=source_security_group_name) - values['group_id'] = source_security_group['id'] - elif cidr_ip: - # If this fails, it throws an exception. This is what we want. - cidr_ip = urllib.unquote(cidr_ip).decode() - - if not utils.is_valid_cidr(cidr_ip): - # Raise exception for non-valid address - raise exception.EC2APIError(_("Invalid CIDR")) - - values['cidr'] = cidr_ip - else: - values['cidr'] = '0.0.0.0/0' - - if source_security_group_name: - # Open everything if an explicit port range or type/code are not - # specified, but only if a source group was specified. - ip_proto_upper = ip_protocol.upper() if ip_protocol else '' - if (ip_proto_upper == 'ICMP' and - from_port is None and to_port is None): - from_port = -1 - to_port = -1 - elif (ip_proto_upper in ['TCP', 'UDP'] and from_port is None - and to_port is None): - from_port = 1 - to_port = 65535 - - if ip_protocol and from_port is not None and to_port is not None: - - ip_protocol = str(ip_protocol) - try: - # Verify integer conversions - from_port = int(from_port) - to_port = int(to_port) - except ValueError: - if ip_protocol.upper() == 'ICMP': - raise exception.InvalidInput(reason="Type and" - " Code must be integers for ICMP protocol type") - else: - raise exception.InvalidInput(reason="To and From ports " - "must be integers") - - if ip_protocol.upper() not in ['TCP', 'UDP', 'ICMP']: - raise exception.InvalidIpProtocol(protocol=ip_protocol) - - # Verify that from_port must always be less than - # or equal to to_port - if (ip_protocol.upper() in ['TCP', 'UDP'] and - (from_port > to_port)): - raise exception.InvalidPortRange(from_port=from_port, - to_port=to_port, msg="Former value cannot" - " be greater than the later") - - # Verify valid TCP, UDP port ranges - if (ip_protocol.upper() in ['TCP', 'UDP'] and - (from_port < 1 or to_port > 65535)): - raise exception.InvalidPortRange(from_port=from_port, - to_port=to_port, msg="Valid TCP ports should" - " be between 1-65535") - - # Verify ICMP type and code - if (ip_protocol.upper() == "ICMP" and - (from_port < -1 or from_port > 255 or - to_port < -1 or to_port > 255)): - raise exception.InvalidPortRange(from_port=from_port, - to_port=to_port, msg="For ICMP, the" - " type:code must be valid") - - values['protocol'] = ip_protocol.lower() - values['from_port'] = from_port - values['to_port'] = to_port + group_id = source_security_group['id'] + return self.security_group_api.new_group_ingress_rule( + group_id, ip_protocol, from_port, to_port) else: - # If cidr based filtering, protocol and ports are mandatory - if 'cidr' in values: - return None - - return values - - def _security_group_rule_exists(self, security_group, values): - """Indicates whether the specified rule values are already - defined in the given security group. - """ - for rule in security_group.rules: - is_duplicate = True - keys = ('group_id', 'cidr', 'from_port', 'to_port', 'protocol') - for key in keys: - if rule.get(key) != values.get(key): - is_duplicate = False - break - if is_duplicate: - return rule['id'] - return False + cidr = self.security_group_api.parse_cidr(cidr_ip) + return self.security_group_api.new_cidr_ingress_rule( + cidr, ip_protocol, from_port, to_port) - def revoke_security_group_ingress(self, context, group_name=None, - group_id=None, **kwargs): + def _validate_group_identifier(self, group_name, group_id): if not group_name and not group_id: err = _("Not enough parameters, need group_name or group_id") raise exception.EC2APIError(err) - self.compute_api.ensure_default_security_group(context) - notfound = exception.SecurityGroupNotFound - if group_name: - security_group = db.security_group_get_by_name(context, - context.project_id, - group_name) - if not security_group: - raise notfound(security_group_id=group_name) - if group_id: - security_group = db.security_group_get(context, group_id) - if not security_group: - raise notfound(security_group_id=group_id) - - msg = _("Revoke security group ingress %s") - LOG.audit(msg, security_group['name'], context=context) - prevalues = [] - try: - prevalues = kwargs['ip_permissions'] - except KeyError: - prevalues.append(kwargs) - rule_id = None + + def _validate_rulevalues(self, rulesvalues): + if not rulesvalues: + err = _("%s Not enough parameters to build a valid rule") + raise exception.EC2APIError(err % rulesvalues) + + def revoke_security_group_ingress(self, context, group_name=None, + group_id=None, **kwargs): + self._validate_group_identifier(group_name, group_id) + + security_group = self.security_group_api.get(context, group_name, + group_id) + + prevalues = kwargs.get('ip_permissions', [kwargs]) + rule_ids = [] for values in prevalues: rulesvalues = self._rule_args_to_dict(context, values) - if not rulesvalues: - err = _("%s Not enough parameters to build a valid rule") - raise exception.EC2APIError(err % rulesvalues) - + self._validate_rulevalues(rulesvalues) for values_for_rule in rulesvalues: values_for_rule['parent_group_id'] = security_group.id - rule_id = self._security_group_rule_exists(security_group, - values_for_rule) - if rule_id: - db.security_group_rule_destroy(context, rule_id) - rule_ids.append(rule_id) - if rule_id: - # NOTE(vish): we removed a rule, so refresh - self.compute_api.trigger_security_group_rules_refresh( - context, - security_group_id=security_group['id']) - self.sgh.trigger_security_group_rule_destroy_refresh( - context, rule_ids) + + rule_ids.append(self.security_group_api.rule_exists( + security_group, values_for_rule)) + + rule_ids = [id for id in rule_ids if id] + + if rule_ids: + self.security_group_api.remove_rules(context, security_group, + rule_ids) + return True + raise exception.EC2APIError(_("No rule for the specified parameters.")) # TODO(soren): This has only been tested with Boto as the client. @@ -684,64 +575,27 @@ class CloudController(object): # is sketchy. def authorize_security_group_ingress(self, context, group_name=None, group_id=None, **kwargs): - if not group_name and not group_id: - err = _("Not enough parameters, need group_name or group_id") - raise exception.EC2APIError(err) - self.compute_api.ensure_default_security_group(context) - notfound = exception.SecurityGroupNotFound - if group_name: - security_group = db.security_group_get_by_name(context, - context.project_id, - group_name) - if not security_group: - raise notfound(security_group_id=group_name) - if group_id: - security_group = db.security_group_get(context, group_id) - if not security_group: - raise notfound(security_group_id=group_id) - - msg = _("Authorize security group ingress %s") - LOG.audit(msg, security_group['name'], context=context) - prevalues = [] - try: - prevalues = kwargs['ip_permissions'] - except KeyError: - prevalues.append(kwargs) + self._validate_group_identifier(group_name, group_id) + + security_group = self.security_group_api.get(context, group_name, + group_id) + + prevalues = kwargs.get('ip_permissions', [kwargs]) postvalues = [] for values in prevalues: rulesvalues = self._rule_args_to_dict(context, values) - if not rulesvalues: - err = _("%s Not enough parameters to build a valid rule") - raise exception.EC2APIError(err % rulesvalues) + self._validate_rulevalues(rulesvalues) for values_for_rule in rulesvalues: values_for_rule['parent_group_id'] = security_group.id - if self._security_group_rule_exists(security_group, - values_for_rule): + if self.security_group_api.rule_exists(security_group, + values_for_rule): err = _('%s - This rule already exists in group') raise exception.EC2APIError(err % values_for_rule) postvalues.append(values_for_rule) - count = QUOTAS.count(context, 'security_group_rules', - security_group['id']) - try: - QUOTAS.limit_check(context, security_group_rules=count + 1) - except exception.OverQuota: - msg = _("Quota exceeded, too many security group rules.") - raise exception.EC2APIError(msg) - - rule_ids = [] - for values_for_rule in postvalues: - security_group_rule = db.security_group_rule_create( - context, - values_for_rule) - rule_ids.append(security_group_rule['id']) - if postvalues: - self.compute_api.trigger_security_group_rules_refresh( - context, - security_group_id=security_group['id']) - self.sgh.trigger_security_group_rule_create_refresh( - context, rule_ids) + self.security_group_api.add_rules(context, security_group['id'], + security_group['name'], postvalues) return True raise exception.EC2APIError(_("No rule for the specified parameters.")) @@ -766,64 +620,23 @@ class CloudController(object): def create_security_group(self, context, group_name, group_description): if isinstance(group_name, unicode): group_name = group_name.encode('utf-8') - # TODO(Daviey): LP: #813685 extend beyond group_name checking, and - # probably create a param validator that can be used elsewhere. if FLAGS.ec2_strict_validation: # EC2 specification gives constraints for name and description: # Accepts alphanumeric characters, spaces, dashes, and underscores - err = _("Value (%(value)s) for parameter %(param)s is invalid." - " Content limited to Alphanumeric characters," - " spaces, dashes, and underscores.") - if not re.match('^[a-zA-Z0-9_\- ]+$', group_name): - raise exception.InvalidParameterValue( - err=err % {"value": group_name, - "param": "GroupName"}) - if not re.match('^[a-zA-Z0-9_\- ]+$', group_description): - raise exception.InvalidParameterValue( - err=err % {"value": group_description, - "param": "GroupDescription"}) + allowed = '^[a-zA-Z0-9_\- ]+$' + self.security_group_api.validate_property(group_name, 'name', + allowed) + self.security_group_api.validate_property(group_description, + 'description', allowed) else: # Amazon accepts more symbols. # So, allow POSIX [:print:] characters. - if not re.match(r'^[\x20-\x7E]+$', group_name): - err = _("Value (%(value)s) for parameter %(param)s is invalid." - " Content is limited to characters" - " from the [:print:] class.") - raise exception.InvalidParameterValue( - err=err % {"value": group_name, - "param": "GroupName"}) - - if len(group_name) > 255: - err = _("Value (%s) for parameter GroupName is invalid." - " Length exceeds maximum of 255.") % group_name - raise exception.InvalidParameterValue(err=err) - - LOG.audit(_("Create Security Group %s"), group_name, context=context) - self.compute_api.ensure_default_security_group(context) - if db.security_group_exists(context, context.project_id, group_name): - msg = _('group %s already exists') - raise exception.EC2APIError(msg % group_name) + allowed = r'^[\x20-\x7E]+$' + self.security_group_api.validate_property(group_name, 'name', + allowed) - try: - reservations = QUOTAS.reserve(context, security_groups=1) - except exception.OverQuota: - msg = _("Quota exceeded, too many security groups.") - raise exception.EC2APIError(msg) - - try: - group = {'user_id': context.user_id, - 'project_id': context.project_id, - 'name': group_name, - 'description': group_description} - group_ref = db.security_group_create(context, group) - - self.sgh.trigger_security_group_create_refresh(context, group) - - # Commit the reservation - QUOTAS.commit(context, reservations) - except Exception: - with excutils.save_and_reraise_exception(): - QUOTAS.rollback(context, reservations) + group_ref = self.security_group_api.create(context, group_name, + group_description) return {'securityGroupSet': [self._format_security_group(context, group_ref)]} @@ -833,37 +646,11 @@ class CloudController(object): if not group_name and not group_id: err = _("Not enough parameters, need group_name or group_id") raise exception.EC2APIError(err) - notfound = exception.SecurityGroupNotFound - if group_name: - security_group = db.security_group_get_by_name(context, - context.project_id, - group_name) - if not security_group: - raise notfound(security_group_id=group_name) - elif group_id: - security_group = db.security_group_get(context, group_id) - if not security_group: - raise notfound(security_group_id=group_id) - if db.security_group_in_use(context, security_group.id): - raise exception.InvalidGroup(reason="In Use") - - # Get reservations - try: - reservations = QUOTAS.reserve(context, security_groups=-1) - except Exception: - reservations = None - LOG.exception(_("Failed to update usages deallocating " - "security group")) - LOG.audit(_("Delete security group %s"), group_name, context=context) - db.security_group_destroy(context, security_group.id) + security_group = self.security_group_api.get(context, group_name, + group_id) - self.sgh.trigger_security_group_destroy_refresh(context, - security_group.id) - - # Commit the reservations - if reservations: - QUOTAS.commit(context, reservations) + self.security_group_api.destroy(context, security_group) return True @@ -1719,3 +1506,32 @@ class CloudController(object): self.compute_api.start(context, instance_id=instance_id) return {'imageId': image_id} + + +class CloudSecurityGroupAPI(compute.api.SecurityGroupAPI): + @staticmethod + def raise_invalid_property(msg): + raise exception.InvalidParameterValue(err=msg) + + @staticmethod + def raise_group_already_exists(msg): + raise exception.EC2APIError(message=msg) + + @staticmethod + def raise_invalid_group(msg): + raise exception.InvalidGroup(reason=msg) + + @staticmethod + def raise_invalid_cidr(cidr, decoding_exception=None): + if decoding_exception: + raise decoding_exception + else: + raise exception.EC2APIError(_("Invalid CIDR")) + + @staticmethod + def raise_over_quota(msg): + raise exception.EC2APIError(message=msg) + + @staticmethod + def raise_not_found(msg): + pass diff --git a/nova/api/openstack/compute/contrib/security_groups.py b/nova/api/openstack/compute/contrib/security_groups.py index 4a69d392e..5e81347ec 100644 --- a/nova/api/openstack/compute/contrib/security_groups.py +++ b/nova/api/openstack/compute/contrib/security_groups.py @@ -16,7 +16,6 @@ """The security groups extension.""" -import urllib from xml.dom import minidom import webob @@ -32,14 +31,11 @@ from nova import exception from nova import flags from nova import log as logging from nova.openstack.common import excutils -from nova.openstack.common import importutils -from nova import quota from nova import utils LOG = logging.getLogger(__name__) FLAGS = flags.FLAGS -QUOTAS = quota.QUOTAS authorize = extensions.extension_authorizer('compute', 'security_groups') @@ -182,8 +178,9 @@ class SecurityGroupControllerBase(object): """Base class for Security Group controllers.""" def __init__(self): - self.compute_api = compute.API() - self.sgh = importutils.import_object(FLAGS.security_group_handler) + self.security_group_api = NativeSecurityGroupAPI() + self.compute_api = compute.API( + security_group_api=self.security_group_api) def _format_security_group_rule(self, context, rule): sg_rule = {} @@ -195,7 +192,8 @@ class SecurityGroupControllerBase(object): sg_rule['group'] = {} sg_rule['ip_range'] = {} if rule.group_id: - source_group = db.security_group_get(context, rule.group_id) + source_group = self.security_group_api.get(context, + id=rule.group_id) sg_rule['group'] = {'name': source_group.name, 'tenant_id': source_group.project_id} else: @@ -214,68 +212,65 @@ class SecurityGroupControllerBase(object): context, rule)] return security_group + def _authorize_context(self, req): + context = req.environ['nova.context'] + authorize(context) + return context -class SecurityGroupController(SecurityGroupControllerBase): - """The Security group API controller for the OpenStack API.""" - - def _get_security_group(self, context, id): + def _validate_id(self, id): try: - id = int(id) - security_group = db.security_group_get(context, id) + return int(id) except ValueError: msg = _("Security group id should be integer") raise exc.HTTPBadRequest(explanation=msg) - except exception.NotFound as exp: - raise exc.HTTPNotFound(explanation=unicode(exp)) - return security_group + + def _from_body(self, body, key): + if not body: + raise exc.HTTPUnprocessableEntity() + value = body.get(key, None) + if value is None: + raise exc.HTTPUnprocessableEntity() + return value + + +class SecurityGroupController(SecurityGroupControllerBase): + """The Security group API controller for the OpenStack API.""" @wsgi.serializers(xml=SecurityGroupTemplate) def show(self, req, id): """Return data about the given security group.""" - context = req.environ['nova.context'] - authorize(context) - security_group = self._get_security_group(context, id) + context = self._authorize_context(req) + + id = self._validate_id(id) + + security_group = self.security_group_api.get(context, None, id, + map_exception=True) + return {'security_group': self._format_security_group(context, security_group)} def delete(self, req, id): """Delete a security group.""" - context = req.environ['nova.context'] - authorize(context) - security_group = self._get_security_group(context, id) - if db.security_group_in_use(context, security_group.id): - msg = _("Security group is still in use") - raise exc.HTTPBadRequest(explanation=msg) + context = self._authorize_context(req) - # Get reservations - try: - reservations = QUOTAS.reserve(context, security_groups=-1) - except Exception: - reservations = None - LOG.exception(_("Failed to update usages deallocating " - "security group")) + id = self._validate_id(id) - LOG.audit(_("Delete security group %s"), id, context=context) - db.security_group_destroy(context, security_group.id) - self.sgh.trigger_security_group_destroy_refresh( - context, security_group.id) + security_group = self.security_group_api.get(context, None, id, + map_exception=True) - # Commit the reservations - if reservations: - QUOTAS.commit(context, reservations) + self.security_group_api.destroy(context, security_group) return webob.Response(status_int=202) @wsgi.serializers(xml=SecurityGroupsTemplate) def index(self, req): """Returns a list of security groups""" - context = req.environ['nova.context'] - authorize(context) + context = self._authorize_context(req) + + raw_groups = self.security_group_api.list(context, + project=context.project_id) - self.compute_api.ensure_default_security_group(context) - groups = db.security_group_get_by_project(context, - context.project_id) - limited_list = common.limited(groups, req) + limited_list = common.limited(raw_groups, req) result = [self._format_security_group(context, group) for group in limited_list] @@ -287,110 +282,43 @@ class SecurityGroupController(SecurityGroupControllerBase): @wsgi.deserializers(xml=SecurityGroupXMLDeserializer) def create(self, req, body): """Creates a new security group.""" - context = req.environ['nova.context'] - authorize(context) - if not body: - raise exc.HTTPUnprocessableEntity() + context = self._authorize_context(req) - security_group = body.get('security_group', None) - - if security_group is None: - raise exc.HTTPUnprocessableEntity() + security_group = self._from_body(body, 'security_group') group_name = security_group.get('name', None) group_description = security_group.get('description', None) - self._validate_security_group_property(group_name, "name") - self._validate_security_group_property(group_description, - "description") - group_name = group_name.strip() - group_description = group_description.strip() - - try: - reservations = QUOTAS.reserve(context, security_groups=1) - except exception.OverQuota: - msg = _("Quota exceeded, too many security groups.") - raise exc.HTTPBadRequest(explanation=msg) + self.security_group_api.validate_property(group_name, 'name', None) + self.security_group_api.validate_property(group_description, + 'description', None) - try: - LOG.audit(_("Create Security Group %s"), group_name, - context=context) - self.compute_api.ensure_default_security_group(context) - if db.security_group_exists(context, context.project_id, - group_name): - msg = _('Security group %s already exists') % group_name - raise exc.HTTPBadRequest(explanation=msg) - - group = {'user_id': context.user_id, - 'project_id': context.project_id, - 'name': group_name, - 'description': group_description} - group_ref = db.security_group_create(context, group) - self.sgh.trigger_security_group_create_refresh(context, group) - - # Commit the reservation - QUOTAS.commit(context, reservations) - except Exception: - with excutils.save_and_reraise_exception(): - QUOTAS.rollback(context, reservations) + group_ref = self.security_group_api.create(context, group_name, + group_description) return {'security_group': self._format_security_group(context, group_ref)} - def _validate_security_group_property(self, value, typ): - """ typ will be either 'name' or 'description', - depending on the caller - """ - try: - val = value.strip() - except AttributeError: - msg = _("Security group %s is not a string or unicode") % typ - raise exc.HTTPBadRequest(explanation=msg) - if not val: - msg = _("Security group %s cannot be empty.") % typ - raise exc.HTTPBadRequest(explanation=msg) - if len(val) > 255: - msg = _("Security group %s should not be greater " - "than 255 characters.") % typ - raise exc.HTTPBadRequest(explanation=msg) - class SecurityGroupRulesController(SecurityGroupControllerBase): @wsgi.serializers(xml=SecurityGroupRuleTemplate) @wsgi.deserializers(xml=SecurityGroupRulesXMLDeserializer) def create(self, req, body): - context = req.environ['nova.context'] - authorize(context) + context = self._authorize_context(req) - if not body: - raise exc.HTTPUnprocessableEntity() + sg_rule = self._from_body(body, 'security_group_rule') - if not 'security_group_rule' in body: - raise exc.HTTPUnprocessableEntity() + parent_group_id = self._validate_id(sg_rule.get('parent_group_id', + None)) - self.compute_api.ensure_default_security_group(context) - - sg_rule = body['security_group_rule'] - parent_group_id = sg_rule.get('parent_group_id', None) - try: - parent_group_id = int(parent_group_id) - security_group = db.security_group_get(context, parent_group_id) - except ValueError: - msg = _("Parent group id is not integer") - raise exc.HTTPBadRequest(explanation=msg) - except exception.NotFound as exp: - msg = _("Security group (%s) not found") % parent_group_id - raise exc.HTTPNotFound(explanation=msg) - - msg = _("Authorize security group ingress %s") - LOG.audit(msg, security_group['name'], context=context) + security_group = self.security_group_api.get(context, None, + parent_group_id, map_exception=True) try: values = self._rule_args_to_dict(context, to_port=sg_rule.get('to_port'), from_port=sg_rule.get('from_port'), - parent_group_id=sg_rule.get('parent_group_id'), ip_protocol=sg_rule.get('ip_protocol'), cidr=sg_rule.get('cidr'), group_id=sg_rule.get('group_id')) @@ -398,169 +326,50 @@ class SecurityGroupRulesController(SecurityGroupControllerBase): raise exc.HTTPBadRequest(explanation=unicode(exp)) if values is None: - msg = _("Not enough parameters to build a " - "valid rule.") + msg = _("Not enough parameters to build a valid rule.") raise exc.HTTPBadRequest(explanation=msg) values['parent_group_id'] = security_group.id - if self._security_group_rule_exists(security_group, values): + if self.security_group_api.rule_exists(security_group, values): msg = _('This rule already exists in group %s') % parent_group_id raise exc.HTTPBadRequest(explanation=msg) - count = QUOTAS.count(context, 'security_group_rules', parent_group_id) - try: - QUOTAS.limit_check(context, security_group_rules=count + 1) - except exception.OverQuota: - msg = _("Quota exceeded, too many security group rules.") - raise exc.HTTPBadRequest(explanation=msg) - - security_group_rule = db.security_group_rule_create(context, values) - self.sgh.trigger_security_group_rule_create_refresh( - context, [security_group_rule['id']]) - self.compute_api.trigger_security_group_rules_refresh(context, - security_group_id=security_group['id']) + security_group_rule = self.security_group_api.add_rules( + context, parent_group_id, security_group['name'], [values])[0] return {"security_group_rule": self._format_security_group_rule( context, security_group_rule)} - def _security_group_rule_exists(self, security_group, values): - """Indicates whether the specified rule values are already - defined in the given security group. - """ - for rule in security_group.rules: - is_duplicate = True - keys = ('group_id', 'cidr', 'from_port', 'to_port', 'protocol') - for key in keys: - if rule.get(key) != values.get(key): - is_duplicate = False - break - if is_duplicate: - return True - return False - def _rule_args_to_dict(self, context, to_port=None, from_port=None, - parent_group_id=None, ip_protocol=None, - cidr=None, group_id=None): - values = {} + ip_protocol=None, cidr=None, group_id=None): if group_id is not None: - try: - parent_group_id = int(parent_group_id) - group_id = int(group_id) - except ValueError: - msg = _("Parent or group id is not integer") - raise exception.InvalidInput(reason=msg) - - values['group_id'] = group_id + group_id = self._validate_id(group_id) #check if groupId exists - db.security_group_get(context, group_id) - elif cidr: - # If this fails, it throws an exception. This is what we want. - try: - cidr = urllib.unquote(cidr).decode() - except Exception: - raise exception.InvalidCidr(cidr=cidr) - - if not utils.is_valid_cidr(cidr): - # Raise exception for non-valid address - raise exception.InvalidCidr(cidr=cidr) - - values['cidr'] = cidr + self.security_group_api.get(context, id=group_id) + return self.security_group_api.new_group_ingress_rule( + group_id, ip_protocol, from_port, to_port) else: - values['cidr'] = '0.0.0.0/0' - - if group_id: - # Open everything if an explicit port range or type/code are not - # specified, but only if a source group was specified. - ip_proto_upper = ip_protocol.upper() if ip_protocol else '' - if (ip_proto_upper == 'ICMP' and - from_port is None and to_port is None): - from_port = -1 - to_port = -1 - elif (ip_proto_upper in ['TCP', 'UDP'] and from_port is None - and to_port is None): - from_port = 1 - to_port = 65535 - - if ip_protocol and from_port is not None and to_port is not None: - - ip_protocol = str(ip_protocol) - try: - from_port = int(from_port) - to_port = int(to_port) - except ValueError: - if ip_protocol.upper() == 'ICMP': - raise exception.InvalidInput(reason="Type and" - " Code must be integers for ICMP protocol type") - else: - raise exception.InvalidInput(reason="To and From ports " - "must be integers") - - if ip_protocol.upper() not in ['TCP', 'UDP', 'ICMP']: - raise exception.InvalidIpProtocol(protocol=ip_protocol) - - # Verify that from_port must always be less than - # or equal to to_port - if (ip_protocol.upper() in ['TCP', 'UDP'] and - from_port > to_port): - raise exception.InvalidPortRange(from_port=from_port, - to_port=to_port, msg="Former value cannot" - " be greater than the later") - - # Verify valid TCP, UDP port ranges - if (ip_protocol.upper() in ['TCP', 'UDP'] and - (from_port < 1 or to_port > 65535)): - raise exception.InvalidPortRange(from_port=from_port, - to_port=to_port, msg="Valid TCP ports should" - " be between 1-65535") - - # Verify ICMP type and code - if (ip_protocol.upper() == "ICMP" and - (from_port < -1 or from_port > 255 or - to_port < -1 or to_port > 255)): - raise exception.InvalidPortRange(from_port=from_port, - to_port=to_port, msg="For ICMP, the" - " type:code must be valid") - - values['protocol'] = ip_protocol.lower() - values['from_port'] = from_port - values['to_port'] = to_port - else: - # If cidr based filtering, protocol and ports are mandatory - if 'cidr' in values: - return None - - return values + cidr = self.security_group_api.parse_cidr(cidr) + return self.security_group_api.new_cidr_ingress_rule( + cidr, ip_protocol, from_port, to_port) def delete(self, req, id): - context = req.environ['nova.context'] - authorize(context) + context = self._authorize_context(req) - self.compute_api.ensure_default_security_group(context) - try: - id = int(id) - rule = db.security_group_rule_get(context, id) - except ValueError: - msg = _("Rule id is not integer") - raise exc.HTTPBadRequest(explanation=msg) - except exception.NotFound: - msg = _("Rule (%s) not found") % id - raise exc.HTTPNotFound(explanation=msg) + id = self._validate_id(id) + + rule = self.security_group_api.get_rule(context, id) group_id = rule.parent_group_id - self.compute_api.ensure_default_security_group(context) - security_group = db.security_group_get(context, group_id) - msg = _("Revoke security group ingress %s") - LOG.audit(msg, security_group['name'], context=context) + security_group = self.security_group_api.get(context, None, group_id, + map_exception=True) - db.security_group_rule_destroy(context, rule['id']) - self.sgh.trigger_security_group_rule_destroy_refresh( - context, [rule['id']]) - self.compute_api.trigger_security_group_rules_refresh(context, - security_group_id=security_group['id']) + self.security_group_api.remove_rules(context, security_group, + [rule['id']]) return webob.Response(status_int=202) @@ -570,10 +379,9 @@ class ServerSecurityGroupController(SecurityGroupControllerBase): @wsgi.serializers(xml=SecurityGroupsTemplate) def index(self, req, server_id): """Returns a list of security groups for the given instance.""" - context = req.environ['nova.context'] - authorize(context) + context = self._authorize_context(req) - self.compute_api.ensure_default_security_group(context) + self.security_group_api.ensure_default(context) try: instance = self.compute_api.get(context, server_id) @@ -595,16 +403,13 @@ class ServerSecurityGroupController(SecurityGroupControllerBase): class SecurityGroupActionController(wsgi.Controller): def __init__(self, *args, **kwargs): super(SecurityGroupActionController, self).__init__(*args, **kwargs) - self.compute_api = compute.API() - self.sgh = importutils.import_object(FLAGS.security_group_handler) - - @wsgi.action('addSecurityGroup') - def _addSecurityGroup(self, req, id, body): - context = req.environ['nova.context'] - authorize(context) + self.security_group_api = NativeSecurityGroupAPI() + self.compute_api = compute.API( + security_group_api=self.security_group_api) + def _parse(self, body, action): try: - body = body['addSecurityGroup'] + body = body[action] group_name = body['name'] except TypeError: msg = _("Missing parameter dict") @@ -617,11 +422,12 @@ class SecurityGroupActionController(wsgi.Controller): msg = _("Security group name cannot be empty") raise webob.exc.HTTPBadRequest(explanation=msg) + return group_name + + def _invoke(self, method, context, id, group_name): try: instance = self.compute_api.get(context, id) - self.compute_api.add_security_group(context, instance, group_name) - self.sgh.trigger_instance_add_security_group_refresh( - context, instance, group_name) + method(context, instance, group_name) except exception.SecurityGroupNotFound as exp: raise exc.HTTPNotFound(explanation=unicode(exp)) except exception.InstanceNotFound as exp: @@ -631,39 +437,25 @@ class SecurityGroupActionController(wsgi.Controller): return webob.Response(status_int=202) - @wsgi.action('removeSecurityGroup') - def _removeSecurityGroup(self, req, id, body): + @wsgi.action('addSecurityGroup') + def _addSecurityGroup(self, req, id, body): context = req.environ['nova.context'] authorize(context) - try: - body = body['removeSecurityGroup'] - group_name = body['name'] - except TypeError: - msg = _("Missing parameter dict") - raise webob.exc.HTTPBadRequest(explanation=msg) - except KeyError: - msg = _("Security group not specified") - raise webob.exc.HTTPBadRequest(explanation=msg) + group_name = self._parse(body, 'addSecurityGroup') - if not group_name or group_name.strip() == '': - msg = _("Security group name cannot be empty") - raise webob.exc.HTTPBadRequest(explanation=msg) + return self._invoke(self.security_group_api.add_to_instance, + context, id, group_name) - try: - instance = self.compute_api.get(context, id) - self.compute_api.remove_security_group(context, instance, - group_name) - self.sgh.trigger_instance_remove_security_group_refresh( - context, instance, group_name) - except exception.SecurityGroupNotFound as exp: - raise exc.HTTPNotFound(explanation=unicode(exp)) - except exception.InstanceNotFound as exp: - raise exc.HTTPNotFound(explanation=unicode(exp)) - except exception.Invalid as exp: - raise exc.HTTPBadRequest(explanation=unicode(exp)) + @wsgi.action('removeSecurityGroup') + def _removeSecurityGroup(self, req, id, body): + context = req.environ['nova.context'] + authorize(context) - return webob.Response(status_int=202) + group_name = self._parse(body, 'removeSecurityGroup') + + return self._invoke(self.security_group_api.remove_from_instance, + context, id, group_name) class Security_groups(extensions.ExtensionDescriptor): @@ -698,3 +490,29 @@ class Security_groups(extensions.ExtensionDescriptor): resources.append(res) return resources + + +class NativeSecurityGroupAPI(compute.api.SecurityGroupAPI): + @staticmethod + def raise_invalid_property(msg): + raise exc.HTTPBadRequest(explanation=msg) + + @staticmethod + def raise_group_already_exists(msg): + raise exc.HTTPBadRequest(explanation=msg) + + @staticmethod + def raise_invalid_group(msg): + raise exc.HTTPBadRequest(explanation=msg) + + @staticmethod + def raise_invalid_cidr(cidr, decoding_exception=None): + raise exception.InvalidCidr(cidr=cidr) + + @staticmethod + def raise_over_quota(msg): + raise exc.HTTPBadRequest(explanation=msg) + + @staticmethod + def raise_not_found(msg): + raise exc.HTTPNotFound(explanation=msg) diff --git a/nova/compute/api.py b/nova/compute/api.py index 10bcfe457..922a8bace 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -25,6 +25,7 @@ import functools import re import string import time +import urllib from nova import block_device from nova.compute import aggregate_states @@ -43,6 +44,7 @@ from nova import log as logging from nova import network from nova import notifications from nova.openstack.common import excutils +from nova.openstack.common import importutils from nova.openstack.common import jsonutils import nova.policy from nova import quota @@ -92,17 +94,23 @@ def check_instance_state(vm_state=None, task_state=None): return outer -def wrap_check_policy(func): +def policy_decorator(scope): """Check corresponding policy prior of wrapped method to execution""" - @functools.wraps(func) - def wrapped(self, context, target, *args, **kwargs): - check_policy(context, func.__name__, target) - return func(self, context, target, *args, **kwargs) - return wrapped + def outer(func): + @functools.wraps(func) + def wrapped(self, context, target, *args, **kwargs): + check_policy(context, func.__name__, target, scope) + return func(self, context, target, *args, **kwargs) + return wrapped + return outer + +wrap_check_policy = policy_decorator(scope='compute') +wrap_check_security_groups_policy = policy_decorator( + scope='compute:security_groups') -def check_policy(context, action, target): - _action = 'compute:%s' % action +def check_policy(context, action, target, scope='compute'): + _action = '%s:%s' % (scope, action) nova.policy.enforce(context, _action, target) @@ -110,12 +118,13 @@ class API(base.Base): """API for interacting with the compute manager.""" def __init__(self, image_service=None, network_api=None, volume_api=None, - **kwargs): + security_group_api=None, **kwargs): self.image_service = (image_service or nova.image.get_default_image_service()) self.network_api = network_api or network.API() self.volume_api = volume_api or volume.API() + self.security_group_api = security_group_api or SecurityGroupAPI() self.consoleauth_rpcapi = consoleauth_rpcapi.ConsoleAuthAPI() self.scheduler_rpcapi = scheduler_rpcapi.SchedulerAPI() self.compute_rpcapi = compute_rpcapi.ComputeAPI() @@ -389,7 +398,7 @@ class API(base.Base): kernel_id, ramdisk_id = self._handle_kernel_and_ramdisk( context, kernel_id, ramdisk_id, image, image_service) - self.ensure_default_security_group(context) + self.security_group_api.ensure_default(context) if key_data is None and key_name: key_pair = self.db.key_pair_get(context, context.user_id, key_name) @@ -771,80 +780,6 @@ class API(base.Base): return (inst_ret_list, reservation_id) - def ensure_default_security_group(self, context): - """Ensure that a context has a security group. - - Creates a security group for the security context if it does not - already exist. - - :param context: the security context - """ - try: - self.db.security_group_get_by_name(context, - context.project_id, - 'default') - except exception.NotFound: - values = {'name': 'default', - 'description': 'default', - 'user_id': context.user_id, - 'project_id': context.project_id} - self.db.security_group_create(context, values) - - def trigger_security_group_rules_refresh(self, context, security_group_id): - """Called when a rule is added to or removed from a security_group.""" - - security_group = self.db.security_group_get(context, security_group_id) - - hosts = set() - for instance in security_group['instances']: - if instance['host'] is not None: - hosts.add(instance['host']) - - for host in hosts: - self.compute_rpcapi.refresh_security_group_rules(context, - security_group.id, host=host) - - def trigger_security_group_members_refresh(self, context, group_ids): - """Called when a security group gains a new or loses a member. - - Sends an update request to each compute node for whom this is - relevant. - """ - # First, we get the security group rules that reference these groups as - # the grantee.. - security_group_rules = set() - for group_id in group_ids: - security_group_rules.update( - self.db.security_group_rule_get_by_security_group_grantee( - context, - group_id)) - - # ..then we distill the security groups to which they belong.. - security_groups = set() - for rule in security_group_rules: - security_group = self.db.security_group_get( - context, - rule['parent_group_id']) - security_groups.add(security_group) - - # ..then we find the instances that are members of these groups.. - instances = set() - for security_group in security_groups: - for instance in security_group['instances']: - instances.add(instance) - - # ...then we find the hosts where they live... - hosts = set() - for instance in instances: - if instance['host']: - hosts.add(instance['host']) - - # ...and finally we tell these nodes to refresh their view of this - # particular security group. - for host in hosts: - self.compute_rpcapi.refresh_security_group_members(context, - group_id, host=host) - def trigger_provider_fw_rules_refresh(self, context): """Called when a rule is added/removed from a provider firewall""" @@ -853,81 +788,6 @@ class API(base.Base): for host in hosts: self.compute_rpcapi.refresh_provider_fw_rules(context, host) - def _is_security_group_associated_with_server(self, security_group, - instance_uuid): - """Check if the security group is already associated - with the instance. If Yes, return True. - """ - - if not security_group: - return False - - instances = security_group.get('instances') - if not instances: - return False - - for inst in instances: - if (instance_uuid == inst['uuid']): - return True - - return False - - @wrap_check_policy - def add_security_group(self, context, instance, security_group_name): - """Add security group to the instance""" - security_group = self.db.security_group_get_by_name(context, - context.project_id, - security_group_name) - - instance_uuid = instance['uuid'] - - #check if the security group is associated with the server - if self._is_security_group_associated_with_server(security_group, - instance_uuid): - raise exception.SecurityGroupExistsForInstance( - security_group_id=security_group['id'], - instance_id=instance_uuid) - - #check if the instance is in running state - if instance['power_state'] != power_state.RUNNING: - raise exception.InstanceNotRunning(instance_id=instance_uuid) - - self.db.instance_add_security_group(context.elevated(), - instance_uuid, - security_group['id']) - # NOTE(comstud): No instance_uuid argument to this compute manager - # call - self.compute_rpcapi.refresh_security_group_rules(context, - security_group['id'], host=instance['host']) - - @wrap_check_policy - def remove_security_group(self, context, instance, security_group_name): - """Remove the security group associated with the instance""" - security_group = self.db.security_group_get_by_name(context, - context.project_id, - security_group_name) - - instance_uuid = instance['uuid'] - - #check if the security group is associated with the server - if not self._is_security_group_associated_with_server(security_group, - instance_uuid): - raise exception.SecurityGroupNotExistsForInstance( - security_group_id=security_group['id'], - instance_id=instance_uuid) - - #check if the instance is in running state - if instance['power_state'] != power_state.RUNNING: - raise exception.InstanceNotRunning(instance_id=instance_uuid) - - self.db.instance_remove_security_group(context.elevated(), - instance_uuid, - security_group['id']) - # NOTE(comstud): No instance_uuid argument to this compute manager - # call - self.compute_rpcapi.refresh_security_group_rules(context, - security_group['id'], host=instance['host']) - @wrap_check_policy def update(self, context, instance, **kwargs): """Updates the instance in the datastore. @@ -2065,3 +1925,473 @@ class KeypairAPI(base.Base): 'fingerprint': key_pair['fingerprint'], }) return rval + + +class SecurityGroupAPI(base.Base): + """ + Sub-set of the Compute API related to managing security groups + and security group rules + """ + def __init__(self, **kwargs): + super(SecurityGroupAPI, self).__init__(**kwargs) + self.security_group_rpcapi = compute_rpcapi.SecurityGroupAPI() + self.sgh = importutils.import_object(FLAGS.security_group_handler) + + def validate_property(self, value, property, allowed): + """ + Validate given security group property. + + :param value: the value to validate, as a string or unicode + :param property: the property, either 'name' or 'description' + :param allowed: the range of characters allowed + """ + + try: + val = value.strip() + except AttributeError: + msg = _("Security group %s is not a string or unicode") % property + self.raise_invalid_property(msg) + if not val: + msg = _("Security group %s cannot be empty.") % property + self.raise_invalid_property(msg) + + if allowed and not re.match(allowed, val): + # Some validation to ensure that values match API spec. + # - Alphanumeric characters, spaces, dashes, and underscores. + # TODO(Daviey): LP: #813685 extend beyond group_name checking, and + # probably create a param validator that can be used elsewhere. + msg = (_("Value (%(value)s) for parameter Group%(property)s is " + "invalid. Content limited to '%(allowed)'.") % + dict(value=value, allowed=allowed, + property=property.capitalize())) + self.raise_invalid_property(msg) + if len(val) > 255: + msg = _("Security group %s should not be greater " + "than 255 characters.") % property + self.raise_invalid_property(msg) + + def ensure_default(self, context): + """Ensure that a context has a security group. + + Creates a security group for the security context if it does not + already exist. + + :param context: the security context + """ + try: + self.db.security_group_get_by_name(context, + context.project_id, + 'default') + except exception.NotFound: + values = {'name': 'default', + 'description': 'default', + 'user_id': context.user_id, + 'project_id': context.project_id} + self.db.security_group_create(context, values) + + def create(self, context, name, description): + try: + reservations = QUOTAS.reserve(context, security_groups=1) + except exception.OverQuota: + msg = _("Quota exceeded, too many security groups.") + self.raise_over_quota(msg) + + LOG.audit(_("Create Security Group %s"), name, context=context) + + self.ensure_default(context) + + if self.db.security_group_exists(context, context.project_id, name): + msg = _('Security group %s already exists') % name + self.raise_group_already_exists(msg) + + try: + group = {'user_id': context.user_id, + 'project_id': context.project_id, + 'name': name, + 'description': description} + group_ref = self.db.security_group_create(context, group) + self.sgh.trigger_security_group_create_refresh(context, group) + # Commit the reservation + QUOTAS.commit(context, reservations) + except Exception: + with excutils.save_and_reraise_exception(): + QUOTAS.rollback(context, reservations) + + return group_ref + + def get(self, context, name=None, id=None, map_exception=False): + self.ensure_default(context) + try: + if name: + return self.db.security_group_get_by_name(context, + context.project_id, + name) + elif id: + return self.db.security_group_get(context, id) + except exception.NotFound as exp: + if map_exception: + msg = unicode(exp) + self.raise_not_found(msg) + else: + raise + + def list(self, context, names=None, ids=None, project=None): + self.ensure_default(context) + + groups = [] + if names or ids: + if names: + for name in names: + groups.append(self.db.security_group_get_by_name(context, + project, + name)) + if ids: + for id in ids: + groups.append(self.db.security_group_get(context, id)) + + elif context.is_admin: + groups = self.db.security_group_get_all(context) + + elif project: + groups = self.db.security_group_get_by_project(context, project) + + return groups + + def destroy(self, context, security_group): + if self.db.security_group_in_use(context, security_group.id): + msg = _("Security group is still in use") + self.raise_invalid_group(msg) + + # Get reservations + try: + reservations = QUOTAS.reserve(context, security_groups=-1) + except Exception: + reservations = None + LOG.exception(_("Failed to update usages deallocating " + "security group")) + + LOG.audit(_("Delete security group %s"), security_group.name, + context=context) + self.db.security_group_destroy(context, security_group.id) + + self.sgh.trigger_security_group_destroy_refresh(context, + security_group.id) + + # Commit the reservations + if reservations: + QUOTAS.commit(context, reservations) + + def is_associated_with_server(self, security_group, instance_uuid): + """Check if the security group is already associated + with the instance. If Yes, return True. + """ + + if not security_group: + return False + + instances = security_group.get('instances') + if not instances: + return False + + for inst in instances: + if (instance_uuid == inst['uuid']): + return True + + return False + + @wrap_check_security_groups_policy + def add_to_instance(self, context, instance, security_group_name): + """Add security group to the instance""" + security_group = self.db.security_group_get_by_name(context, + context.project_id, + security_group_name) + + instance_uuid = instance['uuid'] + + #check if the security group is associated with the server + if self.is_associated_with_server(security_group, instance_uuid): + raise exception.SecurityGroupExistsForInstance( + security_group_id=security_group['id'], + instance_id=instance_uuid) + + #check if the instance is in running state + if instance['power_state'] != power_state.RUNNING: + raise exception.InstanceNotRunning(instance_id=instance_uuid) + + self.db.instance_add_security_group(context.elevated(), + instance_uuid, + security_group['id']) + params = {"security_group_id": security_group['id']} + # NOTE(comstud): No instance_uuid argument to this compute manager + # call + self.security_group_rpcapi.refresh_security_group_rules(context, + security_group['id'], host=instance['host']) + + self.trigger_handler('instance_add_security_group', + context, instance, security_group_name) + + @wrap_check_security_groups_policy + def remove_from_instance(self, context, instance, security_group_name): + """Remove the security group associated with the instance""" + security_group = self.db.security_group_get_by_name(context, + context.project_id, + security_group_name) + + instance_uuid = instance['uuid'] + + #check if the security group is associated with the server + if not self.is_associated_with_server(security_group, instance_uuid): + raise exception.SecurityGroupNotExistsForInstance( + security_group_id=security_group['id'], + instance_id=instance_uuid) + + #check if the instance is in running state + if instance['power_state'] != power_state.RUNNING: + raise exception.InstanceNotRunning(instance_id=instance_uuid) + + self.db.instance_remove_security_group(context.elevated(), + instance_uuid, + security_group['id']) + params = {"security_group_id": security_group['id']} + # NOTE(comstud): No instance_uuid argument to this compute manager + # call + self.security_group_rpcapi.refresh_security_group_rules(context, + security_group['id'], host=instance['host']) + + self.trigger_handler('instance_remove_security_group', + context, instance, security_group_name) + + def trigger_handler(self, event, *args): + handle = getattr(self.sgh, 'trigger_%s_refresh' % event) + handle(*args) + + def trigger_rules_refresh(self, context, id): + """Called when a rule is added to or removed from a security_group.""" + + security_group = self.db.security_group_get(context, id) + + hosts = set() + for instance in security_group['instances']: + if instance['host'] is not None: + hosts.add(instance['host']) + + for host in hosts: + self.security_group_rpcapi.refresh_security_group_rules(context, + security_group.id, host=host) + + def trigger_members_refresh(self, context, group_ids): + """Called when a security group gains a new or loses a member. + + Sends an update request to each compute node for whom this is + relevant. + """ + # First, we get the security group rules that reference these groups as + # the grantee.. + security_group_rules = set() + for group_id in group_ids: + security_group_rules.update( + self.db.security_group_rule_get_by_security_group_grantee( + context, + group_id)) + + # ..then we distill the security groups to which they belong.. + security_groups = set() + for rule in security_group_rules: + security_group = self.db.security_group_get( + context, + rule['parent_group_id']) + security_groups.add(security_group) + + # ..then we find the instances that are members of these groups.. + instances = set() + for security_group in security_groups: + for instance in security_group['instances']: + instances.add(instance) + + # ...then we find the hosts where they live... + hosts = set() + for instance in instances: + if instance['host']: + hosts.add(instance['host']) + + # ...and finally we tell these nodes to refresh their view of this + # particular security group. + for host in hosts: + self.security_group_rpcapi.refresh_security_group_members(context, + group_id, host=host) + + def parse_cidr(self, cidr): + if cidr: + try: + cidr = urllib.unquote(cidr).decode() + except Exception as e: + self.raise_invalid_cidr(cidr, e) + + if not utils.is_valid_cidr(cidr): + self.raise_invalid_cidr(cidr) + + return cidr + else: + return '0.0.0.0/0' + + @staticmethod + def new_group_ingress_rule(grantee_group_id, protocol, from_port, + to_port): + return SecurityGroupAPI._new_ingress_rule(protocol, from_port, + to_port, group_id=grantee_group_id) + + @staticmethod + def new_cidr_ingress_rule(grantee_cidr, protocol, from_port, to_port): + return SecurityGroupAPI._new_ingress_rule(protocol, from_port, + to_port, cidr=grantee_cidr) + + @staticmethod + def _new_ingress_rule(ip_protocol, from_port, to_port, + group_id=None, cidr=None): + values = {} + + if group_id: + values['group_id'] = group_id + # Open everything if an explicit port range or type/code are not + # specified, but only if a source group was specified. + ip_proto_upper = ip_protocol.upper() if ip_protocol else '' + if (ip_proto_upper == 'ICMP' and + from_port is None and to_port is None): + from_port = -1 + to_port = -1 + elif (ip_proto_upper in ['TCP', 'UDP'] and from_port is None + and to_port is None): + from_port = 1 + to_port = 65535 + + elif cidr: + values['cidr'] = cidr + + if ip_protocol and from_port is not None and to_port is not None: + + ip_protocol = str(ip_protocol) + try: + # Verify integer conversions + from_port = int(from_port) + to_port = int(to_port) + except ValueError: + if ip_protocol.upper() == 'ICMP': + raise exception.InvalidInput(reason="Type and" + " Code must be integers for ICMP protocol type") + else: + raise exception.InvalidInput(reason="To and From ports " + "must be integers") + + if ip_protocol.upper() not in ['TCP', 'UDP', 'ICMP']: + raise exception.InvalidIpProtocol(protocol=ip_protocol) + + # Verify that from_port must always be less than + # or equal to to_port + if (ip_protocol.upper() in ['TCP', 'UDP'] and + (from_port > to_port)): + raise exception.InvalidPortRange(from_port=from_port, + to_port=to_port, msg="Former value cannot" + " be greater than the later") + + # Verify valid TCP, UDP port ranges + if (ip_protocol.upper() in ['TCP', 'UDP'] and + (from_port < 1 or to_port > 65535)): + raise exception.InvalidPortRange(from_port=from_port, + to_port=to_port, msg="Valid TCP ports should" + " be between 1-65535") + + # Verify ICMP type and code + if (ip_protocol.upper() == "ICMP" and + (from_port < -1 or from_port > 255 or + to_port < -1 or to_port > 255)): + raise exception.InvalidPortRange(from_port=from_port, + to_port=to_port, msg="For ICMP, the" + " type:code must be valid") + + values['protocol'] = ip_protocol + values['from_port'] = from_port + values['to_port'] = to_port + + else: + # If cidr based filtering, protocol and ports are mandatory + if cidr: + return None + + return values + + def rule_exists(self, security_group, values): + """Indicates whether the specified rule values are already + defined in the given security group. + """ + for rule in security_group.rules: + is_duplicate = True + keys = ('group_id', 'cidr', 'from_port', 'to_port', 'protocol') + for key in keys: + if rule.get(key) != values.get(key): + is_duplicate = False + break + if is_duplicate: + return rule.get('id') or True + return False + + def get_rule(self, context, id): + self.ensure_default(context) + try: + return self.db.security_group_rule_get(context, id) + except exception.NotFound: + msg = _("Rule (%s) not found") % id + self.raise_not_found(msg) + + def add_rules(self, context, id, name, vals): + count = QUOTAS.count(context, 'security_group_rules', id) + try: + projected = count + len(vals) + QUOTAS.limit_check(context, security_group_rules=projected) + except exception.OverQuota: + msg = _("Quota exceeded, too many security group rules.") + self.raise_over_quota(msg) + + msg = _("Authorize security group ingress %s") + LOG.audit(msg, name, context=context) + + rules = [self.db.security_group_rule_create(context, v) for v in vals] + + self.trigger_rules_refresh(context, id=id) + self.trigger_handler('security_group_rule_create', context, + [r['id'] for r in rules]) + return rules + + def remove_rules(self, context, security_group, rule_ids): + msg = _("Revoke security group ingress %s") + LOG.audit(msg, security_group['name'], context=context) + + for rule_id in rule_ids: + self.db.security_group_rule_destroy(context, rule_id) + + # NOTE(vish): we removed some rules, so refresh + self.trigger_rules_refresh(context, id=security_group['id']) + self.trigger_handler('security_group_rule_destroy', context, rule_ids) + + @staticmethod + def raise_invalid_property(msg): + raise NotImplementedError() + + @staticmethod + def raise_group_already_exists(msg): + raise NotImplementedError() + + @staticmethod + def raise_invalid_group(msg): + raise NotImplementedError() + + @staticmethod + def raise_invalid_cidr(cidr, decoding_exception=None): + raise NotImplementedError() + + @staticmethod + def raise_over_quota(msg): + raise NotImplementedError() + + @staticmethod + def raise_not_found(msg): + raise NotImplementedError() diff --git a/nova/compute/rpcapi.py b/nova/compute/rpcapi.py index 8c25906c9..e7945c7d4 100644 --- a/nova/compute/rpcapi.py +++ b/nova/compute/rpcapi.py @@ -27,6 +27,27 @@ import nova.rpc.proxy FLAGS = flags.FLAGS +def _compute_topic(topic, ctxt, host, instance): + '''Get the topic to use for a message. + + :param topic: the base topic + :param ctxt: request context + :param host: explicit host to send the message to. + :param instance: If an explicit host was not specified, use + instance['host'] + + :returns: A topic string + ''' + if not host: + if not instance: + raise exception.NovaException(_('No compute host specified')) + host = instance['host'] + if not host: + raise exception.NovaException(_('Unable to find host for ' + 'Instance %s') % instance['uuid']) + return rpc.queue_get_for(ctxt, topic, host) + + class ComputeAPI(nova.rpc.proxy.RpcProxy): '''Client side of the compute rpc API. @@ -41,25 +62,6 @@ class ComputeAPI(nova.rpc.proxy.RpcProxy): super(ComputeAPI, self).__init__(topic=FLAGS.compute_topic, default_version=self.RPC_API_VERSION) - def _compute_topic(self, ctxt, host, instance): - '''Get the topic to use for a message. - - :param ctxt: request context - :param host: explicit host to send the message to. - :param instance: If an explicit host was not specified, use - instance['host'] - - :returns: A topic string - ''' - if not host: - if not instance: - raise exception.NovaException(_('No compute host specified')) - host = instance['host'] - if not host: - raise exception.NovaException(_('Unable to find host for ' - 'Instance %s') % instance['uuid']) - return rpc.queue_get_for(ctxt, self.topic, host) - def add_aggregate_host(self, ctxt, aggregate_id, host_param, host): '''Add aggregate host. @@ -71,90 +73,90 @@ class ComputeAPI(nova.rpc.proxy.RpcProxy): ''' self.cast(ctxt, self.make_msg('add_aggregate_host', aggregate_id=aggregate_id, host=host_param), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def add_fixed_ip_to_instance(self, ctxt, instance, network_id): self.cast(ctxt, self.make_msg('add_fixed_ip_to_instance', instance_uuid=instance['uuid'], network_id=network_id), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def attach_volume(self, ctxt, instance, volume_id, mountpoint): self.cast(ctxt, self.make_msg('attach_volume', instance_uuid=instance['uuid'], volume_id=volume_id, mountpoint=mountpoint), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def check_shared_storage_test_file(self, ctxt, filename, host): return self.call(ctxt, self.make_msg('check_shared_storage_test_file', filename=filename), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def cleanup_shared_storage_test_file(self, ctxt, filename, host): self.cast(ctxt, self.make_msg('cleanup_shared_storage_test_file', filename=filename), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def compare_cpu(self, ctxt, cpu_info, host): return self.call(ctxt, self.make_msg('compare_cpu', cpu_info=cpu_info), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def confirm_resize(self, ctxt, instance, migration_id, host, cast=True): rpc_method = self.cast if cast else self.call return rpc_method(ctxt, self.make_msg('confirm_resize', instance_uuid=instance['uuid'], migration_id=migration_id), - topic=self._compute_topic(ctxt, host, instance)) + topic=_compute_topic(self.topic, ctxt, host, instance)) def create_shared_storage_test_file(self, ctxt, host): return self.call(ctxt, self.make_msg('create_shared_storage_test_file'), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def detach_volume(self, ctxt, instance, volume_id): self.cast(ctxt, self.make_msg('detach_volume', instance_uuid=instance['uuid'], volume_id=volume_id), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def finish_resize(self, ctxt, instance, migration_id, image, disk_info, host): self.cast(ctxt, self.make_msg('finish_resize', instance_uuid=instance['uuid'], migration_id=migration_id, image=image, disk_info=disk_info), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def finish_revert_resize(self, ctxt, instance, migration_id, host): self.cast(ctxt, self.make_msg('finish_revert_resize', instance_uuid=instance['uuid'], migration_id=migration_id), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def get_console_output(self, ctxt, instance, tail_length): return self.call(ctxt, self.make_msg('get_console_output', instance_uuid=instance['uuid'], tail_length=tail_length), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def get_console_pool_info(self, ctxt, console_type, host): return self.call(ctxt, self.make_msg('get_console_pool_info', console_type=console_type), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def get_console_topic(self, ctxt, host): return self.call(ctxt, self.make_msg('get_console_topic'), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def get_diagnostics(self, ctxt, instance): return self.call(ctxt, self.make_msg('get_diagnostics', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def get_instance_disk_info(self, ctxt, instance): return self.call(ctxt, self.make_msg('get_instance_disk_info', instance_name=instance['name']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def get_vnc_console(self, ctxt, instance, console_type): return self.call(ctxt, self.make_msg('get_vnc_console', instance_uuid=instance['uuid'], console_type=console_type), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def host_maintenance_mode(self, ctxt, host_param, mode, host): '''Set host maintenance mode @@ -167,60 +169,61 @@ class ComputeAPI(nova.rpc.proxy.RpcProxy): ''' return self.call(ctxt, self.make_msg('host_maintenance_mode', host=host_param, mode=mode), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def host_power_action(self, ctxt, action, host): + topic = _compute_topic(self.topic, ctxt, host, None) return self.call(ctxt, self.make_msg('host_power_action', - action=action), topic=self._compute_topic(ctxt, host, None)) + action=action), topic) def inject_file(self, ctxt, instance, path, file_contents): self.cast(ctxt, self.make_msg('inject_file', instance_uuid=instance['uuid'], path=path, file_contents=file_contents), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def inject_network_info(self, ctxt, instance): self.cast(ctxt, self.make_msg('inject_network_info', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def lock_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('lock_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def post_live_migration_at_destination(self, ctxt, instance, block_migration, host): return self.call(ctxt, self.make_msg('post_live_migration_at_destination', instance_id=instance['id'], block_migration=block_migration), - self._compute_topic(ctxt, host, None)) + _compute_topic(self.topic, ctxt, host, None)) def pause_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('pause_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def power_off_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('power_off_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def power_on_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('power_on_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def pre_live_migration(self, ctxt, instance, block_migration, disk, host): return self.call(ctxt, self.make_msg('pre_live_migration', instance_id=instance['id'], block_migration=block_migration, - disk=disk), self._compute_topic(ctxt, host, None)) + disk=disk), _compute_topic(self.topic, ctxt, host, None)) def reboot_instance(self, ctxt, instance, reboot_type): self.cast(ctxt, self.make_msg('reboot_instance', instance_uuid=instance['uuid'], reboot_type=reboot_type), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def rebuild_instance(self, ctxt, instance, new_pass, injected_files, image_ref, orig_image_ref): @@ -228,22 +231,22 @@ class ComputeAPI(nova.rpc.proxy.RpcProxy): instance_uuid=instance['uuid'], new_pass=new_pass, injected_files=injected_files, image_ref=image_ref, orig_image_ref=orig_image_ref), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def refresh_provider_fw_rules(self, ctxt, host): self.cast(ctxt, self.make_msg('refresh_provider_fw_rules'), - self._compute_topic(ctxt, host, None)) + _compute_topic(self.topic, ctxt, host, None)) def refresh_security_group_rules(self, ctxt, security_group_id, host): self.cast(ctxt, self.make_msg('refresh_security_group_rules', security_group_id=security_group_id), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def refresh_security_group_members(self, ctxt, security_group_id, host): self.cast(ctxt, self.make_msg('refresh_security_group_members', security_group_id=security_group_id), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def remove_aggregate_host(self, ctxt, aggregate_id, host_param, host): '''Remove aggregate host. @@ -256,57 +259,59 @@ class ComputeAPI(nova.rpc.proxy.RpcProxy): ''' self.cast(ctxt, self.make_msg('remove_aggregate_host', aggregate_id=aggregate_id, host=host_param), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def remove_fixed_ip_from_instance(self, ctxt, instance, address): self.cast(ctxt, self.make_msg('remove_fixed_ip_from_instance', instance_uuid=instance['uuid'], address=address), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def remove_volume_connection(self, ctxt, instance, volume_id, host): return self.call(ctxt, self.make_msg('remove_volume_connection', instance_id=instance['id'], volume_id=volume_id), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def rescue_instance(self, ctxt, instance, rescue_password): self.cast(ctxt, self.make_msg('rescue_instance', instance_uuid=instance['uuid'], rescue_password=rescue_password), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def reset_network(self, ctxt, instance): self.cast(ctxt, self.make_msg('reset_network', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def resize_instance(self, ctxt, instance, migration_id, image): + topic = _compute_topic(self.topic, ctxt, None, instance) self.cast(ctxt, self.make_msg('resize_instance', instance_uuid=instance['uuid'], migration_id=migration_id, - image=image), topic=self._compute_topic(ctxt, None, instance)) + image=image), topic) def resume_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('resume_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def revert_resize(self, ctxt, instance, migration_id, host): self.cast(ctxt, self.make_msg('revert_resize', instance_uuid=instance['uuid'], migration_id=migration_id), - topic=self._compute_topic(ctxt, host, instance)) + topic=_compute_topic(self.topic, ctxt, host, instance)) def rollback_live_migration_at_destination(self, ctxt, instance, host): self.cast(ctxt, self.make_msg('rollback_live_migration_at_destination', instance_id=instance['id']), - topic=self._compute_topic(ctxt, host, None)) + topic=_compute_topic(self.topic, ctxt, host, None)) def set_admin_password(self, ctxt, instance, new_pass): self.cast(ctxt, self.make_msg('set_admin_password', instance_uuid=instance['uuid'], new_pass=new_pass), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def set_host_enabled(self, ctxt, enabled, host): + topic = _compute_topic(self.topic, ctxt, host, None) return self.call(ctxt, self.make_msg('set_host_enabled', - enabled=enabled), topic=self._compute_topic(ctxt, host, None)) + enabled=enabled), topic) def snapshot_instance(self, ctxt, instance, image_id, image_type, backup_type, rotation): @@ -314,40 +319,66 @@ class ComputeAPI(nova.rpc.proxy.RpcProxy): instance_uuid=instance['uuid'], image_id=image_id, image_type=image_type, backup_type=backup_type, rotation=rotation), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def start_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('start_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def stop_instance(self, ctxt, instance, cast=True): rpc_method = self.cast if cast else self.call return rpc_method(ctxt, self.make_msg('stop_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def suspend_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('suspend_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def terminate_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('terminate_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def unlock_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('unlock_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def unpause_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('unpause_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) def unrescue_instance(self, ctxt, instance): self.cast(ctxt, self.make_msg('unrescue_instance', instance_uuid=instance['uuid']), - topic=self._compute_topic(ctxt, None, instance)) + topic=_compute_topic(self.topic, ctxt, None, instance)) + + +class SecurityGroupAPI(nova.rpc.proxy.RpcProxy): + '''Client side of the security group rpc API. + + API version history: + + 1.0 - Initial version. + ''' + + RPC_API_VERSION = '1.0' + + def __init__(self): + super(SecurityGroupAPI, self).__init__(topic=FLAGS.compute_topic, + default_version=self.RPC_API_VERSION) + + def refresh_security_group_rules(self, ctxt, security_group_id, host): + self.cast(ctxt, self.make_msg('refresh_security_group_rules', + security_group_id=security_group_id), + topic=_compute_topic(self.topic, ctxt, host, None)) + + def refresh_security_group_members(self, ctxt, security_group_id, + host): + self.cast(ctxt, self.make_msg('refresh_security_group_members', + security_group_id=security_group_id), + topic=_compute_topic(self.topic, ctxt, host, None)) diff --git a/nova/network/manager.py b/nova/network/manager.py index 1f1580634..72b41b81f 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -764,8 +764,9 @@ class NetworkManager(manager.SchedulerDependentManager): temp = importutils.import_object(FLAGS.floating_ip_dns_manager) self.floating_dns_manager = temp self.network_api = network_api.API() - self.compute_api = compute_api.API() - self.sgh = importutils.import_object(FLAGS.security_group_handler) + self.security_group_api = compute_api.SecurityGroupAPI() + self.compute_api = compute_api.API( + security_group_api=self.security_group_api) # NOTE(tr3buchet: unless manager subclassing NetworkManager has # already imported ipam, import nova ipam here @@ -843,10 +844,10 @@ class NetworkManager(manager.SchedulerDependentManager): instance_ref = self.db.instance_get(admin_context, instance_id) groups = instance_ref['security_groups'] group_ids = [group['id'] for group in groups] - self.compute_api.trigger_security_group_members_refresh(admin_context, - group_ids) - self.sgh.trigger_security_group_members_refresh(admin_context, + self.security_group_api.trigger_members_refresh(admin_context, group_ids) + self.security_group_api.trigger_handler('security_group_members', + admin_context, group_ids) def get_floating_ips_by_fixed_address(self, context, fixed_address): # NOTE(jkoelker) This is just a stub function. Managers supporting diff --git a/nova/tests/api/ec2/test_cloud.py b/nova/tests/api/ec2/test_cloud.py index 5bf20a4e0..afb906f45 100644 --- a/nova/tests/api/ec2/test_cloud.py +++ b/nova/tests/api/ec2/test_cloud.py @@ -295,7 +295,7 @@ class CloudTestCase(test.TestCase): def test_security_group_quota_limit(self): self.flags(quota_security_groups=10) - for i in range(1, 10): + for i in range(1, FLAGS.quota_security_groups + 1): name = 'test name %i' % i descript = 'test description %i' % i create = self.cloud.create_security_group diff --git a/nova/tests/compute/test_compute.py b/nova/tests/compute/test_compute.py index a32e0de05..92e5d193c 100644 --- a/nova/tests/compute/test_compute.py +++ b/nova/tests/compute/test_compute.py @@ -2001,7 +2001,9 @@ class ComputeAPITestCase(BaseTestCase): super(ComputeAPITestCase, self).setUp() self.stubs.Set(nova.network.API, 'get_instance_nw_info', fake_get_nw_info) - self.compute_api = compute.API() + self.security_group_api = compute.api.SecurityGroupAPI() + self.compute_api = compute.API( + security_group_api=self.security_group_api) self.fake_image = { 'id': 1, 'properties': {'kernel_id': 'fake_kernel_id', @@ -3628,12 +3630,12 @@ class ComputeAPITestCase(BaseTestCase): self.compute.run_instance(self.context, instance['uuid']) instance = self.compute_api.get(self.context, instance['uuid']) security_group_name = self._create_group()['name'] - self.compute_api.add_security_group(self.context, - instance, - security_group_name) - self.compute_api.remove_security_group(self.context, - instance, - security_group_name) + self.security_group_api.add_to_instance(self.context, + instance, + security_group_name) + self.security_group_api.remove_from_instance(self.context, + instance, + security_group_name) def test_get_diagnostics(self): instance = self._create_fake_instance() diff --git a/nova/tests/compute/test_rpcapi.py b/nova/tests/compute/test_rpcapi.py index 47fb10645..a0da63918 100644 --- a/nova/tests/compute/test_rpcapi.py +++ b/nova/tests/compute/test_rpcapi.py @@ -44,7 +44,12 @@ class ComputeRpcAPITestCase(test.TestCase): def _test_compute_api(self, method, rpc_method, **kwargs): ctxt = context.RequestContext('fake_user', 'fake_project') - rpcapi = compute_rpcapi.ComputeAPI() + if 'rpcapi_class' in kwargs: + rpcapi_class = kwargs['rpcapi_class'] + del kwargs['rpcapi_class'] + else: + rpcapi_class = compute_rpcapi.ComputeAPI + rpcapi = rpcapi_class() expected_retval = 'foo' if method == 'call' else None expected_msg = rpcapi.make_msg(method, **kwargs) @@ -224,10 +229,12 @@ class ComputeRpcAPITestCase(test.TestCase): def test_refresh_security_group_rules(self): self._test_compute_api('refresh_security_group_rules', 'cast', + rpcapi_class=compute_rpcapi.SecurityGroupAPI, security_group_id='id', host='host') def test_refresh_security_group_members(self): self._test_compute_api('refresh_security_group_members', 'cast', + rpcapi_class=compute_rpcapi.SecurityGroupAPI, security_group_id='id', host='host') def test_remove_aggregate_host(self): diff --git a/nova/tests/policy.json b/nova/tests/policy.json index aa9c79749..8f8ea769c 100644 --- a/nova/tests/policy.json +++ b/nova/tests/policy.json @@ -58,8 +58,8 @@ "compute:snapshot": [], "compute:backup": [], - "compute:add_security_group": [], - "compute:remove_security_group": [], + "compute:security_groups:add_to_instance": [], + "compute:security_groups:remove_from_instance": [], "compute:delete": [], "compute:soft_delete": [], |
