diff options
-rw-r--r-- | nova/api/ec2/__init__.py | 2 | ||||
-rw-r--r-- | nova/api/ec2/cloud.py | 207 | ||||
-rw-r--r-- | nova/auth/manager.py | 2 | ||||
-rw-r--r-- | nova/compute/manager.py | 5 | ||||
-rw-r--r-- | nova/db/api.py | 70 | ||||
-rw-r--r-- | nova/db/sqlalchemy/api.py | 129 | ||||
-rw-r--r-- | nova/db/sqlalchemy/models.py | 73 | ||||
-rw-r--r-- | nova/db/sqlalchemy/session.py | 9 | ||||
-rw-r--r-- | nova/exception.py | 3 | ||||
-rw-r--r-- | nova/network/manager.py | 1 | ||||
-rw-r--r-- | nova/process.py | 2 | ||||
-rw-r--r-- | nova/test.py | 2 | ||||
-rw-r--r-- | nova/tests/api_unittest.py | 188 | ||||
-rw-r--r-- | nova/tests/objectstore_unittest.py | 2 | ||||
-rw-r--r-- | nova/tests/virt_unittest.py | 142 | ||||
-rw-r--r-- | nova/virt/interfaces.template | 1 | ||||
-rw-r--r-- | nova/virt/libvirt.qemu.xml.template | 4 | ||||
-rw-r--r-- | nova/virt/libvirt.uml.xml.template | 4 | ||||
-rw-r--r-- | nova/virt/libvirt_conn.py | 176 | ||||
-rw-r--r-- | run_tests.py | 1 |
20 files changed, 988 insertions, 35 deletions
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 7a958f841..8111ef023 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -142,6 +142,8 @@ class Authorizer(wsgi.Middleware): 'CreateKeyPair': ['all'], 'DeleteKeyPair': ['all'], 'DescribeSecurityGroups': ['all'], + 'AuthorizeSecurityGroupIngress': ['netadmin'], + 'RevokeSecurityGroupIngress': ['netadmin'], 'CreateSecurityGroup': ['netadmin'], 'DeleteSecurityGroup': ['netadmin'], 'GetConsoleOutput': ['projectmanager', 'sysadmin'], diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index d3f54367b..ca3f71036 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -28,6 +28,8 @@ import logging import os import time +import IPy + from nova import crypto from nova import db from nova import exception @@ -42,6 +44,7 @@ from nova.api.ec2 import images FLAGS = flags.FLAGS flags.DECLARE('storage_availability_zone', 'nova.volume.manager') +InvalidInputException = exception.InvalidInputException class QuotaError(exception.ApiError): """Quota Exceeeded""" @@ -112,6 +115,15 @@ class CloudController(object): result[key] = [line] return result + def _trigger_refresh_security_group(self, security_group): + nodes = set([instance['host'] for instance in security_group.instances + if instance['host'] is not None]) + for node in nodes: + rpc.call('%s.%s' % (FLAGS.compute_topic, node), + { "method": "refresh_security_group", + "args": { "context": None, + "security_group_id": security_group.id}}) + def get_metadata(self, address): instance_ref = db.fixed_ip_get_instance(None, address) if instance_ref is None: @@ -231,18 +243,171 @@ class CloudController(object): pass return True - def describe_security_groups(self, context, group_names, **kwargs): - groups = {'securityGroupSet': []} + def describe_security_groups(self, context, group_name=None, **kwargs): + self._ensure_default_security_group(context) + if context.user.is_admin(): + groups = db.security_group_get_all(context) + else: + groups = db.security_group_get_by_project(context, + context.project.id) + groups = [self._format_security_group(context, g) for g in groups] + if not group_name is None: + groups = [g for g in groups if g.name in group_name] + + return {'securityGroupInfo': groups } + + def _format_security_group(self, context, group): + g = {} + g['groupDescription'] = group.description + g['groupName'] = group.name + g['ownerId'] = context.user.id + g['ipPermissions'] = [] + for rule in group.rules: + r = {} + r['ipProtocol'] = rule.protocol + r['fromPort'] = rule.from_port + r['toPort'] = rule.to_port + r['groups'] = [] + r['ipRanges'] = [] + if rule.group_id: + source_group = db.security_group_get(context, rule.group_id) + r['groups'] += [{'groupName': source_group.name, + 'userId': source_group.user_id}] + else: + r['ipRanges'] += [{'cidrIp': rule.cidr}] + g['ipPermissions'] += [r] + return g + + + def _authorize_revoke_rule_args_to_dict(self, context, + to_port=None, from_port=None, + ip_protocol=None, cidr_ip=None, + user_id=None, + source_security_group_name=None, + source_security_group_owner_id=None): + + values = {} + + if source_security_group_name: + source_project_id = self._get_source_project_id(context, + source_security_group_owner_id) + + source_security_group = \ + db.security_group_get_by_name(context, + source_project_id, + source_security_group_name) + values['group_id'] = source_security_group.id + elif cidr_ip: + # If this fails, it throws an exception. This is what we want. + IPy.IP(cidr_ip) + values['cidr'] = cidr_ip + else: + values['cidr'] = '0.0.0.0/0' + + if ip_protocol and from_port and to_port: + from_port = int(from_port) + to_port = int(to_port) + ip_protocol = str(ip_protocol) + + if ip_protocol.upper() not in ['TCP','UDP','ICMP']: + raise InvalidInputException('%s is not a valid ipProtocol' % + (ip_protocol,)) + if ((min(from_port, to_port) < -1) or + (max(from_port, to_port) > 65535)): + raise InvalidInputException('Invalid port range') + + values['protocol'] = ip_protocol + values['from_port'] = from_port + values['to_port'] = to_port + else: + # If cidr based filtering, protocol and ports are mandatory + if 'cidr' in values: + return None + + return values + + def revoke_security_group_ingress(self, context, group_name, **kwargs): + self._ensure_default_security_group(context) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + + criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs) + + for rule in security_group.rules: + for (k,v) in criteria.iteritems(): + if getattr(rule, k, False) != v: + break + # If we make it here, we have a match + db.security_group_rule_destroy(context, rule.id) + + self._trigger_refresh_security_group(security_group) + + return True + + # TODO(soren): Dupe detection. Adding the same rule twice actually + # adds the same rule twice to the rule set, which is + # pointless. + # TODO(soren): This has only been tested with Boto as the client. + # Unfortunately, it seems Boto is using an old API + # for these operations, so support for newer API versions + # is sketchy. + def authorize_security_group_ingress(self, context, group_name, **kwargs): + self._ensure_default_security_group(context) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + + values = self._authorize_revoke_rule_args_to_dict(context, **kwargs) + values['parent_group_id'] = security_group.id - # Stubbed for now to unblock other things. - return groups + security_group_rule = db.security_group_rule_create(context, values) + + self._trigger_refresh_security_group(security_group) - def create_security_group(self, context, group_name, **kwargs): return True + def _get_source_project_id(self, context, source_security_group_owner_id): + if source_security_group_owner_id: + # Parse user:project for source group. + source_parts = source_security_group_owner_id.split(':') + + # If no project name specified, assume it's same as user name. + # Since we're looking up by project name, the user name is not + # used here. It's only read for EC2 API compatibility. + if len(source_parts) == 2: + source_project_id = source_parts[1] + else: + source_project_id = source_parts[0] + else: + source_project_id = context.project.id + + return source_project_id + + + def create_security_group(self, context, group_name, group_description): + self._ensure_default_security_group(context) + if db.securitygroup_exists(context, context.project.id, group_name): + raise exception.ApiError('group %s already exists' % group_name) + + group = {'user_id' : context.user.id, + 'project_id': context.project.id, + 'name': group_name, + 'description': group_description} + group_ref = db.security_group_create(context, group) + + return {'securityGroupSet': [self._format_security_group(context, + group_ref)]} + + def delete_security_group(self, context, group_name, **kwargs): + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + db.security_group_destroy(context, security_group.id) return True + def get_console_output(self, context, instance_id, **kwargs): # instance_id is passed in as a list of instances instance_ref = db.instance_get_by_ec2_id(context, instance_id[0]) @@ -530,6 +695,18 @@ class CloudController(object): "project_id": context.project.id}}) return db.queue_get_for(context, FLAGS.network_topic, host) + def _ensure_default_security_group(self, context): + try: + db.security_group_get_by_name(context, + context.project.id, + 'default') + except exception.NotFound: + values = { 'name' : 'default', + 'description' : 'default', + 'user_id' : context.user.id, + 'project_id' : context.project.id } + group = db.security_group_create({}, values) + def run_instances(self, context, **kwargs): instance_type = kwargs.get('instance_type', 'm1.small') if instance_type not in INSTANCE_TYPES: @@ -577,8 +754,17 @@ class CloudController(object): kwargs['key_name']) key_data = key_pair_ref['public_key'] - # TODO: Get the real security group of launch in here - security_group = "default" + security_group_arg = kwargs.get('security_group', ["default"]) + if not type(security_group_arg) is list: + security_group_arg = [security_group_arg] + + security_groups = [] + self._ensure_default_security_group(context) + for security_group_name in security_group_arg: + group = db.security_group_get_by_name(context, + context.project.id, + security_group_name) + security_groups.append(group['id']) reservation_id = utils.generate_uid('r') base_options = {} @@ -592,7 +778,8 @@ class CloudController(object): base_options['user_id'] = context.user.id base_options['project_id'] = context.project.id base_options['user_data'] = kwargs.get('user_data', '') - base_options['security_group'] = security_group + + type_data = INSTANCE_TYPES[instance_type] base_options['instance_type'] = instance_type base_options['display_name'] = kwargs.get('display_name') base_options['display_description'] = kwargs.get('display_description') @@ -606,6 +793,10 @@ class CloudController(object): instance_ref = db.instance_create(context, base_options) inst_id = instance_ref['id'] + for security_group_id in security_groups: + db.instance_add_security_group(context, inst_id, + security_group_id) + inst = {} inst['mac_address'] = utils.generate_mac() inst['launch_index'] = num diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 0bc12c80f..e2e035d37 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -490,6 +490,7 @@ class AuthManager(object): except: drv.delete_project(project.id) raise + return project def modify_project(self, project, manager_user=None, description=None): @@ -565,6 +566,7 @@ class AuthManager(object): except: logging.exception('Could not destroy network for %s', project) + with self.driver() as drv: drv.delete_project(Project.safe_id(project)) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index f370ede8b..02ac3cb4c 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -64,6 +64,11 @@ class ComputeManager(manager.Manager): @defer.inlineCallbacks @exception.wrap_exception + def refresh_security_group(self, context, security_group_id, **_kwargs): + yield self.driver.refresh_security_group(security_group_id) + + @defer.inlineCallbacks + @exception.wrap_exception def run_instance(self, context, instance_id, **_kwargs): """Launch a new instance with specified options.""" instance_ref = self.db.instance_get(context, instance_id) diff --git a/nova/db/api.py b/nova/db/api.py index b68a0fe8f..9ce3dfb2a 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -304,6 +304,11 @@ def instance_update(context, instance_id, values): return IMPL.instance_update(context, instance_id, values) +def instance_add_security_group(context, instance_id, security_group_id): + """Associate the given security group with the given instance""" + return IMPL.instance_add_security_group(context, instance_id, security_group_id) + + ################### @@ -565,3 +570,68 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + + +#################### + + +def security_group_get_all(context): + """Get all security groups""" + return IMPL.security_group_get_all(context) + + +def security_group_get(context, security_group_id): + """Get security group by its internal id""" + return IMPL.security_group_get(context, security_group_id) + + +def security_group_get_by_name(context, project_id, group_name): + """Returns a security group with the specified name from a project""" + return IMPL.security_group_get_by_name(context, project_id, group_name) + + +def security_group_get_by_project(context, project_id): + """Get all security groups belonging to a project""" + return IMPL.security_group_get_by_project(context, project_id) + + +def security_group_get_by_instance(context, instance_id): + """Get security groups to which the instance is assigned""" + return IMPL.security_group_get_by_instance(context, instance_id) + + +def securitygroup_exists(context, project_id, group_name): + """Indicates if a group name exists in a project""" + return IMPL.security_group_exists(context, project_id, group_name) + + +def security_group_create(context, values): + """Create a new security group""" + return IMPL.security_group_create(context, values) + + +def security_group_destroy(context, security_group_id): + """Deletes a security group""" + return IMPL.security_group_destroy(context, security_group_id) + + +def security_group_destroy_all(context): + """Deletes a security group""" + return IMPL.security_group_destroy_all(context) + + +#################### + + +def security_group_rule_create(context, values): + """Create a new security group""" + return IMPL.security_group_rule_create(context, values) + + +def security_group_rule_get_by_security_group(context, security_group_id): + """Get all rules for a a given security group""" + return IMPL.security_group_rule_get_by_security_group(context, security_group_id) + +def security_group_rule_destroy(context, security_group_rule_id): + """Deletes a security group rule""" + return IMPL.security_group_rule_destroy(context, security_group_rule_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9c3caf9af..013e8ab16 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -29,8 +29,10 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import eagerload from sqlalchemy.orm import joinedload_all -from sqlalchemy.sql import exists, func +from sqlalchemy.sql import exists +from sqlalchemy.sql import func FLAGS = flags.FLAGS @@ -410,7 +412,8 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) + return models.Instance().find(instance_id, deleted=_deleted(context), + options=eagerload('security_groups')) def instance_get_all(context): @@ -511,6 +514,17 @@ def instance_update(_context, instance_id, values): instance_ref.save(session=session) +def instance_add_security_group(context, instance_id, security_group_id): + """Associate the given security group with the given instance""" + session = get_session() + with session.begin(): + instance_ref = models.Instance.find(instance_id, session=session) + security_group_ref = models.SecurityGroup.find(security_group_id, + session=session) + instance_ref.security_groups += [security_group_ref] + instance_ref.save(session=session) + + ################### @@ -925,3 +939,114 @@ def volume_update(_context, volume_id, values): for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) + + +################### + + +def security_group_get_all(_context): + session = get_session() + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).filter_by(deleted=False + ).all() + + +def security_group_get(_context, security_group_id): + session = get_session() + result = session.query(models.SecurityGroup + ).options(eagerload('rules') + ).get(security_group_id) + if not result: + raise exception.NotFound("No secuity group with id %s" % + security_group_id) + return result + + +def security_group_get_by_name(context, project_id, group_name): + session = get_session() + group_ref = session.query(models.SecurityGroup + ).options(eagerload('rules') + ).options(eagerload('instances') + ).filter_by(project_id=project_id + ).filter_by(name=group_name + ).filter_by(deleted=False + ).first() + if not group_ref: + raise exception.NotFound( + 'No security group named %s for project: %s' \ + % (group_name, project_id)) + return group_ref + + +def security_group_get_by_project(_context, project_id): + session = get_session() + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).all() + + +def security_group_get_by_instance(_context, instance_id): + session = get_session() + with session.begin(): + return session.query(models.Instance + ).join(models.Instance.security_groups + ).filter_by(deleted=False + ).all() + + +def security_group_exists(_context, project_id, group_name): + try: + group = security_group_get_by_name(_context, project_id, group_name) + return group != None + except exception.NotFound: + return False + + +def security_group_create(_context, values): + security_group_ref = models.SecurityGroup() + # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception + # once save() is called. This will get cleaned up in next orm pass. + security_group_ref.rules + for (key, value) in values.iteritems(): + security_group_ref[key] = value + security_group_ref.save() + return security_group_ref + + +def security_group_destroy(_context, security_group_id): + session = get_session() + with session.begin(): + # TODO(vish): do we have to use sql here? + session.execute('update security_group set deleted=1 where id=:id', + {'id': security_group_id}) + session.execute('update security_group_rules set deleted=1 ' + 'where group_id=:id', + {'id': security_group_id}) + +def security_group_destroy_all(_context): + session = get_session() + with session.begin(): + # TODO(vish): do we have to use sql here? + session.execute('update security_group set deleted=1') + session.execute('update security_group_rules set deleted=1') + +################### + + +def security_group_rule_create(_context, values): + security_group_rule_ref = models.SecurityGroupIngressRule() + for (key, value) in values.iteritems(): + security_group_rule_ref[key] = value + security_group_rule_ref.save() + return security_group_rule_ref + +def security_group_rule_destroy(_context, security_group_rule_id): + session = get_session() + with session.begin(): + model = models.SecurityGroupIngressRule + security_group_rule = model.find(security_group_rule_id, + session=session) + security_group_rule.delete(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 01e58b05e..d4caf0b52 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -25,7 +25,7 @@ import datetime # TODO(vish): clean up these imports from sqlalchemy.orm import relationship, backref, exc, object_mapper -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, Integer, String, Table from sqlalchemy import ForeignKey, DateTime, Boolean, Text from sqlalchemy.ext.declarative import declarative_base @@ -69,15 +69,19 @@ class NovaBase(object): ).count() @classmethod - def find(cls, obj_id, session=None, deleted=False): + def find(cls, obj_id, session=None, deleted=False, options=None): """Find object by id""" if not session: session = get_session() try: - return session.query(cls + query = session.query(cls ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() + ).filter_by(deleted=deleted) + + if options: + query = query.options(options) + + return query.one() except exc.NoResultFound: new_exc = exception.NotFound("No model for id %s" % obj_id) raise new_exc.__class__, new_exc, sys.exc_info()[2] @@ -230,7 +234,6 @@ class Instance(BASE, NovaBase): launch_index = Column(Integer) key_name = Column(String(255)) key_data = Column(Text) - security_group = Column(String(255)) state = Column(Integer) state_description = Column(String(255)) @@ -337,13 +340,65 @@ class ExportDevice(BASE, NovaBase): uselist=False)) +security_group_instance_association = Table('security_group_instance_association', + BASE.metadata, + Column('security_group_id', Integer, + ForeignKey('security_group.id')), + Column('instance_id', Integer, + ForeignKey('instances.id'))) + +class SecurityGroup(BASE, NovaBase): + """Represents a security group""" + __tablename__ = 'security_group' + id = Column(Integer, primary_key=True) + + name = Column(String(255)) + description = Column(String(255)) + user_id = Column(String(255)) + project_id = Column(String(255)) + + instances = relationship(Instance, + secondary=security_group_instance_association, + backref='security_groups') + + @property + def user(self): + return auth.manager.AuthManager().get_user(self.user_id) + + @property + def project(self): + return auth.manager.AuthManager().get_project(self.project_id) + + +class SecurityGroupIngressRule(BASE, NovaBase): + """Represents a rule in a security group""" + __tablename__ = 'security_group_rules' + id = Column(Integer, primary_key=True) + + parent_group_id = Column(Integer, ForeignKey('security_group.id')) + parent_group = relationship("SecurityGroup", backref="rules", + foreign_keys=parent_group_id, + primaryjoin=parent_group_id==SecurityGroup.id) + + protocol = Column(String(5)) # "tcp", "udp", or "icmp" + from_port = Column(Integer) + to_port = Column(Integer) + cidr = Column(String(255)) + + # Note: This is not the parent SecurityGroup. It's SecurityGroup we're + # granting access for. + group_id = Column(Integer, ForeignKey('security_group.id')) + + class KeyPair(BASE, NovaBase): """Represents a public key pair for ssh""" __tablename__ = 'key_pairs' id = Column(Integer, primary_key=True) + name = Column(String(255)) user_id = Column(String(255)) + project_id = Column(String(255)) fingerprint = Column(String(255)) public_key = Column(Text) @@ -484,9 +539,9 @@ class FloatingIp(BASE, NovaBase): def register_models(): """Register Models and create metadata""" from sqlalchemy import create_engine - models = (Service, Instance, Volume, ExportDevice, - FixedIp, FloatingIp, Network, NetworkIndex, - AuthToken) # , Image, Host) + models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp, + Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule, + AuthToken) # , Image, Host engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index 69a205378..826754f6a 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -36,7 +36,8 @@ def get_session(autocommit=True, expire_on_commit=False): if not _MAKER: if not _ENGINE: _ENGINE = create_engine(FLAGS.sql_connection, echo=False) - _MAKER = sessionmaker(bind=_ENGINE, - autocommit=autocommit, - expire_on_commit=expire_on_commit) - return _MAKER() + _MAKER = (sessionmaker(bind=_ENGINE, + autocommit=autocommit, + expire_on_commit=expire_on_commit)) + session = _MAKER() + return session diff --git a/nova/exception.py b/nova/exception.py index b8894758f..f157fab2d 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -69,6 +69,9 @@ class NotEmpty(Error): class Invalid(Error): pass +class InvalidInputException(Error): + pass + def wrap_exception(f): def _wrap(*args, **kw): diff --git a/nova/network/manager.py b/nova/network/manager.py index a7126ea4f..8f1924ac9 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -201,7 +201,6 @@ class FlatManager(NetworkManager): # in the datastore? net = {} net['injected'] = True - net['network_str'] = FLAGS.flat_network_network net['netmask'] = FLAGS.flat_network_netmask net['bridge'] = FLAGS.flat_network_bridge net['gateway'] = FLAGS.flat_network_gateway diff --git a/nova/process.py b/nova/process.py index b3cad894b..13cb90e82 100644 --- a/nova/process.py +++ b/nova/process.py @@ -113,7 +113,7 @@ class BackRelayWithInput(protocol.ProcessProtocol): if self.started_deferred: self.started_deferred.callback(self) if self.process_input: - self.transport.write(self.process_input) + self.transport.write(str(self.process_input)) self.transport.closeStdin() def get_process_output(executable, args=None, env=None, path=None, diff --git a/nova/test.py b/nova/test.py index 1f4b33272..08e1dea2d 100644 --- a/nova/test.py +++ b/nova/test.py @@ -31,6 +31,7 @@ from tornado import ioloop from twisted.internet import defer from twisted.trial import unittest +from nova import db from nova import fakerabbit from nova import flags from nova import rpc @@ -83,6 +84,7 @@ class TrialTestCase(unittest.TestCase): if FLAGS.fake_rabbit: fakerabbit.reset_all() + db.security_group_destroy_all(None) super(TrialTestCase, self).tearDown() diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index c040cdad3..7ab27e000 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -91,6 +91,9 @@ class ApiEc2TestCase(test.BaseTestCase): self.host = '127.0.0.1' self.app = api.API() + + def expect_http(self, host=None, is_secure=False): + """Returns a new EC2 connection""" self.ec2 = boto.connect_ec2( aws_access_key_id='fake', aws_secret_access_key='fake', @@ -100,9 +103,6 @@ class ApiEc2TestCase(test.BaseTestCase): path='/services/Cloud') self.mox.StubOutWithMock(self.ec2, 'new_http_connection') - - def expect_http(self, host=None, is_secure=False): - """Returns a new EC2 connection""" http = FakeHttplibConnection( self.app, '%s:8773' % (self.host), False) # pylint: disable-msg=E1103 @@ -138,3 +138,185 @@ class ApiEc2TestCase(test.BaseTestCase): self.assertEquals(len(results), 1) self.manager.delete_project(project) self.manager.delete_user(user) + + def test_get_all_security_groups(self): + """Test that we can retrieve security groups""" + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + rv = self.ec2.get_all_security_groups() + + self.assertEquals(len(rv), 1) + self.assertEquals(rv[0].name, 'default') + + self.manager.delete_project(project) + self.manager.delete_user(user) + + def test_create_delete_security_group(self): + """Test that we can create a security group""" + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + # At the moment, you need both of these to actually be netadmin + self.manager.add_role('fake', 'netadmin') + project.add_role('fake', 'netadmin') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + self.assertEquals(len(rv), 2) + self.assertTrue(security_group_name in [group.name for group in rv]) + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.manager.delete_project(project) + self.manager.delete_user(user) + + def test_authorize_revoke_security_group_cidr(self): + """ + Test that we can add and remove CIDR based rules + to a security group + """ + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake') + project = self.manager.create_project('fake', 'fake', 'fake') + + # At the moment, you need both of these to actually be netadmin + self.manager.add_role('fake', 'netadmin') + project.add_role('fake', 'netadmin') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + group = self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.authorize('tcp', 80, 81, '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + # I don't bother checkng that we actually find it here, + # because the create/delete unit test further up should + # be good enough for that. + for group in rv: + if group.name == security_group_name: + self.assertEquals(len(group.rules), 1) + self.assertEquals(int(group.rules[0].from_port), 80) + self.assertEquals(int(group.rules[0].to_port), 81) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.revoke('tcp', 80, 81, '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + rv = self.ec2.get_all_security_groups() + + self.assertEqual(len(rv), 1) + self.assertEqual(rv[0].name, 'default') + + self.manager.delete_project(project) + self.manager.delete_user(user) + + return + + def test_authorize_revoke_security_group_foreign_group(self): + """ + Test that we can grant and revoke another security group access + to a security group + """ + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + # At the moment, you need both of these to actually be netadmin + self.manager.add_role('fake', 'netadmin') + project.add_role('fake', 'netadmin') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + group = self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + + other_group = self.ec2.create_security_group(other_security_group_name, + 'some other group') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.authorize(src_group=other_group) + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + + # I don't bother checkng that we actually find it here, + # because the create/delete unit test further up should + # be good enough for that. + for group in rv: + if group.name == security_group_name: + self.assertEquals(len(group.rules), 1) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), + '%s-%s' % (other_security_group_name, 'fake')) + + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + + for group in rv: + if group.name == security_group_name: + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + group.revoke(src_group=other_group) + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.manager.delete_project(project) + self.manager.delete_user(user) + + return diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py index 5a599ff3a..1d6b9e826 100644 --- a/nova/tests/objectstore_unittest.py +++ b/nova/tests/objectstore_unittest.py @@ -191,7 +191,7 @@ class S3APITestCase(test.TrialTestCase): """Setup users, projects, and start a test server.""" super(S3APITestCase, self).setUp() - FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver', + FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets') self.auth_manager = manager.AuthManager() diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 2aab16809..5e9505374 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -14,23 +14,31 @@ # License for the specific language governing permissions and limitations # under the License. +from xml.dom.minidom import parseString + +from nova import db from nova import flags from nova import test +from nova.api.ec2 import cloud from nova.virt import libvirt_conn FLAGS = flags.FLAGS class LibvirtConnTestCase(test.TrialTestCase): - def test_get_uri_and_template(self): + def bitrot_test_get_uri_and_template(self): class MockDataModel(object): + def __getitem__(self, name): + return self.datamodel[name] + def __init__(self): self.datamodel = { 'name' : 'i-cafebabe', 'memory_kb' : '1024000', 'basepath' : '/some/path', 'bridge_name' : 'br100', 'mac_address' : '02:12:34:46:56:67', - 'vcpus' : 2 } + 'vcpus' : 2, + 'project_id' : None } type_uri_map = { 'qemu' : ('qemu:///system', [lambda s: '<domain type=\'qemu\'>' in s, @@ -53,7 +61,7 @@ class LibvirtConnTestCase(test.TrialTestCase): self.assertEquals(uri, expected_uri) for i, check in enumerate(checks): - xml = conn.toXml(MockDataModel()) + xml = conn.to_xml(MockDataModel()) self.assertTrue(check(xml), '%s failed check %d' % (xml, i)) # Deliberately not just assigning this string to FLAGS.libvirt_uri and @@ -67,3 +75,131 @@ class LibvirtConnTestCase(test.TrialTestCase): uri, template = conn.get_uri_and_template() self.assertEquals(uri, testuri) + +class NWFilterTestCase(test.TrialTestCase): + def setUp(self): + super(NWFilterTestCase, self).setUp() + + class Mock(object): + pass + + self.context = Mock() + self.context.user = Mock() + self.context.user.id = 'fake' + self.context.user.is_superuser = lambda:True + self.context.project = Mock() + self.context.project.id = 'fake' + + self.fake_libvirt_connection = Mock() + + self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection) + + def test_cidr_rule_nwfilter_xml(self): + cloud_controller = cloud.CloudController() + cloud_controller.create_security_group(self.context, + 'testgroup', + 'test group description') + cloud_controller.authorize_security_group_ingress(self.context, + 'testgroup', + from_port='80', + to_port='81', + ip_protocol='tcp', + cidr_ip='0.0.0.0/0') + + + security_group = db.security_group_get_by_name({}, 'fake', 'testgroup') + + xml = self.fw.security_group_to_nwfilter_xml(security_group.id) + + dom = parseString(xml) + self.assertEqual(dom.firstChild.tagName, 'filter') + + rules = dom.getElementsByTagName('rule') + self.assertEqual(len(rules), 1) + + # It's supposed to allow inbound traffic. + self.assertEqual(rules[0].getAttribute('action'), 'accept') + self.assertEqual(rules[0].getAttribute('direction'), 'in') + + # Must be lower priority than the base filter (which blocks everything) + self.assertTrue(int(rules[0].getAttribute('priority')) < 1000) + + ip_conditions = rules[0].getElementsByTagName('tcp') + self.assertEqual(len(ip_conditions), 1) + self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0/0') + self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') + self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') + + + self.teardown_security_group() + + def teardown_security_group(self): + cloud_controller = cloud.CloudController() + cloud_controller.delete_security_group(self.context, 'testgroup') + + + def setup_and_return_security_group(self): + cloud_controller = cloud.CloudController() + cloud_controller.create_security_group(self.context, + 'testgroup', + 'test group description') + cloud_controller.authorize_security_group_ingress(self.context, + 'testgroup', + from_port='80', + to_port='81', + ip_protocol='tcp', + cidr_ip='0.0.0.0/0') + + return db.security_group_get_by_name({}, 'fake', 'testgroup') + + def test_creates_base_rule_first(self): + # These come pre-defined by libvirt + self.defined_filters = ['no-mac-spoofing', + 'no-ip-spoofing', + 'no-arp-spoofing', + 'allow-dhcp-server'] + + self.recursive_depends = {} + for f in self.defined_filters: + self.recursive_depends[f] = [] + + def _filterDefineXMLMock(xml): + dom = parseString(xml) + name = dom.firstChild.getAttribute('name') + self.recursive_depends[name] = [] + for f in dom.getElementsByTagName('filterref'): + ref = f.getAttribute('filter') + self.assertTrue(ref in self.defined_filters, + ('%s referenced filter that does ' + + 'not yet exist: %s') % (name, ref)) + dependencies = [ref] + self.recursive_depends[ref] + self.recursive_depends[name] += dependencies + + self.defined_filters.append(name) + return True + + self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock + + instance_ref = db.instance_create({}, {'user_id': 'fake', + 'project_id': 'fake'}) + inst_id = instance_ref['id'] + + def _ensure_all_called(_): + instance_filter = 'nova-instance-%s' % instance_ref['str_id'] + secgroup_filter = 'nova-secgroup-%s' % self.security_group['id'] + for required in [secgroup_filter, 'allow-dhcp-server', + 'no-arp-spoofing', 'no-ip-spoofing', + 'no-mac-spoofing']: + self.assertTrue(required in self.recursive_depends[instance_filter], + "Instance's filter does not include %s" % required) + + self.security_group = self.setup_and_return_security_group() + + db.instance_add_security_group({}, inst_id, self.security_group.id) + instance = db.instance_get({}, inst_id) + + d = self.fw.setup_nwfilters_for_instance(instance) + d.addCallback(_ensure_all_called) + d.addCallback(lambda _:self.teardown_security_group()) + + return d diff --git a/nova/virt/interfaces.template b/nova/virt/interfaces.template index 11df301f6..87b92b84a 100644 --- a/nova/virt/interfaces.template +++ b/nova/virt/interfaces.template @@ -10,7 +10,6 @@ auto eth0 iface eth0 inet static address %(address)s netmask %(netmask)s - network %(network)s broadcast %(broadcast)s gateway %(gateway)s dns-nameservers %(dns)s diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template index 17bd79b7c..2538b1ade 100644 --- a/nova/virt/libvirt.qemu.xml.template +++ b/nova/virt/libvirt.qemu.xml.template @@ -20,6 +20,10 @@ <source bridge='%(bridge_name)s'/> <mac address='%(mac_address)s'/> <!-- <model type='virtio'/> CANT RUN virtio network right now --> + <filterref filter="nova-instance-%(name)s"> + <parameter name="IP" value="%(ip_address)s" /> + <parameter name="DHCPSERVER" value="%(dhcp_server)s" /> + </filterref> </interface> <serial type="file"> <source path='%(basepath)s/console.log'/> diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template index c039d6d90..bb8b47911 100644 --- a/nova/virt/libvirt.uml.xml.template +++ b/nova/virt/libvirt.uml.xml.template @@ -14,6 +14,10 @@ <interface type='bridge'> <source bridge='%(bridge_name)s'/> <mac address='%(mac_address)s'/> + <filterref filter="nova-instance-%(name)s"> + <parameter name="IP" value="%(ip_address)s" /> + <parameter name="DHCPSERVER" value="%(dhcp_server)s" /> + </filterref> </interface> <console type="file"> <source path='%(basepath)s/console.log'/> diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index d868e083c..c86f3ffb7 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -27,6 +27,7 @@ import shutil from twisted.internet import defer from twisted.internet import task +from twisted.internet import threads from nova import db from nova import exception @@ -214,6 +215,7 @@ class LibvirtConnection(object): instance['id'], power_state.NOSTATE, 'launching') + yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance) yield self._create_image(instance, xml) yield self._conn.createXML(xml, 0) # TODO(termie): this should actually register @@ -285,7 +287,6 @@ class LibvirtConnection(object): address = db.instance_get_fixed_address(None, inst['id']) with open(FLAGS.injected_network_template) as f: net = f.read() % {'address': address, - 'network': network_ref['network'], 'netmask': network_ref['netmask'], 'gateway': network_ref['gateway'], 'broadcast': network_ref['broadcast'], @@ -317,6 +318,9 @@ class LibvirtConnection(object): network = db.project_get_network(None, instance['project_id']) # FIXME(vish): stick this in db instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']] + ip_address = db.instance_get_fixed_address({}, instance['id']) + # Assume that the gateway also acts as the dhcp server. + dhcp_server = network['gateway'] xml_info = {'type': FLAGS.libvirt_type, 'name': instance['name'], 'basepath': os.path.join(FLAGS.instances_path, @@ -324,7 +328,9 @@ class LibvirtConnection(object): 'memory_kb': instance_type['memory_mb'] * 1024, 'vcpus': instance_type['vcpus'], 'bridge_name': network['bridge'], - 'mac_address': instance['mac_address']} + 'mac_address': instance['mac_address'], + 'ip_address': ip_address, + 'dhcp_server': dhcp_server } libvirt_xml = self.libvirt_xml % xml_info logging.debug('instance %s: finished toXML method', instance['name']) @@ -438,3 +444,169 @@ class LibvirtConnection(object): """ domain = self._conn.lookupByName(instance_name) return domain.interfaceStats(interface) + + + def refresh_security_group(self, security_group_id): + fw = NWFilterFirewall(self._conn) + fw.ensure_security_group_filter(security_group_id) + + +class NWFilterFirewall(object): + """ + This class implements a network filtering mechanism versatile + enough for EC2 style Security Group filtering by leveraging + libvirt's nwfilter. + + First, all instances get a filter ("nova-base-filter") applied. + This filter drops all incoming ipv4 and ipv6 connections. + Outgoing connections are never blocked. + + Second, every security group maps to a nwfilter filter(*). + NWFilters can be updated at runtime and changes are applied + immediately, so changes to security groups can be applied at + runtime (as mandated by the spec). + + Security group rules are named "nova-secgroup-<id>" where <id> + is the internal id of the security group. They're applied only on + hosts that have instances in the security group in question. + + Updates to security groups are done by updating the data model + (in response to API calls) followed by a request sent to all + the nodes with instances in the security group to refresh the + security group. + + Each instance has its own NWFilter, which references the above + mentioned security group NWFilters. This was done because + interfaces can only reference one filter while filters can + reference multiple other filters. This has the added benefit of + actually being able to add and remove security groups from an + instance at run time. This functionality is not exposed anywhere, + though. + + Outstanding questions: + + The name is unique, so would there be any good reason to sync + the uuid across the nodes (by assigning it from the datamodel)? + + + (*) This sentence brought to you by the redundancy department of + redundancy. + """ + + def __init__(self, get_connection): + self._conn = get_connection + + + nova_base_filter = '''<filter name='nova-base' chain='root'> + <uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid> + <filterref filter='no-mac-spoofing'/> + <filterref filter='no-ip-spoofing'/> + <filterref filter='no-arp-spoofing'/> + <filterref filter='allow-dhcp-server'/> + <filterref filter='nova-allow-dhcp-server'/> + <filterref filter='nova-base-ipv4'/> + <filterref filter='nova-base-ipv6'/> + </filter>''' + + nova_dhcp_filter = '''<filter name='nova-allow-dhcp-server' chain='ipv4'> + <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid> + <rule action='accept' direction='out' + priority='100'> + <udp srcipaddr='0.0.0.0' + dstipaddr='255.255.255.255' + srcportstart='68' + dstportstart='67'/> + </rule> + <rule action='accept' direction='in' priority='100'> + <udp srcipaddr='$DHCPSERVER' + srcportstart='67' + dstportstart='68'/> + </rule> + </filter>''' + + def nova_base_ipv4_filter(self): + retval = "<filter name='nova-base-ipv4' chain='ipv4'>" + for protocol in ['tcp', 'udp', 'icmp']: + for direction,action,priority in [('out','accept', 400), + ('in','drop', 399)]: + retval += """<rule action='%s' direction='%s' priority='%d'> + <%s /> + </rule>""" % (action, direction, + priority, protocol) + retval += '</filter>' + return retval + + + def nova_base_ipv6_filter(self): + retval = "<filter name='nova-base-ipv6' chain='ipv6'>" + for protocol in ['tcp', 'udp', 'icmp']: + for direction,action,priority in [('out','accept',400), + ('in','drop',399)]: + retval += """<rule action='%s' direction='%s' priority='%d'> + <%s-ipv6 /> + </rule>""" % (action, direction, + priority, protocol) + retval += '</filter>' + return retval + + + def _define_filter(self, xml): + if callable(xml): + xml = xml() + d = threads.deferToThread(self._conn.nwfilterDefineXML, xml) + return d + + + @defer.inlineCallbacks + def setup_nwfilters_for_instance(self, instance): + """ + Creates an NWFilter for the given instance. In the process, + it makes sure the filters for the security groups as well as + the base filter are all in place. + """ + + yield self._define_filter(self.nova_base_ipv4_filter) + yield self._define_filter(self.nova_base_ipv6_filter) + yield self._define_filter(self.nova_dhcp_filter) + yield self._define_filter(self.nova_base_filter) + + nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" + + " <filterref filter='nova-base' />\n" + ) % instance['name'] + + for security_group in instance.security_groups: + yield self.ensure_security_group_filter(security_group['id']) + + nwfilter_xml += (" <filterref filter='nova-secgroup-%d' />\n" + ) % security_group['id'] + nwfilter_xml += "</filter>" + + yield self._define_filter(nwfilter_xml) + return + + def ensure_security_group_filter(self, security_group_id): + return self._define_filter( + self.security_group_to_nwfilter_xml(security_group_id)) + + + def security_group_to_nwfilter_xml(self, security_group_id): + security_group = db.security_group_get({}, security_group_id) + rule_xml = "" + for rule in security_group.rules: + rule_xml += "<rule action='accept' direction='in' priority='300'>" + if rule.cidr: + rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr) + if rule.protocol in ['tcp', 'udp']: + rule_xml += "dstportstart='%s' dstportend='%s' " % \ + (rule.from_port, rule.to_port) + elif rule.protocol == 'icmp': + logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port)) + if rule.from_port != -1: + rule_xml += "type='%s' " % rule.from_port + if rule.to_port != -1: + rule_xml += "code='%s' " % rule.to_port + + rule_xml += '/>\n' + rule_xml += "</rule>\n" + xml = '''<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>''' % (security_group_id, rule_xml,) + return xml diff --git a/run_tests.py b/run_tests.py index 4121f4c06..8bb068ed1 100644 --- a/run_tests.py +++ b/run_tests.py @@ -64,6 +64,7 @@ from nova.tests.scheduler_unittest import * from nova.tests.service_unittest import * from nova.tests.validator_unittest import * from nova.tests.volume_unittest import * +from nova.tests.virt_unittest import * FLAGS = flags.FLAGS |