diff options
-rw-r--r-- | nova/api/ec2/__init__.py | 2 | ||||
-rw-r--r-- | nova/api/ec2/cloud.py | 232 | ||||
-rw-r--r-- | nova/auth/manager.py | 2 | ||||
-rw-r--r-- | nova/compute/manager.py | 5 | ||||
-rw-r--r-- | nova/db/api.py | 70 | ||||
-rw-r--r-- | nova/db/sqlalchemy/api.py | 188 | ||||
-rw-r--r-- | nova/db/sqlalchemy/models.py | 64 | ||||
-rw-r--r-- | nova/db/sqlalchemy/session.py | 9 | ||||
-rw-r--r-- | nova/exception.py | 3 | ||||
-rw-r--r-- | nova/network/manager.py | 1 | ||||
-rw-r--r-- | nova/process.py | 2 | ||||
-rw-r--r-- | nova/test.py | 2 | ||||
-rw-r--r-- | nova/tests/api_unittest.py | 188 | ||||
-rw-r--r-- | nova/tests/objectstore_unittest.py | 2 | ||||
-rw-r--r-- | nova/tests/virt_unittest.py | 184 | ||||
-rw-r--r-- | nova/virt/interfaces.template | 1 | ||||
-rw-r--r-- | nova/virt/libvirt.qemu.xml.template | 4 | ||||
-rw-r--r-- | nova/virt/libvirt.uml.xml.template | 4 | ||||
-rw-r--r-- | nova/virt/libvirt_conn.py | 209 | ||||
-rw-r--r-- | run_tests.py | 1 |
20 files changed, 1132 insertions, 41 deletions
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 6b538a7f1..6e771f064 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -142,6 +142,8 @@ class Authorizer(wsgi.Middleware): 'CreateKeyPair': ['all'], 'DeleteKeyPair': ['all'], 'DescribeSecurityGroups': ['all'], + 'AuthorizeSecurityGroupIngress': ['netadmin'], + 'RevokeSecurityGroupIngress': ['netadmin'], 'CreateSecurityGroup': ['netadmin'], 'DeleteSecurityGroup': ['netadmin'], 'GetConsoleOutput': ['projectmanager', 'sysadmin'], diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 619c1a4b0..7839dc92c 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -28,6 +28,8 @@ import logging import os import time +import IPy + from nova import crypto from nova import db from nova import exception @@ -43,6 +45,7 @@ from nova.api.ec2 import images FLAGS = flags.FLAGS flags.DECLARE('storage_availability_zone', 'nova.volume.manager') +InvalidInputException = exception.InvalidInputException class QuotaError(exception.ApiError): """Quota Exceeeded""" @@ -127,6 +130,15 @@ class CloudController(object): result[key] = [line] return result + def _trigger_refresh_security_group(self, security_group): + nodes = set([instance['host'] for instance in security_group.instances + if instance['host'] is not None]) + for node in nodes: + rpc.call('%s.%s' % (FLAGS.compute_topic, node), + { "method": "refresh_security_group", + "args": { "context": None, + "security_group_id": security_group.id}}) + def get_metadata(self, address): instance_ref = db.fixed_ip_get_instance(None, address) if instance_ref is None: @@ -246,18 +258,195 @@ class CloudController(object): pass return True - def describe_security_groups(self, context, group_names, **kwargs): - groups = {'securityGroupSet': []} + def describe_security_groups(self, context, group_name=None, **kwargs): + self._ensure_default_security_group(context) + if context.user.is_admin(): + groups = db.security_group_get_all(context) + else: + groups = db.security_group_get_by_project(context, + context.project.id) + groups = [self._format_security_group(context, g) for g in groups] + if not group_name is None: + groups = [g for g in groups if g.name in group_name] + + return {'securityGroupInfo': groups } + + def _format_security_group(self, context, group): + g = {} + g['groupDescription'] = group.description + g['groupName'] = group.name + g['ownerId'] = group.project_id + g['ipPermissions'] = [] + for rule in group.rules: + r = {} + r['ipProtocol'] = rule.protocol + r['fromPort'] = rule.from_port + r['toPort'] = rule.to_port + r['groups'] = [] + r['ipRanges'] = [] + if rule.group_id: + source_group = db.security_group_get(context, rule.group_id) + r['groups'] += [{'groupName': source_group.name, + 'userId': source_group.project_id}] + else: + r['ipRanges'] += [{'cidrIp': rule.cidr}] + g['ipPermissions'] += [r] + return g + + + def _authorize_revoke_rule_args_to_dict(self, context, + to_port=None, from_port=None, + ip_protocol=None, cidr_ip=None, + user_id=None, + source_security_group_name=None, + source_security_group_owner_id=None): + + values = {} + + if source_security_group_name: + source_project_id = self._get_source_project_id(context, + source_security_group_owner_id) + + source_security_group = \ + db.security_group_get_by_name(context, + source_project_id, + source_security_group_name) + values['group_id'] = source_security_group['id'] + elif cidr_ip: + # If this fails, it throws an exception. This is what we want. + IPy.IP(cidr_ip) + values['cidr'] = cidr_ip + else: + values['cidr'] = '0.0.0.0/0' + + if ip_protocol and from_port and to_port: + from_port = int(from_port) + to_port = int(to_port) + ip_protocol = str(ip_protocol) + + if ip_protocol.upper() not in ['TCP','UDP','ICMP']: + raise InvalidInputException('%s is not a valid ipProtocol' % + (ip_protocol,)) + if ((min(from_port, to_port) < -1) or + (max(from_port, to_port) > 65535)): + raise InvalidInputException('Invalid port range') + + values['protocol'] = ip_protocol + values['from_port'] = from_port + values['to_port'] = to_port + else: + # If cidr based filtering, protocol and ports are mandatory + if 'cidr' in values: + return None + + return values - # Stubbed for now to unblock other things. - return groups - def create_security_group(self, context, group_name, **kwargs): + def _security_group_rule_exists(self, security_group, values): + """Indicates whether the specified rule values are already + defined in the given security group. + """ + for rule in security_group.rules: + if 'group_id' in values: + if rule['group_id'] == values['group_id']: + return True + else: + is_duplicate = True + for key in ('cidr', 'from_port', 'to_port', 'protocol'): + if rule[key] != values[key]: + is_duplicate = False + break + if is_duplicate: + return True + return False + + + def revoke_security_group_ingress(self, context, group_name, **kwargs): + self._ensure_default_security_group(context) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + + criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs) + if criteria == None: + raise exception.ApiError("No rule for the specified parameters.") + + for rule in security_group.rules: + match = True + for (k,v) in criteria.iteritems(): + if getattr(rule, k, False) != v: + match = False + if match: + db.security_group_rule_destroy(context, rule['id']) + self._trigger_refresh_security_group(security_group) + return True + raise exception.ApiError("No rule for the specified parameters.") + + # TODO(soren): This has only been tested with Boto as the client. + # Unfortunately, it seems Boto is using an old API + # for these operations, so support for newer API versions + # is sketchy. + def authorize_security_group_ingress(self, context, group_name, **kwargs): + self._ensure_default_security_group(context) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + + values = self._authorize_revoke_rule_args_to_dict(context, **kwargs) + values['parent_group_id'] = security_group.id + + if self._security_group_rule_exists(security_group, values): + raise exception.ApiError('This rule already exists in group %s' % + group_name) + + security_group_rule = db.security_group_rule_create(context, values) + + self._trigger_refresh_security_group(security_group) + return True + + def _get_source_project_id(self, context, source_security_group_owner_id): + if source_security_group_owner_id: + # Parse user:project for source group. + source_parts = source_security_group_owner_id.split(':') + + # If no project name specified, assume it's same as user name. + # Since we're looking up by project name, the user name is not + # used here. It's only read for EC2 API compatibility. + if len(source_parts) == 2: + source_project_id = source_parts[1] + else: + source_project_id = source_parts[0] + else: + source_project_id = context.project.id + + return source_project_id + + + def create_security_group(self, context, group_name, group_description): + self._ensure_default_security_group(context) + if db.security_group_exists(context, context.project.id, group_name): + raise exception.ApiError('group %s already exists' % group_name) + + group = {'user_id' : context.user.id, + 'project_id': context.project.id, + 'name': group_name, + 'description': group_description} + group_ref = db.security_group_create(context, group) + + return {'securityGroupSet': [self._format_security_group(context, + group_ref)]} + + def delete_security_group(self, context, group_name, **kwargs): + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + db.security_group_destroy(context, security_group.id) return True + def get_console_output(self, context, instance_id, **kwargs): # instance_id is passed in as a list of instances ec2_id = instance_id[0] @@ -554,6 +743,18 @@ class CloudController(object): "project_id": context.project.id}}) return db.queue_get_for(context, FLAGS.network_topic, host) + def _ensure_default_security_group(self, context): + try: + db.security_group_get_by_name(context, + context.project.id, + 'default') + except exception.NotFound: + values = { 'name' : 'default', + 'description' : 'default', + 'user_id' : context.user.id, + 'project_id' : context.project.id } + group = db.security_group_create(context, values) + def run_instances(self, context, **kwargs): instance_type = kwargs.get('instance_type', 'm1.small') if instance_type not in INSTANCE_TYPES: @@ -601,8 +802,17 @@ class CloudController(object): kwargs['key_name']) key_data = key_pair_ref['public_key'] - # TODO: Get the real security group of launch in here - security_group = "default" + security_group_arg = kwargs.get('security_group', ["default"]) + if not type(security_group_arg) is list: + security_group_arg = [security_group_arg] + + security_groups = [] + self._ensure_default_security_group(context) + for security_group_name in security_group_arg: + group = db.security_group_get_by_name(context, + context.project.id, + security_group_name) + security_groups.append(group['id']) reservation_id = utils.generate_uid('r') base_options = {} @@ -616,12 +826,12 @@ class CloudController(object): base_options['user_id'] = context.user.id base_options['project_id'] = context.project.id base_options['user_data'] = kwargs.get('user_data', '') - base_options['security_group'] = security_group - base_options['instance_type'] = instance_type + base_options['display_name'] = kwargs.get('display_name') base_options['display_description'] = kwargs.get('display_description') type_data = INSTANCE_TYPES[instance_type] + base_options['instance_type'] = instance_type base_options['memory_mb'] = type_data['memory_mb'] base_options['vcpus'] = type_data['vcpus'] base_options['local_gb'] = type_data['local_gb'] @@ -630,6 +840,10 @@ class CloudController(object): instance_ref = db.instance_create(context, base_options) inst_id = instance_ref['id'] + for security_group_id in security_groups: + db.instance_add_security_group(context, inst_id, + security_group_id) + inst = {} inst['mac_address'] = utils.generate_mac() inst['launch_index'] = num diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 49235c910..58e33969b 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -490,6 +490,7 @@ class AuthManager(object): except: drv.delete_project(project.id) raise + return project def modify_project(self, project, manager_user=None, description=None): @@ -565,6 +566,7 @@ class AuthManager(object): except: logging.exception('Could not destroy network for %s', project) + with self.driver() as drv: drv.delete_project(Project.safe_id(project)) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index 99705d3a9..ef7e9da6f 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -64,6 +64,11 @@ class ComputeManager(manager.Manager): @defer.inlineCallbacks @exception.wrap_exception + def refresh_security_group(self, context, security_group_id, **_kwargs): + yield self.driver.refresh_security_group(security_group_id) + + @defer.inlineCallbacks + @exception.wrap_exception def run_instance(self, context, instance_id, **_kwargs): """Launch a new instance with specified options.""" instance_ref = self.db.instance_get(context, instance_id) diff --git a/nova/db/api.py b/nova/db/api.py index 2f0879c5a..4be8df397 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -304,6 +304,11 @@ def instance_update(context, instance_id, values): return IMPL.instance_update(context, instance_id, values) +def instance_add_security_group(context, instance_id, security_group_id): + """Associate the given security group with the given instance""" + return IMPL.instance_add_security_group(context, instance_id, security_group_id) + + ################### @@ -571,6 +576,71 @@ def volume_update(context, volume_id, values): return IMPL.volume_update(context, volume_id, values) +#################### + + +def security_group_get_all(context): + """Get all security groups""" + return IMPL.security_group_get_all(context) + + +def security_group_get(context, security_group_id): + """Get security group by its internal id""" + return IMPL.security_group_get(context, security_group_id) + + +def security_group_get_by_name(context, project_id, group_name): + """Returns a security group with the specified name from a project""" + return IMPL.security_group_get_by_name(context, project_id, group_name) + + +def security_group_get_by_project(context, project_id): + """Get all security groups belonging to a project""" + return IMPL.security_group_get_by_project(context, project_id) + + +def security_group_get_by_instance(context, instance_id): + """Get security groups to which the instance is assigned""" + return IMPL.security_group_get_by_instance(context, instance_id) + + +def security_group_exists(context, project_id, group_name): + """Indicates if a group name exists in a project""" + return IMPL.security_group_exists(context, project_id, group_name) + + +def security_group_create(context, values): + """Create a new security group""" + return IMPL.security_group_create(context, values) + + +def security_group_destroy(context, security_group_id): + """Deletes a security group""" + return IMPL.security_group_destroy(context, security_group_id) + + +def security_group_destroy_all(context): + """Deletes a security group""" + return IMPL.security_group_destroy_all(context) + + +#################### + + +def security_group_rule_create(context, values): + """Create a new security group""" + return IMPL.security_group_rule_create(context, values) + + +def security_group_rule_get_by_security_group(context, security_group_id): + """Get all rules for a a given security group""" + return IMPL.security_group_rule_get_by_security_group(context, security_group_id) + +def security_group_rule_destroy(context, security_group_rule_id): + """Deletes a security group rule""" + return IMPL.security_group_rule_destroy(context, security_group_rule_id) + + ################### diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 6f1ea7c23..50d802774 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -29,8 +29,11 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload, joinedload_all -from sqlalchemy.sql import exists, func +from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload_all +from sqlalchemy.sql import exists +from sqlalchemy.sql import func +from sqlalchemy.orm.exc import NoResultFound FLAGS = flags.FLAGS @@ -571,11 +574,13 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(id=instance_id ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(project_id=context.project.id ).filter_by(id=instance_id ).filter_by(deleted=False @@ -591,6 +596,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(deleted=can_read_deleted(context) ).all() @@ -600,6 +606,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() @@ -612,6 +619,7 @@ def instance_get_all_by_project(context, project_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(project_id=project_id ).filter_by(deleted=can_read_deleted(context) ).all() @@ -624,12 +632,14 @@ def instance_get_all_by_reservation(context, reservation_id): if is_admin_context(context): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(reservation_id=reservation_id ).filter_by(deleted=can_read_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(project_id=context.project.id ).filter_by(reservation_id=reservation_id ).filter_by(deleted=False @@ -642,11 +652,13 @@ def instance_get_by_internal_id(context, internal_id): if is_admin_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(internal_id=internal_id ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(project_id=context.project.id ).filter_by(internal_id=internal_id ).filter_by(deleted=False @@ -718,6 +730,18 @@ def instance_update(context, instance_id, values): instance_ref.save(session=session) +def instance_add_security_group(context, instance_id, security_group_id): + """Associate the given security group with the given instance""" + session = get_session() + with session.begin(): + instance_ref = instance_get(context, instance_id, session=session) + security_group_ref = security_group_get(context, + security_group_id, + session=session) + instance_ref.security_groups += [security_group_ref] + instance_ref.save(session=session) + + ################### @@ -1192,6 +1216,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): + session = get_session() return session.query(models.Volume ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1282,6 +1307,163 @@ def volume_update(context, volume_id, values): ################### +@require_context +def security_group_get_all(context): + session = get_session() + return session.query(models.SecurityGroup + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload_all('rules') + ).all() + + +@require_context +def security_group_get(context, security_group_id, session=None): + if not session: + session = get_session() + if is_admin_context(context): + result = session.query(models.SecurityGroup + ).filter_by(deleted=can_read_deleted(context), + ).filter_by(id=security_group_id + ).options(joinedload_all('rules') + ).first() + else: + result = session.query(models.SecurityGroup + ).filter_by(deleted=False + ).filter_by(id=security_group_id + ).filter_by(project_id=context.project_id + ).options(joinedload_all('rules') + ).first() + if not result: + raise exception.NotFound("No secuity group with id %s" % + security_group_id) + return result + + +@require_context +def security_group_get_by_name(context, project_id, group_name): + session = get_session() + result = session.query(models.SecurityGroup + ).filter_by(project_id=project_id + ).filter_by(name=group_name + ).filter_by(deleted=False + ).options(joinedload_all('rules') + ).options(joinedload_all('instances') + ).first() + if not result: + raise exception.NotFound( + 'No security group named %s for project: %s' \ + % (group_name, project_id)) + return result + + +@require_context +def security_group_get_by_project(context, project_id): + session = get_session() + return session.query(models.SecurityGroup + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).options(joinedload_all('rules') + ).all() + + +@require_context +def security_group_get_by_instance(context, instance_id): + session = get_session() + return session.query(models.SecurityGroup + ).filter_by(deleted=False + ).options(joinedload_all('rules') + ).join(models.SecurityGroup.instances + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).all() + + +@require_context +def security_group_exists(context, project_id, group_name): + try: + group = security_group_get_by_name(context, project_id, group_name) + return group != None + except exception.NotFound: + return False + + +@require_context +def security_group_create(context, values): + security_group_ref = models.SecurityGroup() + # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception + # once save() is called. This will get cleaned up in next orm pass. + security_group_ref.rules + for (key, value) in values.iteritems(): + security_group_ref[key] = value + security_group_ref.save() + return security_group_ref + + +@require_context +def security_group_destroy(context, security_group_id): + session = get_session() + with session.begin(): + # TODO(vish): do we have to use sql here? + session.execute('update security_groups set deleted=1 where id=:id', + {'id': security_group_id}) + session.execute('update security_group_rules set deleted=1 ' + 'where group_id=:id', + {'id': security_group_id}) + +@require_context +def security_group_destroy_all(context, session=None): + if not session: + session = get_session() + with session.begin(): + # TODO(vish): do we have to use sql here? + session.execute('update security_groups set deleted=1') + session.execute('update security_group_rules set deleted=1') + + +################### + + +@require_context +def security_group_rule_get(context, security_group_rule_id, session=None): + if not session: + session = get_session() + if is_admin_context(context): + result = session.query(models.SecurityGroupIngressRule + ).filter_by(deleted=can_read_deleted(context) + ).filter_by(id=security_group_rule_id + ).first() + else: + # TODO(vish): Join to group and check for project_id + result = session.query(models.SecurityGroupIngressRule + ).filter_by(deleted=False + ).filter_by(id=security_group_rule_id + ).first() + if not result: + raise exception.NotFound("No secuity group rule with id %s" % + security_group_rule_id) + return result + + +@require_context +def security_group_rule_create(context, values): + security_group_rule_ref = models.SecurityGroupIngressRule() + for (key, value) in values.iteritems(): + security_group_rule_ref[key] = value + security_group_rule_ref.save() + return security_group_rule_ref + +@require_context +def security_group_rule_destroy(context, security_group_rule_id): + session = get_session() + with session.begin(): + security_group_rule = security_group_rule_get(context, + security_group_rule_id, + session=session) + security_group_rule.delete(session=session) + + +################### + @require_admin_context def user_get(context, id, session=None): if not session: @@ -1491,6 +1673,8 @@ def user_add_project_role(context, user_id, project_id, role): ################### + +@require_admin_context def host_get_networks(context, host): session = get_session() with session.begin(): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 9809eb7a7..7dfc39f6f 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -187,7 +187,6 @@ class Instance(BASE, NovaBase): launch_index = Column(Integer) key_name = Column(String(255)) key_data = Column(Text) - security_group = Column(String(255)) state = Column(Integer) state_description = Column(String(255)) @@ -289,10 +288,66 @@ class ExportDevice(BASE, NovaBase): 'ExportDevice.deleted==False)') +class SecurityGroupInstanceAssociation(BASE, NovaBase): + __tablename__ = 'security_group_instance_association' + id = Column(Integer, primary_key=True) + security_group_id = Column(Integer, ForeignKey('security_groups.id')) + instance_id = Column(Integer, ForeignKey('instances.id')) + + +class SecurityGroup(BASE, NovaBase): + """Represents a security group""" + __tablename__ = 'security_groups' + id = Column(Integer, primary_key=True) + + name = Column(String(255)) + description = Column(String(255)) + user_id = Column(String(255)) + project_id = Column(String(255)) + + instances = relationship(Instance, + secondary="security_group_instance_association", + primaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id," + "SecurityGroup.deleted == False)", + secondaryjoin="and_(SecurityGroupInstanceAssociation.instance_id == Instance.id," + "Instance.deleted == False)", + backref='security_groups') + + @property + def user(self): + return auth.manager.AuthManager().get_user(self.user_id) + + @property + def project(self): + return auth.manager.AuthManager().get_project(self.project_id) + + +class SecurityGroupIngressRule(BASE, NovaBase): + """Represents a rule in a security group""" + __tablename__ = 'security_group_rules' + id = Column(Integer, primary_key=True) + + parent_group_id = Column(Integer, ForeignKey('security_groups.id')) + parent_group = relationship("SecurityGroup", backref="rules", + foreign_keys=parent_group_id, + primaryjoin="and_(SecurityGroupIngressRule.parent_group_id == SecurityGroup.id," + "SecurityGroupIngressRule.deleted == False)") + + protocol = Column(String(5)) # "tcp", "udp", or "icmp" + from_port = Column(Integer) + to_port = Column(Integer) + cidr = Column(String(255)) + + # Note: This is not the parent SecurityGroup. It's SecurityGroup we're + # granting access for. + group_id = Column(Integer, ForeignKey('security_groups.id')) + + class KeyPair(BASE, NovaBase): """Represents a public key pair for ssh""" __tablename__ = 'key_pairs' id = Column(Integer, primary_key=True) + name = Column(String(255)) user_id = Column(String(255)) @@ -461,9 +516,10 @@ class FloatingIp(BASE, NovaBase): def register_models(): """Register Models and create metadata""" from sqlalchemy import create_engine - models = (Service, Instance, Volume, ExportDevice, - FixedIp, FloatingIp, Network, NetworkIndex, - AuthToken, UserProjectAssociation, User, Project) # , Image, Host) + models = (Service, Instance, Volume, ExportDevice, FixedIp, + FloatingIp, Network, NetworkIndex, SecurityGroup, + SecurityGroupIngressRule, SecurityGroupInstanceAssociation, + AuthToken, User, Project) # , Image, Host engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index 69a205378..826754f6a 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -36,7 +36,8 @@ def get_session(autocommit=True, expire_on_commit=False): if not _MAKER: if not _ENGINE: _ENGINE = create_engine(FLAGS.sql_connection, echo=False) - _MAKER = sessionmaker(bind=_ENGINE, - autocommit=autocommit, - expire_on_commit=expire_on_commit) - return _MAKER() + _MAKER = (sessionmaker(bind=_ENGINE, + autocommit=autocommit, + expire_on_commit=expire_on_commit)) + session = _MAKER() + return session diff --git a/nova/exception.py b/nova/exception.py index b8894758f..f157fab2d 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -69,6 +69,9 @@ class NotEmpty(Error): class Invalid(Error): pass +class InvalidInputException(Error): + pass + def wrap_exception(f): def _wrap(*args, **kw): diff --git a/nova/network/manager.py b/nova/network/manager.py index 9c1846dd9..093a6be9a 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -211,7 +211,6 @@ class FlatManager(NetworkManager): # in the datastore? net = {} net['injected'] = True - net['network_str'] = FLAGS.flat_network_network net['netmask'] = FLAGS.flat_network_netmask net['bridge'] = FLAGS.flat_network_bridge net['gateway'] = FLAGS.flat_network_gateway diff --git a/nova/process.py b/nova/process.py index b3cad894b..13cb90e82 100644 --- a/nova/process.py +++ b/nova/process.py @@ -113,7 +113,7 @@ class BackRelayWithInput(protocol.ProcessProtocol): if self.started_deferred: self.started_deferred.callback(self) if self.process_input: - self.transport.write(self.process_input) + self.transport.write(str(self.process_input)) self.transport.closeStdin() def get_process_output(executable, args=None, env=None, path=None, diff --git a/nova/test.py b/nova/test.py index 1f4b33272..08e1dea2d 100644 --- a/nova/test.py +++ b/nova/test.py @@ -31,6 +31,7 @@ from tornado import ioloop from twisted.internet import defer from twisted.trial import unittest +from nova import db from nova import fakerabbit from nova import flags from nova import rpc @@ -83,6 +84,7 @@ class TrialTestCase(unittest.TestCase): if FLAGS.fake_rabbit: fakerabbit.reset_all() + db.security_group_destroy_all(None) super(TrialTestCase, self).tearDown() diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index c040cdad3..7ab27e000 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -91,6 +91,9 @@ class ApiEc2TestCase(test.BaseTestCase): self.host = '127.0.0.1' self.app = api.API() + + def expect_http(self, host=None, is_secure=False): + """Returns a new EC2 connection""" self.ec2 = boto.connect_ec2( aws_access_key_id='fake', aws_secret_access_key='fake', @@ -100,9 +103,6 @@ class ApiEc2TestCase(test.BaseTestCase): path='/services/Cloud') self.mox.StubOutWithMock(self.ec2, 'new_http_connection') - - def expect_http(self, host=None, is_secure=False): - """Returns a new EC2 connection""" http = FakeHttplibConnection( self.app, '%s:8773' % (self.host), False) # pylint: disable-msg=E1103 @@ -138,3 +138,185 @@ class ApiEc2TestCase(test.BaseTestCase): self.assertEquals(len(results), 1) self.manager.delete_project(project) self.manager.delete_user(user) + + def test_get_all_security_groups(self): + """Test that we can retrieve security groups""" + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + rv = self.ec2.get_all_security_groups() + + self.assertEquals(len(rv), 1) + self.assertEquals(rv[0].name, 'default') + + self.manager.delete_project(project) + self.manager.delete_user(user) + + def test_create_delete_security_group(self): + """Test that we can create a security group""" + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + # At the moment, you need both of these to actually be netadmin + self.manager.add_role('fake', 'netadmin') + project.add_role('fake', 'netadmin') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + self.assertEquals(len(rv), 2) + self.assertTrue(security_group_name in [group.name for group in rv]) + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.manager.delete_project(project) + self.manager.delete_user(user) + + def test_authorize_revoke_security_group_cidr(self): + """ + Test that we can add and remove CIDR based rules + to a security group + """ + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake') + project = self.manager.create_project('fake', 'fake', 'fake') + + # At the moment, you need both of these to actually be netadmin + self.manager.add_role('fake', 'netadmin') + project.add_role('fake', 'netadmin') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + group = self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.authorize('tcp', 80, 81, '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + # I don't bother checkng that we actually find it here, + # because the create/delete unit test further up should + # be good enough for that. + for group in rv: + if group.name == security_group_name: + self.assertEquals(len(group.rules), 1) + self.assertEquals(int(group.rules[0].from_port), 80) + self.assertEquals(int(group.rules[0].to_port), 81) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.revoke('tcp', 80, 81, '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + rv = self.ec2.get_all_security_groups() + + self.assertEqual(len(rv), 1) + self.assertEqual(rv[0].name, 'default') + + self.manager.delete_project(project) + self.manager.delete_user(user) + + return + + def test_authorize_revoke_security_group_foreign_group(self): + """ + Test that we can grant and revoke another security group access + to a security group + """ + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + # At the moment, you need both of these to actually be netadmin + self.manager.add_role('fake', 'netadmin') + project.add_role('fake', 'netadmin') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + group = self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + + other_group = self.ec2.create_security_group(other_security_group_name, + 'some other group') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.authorize(src_group=other_group) + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + + # I don't bother checkng that we actually find it here, + # because the create/delete unit test further up should + # be good enough for that. + for group in rv: + if group.name == security_group_name: + self.assertEquals(len(group.rules), 1) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), + '%s-%s' % (other_security_group_name, 'fake')) + + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + + for group in rv: + if group.name == security_group_name: + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + group.revoke(src_group=other_group) + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.manager.delete_project(project) + self.manager.delete_user(user) + + return diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py index eb2ee0406..872f1ab23 100644 --- a/nova/tests/objectstore_unittest.py +++ b/nova/tests/objectstore_unittest.py @@ -210,7 +210,7 @@ class S3APITestCase(test.TrialTestCase): """Setup users, projects, and start a test server.""" super(S3APITestCase, self).setUp() - FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver', + FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets') self.auth_manager = manager.AuthManager() diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 730928f39..684347473 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -14,11 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. -from xml.etree.ElementTree import fromstring as parseXml +from xml.etree.ElementTree import fromstring as xml_to_tree +from xml.dom.minidom import parseString as xml_to_dom +from nova import db from nova import flags from nova import test +from nova.api import context +from nova.api.ec2 import cloud from nova.auth import manager + # Needed to get FLAGS.instances_path defined: from nova.compute import manager as compute_manager from nova.virt import libvirt_conn @@ -33,34 +38,49 @@ class LibvirtConnTestCase(test.TrialTestCase): FLAGS.instances_path = '' def test_get_uri_and_template(self): - instance = { 'name' : 'i-cafebabe', - 'id' : 'i-cafebabe', + ip = '10.11.12.13' + + instance = { 'internal_id' : 1, 'memory_kb' : '1024000', 'basepath' : '/some/path', 'bridge_name' : 'br100', 'mac_address' : '02:12:34:46:56:67', 'vcpus' : 2, 'project_id' : 'fake', - 'ip_address' : '10.11.12.13', 'bridge' : 'br101', 'instance_type' : 'm1.small'} + instance_ref = db.instance_create(None, instance) + network_ref = db.project_get_network(None, self.project.id) + + fixed_ip = { 'address' : ip, + 'network_id' : network_ref['id'] } + + fixed_ip_ref = db.fixed_ip_create(None, fixed_ip) + db.fixed_ip_update(None, ip, { 'allocated' : True, + 'instance_id' : instance_ref['id'] }) + type_uri_map = { 'qemu' : ('qemu:///system', - [(lambda t: t.find('.').tag, 'domain'), - (lambda t: t.find('.').get('type'), 'qemu'), + [(lambda t: t.find('.').get('type'), 'qemu'), (lambda t: t.find('./os/type').text, 'hvm'), (lambda t: t.find('./devices/emulator'), None)]), 'kvm' : ('qemu:///system', - [(lambda t: t.find('.').tag, 'domain'), - (lambda t: t.find('.').get('type'), 'kvm'), + [(lambda t: t.find('.').get('type'), 'kvm'), (lambda t: t.find('./os/type').text, 'hvm'), (lambda t: t.find('./devices/emulator'), None)]), 'uml' : ('uml:///system', - [(lambda t: t.find('.').tag, 'domain'), - (lambda t: t.find('.').get('type'), 'uml'), + [(lambda t: t.find('.').get('type'), 'uml'), (lambda t: t.find('./os/type').text, 'uml')]), } + common_checks = [(lambda t: t.find('.').tag, 'domain'), + (lambda t: \ + t.find('./devices/interface/filterref/parameter') \ + .get('name'), 'IP'), + (lambda t: \ + t.find('./devices/interface/filterref/parameter') \ + .get('value'), '10.11.12.13')] + for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems(): FLAGS.libvirt_type = libvirt_type conn = libvirt_conn.LibvirtConnection(True) @@ -68,13 +88,18 @@ class LibvirtConnTestCase(test.TrialTestCase): uri, template = conn.get_uri_and_template() self.assertEquals(uri, expected_uri) - xml = conn.to_xml(instance) - tree = parseXml(xml) + xml = conn.to_xml(instance_ref) + tree = xml_to_tree(xml) for i, (check, expected_result) in enumerate(checks): self.assertEqual(check(tree), expected_result, '%s failed check %d' % (xml, i)) + for i, (check, expected_result) in enumerate(common_checks): + self.assertEqual(check(tree), + expected_result, + '%s failed common check %d' % (xml, i)) + # Deliberately not just assigning this string to FLAGS.libvirt_uri and # checking against that later on. This way we make sure the # implementation doesn't fiddle around with the FLAGS. @@ -90,3 +115,138 @@ class LibvirtConnTestCase(test.TrialTestCase): def tearDown(self): self.manager.delete_project(self.project) self.manager.delete_user(self.user) + +class NWFilterTestCase(test.TrialTestCase): + def setUp(self): + super(NWFilterTestCase, self).setUp() + + class Mock(object): + pass + + self.manager = manager.AuthManager() + self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + self.project = self.manager.create_project('fake', 'fake', 'fake') + self.context = context.APIRequestContext(self.user, self.project) + + self.fake_libvirt_connection = Mock() + + self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection) + + def tearDown(self): + self.manager.delete_project(self.project) + self.manager.delete_user(self.user) + + + def test_cidr_rule_nwfilter_xml(self): + cloud_controller = cloud.CloudController() + cloud_controller.create_security_group(self.context, + 'testgroup', + 'test group description') + cloud_controller.authorize_security_group_ingress(self.context, + 'testgroup', + from_port='80', + to_port='81', + ip_protocol='tcp', + cidr_ip='0.0.0.0/0') + + + security_group = db.security_group_get_by_name(self.context, + 'fake', + 'testgroup') + + xml = self.fw.security_group_to_nwfilter_xml(security_group.id) + + dom = xml_to_dom(xml) + self.assertEqual(dom.firstChild.tagName, 'filter') + + rules = dom.getElementsByTagName('rule') + self.assertEqual(len(rules), 1) + + # It's supposed to allow inbound traffic. + self.assertEqual(rules[0].getAttribute('action'), 'accept') + self.assertEqual(rules[0].getAttribute('direction'), 'in') + + # Must be lower priority than the base filter (which blocks everything) + self.assertTrue(int(rules[0].getAttribute('priority')) < 1000) + + ip_conditions = rules[0].getElementsByTagName('tcp') + self.assertEqual(len(ip_conditions), 1) + self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0') + self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0') + self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') + self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') + + + self.teardown_security_group() + + def teardown_security_group(self): + cloud_controller = cloud.CloudController() + cloud_controller.delete_security_group(self.context, 'testgroup') + + + def setup_and_return_security_group(self): + cloud_controller = cloud.CloudController() + cloud_controller.create_security_group(self.context, + 'testgroup', + 'test group description') + cloud_controller.authorize_security_group_ingress(self.context, + 'testgroup', + from_port='80', + to_port='81', + ip_protocol='tcp', + cidr_ip='0.0.0.0/0') + + return db.security_group_get_by_name(self.context, 'fake', 'testgroup') + + def test_creates_base_rule_first(self): + # These come pre-defined by libvirt + self.defined_filters = ['no-mac-spoofing', + 'no-ip-spoofing', + 'no-arp-spoofing', + 'allow-dhcp-server'] + + self.recursive_depends = {} + for f in self.defined_filters: + self.recursive_depends[f] = [] + + def _filterDefineXMLMock(xml): + dom = xml_to_dom(xml) + name = dom.firstChild.getAttribute('name') + self.recursive_depends[name] = [] + for f in dom.getElementsByTagName('filterref'): + ref = f.getAttribute('filter') + self.assertTrue(ref in self.defined_filters, + ('%s referenced filter that does ' + + 'not yet exist: %s') % (name, ref)) + dependencies = [ref] + self.recursive_depends[ref] + self.recursive_depends[name] += dependencies + + self.defined_filters.append(name) + return True + + self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock + + instance_ref = db.instance_create(self.context, + {'user_id': 'fake', + 'project_id': 'fake'}) + inst_id = instance_ref['id'] + + def _ensure_all_called(_): + instance_filter = 'nova-instance-%s' % instance_ref['name'] + secgroup_filter = 'nova-secgroup-%s' % self.security_group['id'] + for required in [secgroup_filter, 'allow-dhcp-server', + 'no-arp-spoofing', 'no-ip-spoofing', + 'no-mac-spoofing']: + self.assertTrue(required in self.recursive_depends[instance_filter], + "Instance's filter does not include %s" % required) + + self.security_group = self.setup_and_return_security_group() + + db.instance_add_security_group(self.context, inst_id, self.security_group.id) + instance = db.instance_get(self.context, inst_id) + + d = self.fw.setup_nwfilters_for_instance(instance) + d.addCallback(_ensure_all_called) + d.addCallback(lambda _:self.teardown_security_group()) + + return d diff --git a/nova/virt/interfaces.template b/nova/virt/interfaces.template index 11df301f6..87b92b84a 100644 --- a/nova/virt/interfaces.template +++ b/nova/virt/interfaces.template @@ -10,7 +10,6 @@ auto eth0 iface eth0 inet static address %(address)s netmask %(netmask)s - network %(network)s broadcast %(broadcast)s gateway %(gateway)s dns-nameservers %(dns)s diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template index 17bd79b7c..2538b1ade 100644 --- a/nova/virt/libvirt.qemu.xml.template +++ b/nova/virt/libvirt.qemu.xml.template @@ -20,6 +20,10 @@ <source bridge='%(bridge_name)s'/> <mac address='%(mac_address)s'/> <!-- <model type='virtio'/> CANT RUN virtio network right now --> + <filterref filter="nova-instance-%(name)s"> + <parameter name="IP" value="%(ip_address)s" /> + <parameter name="DHCPSERVER" value="%(dhcp_server)s" /> + </filterref> </interface> <serial type="file"> <source path='%(basepath)s/console.log'/> diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template index c039d6d90..bb8b47911 100644 --- a/nova/virt/libvirt.uml.xml.template +++ b/nova/virt/libvirt.uml.xml.template @@ -14,6 +14,10 @@ <interface type='bridge'> <source bridge='%(bridge_name)s'/> <mac address='%(mac_address)s'/> + <filterref filter="nova-instance-%(name)s"> + <parameter name="IP" value="%(ip_address)s" /> + <parameter name="DHCPSERVER" value="%(dhcp_server)s" /> + </filterref> </interface> <console type="file"> <source path='%(basepath)s/console.log'/> diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index d868e083c..319f7d2af 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -25,14 +25,17 @@ import logging import os import shutil +import IPy from twisted.internet import defer from twisted.internet import task +from twisted.internet import threads from nova import db from nova import exception from nova import flags from nova import process from nova import utils +#from nova.api import context from nova.auth import manager from nova.compute import disk from nova.compute import instance_types @@ -60,6 +63,9 @@ flags.DEFINE_string('libvirt_uri', '', 'Override the default libvirt URI (which is dependent' ' on libvirt_type)') +flags.DEFINE_bool('allow_project_net_traffic', + True, + 'Whether to allow in project network traffic') def get_connection(read_only): @@ -134,7 +140,7 @@ class LibvirtConnection(object): d.addCallback(lambda _: self._cleanup(instance)) # FIXME: What does this comment mean? # TODO(termie): short-circuit me for tests - # WE'LL save this for when we do shutdown, + # WE'LL save this for when we do shutdown, # instead of destroy - but destroy returns immediately timer = task.LoopingCall(f=None) def _wait_for_shutdown(): @@ -214,6 +220,7 @@ class LibvirtConnection(object): instance['id'], power_state.NOSTATE, 'launching') + yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance) yield self._create_image(instance, xml) yield self._conn.createXML(xml, 0) # TODO(termie): this should actually register @@ -285,7 +292,6 @@ class LibvirtConnection(object): address = db.instance_get_fixed_address(None, inst['id']) with open(FLAGS.injected_network_template) as f: net = f.read() % {'address': address, - 'network': network_ref['network'], 'netmask': network_ref['netmask'], 'gateway': network_ref['gateway'], 'broadcast': network_ref['broadcast'], @@ -317,6 +323,9 @@ class LibvirtConnection(object): network = db.project_get_network(None, instance['project_id']) # FIXME(vish): stick this in db instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']] + ip_address = db.instance_get_fixed_address({}, instance['id']) + # Assume that the gateway also acts as the dhcp server. + dhcp_server = network['gateway'] xml_info = {'type': FLAGS.libvirt_type, 'name': instance['name'], 'basepath': os.path.join(FLAGS.instances_path, @@ -324,7 +333,9 @@ class LibvirtConnection(object): 'memory_kb': instance_type['memory_mb'] * 1024, 'vcpus': instance_type['vcpus'], 'bridge_name': network['bridge'], - 'mac_address': instance['mac_address']} + 'mac_address': instance['mac_address'], + 'ip_address': ip_address, + 'dhcp_server': dhcp_server } libvirt_xml = self.libvirt_xml % xml_info logging.debug('instance %s: finished toXML method', instance['name']) @@ -438,3 +449,195 @@ class LibvirtConnection(object): """ domain = self._conn.lookupByName(instance_name) return domain.interfaceStats(interface) + + + def refresh_security_group(self, security_group_id): + fw = NWFilterFirewall(self._conn) + fw.ensure_security_group_filter(security_group_id) + + +class NWFilterFirewall(object): + """ + This class implements a network filtering mechanism versatile + enough for EC2 style Security Group filtering by leveraging + libvirt's nwfilter. + + First, all instances get a filter ("nova-base-filter") applied. + This filter drops all incoming ipv4 and ipv6 connections. + Outgoing connections are never blocked. + + Second, every security group maps to a nwfilter filter(*). + NWFilters can be updated at runtime and changes are applied + immediately, so changes to security groups can be applied at + runtime (as mandated by the spec). + + Security group rules are named "nova-secgroup-<id>" where <id> + is the internal id of the security group. They're applied only on + hosts that have instances in the security group in question. + + Updates to security groups are done by updating the data model + (in response to API calls) followed by a request sent to all + the nodes with instances in the security group to refresh the + security group. + + Each instance has its own NWFilter, which references the above + mentioned security group NWFilters. This was done because + interfaces can only reference one filter while filters can + reference multiple other filters. This has the added benefit of + actually being able to add and remove security groups from an + instance at run time. This functionality is not exposed anywhere, + though. + + Outstanding questions: + + The name is unique, so would there be any good reason to sync + the uuid across the nodes (by assigning it from the datamodel)? + + + (*) This sentence brought to you by the redundancy department of + redundancy. + """ + + def __init__(self, get_connection): + self._conn = get_connection + + + nova_base_filter = '''<filter name='nova-base' chain='root'> + <uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid> + <filterref filter='no-mac-spoofing'/> + <filterref filter='no-ip-spoofing'/> + <filterref filter='no-arp-spoofing'/> + <filterref filter='allow-dhcp-server'/> + <filterref filter='nova-allow-dhcp-server'/> + <filterref filter='nova-base-ipv4'/> + <filterref filter='nova-base-ipv6'/> + </filter>''' + + nova_dhcp_filter = '''<filter name='nova-allow-dhcp-server' chain='ipv4'> + <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid> + <rule action='accept' direction='out' + priority='100'> + <udp srcipaddr='0.0.0.0' + dstipaddr='255.255.255.255' + srcportstart='68' + dstportstart='67'/> + </rule> + <rule action='accept' direction='in' priority='100'> + <udp srcipaddr='$DHCPSERVER' + srcportstart='67' + dstportstart='68'/> + </rule> + </filter>''' + + def nova_base_ipv4_filter(self): + retval = "<filter name='nova-base-ipv4' chain='ipv4'>" + for protocol in ['tcp', 'udp', 'icmp']: + for direction,action,priority in [('out','accept', 399), + ('inout','drop', 400)]: + retval += """<rule action='%s' direction='%s' priority='%d'> + <%s /> + </rule>""" % (action, direction, + priority, protocol) + retval += '</filter>' + return retval + + + def nova_base_ipv6_filter(self): + retval = "<filter name='nova-base-ipv6' chain='ipv6'>" + for protocol in ['tcp', 'udp', 'icmp']: + for direction,action,priority in [('out','accept',399), + ('inout','drop',400)]: + retval += """<rule action='%s' direction='%s' priority='%d'> + <%s-ipv6 /> + </rule>""" % (action, direction, + priority, protocol) + retval += '</filter>' + return retval + + + def nova_project_filter(self, project, net, mask): + retval = "<filter name='nova-project-%s' chain='ipv4'>" % project + for protocol in ['tcp', 'udp', 'icmp']: + retval += """<rule action='accept' direction='in' priority='200'> + <%s srcipaddr='%s' srcipmask='%s' /> + </rule>""" % (protocol, net, mask) + retval += '</filter>' + return retval + + + def _define_filter(self, xml): + if callable(xml): + xml = xml() + d = threads.deferToThread(self._conn.nwfilterDefineXML, xml) + return d + + + @staticmethod + def _get_net_and_mask(cidr): + net = IPy.IP(cidr) + return str(net.net()), str(net.netmask()) + + @defer.inlineCallbacks + def setup_nwfilters_for_instance(self, instance): + """ + Creates an NWFilter for the given instance. In the process, + it makes sure the filters for the security groups as well as + the base filter are all in place. + """ + + yield self._define_filter(self.nova_base_ipv4_filter) + yield self._define_filter(self.nova_base_ipv6_filter) + yield self._define_filter(self.nova_dhcp_filter) + yield self._define_filter(self.nova_base_filter) + + nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" + + " <filterref filter='nova-base' />\n" + ) % instance['name'] + + if FLAGS.allow_project_net_traffic: + network_ref = db.project_get_network({}, instance['project_id']) + net, mask = self._get_net_and_mask(network_ref['cidr']) + project_filter = self.nova_project_filter(instance['project_id'], + net, mask) + yield self._define_filter(project_filter) + + nwfilter_xml += (" <filterref filter='nova-project-%s' />\n" + ) % instance['project_id'] + + for security_group in instance.security_groups: + yield self.ensure_security_group_filter(security_group['id']) + + nwfilter_xml += (" <filterref filter='nova-secgroup-%d' />\n" + ) % security_group['id'] + nwfilter_xml += "</filter>" + + yield self._define_filter(nwfilter_xml) + return + + def ensure_security_group_filter(self, security_group_id): + return self._define_filter( + self.security_group_to_nwfilter_xml(security_group_id)) + + + def security_group_to_nwfilter_xml(self, security_group_id): + security_group = db.security_group_get({}, security_group_id) + rule_xml = "" + for rule in security_group.rules: + rule_xml += "<rule action='accept' direction='in' priority='300'>" + if rule.cidr: + net, mask = self._get_net_and_mask(rule.cidr) + rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % (rule.protocol, net, mask) + if rule.protocol in ['tcp', 'udp']: + rule_xml += "dstportstart='%s' dstportend='%s' " % \ + (rule.from_port, rule.to_port) + elif rule.protocol == 'icmp': + logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port)) + if rule.from_port != -1: + rule_xml += "type='%s' " % rule.from_port + if rule.to_port != -1: + rule_xml += "code='%s' " % rule.to_port + + rule_xml += '/>\n' + rule_xml += "</rule>\n" + xml = '''<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>''' % (security_group_id, rule_xml,) + return xml diff --git a/run_tests.py b/run_tests.py index fa1e6f15b..0b27ec6cf 100644 --- a/run_tests.py +++ b/run_tests.py @@ -65,6 +65,7 @@ from nova.tests.service_unittest import * from nova.tests.validator_unittest import * from nova.tests.virt_unittest import * from nova.tests.volume_unittest import * +from nova.tests.virt_unittest import * FLAGS = flags.FLAGS |